diff --git a/.gitattributes b/.gitattributes index 0ca74e648af6fe669e87215f0b78e6b0947029c2..460c60cde5ea001bdbb4e378afbd7ff25f93580c 100644 --- a/.gitattributes +++ b/.gitattributes @@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text +custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text +custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text diff --git a/__init__.py b/__init__.py index b4719805b60c1030e0ca7d8517d24f493b3081ff..33e7a7f594ef441479257c788e4c0d6e08657fc8 100644 --- a/__init__.py +++ b/__init__.py @@ -1,214 +1 @@ -import sys, os -from .utils import here, define_preprocessor_inputs, INPUT -from pathlib import Path -import traceback -import importlib -from .log import log, blue_text, cyan_text, get_summary, get_label -from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS -from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS -#Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17 -sys.path.insert(0, str(Path(here, "src").resolve())) -for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]: - sys.path.append(str(Path(here, "src", pkg_name).resolve())) - -#Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out -#https://github.com/pytorch/pytorch/issues/77764 -#https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485 -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1') - - -def load_nodes(): - shorted_errors = [] - full_error_messages = [] - node_class_mappings = {} - node_display_name_mappings = {} - - for filename in (here / "node_wrappers").iterdir(): - module_name = filename.stem - if module_name.startswith('.'): continue #Skip hidden files created by the OS (e.g. [.DS_Store](https://en.wikipedia.org/wiki/.DS_Store)) - try: - module = importlib.import_module( - f".node_wrappers.{module_name}", package=__package__ - ) - node_class_mappings.update(getattr(module, "NODE_CLASS_MAPPINGS")) - if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"): - node_display_name_mappings.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS")) - - log.debug(f"Imported {module_name} nodes") - - except AttributeError: - pass # wip nodes - except Exception: - error_message = traceback.format_exc() - full_error_messages.append(error_message) - error_message = error_message.splitlines()[-1] - shorted_errors.append( - f"Failed to import module {module_name} because {error_message}" - ) - - if len(shorted_errors) > 0: - full_err_log = '\n\n'.join(full_error_messages) - print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n") - log.info( - f"Some nodes failed to load:\n\t" - + "\n\t".join(shorted_errors) - + "\n\n" - + "Check that you properly installed the dependencies.\n" - + "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)" - ) - return node_class_mappings, node_display_name_mappings - -AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes() - -#For nodes not mapping image to image or has special requirements -AIO_NOT_SUPPORTED = ["InpaintPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "DiffusionEdge_Preprocessor"] -AIO_NOT_SUPPORTED += ["SavePoseKpsAsJsonFile", "FacialPartColoringFromPoseKps", "UpperBodyTrackingFromPoseKps", "RenderPeopleKps", "RenderAnimalKps"] -AIO_NOT_SUPPORTED += ["Unimatch_OptFlowPreprocessor", "MaskOptFlow"] - -def preprocessor_options(): - auxs = list(AUX_NODE_MAPPINGS.keys()) - auxs.insert(0, "none") - for name in AIO_NOT_SUPPORTED: - if name in auxs: - auxs.remove(name) - return auxs - - -PREPROCESSOR_OPTIONS = preprocessor_options() - -class AIO_Preprocessor: - @classmethod - def INPUT_TYPES(s): - return define_preprocessor_inputs( - preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"), - resolution=INPUT.RESOLUTION() - ) - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - - CATEGORY = "ControlNet Preprocessors" - - def execute(self, preprocessor, image, resolution=512): - if preprocessor == "none": - return (image, ) - else: - aux_class = AUX_NODE_MAPPINGS[preprocessor] - input_types = aux_class.INPUT_TYPES() - input_types = { - **input_types["required"], - **(input_types["optional"] if "optional" in input_types else {}) - } - params = {} - for name, input_type in input_types.items(): - if name == "image": - params[name] = image - continue - - if name == "resolution": - params[name] = resolution - continue - - if len(input_type) == 2 and ("default" in input_type[1]): - params[name] = input_type[1]["default"] - continue - - default_values = { "INT": 0, "FLOAT": 0.0 } - if input_type[0] in default_values: - params[name] = default_values[input_type[0]] - - return getattr(aux_class(), aux_class.FUNCTION)(**params) - -class ControlNetAuxSimpleAddText: - @classmethod - def INPUT_TYPES(s): - return dict( - required=dict(image=INPUT.IMAGE(), text=INPUT.STRING()) - ) - - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - CATEGORY = "ControlNet Preprocessors" - def execute(self, image, text): - from PIL import Image, ImageDraw, ImageFont - import numpy as np - import torch - - font = ImageFont.truetype(str((here / "NotoSans-Regular.ttf").resolve()), 40) - img = Image.fromarray(image[0].cpu().numpy().__mul__(255.).astype(np.uint8)) - ImageDraw.Draw(img).text((0,0), text, fill=(0,255,0), font=font) - return (torch.from_numpy(np.array(img)).unsqueeze(0) / 255.,) - -class ExecuteAllControlNetPreprocessors: - @classmethod - def INPUT_TYPES(s): - return define_preprocessor_inputs(resolution=INPUT.RESOLUTION()) - RETURN_TYPES = ("IMAGE",) - FUNCTION = "execute" - - CATEGORY = "ControlNet Preprocessors" - - def execute(self, image, resolution=512): - try: - from comfy_execution.graph_utils import GraphBuilder - except: - raise RuntimeError("ExecuteAllControlNetPreprocessor requries [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature") - - graph = GraphBuilder() - curr_outputs = [] - for preprocc in PREPROCESSOR_OPTIONS: - preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution) - hint_img = preprocc_node.out(0) - add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc) - curr_outputs.append(add_text_node.out(0)) - - while len(curr_outputs) > 1: - _outputs = [] - for i in range(0, len(curr_outputs), 2): - if i+1 < len(curr_outputs): - image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i+1]) - _outputs.append(image_batch.out(0)) - else: - _outputs.append(curr_outputs[i]) - curr_outputs = _outputs - - return { - "result": (curr_outputs[0],), - "expand": graph.finalize(), - } - -class ControlNetPreprocessorSelector: - @classmethod - def INPUT_TYPES(s): - return { - "required": { - "preprocessor": (PREPROCESSOR_OPTIONS,), - } - } - - RETURN_TYPES = (PREPROCESSOR_OPTIONS,) - RETURN_NAMES = ("preprocessor",) - FUNCTION = "get_preprocessor" - - CATEGORY = "ControlNet Preprocessors" - - def get_preprocessor(self, preprocessor: str): - return (preprocessor,) - - -NODE_CLASS_MAPPINGS = { - **AUX_NODE_MAPPINGS, - "AIO_Preprocessor": AIO_Preprocessor, - "ControlNetPreprocessorSelector": ControlNetPreprocessorSelector, - **HIE_NODE_CLASS_MAPPINGS, - "ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors, - "ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText -} - -NODE_DISPLAY_NAME_MAPPINGS = { - **AUX_DISPLAY_NAME_MAPPINGS, - "AIO_Preprocessor": "AIO Aux Preprocessor", - "ControlNetPreprocessorSelector": "Preprocessor Selector", - **HIE_NODE_DISPLAY_NAME_MAPPINGS, - "ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors" -} +#Dummy file ensuring this package will be recognized \ No newline at end of file diff --git a/custom_albumentations/LICENSE b/custom_albumentations/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..68e16af191003d844016fc83dafbc1fd9dd65104 --- /dev/null +++ b/custom_albumentations/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2017 Buslaev Alexander, Alexander Parinov, Vladimir Iglovikov + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/custom_albumentations/__init__.py b/custom_albumentations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e9d0b803bd01f421cdc41a575c1668985331d5a --- /dev/null +++ b/custom_albumentations/__init__.py @@ -0,0 +1,15 @@ +from __future__ import absolute_import + +__version__ = "1.3.1" + +from .augmentations import * +from .core.composition import * +from .core.serialization import * +from .core.transforms_interface import * + +try: + from .imgaug.transforms import * # type: ignore +except ImportError: + # imgaug is not installed by default, so we import stubs. + # Run `pip install -U albumentations[imgaug] if you need augmentations from imgaug.` + from .imgaug.stubs import * # type: ignore diff --git a/custom_albumentations/augmentations/__init__.py b/custom_albumentations/augmentations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..79da4b649f7e0f1afad439c3abef5b1389983095 --- /dev/null +++ b/custom_albumentations/augmentations/__init__.py @@ -0,0 +1,21 @@ +# Common classes +from .blur.functional import * +from .blur.transforms import * +from .crops.functional import * +from .crops.transforms import * + +# New transformations goes to individual files listed below +from .domain_adaptation import * +from .dropout.channel_dropout import * +from .dropout.coarse_dropout import * +from .dropout.cutout import * +from .dropout.functional import * +from .dropout.grid_dropout import * +from .dropout.mask_dropout import * +from .functional import * +from .geometric.functional import * +from .geometric.resize import * +from .geometric.rotate import * +from .geometric.transforms import * +from .transforms import * +from .utils import * diff --git a/custom_albumentations/augmentations/blur/__init__.py b/custom_albumentations/augmentations/blur/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d447f1d7e2a97edbeaf95d461eeb66f7a0bae95 --- /dev/null +++ b/custom_albumentations/augmentations/blur/__init__.py @@ -0,0 +1,2 @@ +from .functional import * +from .transforms import * diff --git a/custom_albumentations/augmentations/blur/functional.py b/custom_albumentations/augmentations/blur/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..755732341bd97279c69350773853074513455488 --- /dev/null +++ b/custom_albumentations/augmentations/blur/functional.py @@ -0,0 +1,106 @@ +from itertools import product +from math import ceil +from typing import Sequence, Union + +import cv2 +import numpy as np + +from custom_albumentations.augmentations.functional import convolve +from custom_albumentations.augmentations.geometric.functional import scale +from custom_albumentations.augmentations.utils import ( + _maybe_process_in_chunks, + clipped, + preserve_shape, +) + +__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"] + + +@preserve_shape +def blur(img: np.ndarray, ksize: int) -> np.ndarray: + blur_fn = _maybe_process_in_chunks(cv2.blur, ksize=(ksize, ksize)) + return blur_fn(img) + + +@preserve_shape +def median_blur(img: np.ndarray, ksize: int) -> np.ndarray: + if img.dtype == np.float32 and ksize not in {3, 5}: + raise ValueError(f"Invalid ksize value {ksize}. For a float32 image the only valid ksize values are 3 and 5") + + blur_fn = _maybe_process_in_chunks(cv2.medianBlur, ksize=ksize) + return blur_fn(img) + + +@preserve_shape +def gaussian_blur(img: np.ndarray, ksize: int, sigma: float = 0) -> np.ndarray: + # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8` + blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma) + return blur_fn(img) + + +@preserve_shape +def glass_blur( + img: np.ndarray, sigma: float, max_delta: int, iterations: int, dxy: np.ndarray, mode: str +) -> np.ndarray: + x = cv2.GaussianBlur(np.array(img), sigmaX=sigma, ksize=(0, 0)) + + if mode == "fast": + hs = np.arange(img.shape[0] - max_delta, max_delta, -1) + ws = np.arange(img.shape[1] - max_delta, max_delta, -1) + h: Union[int, np.ndarray] = np.tile(hs, ws.shape[0]) + w: Union[int, np.ndarray] = np.repeat(ws, hs.shape[0]) + + for i in range(iterations): + dy = dxy[:, i, 0] + dx = dxy[:, i, 1] + x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w] + + elif mode == "exact": + for ind, (i, h, w) in enumerate( + product( + range(iterations), + range(img.shape[0] - max_delta, max_delta, -1), + range(img.shape[1] - max_delta, max_delta, -1), + ) + ): + ind = ind if ind < len(dxy) else ind % len(dxy) + dy = dxy[ind, i, 0] + dx = dxy[ind, i, 1] + x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w] + else: + ValueError(f"Unsupported mode `{mode}`. Supports only `fast` and `exact`.") + + return cv2.GaussianBlur(x, sigmaX=sigma, ksize=(0, 0)) + + +def defocus(img: np.ndarray, radius: int, alias_blur: float) -> np.ndarray: + length = np.arange(-max(8, radius), max(8, radius) + 1) + ksize = 3 if radius <= 8 else 5 + + x, y = np.meshgrid(length, length) + aliased_disk = np.array((x**2 + y**2) <= radius**2, dtype=np.float32) + aliased_disk /= np.sum(aliased_disk) + + kernel = gaussian_blur(aliased_disk, ksize, sigma=alias_blur) + return convolve(img, kernel=kernel) + + +def central_zoom(img: np.ndarray, zoom_factor: int) -> np.ndarray: + h, w = img.shape[:2] + h_ch, w_ch = ceil(h / zoom_factor), ceil(w / zoom_factor) + h_top, w_top = (h - h_ch) // 2, (w - w_ch) // 2 + + img = scale(img[h_top : h_top + h_ch, w_top : w_top + w_ch], zoom_factor, cv2.INTER_LINEAR) + h_trim_top, w_trim_top = (img.shape[0] - h) // 2, (img.shape[1] - w) // 2 + return img[h_trim_top : h_trim_top + h, w_trim_top : w_trim_top + w] + + +@clipped +def zoom_blur(img: np.ndarray, zoom_factors: Union[np.ndarray, Sequence[int]]) -> np.ndarray: + out = np.zeros_like(img, dtype=np.float32) + for zoom_factor in zoom_factors: + out += central_zoom(img, zoom_factor) + + img = ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype) + + return img diff --git a/custom_albumentations/augmentations/blur/transforms.py b/custom_albumentations/augmentations/blur/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..50194e8961154d060354e5d5acc8e73d28cb6261 --- /dev/null +++ b/custom_albumentations/augmentations/blur/transforms.py @@ -0,0 +1,486 @@ +import random +import warnings +from typing import Any, Dict, List, Sequence, Tuple + +import cv2 +import numpy as np + +from custom_albumentations import random_utils +from custom_albumentations.augmentations import functional as FMain +from custom_albumentations.augmentations.blur import functional as F +from custom_albumentations.core.transforms_interface import ( + ImageOnlyTransform, + ScaleFloatType, + ScaleIntType, + to_tuple, +) + +__all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"] + + +class Blur(ImageOnlyTransform): + """Blur the input image using a random-sized kernel. + + Args: + blur_limit (int, (int, int)): maximum kernel size for blurring the input image. + Should be in range [3, inf). Default: (3, 7). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5): + super().__init__(always_apply, p) + self.blur_limit = to_tuple(blur_limit, 3) + + def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray: + return F.blur(img, ksize) + + def get_params(self) -> Dict[str, Any]: + return {"ksize": int(random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2))))} + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ("blur_limit",) + + +class MotionBlur(Blur): + """Apply motion blur to the input image using a random-sized kernel. + + Args: + blur_limit (int): maximum kernel size for blurring the input image. + Should be in range [3, inf). Default: (3, 7). + allow_shifted (bool): if set to true creates non shifted kernels only, + otherwise creates randomly shifted kernels. Default: True. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + blur_limit: ScaleIntType = 7, + allow_shifted: bool = True, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(blur_limit=blur_limit, always_apply=always_apply, p=p) + self.allow_shifted = allow_shifted + + if not allow_shifted and self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1: + raise ValueError(f"Blur limit must be odd when centered=True. Got: {self.blur_limit}") + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return super().get_transform_init_args_names() + ("allow_shifted",) + + def apply(self, img: np.ndarray, kernel: np.ndarray = None, **params) -> np.ndarray: # type: ignore + return FMain.convolve(img, kernel=kernel) + + def get_params(self) -> Dict[str, Any]: + ksize = random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2))) + if ksize <= 2: + raise ValueError("ksize must be > 2. Got: {}".format(ksize)) + kernel = np.zeros((ksize, ksize), dtype=np.uint8) + x1, x2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1) + if x1 == x2: + y1, y2 = random.sample(range(ksize), 2) + else: + y1, y2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1) + + def make_odd_val(v1, v2): + len_v = abs(v1 - v2) + 1 + if len_v % 2 != 1: + if v2 > v1: + v2 -= 1 + else: + v1 -= 1 + return v1, v2 + + if not self.allow_shifted: + x1, x2 = make_odd_val(x1, x2) + y1, y2 = make_odd_val(y1, y2) + + xc = (x1 + x2) / 2 + yc = (y1 + y2) / 2 + + center = ksize / 2 - 0.5 + dx = xc - center + dy = yc - center + x1, x2 = [int(i - dx) for i in [x1, x2]] + y1, y2 = [int(i - dy) for i in [y1, y2]] + + cv2.line(kernel, (x1, y1), (x2, y2), 1, thickness=1) + + # Normalize kernel + return {"kernel": kernel.astype(np.float32) / np.sum(kernel)} + + +class MedianBlur(Blur): + """Blur the input image using a median filter with a random aperture linear size. + + Args: + blur_limit (int): maximum aperture linear size for blurring the input image. + Must be odd and in range [3, inf). Default: (3, 7). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5): + super().__init__(blur_limit, always_apply, p) + + if self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1: + raise ValueError("MedianBlur supports only odd blur limits.") + + def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray: + return F.median_blur(img, ksize) + + +class GaussianBlur(ImageOnlyTransform): + """Blur the input image using a Gaussian filter with a random kernel size. + + Args: + blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image. + Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma + as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`. + If set single value `blur_limit` will be in range (0, blur_limit). + Default: (3, 7). + sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf). + If set single value `sigma_limit` will be in range (0, sigma_limit). + If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + blur_limit: ScaleIntType = (3, 7), + sigma_limit: ScaleFloatType = 0, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + self.blur_limit = to_tuple(blur_limit, 0) + self.sigma_limit = to_tuple(sigma_limit if sigma_limit is not None else 0, 0) + + if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0: + self.blur_limit = 3, max(3, self.blur_limit[1]) + warnings.warn( + "blur_limit and sigma_limit minimum value can not be both equal to 0. " + "blur_limit minimum value changed to 3." + ) + + if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or ( + self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1 + ): + raise ValueError("GaussianBlur supports only odd blur limits.") + + def apply(self, img: np.ndarray, ksize: int = 3, sigma: float = 0, **params) -> np.ndarray: + return F.gaussian_blur(img, ksize, sigma=sigma) + + def get_params(self) -> Dict[str, float]: + ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1) + if ksize != 0 and ksize % 2 != 1: + ksize = (ksize + 1) % (self.blur_limit[1] + 1) + + return {"ksize": ksize, "sigma": random.uniform(*self.sigma_limit)} + + def get_transform_init_args_names(self) -> Tuple[str, str]: + return ("blur_limit", "sigma_limit") + + +class GlassBlur(Blur): + """Apply glass noise to the input image. + + Args: + sigma (float): standard deviation for Gaussian kernel. + max_delta (int): max distance between pixels which are swapped. + iterations (int): number of repeats. + Should be in range [1, inf). Default: (2). + mode (str): mode of computation: fast or exact. Default: "fast". + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + + Reference: + | https://arxiv.org/abs/1903.12261 + | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py + """ + + def __init__( + self, + sigma: float = 0.7, + max_delta: int = 4, + iterations: int = 2, + always_apply: bool = False, + mode: str = "fast", + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + if iterations < 1: + raise ValueError(f"Iterations should be more or equal to 1, but we got {iterations}") + + if mode not in ["fast", "exact"]: + raise ValueError(f"Mode should be 'fast' or 'exact', but we got {mode}") + + self.sigma = sigma + self.max_delta = max_delta + self.iterations = iterations + self.mode = mode + + def apply(self, img: np.ndarray, dxy: np.ndarray = None, **params) -> np.ndarray: # type: ignore + assert dxy is not None + return F.glass_blur(img, self.sigma, self.max_delta, self.iterations, dxy, self.mode) + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, np.ndarray]: + img = params["image"] + + # generate array containing all necessary values for transformations + width_pixels = img.shape[0] - self.max_delta * 2 + height_pixels = img.shape[1] - self.max_delta * 2 + total_pixels = width_pixels * height_pixels + dxy = random_utils.randint(-self.max_delta, self.max_delta, size=(total_pixels, self.iterations, 2)) + + return {"dxy": dxy} + + def get_transform_init_args_names(self) -> Tuple[str, str, str]: + return ("sigma", "max_delta", "iterations") + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + +class AdvancedBlur(ImageOnlyTransform): + """Blur the input image using a Generalized Normal filter with a randomly selected parameters. + This transform also adds multiplicative noise to generated kernel before convolution. + + Args: + blur_limit: maximum Gaussian kernel size for blurring the input image. + Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma + as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`. + If set single value `blur_limit` will be in range (0, blur_limit). + Default: (3, 7). + sigmaX_limit: Gaussian kernel standard deviation. Must be in range [0, inf). + If set single value `sigmaX_limit` will be in range (0, sigma_limit). + If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0. + sigmaY_limit: Same as `sigmaY_limit` for another dimension. + rotate_limit: Range from which a random angle used to rotate Gaussian kernel is picked. + If limit is a single int an angle is picked from (-rotate_limit, rotate_limit). Default: (-90, 90). + beta_limit: Distribution shape parameter, 1 is the normal distribution. Values below 1.0 make distribution + tails heavier than normal, values above 1.0 make it lighter than normal. Default: (0.5, 8.0). + noise_limit: Multiplicative factor that control strength of kernel noise. Must be positive and preferably + centered around 1.0. If set single value `noise_limit` will be in range (0, noise_limit). + Default: (0.75, 1.25). + p (float): probability of applying the transform. Default: 0.5. + + Reference: + https://arxiv.org/abs/2107.10833 + + Targets: + image + Image types: + uint8, float32 + """ + + def __init__( + self, + blur_limit: ScaleIntType = (3, 7), + sigmaX_limit: ScaleFloatType = (0.2, 1.0), + sigmaY_limit: ScaleFloatType = (0.2, 1.0), + rotate_limit: ScaleIntType = 90, + beta_limit: ScaleFloatType = (0.5, 8.0), + noise_limit: ScaleFloatType = (0.9, 1.1), + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + self.blur_limit = to_tuple(blur_limit, 3) + self.sigmaX_limit = self.__check_values(to_tuple(sigmaX_limit, 0.0), name="sigmaX_limit") + self.sigmaY_limit = self.__check_values(to_tuple(sigmaY_limit, 0.0), name="sigmaY_limit") + self.rotate_limit = to_tuple(rotate_limit) + self.beta_limit = to_tuple(beta_limit, low=0.0) + self.noise_limit = self.__check_values(to_tuple(noise_limit, 0.0), name="noise_limit") + + if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or ( + self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1 + ): + raise ValueError("AdvancedBlur supports only odd blur limits.") + + if self.sigmaX_limit[0] == 0 and self.sigmaY_limit[0] == 0: + raise ValueError("sigmaX_limit and sigmaY_limit minimum value can not be both equal to 0.") + + if not (self.beta_limit[0] < 1.0 < self.beta_limit[1]): + raise ValueError("Beta limit is expected to include 1.0") + + @staticmethod + def __check_values( + value: Sequence[float], name: str, bounds: Tuple[float, float] = (0, float("inf")) + ) -> Sequence[float]: + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError(f"{name} values should be between {bounds}") + return value + + def apply(self, img: np.ndarray, kernel: np.ndarray = np.array(None), **params) -> np.ndarray: + return FMain.convolve(img, kernel=kernel) + + def get_params(self) -> Dict[str, np.ndarray]: + ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2) + sigmaX = random.uniform(*self.sigmaX_limit) + sigmaY = random.uniform(*self.sigmaY_limit) + angle = np.deg2rad(random.uniform(*self.rotate_limit)) + + # Split into 2 cases to avoid selection of narrow kernels (beta > 1) too often. + if random.random() < 0.5: + beta = random.uniform(self.beta_limit[0], 1) + else: + beta = random.uniform(1, self.beta_limit[1]) + + noise_matrix = random_utils.uniform(self.noise_limit[0], self.noise_limit[1], size=[ksize, ksize]) + + # Generate mesh grid centered at zero. + ax = np.arange(-ksize // 2 + 1.0, ksize // 2 + 1.0) + # Shape (ksize, ksize, 2) + grid = np.stack(np.meshgrid(ax, ax), axis=-1) + + # Calculate rotated sigma matrix + d_matrix = np.array([[sigmaX**2, 0], [0, sigmaY**2]]) + u_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]]) + sigma_matrix = np.dot(u_matrix, np.dot(d_matrix, u_matrix.T)) + + inverse_sigma = np.linalg.inv(sigma_matrix) + # Described in "Parameter Estimation For Multivariate Generalized Gaussian Distributions" + kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta)) + # Add noise + kernel = kernel * noise_matrix + + # Normalize kernel + kernel = kernel.astype(np.float32) / np.sum(kernel) + return {"kernel": kernel} + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]: + return ( + "blur_limit", + "sigmaX_limit", + "sigmaY_limit", + "rotate_limit", + "beta_limit", + "noise_limit", + ) + + +class Defocus(ImageOnlyTransform): + """ + Apply defocus transform. See https://arxiv.org/abs/1903.12261. + + Args: + radius ((int, int) or int): range for radius of defocusing. + If limit is a single int, the range will be [1, limit]. Default: (3, 10). + alias_blur ((float, float) or float): range for alias_blur of defocusing (sigma of gaussian blur). + If limit is a single float, the range will be (0, limit). Default: (0.1, 0.5). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + Any + """ + + def __init__( + self, + radius: ScaleIntType = (3, 10), + alias_blur: ScaleFloatType = (0.1, 0.5), + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + self.radius = to_tuple(radius, low=1) + self.alias_blur = to_tuple(alias_blur, low=0) + + if self.radius[0] <= 0: + raise ValueError("Parameter radius must be positive") + + if self.alias_blur[0] < 0: + raise ValueError("Parameter alias_blur must be non-negative") + + def apply(self, img: np.ndarray, radius: int = 3, alias_blur: float = 0.5, **params) -> np.ndarray: + return F.defocus(img, radius, alias_blur) + + def get_params(self) -> Dict[str, Any]: + return { + "radius": random_utils.randint(self.radius[0], self.radius[1] + 1), + "alias_blur": random_utils.uniform(self.alias_blur[0], self.alias_blur[1]), + } + + def get_transform_init_args_names(self) -> Tuple[str, str]: + return ("radius", "alias_blur") + + +class ZoomBlur(ImageOnlyTransform): + """ + Apply zoom blur transform. See https://arxiv.org/abs/1903.12261. + + Args: + max_factor ((float, float) or float): range for max factor for blurring. + If max_factor is a single float, the range will be (1, limit). Default: (1, 1.31). + All max_factor values should be larger than 1. + step_factor ((float, float) or float): If single float will be used as step parameter for np.arange. + If tuple of float step_factor will be in range `[step_factor[0], step_factor[1])`. Default: (0.01, 0.03). + All step_factor values should be positive. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + Any + """ + + def __init__( + self, + max_factor: ScaleFloatType = 1.31, + step_factor: ScaleFloatType = (0.01, 0.03), + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + self.max_factor = to_tuple(max_factor, low=1.0) + self.step_factor = to_tuple(step_factor, step_factor) + + if self.max_factor[0] < 1: + raise ValueError("Max factor must be larger or equal 1") + if self.step_factor[0] <= 0: + raise ValueError("Step factor must be positive") + + def apply(self, img: np.ndarray, zoom_factors: np.ndarray = np.array(None), **params) -> np.ndarray: + assert zoom_factors is not None + return F.zoom_blur(img, zoom_factors) + + def get_params(self) -> Dict[str, Any]: + max_factor = random.uniform(self.max_factor[0], self.max_factor[1]) + step_factor = random.uniform(self.step_factor[0], self.step_factor[1]) + return {"zoom_factors": np.arange(1.0, max_factor, step_factor)} + + def get_transform_init_args_names(self) -> Tuple[str, str]: + return ("max_factor", "step_factor") diff --git a/custom_albumentations/augmentations/crops/__init__.py b/custom_albumentations/augmentations/crops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d447f1d7e2a97edbeaf95d461eeb66f7a0bae95 --- /dev/null +++ b/custom_albumentations/augmentations/crops/__init__.py @@ -0,0 +1,2 @@ +from .functional import * +from .transforms import * diff --git a/custom_albumentations/augmentations/crops/functional.py b/custom_albumentations/augmentations/crops/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..d3b458015604a4da040f5ef605fdc7b41a418218 --- /dev/null +++ b/custom_albumentations/augmentations/crops/functional.py @@ -0,0 +1,317 @@ +from typing import Optional, Sequence, Tuple + +import cv2 +import numpy as np + +from custom_albumentations.augmentations.utils import ( + _maybe_process_in_chunks, + preserve_channel_dim, +) + +from ...core.bbox_utils import denormalize_bbox, normalize_bbox +from ...core.transforms_interface import BoxInternalType, KeypointInternalType +from ..geometric import functional as FGeometric + +__all__ = [ + "get_random_crop_coords", + "random_crop", + "crop_bbox_by_coords", + "bbox_random_crop", + "crop_keypoint_by_coords", + "keypoint_random_crop", + "get_center_crop_coords", + "center_crop", + "bbox_center_crop", + "keypoint_center_crop", + "crop", + "bbox_crop", + "clamping_crop", + "crop_and_pad", + "crop_and_pad_bbox", + "crop_and_pad_keypoint", +] + + +def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float, w_start: float): + # h_start is [0, 1) and should map to [0, (height - crop_height)] (note inclusive) + # This is conceptually equivalent to mapping onto `range(0, (height - crop_height + 1))` + # See: https://github.com/albumentations-team/albumentations/pull/1080 + y1 = int((height - crop_height + 1) * h_start) + y2 = y1 + crop_height + x1 = int((width - crop_width + 1) * w_start) + x2 = x1 + crop_width + return x1, y1, x2, y2 + + +def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float): + height, width = img.shape[:2] + if height < crop_height or width < crop_width: + raise ValueError( + "Requested crop size ({crop_height}, {crop_width}) is " + "larger than the image size ({height}, {width})".format( + crop_height=crop_height, crop_width=crop_width, height=height, width=width + ) + ) + x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start) + img = img[y1:y2, x1:x2] + return img + + +def crop_bbox_by_coords( + bbox: BoxInternalType, + crop_coords: Tuple[int, int, int, int], + crop_height: int, + crop_width: int, + rows: int, + cols: int, +): + """Crop a bounding box using the provided coordinates of bottom-left and top-right corners in pixels and the + required height and width of the crop. + + Args: + bbox (tuple): A cropped box `(x_min, y_min, x_max, y_max)`. + crop_coords (tuple): Crop coordinates `(x1, y1, x2, y2)`. + crop_height (int): + crop_width (int): + rows (int): Image rows. + cols (int): Image cols. + + Returns: + tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`. + + """ + bbox = denormalize_bbox(bbox, rows, cols) + x_min, y_min, x_max, y_max = bbox[:4] + x1, y1, _, _ = crop_coords + cropped_bbox = x_min - x1, y_min - y1, x_max - x1, y_max - y1 + return normalize_bbox(cropped_bbox, crop_height, crop_width) + + +def bbox_random_crop( + bbox: BoxInternalType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int +): + crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start) + return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols) + + +def crop_keypoint_by_coords( + keypoint: KeypointInternalType, crop_coords: Tuple[int, int, int, int] +): # skipcq: PYL-W0613 + """Crop a keypoint using the provided coordinates of bottom-left and top-right corners in pixels and the + required height and width of the crop. + + Args: + keypoint (tuple): A keypoint `(x, y, angle, scale)`. + crop_coords (tuple): Crop box coords `(x1, x2, y1, y2)`. + + Returns: + A keypoint `(x, y, angle, scale)`. + + """ + x, y, angle, scale = keypoint[:4] + x1, y1, _, _ = crop_coords + return x - x1, y - y1, angle, scale + + +def keypoint_random_crop( + keypoint: KeypointInternalType, + crop_height: int, + crop_width: int, + h_start: float, + w_start: float, + rows: int, + cols: int, +): + """Keypoint random crop. + + Args: + keypoint: (tuple): A keypoint `(x, y, angle, scale)`. + crop_height (int): Crop height. + crop_width (int): Crop width. + h_start (int): Crop height start. + w_start (int): Crop width start. + rows (int): Image height. + cols (int): Image width. + + Returns: + A keypoint `(x, y, angle, scale)`. + + """ + crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start) + return crop_keypoint_by_coords(keypoint, crop_coords) + + +def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int): + y1 = (height - crop_height) // 2 + y2 = y1 + crop_height + x1 = (width - crop_width) // 2 + x2 = x1 + crop_width + return x1, y1, x2, y2 + + +def center_crop(img: np.ndarray, crop_height: int, crop_width: int): + height, width = img.shape[:2] + if height < crop_height or width < crop_width: + raise ValueError( + "Requested crop size ({crop_height}, {crop_width}) is " + "larger than the image size ({height}, {width})".format( + crop_height=crop_height, crop_width=crop_width, height=height, width=width + ) + ) + x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width) + img = img[y1:y2, x1:x2] + return img + + +def bbox_center_crop(bbox: BoxInternalType, crop_height: int, crop_width: int, rows: int, cols: int): + crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width) + return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols) + + +def keypoint_center_crop(keypoint: KeypointInternalType, crop_height: int, crop_width: int, rows: int, cols: int): + """Keypoint center crop. + + Args: + keypoint (tuple): A keypoint `(x, y, angle, scale)`. + crop_height (int): Crop height. + crop_width (int): Crop width. + rows (int): Image height. + cols (int): Image width. + + Returns: + tuple: A keypoint `(x, y, angle, scale)`. + + """ + crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width) + return crop_keypoint_by_coords(keypoint, crop_coords) + + +def crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int): + height, width = img.shape[:2] + if x_max <= x_min or y_max <= y_min: + raise ValueError( + "We should have x_min < x_max and y_min < y_max. But we got" + " (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})".format( + x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max + ) + ) + + if x_min < 0 or x_max > width or y_min < 0 or y_max > height: + raise ValueError( + "Values for crop should be non negative and equal or smaller than image sizes" + "(x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max}, " + "height = {height}, width = {width})".format( + x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, height=height, width=width + ) + ) + + return img[y_min:y_max, x_min:x_max] + + +def bbox_crop(bbox: BoxInternalType, x_min: int, y_min: int, x_max: int, y_max: int, rows: int, cols: int): + """Crop a bounding box. + + Args: + bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`. + x_min (int): + y_min (int): + x_max (int): + y_max (int): + rows (int): Image rows. + cols (int): Image cols. + + Returns: + tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`. + + """ + crop_coords = x_min, y_min, x_max, y_max + crop_height = y_max - y_min + crop_width = x_max - x_min + return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols) + + +def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int): + h, w = img.shape[:2] + if x_min < 0: + x_min = 0 + if y_min < 0: + y_min = 0 + if y_max >= h: + y_max = h - 1 + if x_max >= w: + x_max = w - 1 + return img[int(y_min) : int(y_max), int(x_min) : int(x_max)] + + +@preserve_channel_dim +def crop_and_pad( + img: np.ndarray, + crop_params: Optional[Sequence[int]], + pad_params: Optional[Sequence[int]], + pad_value: Optional[float], + rows: int, + cols: int, + interpolation: int, + pad_mode: int, + keep_size: bool, +) -> np.ndarray: + if crop_params is not None and any(i != 0 for i in crop_params): + img = crop(img, *crop_params) + if pad_params is not None and any(i != 0 for i in pad_params): + img = FGeometric.pad_with_params( + img, pad_params[0], pad_params[1], pad_params[2], pad_params[3], border_mode=pad_mode, value=pad_value + ) + + if keep_size: + resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(cols, rows), interpolation=interpolation) + img = resize_fn(img) + + return img + + +def crop_and_pad_bbox( + bbox: BoxInternalType, + crop_params: Optional[Sequence[int]], + pad_params: Optional[Sequence[int]], + rows, + cols, + result_rows, + result_cols, +) -> BoxInternalType: + x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4] + + if crop_params is not None: + crop_x, crop_y = crop_params[:2] + x1, y1, x2, y2 = x1 - crop_x, y1 - crop_y, x2 - crop_x, y2 - crop_y + if pad_params is not None: + top, bottom, left, right = pad_params + x1, y1, x2, y2 = x1 + left, y1 + top, x2 + left, y2 + top + + return normalize_bbox((x1, y1, x2, y2), result_rows, result_cols) + + +def crop_and_pad_keypoint( + keypoint: KeypointInternalType, + crop_params: Optional[Sequence[int]], + pad_params: Optional[Sequence[int]], + rows: int, + cols: int, + result_rows: int, + result_cols: int, + keep_size: bool, +) -> KeypointInternalType: + x, y, angle, scale = keypoint[:4] + + if crop_params is not None: + crop_x1, crop_y1, crop_x2, crop_y2 = crop_params + x, y = x - crop_x1, y - crop_y1 + if pad_params is not None: + top, bottom, left, right = pad_params + x, y = x + left, y + top + + if keep_size and (result_cols != cols or result_rows != rows): + scale_x = cols / result_cols + scale_y = rows / result_rows + return FGeometric.keypoint_scale((x, y, angle, scale), scale_x, scale_y) + + return x, y, angle, scale diff --git a/custom_albumentations/augmentations/crops/transforms.py b/custom_albumentations/augmentations/crops/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f4c4af379edfb3b6ea7945ed547f749f401702b9 --- /dev/null +++ b/custom_albumentations/augmentations/crops/transforms.py @@ -0,0 +1,943 @@ +import math +import random +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np + +from custom_albumentations.core.bbox_utils import union_of_bboxes + +from ...core.transforms_interface import ( + BoxInternalType, + DualTransform, + KeypointInternalType, + to_tuple, +) +from ..geometric import functional as FGeometric +from . import functional as F + +__all__ = [ + "RandomCrop", + "CenterCrop", + "Crop", + "CropNonEmptyMaskIfExists", + "RandomSizedCrop", + "RandomResizedCrop", + "RandomCropNearBBox", + "RandomSizedBBoxSafeCrop", + "CropAndPad", + "RandomCropFromBorders", + "BBoxSafeRandomCrop", +] + + +class RandomCrop(DualTransform): + """Crop a random part of the input. + + Args: + height (int): height of the crop. + width (int): width of the crop. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, height, width, always_apply=False, p=1.0): + super().__init__(always_apply, p) + self.height = height + self.width = width + + def apply(self, img, h_start=0, w_start=0, **params): + return F.random_crop(img, self.height, self.width, h_start, w_start) + + def get_params(self): + return {"h_start": random.random(), "w_start": random.random()} + + def apply_to_bbox(self, bbox, **params): + return F.bbox_random_crop(bbox, self.height, self.width, **params) + + def apply_to_keypoint(self, keypoint, **params): + return F.keypoint_random_crop(keypoint, self.height, self.width, **params) + + def get_transform_init_args_names(self): + return ("height", "width") + + +class CenterCrop(DualTransform): + """Crop the central part of the input. + + Args: + height (int): height of the crop. + width (int): width of the crop. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + + Note: + It is recommended to use uint8 images as input. + Otherwise the operation will require internal conversion + float32 -> uint8 -> float32 that causes worse performance. + """ + + def __init__(self, height, width, always_apply=False, p=1.0): + super(CenterCrop, self).__init__(always_apply, p) + self.height = height + self.width = width + + def apply(self, img, **params): + return F.center_crop(img, self.height, self.width) + + def apply_to_bbox(self, bbox, **params): + return F.bbox_center_crop(bbox, self.height, self.width, **params) + + def apply_to_keypoint(self, keypoint, **params): + return F.keypoint_center_crop(keypoint, self.height, self.width, **params) + + def get_transform_init_args_names(self): + return ("height", "width") + + +class Crop(DualTransform): + """Crop region from image. + + Args: + x_min (int): Minimum upper left x coordinate. + y_min (int): Minimum upper left y coordinate. + x_max (int): Maximum lower right x coordinate. + y_max (int): Maximum lower right y coordinate. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, x_min=0, y_min=0, x_max=1024, y_max=1024, always_apply=False, p=1.0): + super(Crop, self).__init__(always_apply, p) + self.x_min = x_min + self.y_min = y_min + self.x_max = x_max + self.y_max = y_max + + def apply(self, img, **params): + return F.crop(img, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max) + + def apply_to_bbox(self, bbox, **params): + return F.bbox_crop(bbox, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max, **params) + + def apply_to_keypoint(self, keypoint, **params): + return F.crop_keypoint_by_coords(keypoint, crop_coords=(self.x_min, self.y_min, self.x_max, self.y_max)) + + def get_transform_init_args_names(self): + return ("x_min", "y_min", "x_max", "y_max") + + +class CropNonEmptyMaskIfExists(DualTransform): + """Crop area with mask if mask is non-empty, else make random crop. + + Args: + height (int): vertical size of crop in pixels + width (int): horizontal size of crop in pixels + ignore_values (list of int): values to ignore in mask, `0` values are always ignored + (e.g. if background value is 5 set `ignore_values=[5]` to ignore) + ignore_channels (list of int): channels to ignore in mask + (e.g. if background is a first channel set `ignore_channels=[0]` to ignore) + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, height, width, ignore_values=None, ignore_channels=None, always_apply=False, p=1.0): + super(CropNonEmptyMaskIfExists, self).__init__(always_apply, p) + + if ignore_values is not None and not isinstance(ignore_values, list): + raise ValueError("Expected `ignore_values` of type `list`, got `{}`".format(type(ignore_values))) + if ignore_channels is not None and not isinstance(ignore_channels, list): + raise ValueError("Expected `ignore_channels` of type `list`, got `{}`".format(type(ignore_channels))) + + self.height = height + self.width = width + self.ignore_values = ignore_values + self.ignore_channels = ignore_channels + + def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.crop(img, x_min, y_min, x_max, y_max) + + def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.bbox_crop( + bbox, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, rows=params["rows"], cols=params["cols"] + ) + + def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max)) + + def _preprocess_mask(self, mask): + mask_height, mask_width = mask.shape[:2] + + if self.ignore_values is not None: + ignore_values_np = np.array(self.ignore_values) + mask = np.where(np.isin(mask, ignore_values_np), 0, mask) + + if mask.ndim == 3 and self.ignore_channels is not None: + target_channels = np.array([ch for ch in range(mask.shape[-1]) if ch not in self.ignore_channels]) + mask = np.take(mask, target_channels, axis=-1) + + if self.height > mask_height or self.width > mask_width: + raise ValueError( + "Crop size ({},{}) is larger than image ({},{})".format( + self.height, self.width, mask_height, mask_width + ) + ) + + return mask + + def update_params(self, params, **kwargs): + super().update_params(params, **kwargs) + if "mask" in kwargs: + mask = self._preprocess_mask(kwargs["mask"]) + elif "masks" in kwargs and len(kwargs["masks"]): + masks = kwargs["masks"] + mask = self._preprocess_mask(np.copy(masks[0])) # need copy as we perform in-place mod afterwards + for m in masks[1:]: + mask |= self._preprocess_mask(m) + else: + raise RuntimeError("Can not find mask for CropNonEmptyMaskIfExists") + + mask_height, mask_width = mask.shape[:2] + + if mask.any(): + mask = mask.sum(axis=-1) if mask.ndim == 3 else mask + non_zero_yx = np.argwhere(mask) + y, x = random.choice(non_zero_yx) + x_min = x - random.randint(0, self.width - 1) + y_min = y - random.randint(0, self.height - 1) + x_min = np.clip(x_min, 0, mask_width - self.width) + y_min = np.clip(y_min, 0, mask_height - self.height) + else: + x_min = random.randint(0, mask_width - self.width) + y_min = random.randint(0, mask_height - self.height) + + x_max = x_min + self.width + y_max = y_min + self.height + + params.update({"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}) + return params + + def get_transform_init_args_names(self): + return ("height", "width", "ignore_values", "ignore_channels") + + +class _BaseRandomSizedCrop(DualTransform): + # Base class for RandomSizedCrop and RandomResizedCrop + + def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0): + super(_BaseRandomSizedCrop, self).__init__(always_apply, p) + self.height = height + self.width = width + self.interpolation = interpolation + + def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params): + crop = F.random_crop(img, crop_height, crop_width, h_start, w_start) + return FGeometric.resize(crop, self.height, self.width, interpolation) + + def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params): + return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols) + + def apply_to_keypoint(self, keypoint, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params): + keypoint = F.keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, rows, cols) + scale_x = self.width / crop_width + scale_y = self.height / crop_height + keypoint = FGeometric.keypoint_scale(keypoint, scale_x, scale_y) + return keypoint + + +class RandomSizedCrop(_BaseRandomSizedCrop): + """Crop a random part of the input and rescale it to some size. + + Args: + min_max_height ((int, int)): crop size limits. + height (int): height after crop and resize. + width (int): width after crop and resize. + w2h_ratio (float): aspect ratio of crop. + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, min_max_height, height, width, w2h_ratio=1.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0 + ): + super(RandomSizedCrop, self).__init__( + height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p + ) + self.min_max_height = min_max_height + self.w2h_ratio = w2h_ratio + + def get_params(self): + crop_height = random.randint(self.min_max_height[0], self.min_max_height[1]) + return { + "h_start": random.random(), + "w_start": random.random(), + "crop_height": crop_height, + "crop_width": int(crop_height * self.w2h_ratio), + } + + def get_transform_init_args_names(self): + return "min_max_height", "height", "width", "w2h_ratio", "interpolation" + + +class RandomResizedCrop(_BaseRandomSizedCrop): + """Torchvision's variant of crop a random part of the input and rescale it to some size. + + Args: + height (int): height after crop and resize. + width (int): width after crop and resize. + scale ((float, float)): range of size of the origin size cropped + ratio ((float, float)): range of aspect ratio of the origin aspect ratio cropped + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + height, + width, + scale=(0.08, 1.0), + ratio=(0.75, 1.3333333333333333), + interpolation=cv2.INTER_LINEAR, + always_apply=False, + p=1.0, + ): + super(RandomResizedCrop, self).__init__( + height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p + ) + self.scale = scale + self.ratio = ratio + + def get_params_dependent_on_targets(self, params): + img = params["image"] + area = img.shape[0] * img.shape[1] + + for _attempt in range(10): + target_area = random.uniform(*self.scale) * area + log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1])) + aspect_ratio = math.exp(random.uniform(*log_ratio)) + + w = int(round(math.sqrt(target_area * aspect_ratio))) # skipcq: PTC-W0028 + h = int(round(math.sqrt(target_area / aspect_ratio))) # skipcq: PTC-W0028 + + if 0 < w <= img.shape[1] and 0 < h <= img.shape[0]: + i = random.randint(0, img.shape[0] - h) + j = random.randint(0, img.shape[1] - w) + return { + "crop_height": h, + "crop_width": w, + "h_start": i * 1.0 / (img.shape[0] - h + 1e-10), + "w_start": j * 1.0 / (img.shape[1] - w + 1e-10), + } + + # Fallback to central crop + in_ratio = img.shape[1] / img.shape[0] + if in_ratio < min(self.ratio): + w = img.shape[1] + h = int(round(w / min(self.ratio))) + elif in_ratio > max(self.ratio): + h = img.shape[0] + w = int(round(h * max(self.ratio))) + else: # whole image + w = img.shape[1] + h = img.shape[0] + i = (img.shape[0] - h) // 2 + j = (img.shape[1] - w) // 2 + return { + "crop_height": h, + "crop_width": w, + "h_start": i * 1.0 / (img.shape[0] - h + 1e-10), + "w_start": j * 1.0 / (img.shape[1] - w + 1e-10), + } + + def get_params(self): + return {} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return "height", "width", "scale", "ratio", "interpolation" + + +class RandomCropNearBBox(DualTransform): + """Crop bbox from image with random shift by x,y coordinates + + Args: + max_part_shift (float, (float, float)): Max shift in `height` and `width` dimensions relative + to `cropping_bbox` dimension. + If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift). + Default (0.3, 0.3). + cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox` + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + + Examples: + >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')], + >>> bbox_params=BboxParams("pascal_voc")) + >>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20]) + + """ + + def __init__( + self, + max_part_shift: Union[float, Tuple[float, float]] = (0.3, 0.3), + cropping_box_key: str = "cropping_bbox", + always_apply: bool = False, + p: float = 1.0, + ): + super(RandomCropNearBBox, self).__init__(always_apply, p) + self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift) + self.cropping_bbox_key = cropping_box_key + + if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1: + raise ValueError("Invalid max_part_shift. Got: {}".format(max_part_shift)) + + def apply( + self, img: np.ndarray, x_min: int = 0, x_max: int = 0, y_min: int = 0, y_max: int = 0, **params + ) -> np.ndarray: + return F.clamping_crop(img, x_min, y_min, x_max, y_max) + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, int]: + bbox = params[self.cropping_bbox_key] + h_max_shift = round((bbox[3] - bbox[1]) * self.max_part_shift[0]) + w_max_shift = round((bbox[2] - bbox[0]) * self.max_part_shift[1]) + + x_min = bbox[0] - random.randint(-w_max_shift, w_max_shift) + x_max = bbox[2] + random.randint(-w_max_shift, w_max_shift) + + y_min = bbox[1] - random.randint(-h_max_shift, h_max_shift) + y_max = bbox[3] + random.randint(-h_max_shift, h_max_shift) + + x_min = max(0, x_min) + y_min = max(0, y_min) + + return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max} + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return F.bbox_crop(bbox, **params) + + def apply_to_keypoint( + self, + keypoint: Tuple[float, float, float, float], + x_min: int = 0, + x_max: int = 0, + y_min: int = 0, + y_max: int = 0, + **params + ) -> Tuple[float, float, float, float]: + return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max)) + + @property + def targets_as_params(self) -> List[str]: + return [self.cropping_bbox_key] + + def get_transform_init_args_names(self) -> Tuple[str]: + return ("max_part_shift",) + + +class BBoxSafeRandomCrop(DualTransform): + """Crop a random part of the input without loss of bboxes. + Args: + erosion_rate (float): erosion rate applied on input image height before crop. + p (float): probability of applying the transform. Default: 1. + Targets: + image, mask, bboxes + Image types: + uint8, float32 + """ + + def __init__(self, erosion_rate=0.0, always_apply=False, p=1.0): + super(BBoxSafeRandomCrop, self).__init__(always_apply, p) + self.erosion_rate = erosion_rate + + def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, **params): + return F.random_crop(img, crop_height, crop_width, h_start, w_start) + + def get_params_dependent_on_targets(self, params): + img_h, img_w = params["image"].shape[:2] + if len(params["bboxes"]) == 0: # less likely, this class is for use with bboxes. + erosive_h = int(img_h * (1.0 - self.erosion_rate)) + crop_height = img_h if erosive_h >= img_h else random.randint(erosive_h, img_h) + return { + "h_start": random.random(), + "w_start": random.random(), + "crop_height": crop_height, + "crop_width": int(crop_height * img_w / img_h), + } + # get union of all bboxes + x, y, x2, y2 = union_of_bboxes( + width=img_w, height=img_h, bboxes=params["bboxes"], erosion_rate=self.erosion_rate + ) + # find bigger region + bx, by = x * random.random(), y * random.random() + bx2, by2 = x2 + (1 - x2) * random.random(), y2 + (1 - y2) * random.random() + bw, bh = bx2 - bx, by2 - by + crop_height = img_h if bh >= 1.0 else int(img_h * bh) + crop_width = img_w if bw >= 1.0 else int(img_w * bw) + h_start = np.clip(0.0 if bh >= 1.0 else by / (1.0 - bh), 0.0, 1.0) + w_start = np.clip(0.0 if bw >= 1.0 else bx / (1.0 - bw), 0.0, 1.0) + return {"h_start": h_start, "w_start": w_start, "crop_height": crop_height, "crop_width": crop_width} + + def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params): + return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols) + + @property + def targets_as_params(self): + return ["image", "bboxes"] + + def get_transform_init_args_names(self): + return ("erosion_rate",) + + +class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop): + """Crop a random part of the input and rescale it to some size without loss of bboxes. + Args: + height (int): height after crop and resize. + width (int): width after crop and resize. + erosion_rate (float): erosion rate applied on input image height before crop. + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + Targets: + image, mask, bboxes + Image types: + uint8, float32 + """ + + def __init__(self, height, width, erosion_rate=0.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0): + super(RandomSizedBBoxSafeCrop, self).__init__(erosion_rate, always_apply, p) + self.height = height + self.width = width + self.interpolation = interpolation + + def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params): + crop = F.random_crop(img, crop_height, crop_width, h_start, w_start) + return FGeometric.resize(crop, self.height, self.width, interpolation) + + def get_transform_init_args_names(self): + return super().get_transform_init_args_names() + ("height", "width", "interpolation") + + +class CropAndPad(DualTransform): + """Crop and pad images by pixel amounts or fractions of image sizes. + Cropping removes pixels at the sides (i.e. extracts a subimage from a given full image). + Padding adds pixels to the sides (e.g. black pixels). + This transformation will never crop images below a height or width of ``1``. + + Note: + This transformation automatically resizes images back to their original size. To deactivate this, add the + parameter ``keep_size=False``. + + Args: + px (int or tuple): + The number of pixels to crop (negative values) or pad (positive values) + on each side of the image. Either this or the parameter `percent` may + be set, not both at the same time. + * If ``None``, then pixel-based cropping/padding will not be used. + * If ``int``, then that exact number of pixels will always be cropped/padded. + * If a ``tuple`` of two ``int`` s with values ``a`` and ``b``, + then each side will be cropped/padded by a random amount sampled + uniformly per image and side from the interval ``[a, b]``. If + however `sample_independently` is set to ``False``, only one + value will be sampled per image and used for all sides. + * If a ``tuple`` of four entries, then the entries represent top, + right, bottom, left. Each entry may be a single ``int`` (always + crop/pad by exactly that value), a ``tuple`` of two ``int`` s + ``a`` and ``b`` (crop/pad by an amount within ``[a, b]``), a + ``list`` of ``int`` s (crop/pad by a random value that is + contained in the ``list``). + percent (float or tuple): + The number of pixels to crop (negative values) or pad (positive values) + on each side of the image given as a *fraction* of the image + height/width. E.g. if this is set to ``-0.1``, the transformation will + always crop away ``10%`` of the image's height at both the top and the + bottom (both ``10%`` each), as well as ``10%`` of the width at the + right and left. + Expected value range is ``(-1.0, inf)``. + Either this or the parameter `px` may be set, not both + at the same time. + * If ``None``, then fraction-based cropping/padding will not be + used. + * If ``float``, then that fraction will always be cropped/padded. + * If a ``tuple`` of two ``float`` s with values ``a`` and ``b``, + then each side will be cropped/padded by a random fraction + sampled uniformly per image and side from the interval + ``[a, b]``. If however `sample_independently` is set to + ``False``, only one value will be sampled per image and used for + all sides. + * If a ``tuple`` of four entries, then the entries represent top, + right, bottom, left. Each entry may be a single ``float`` + (always crop/pad by exactly that percent value), a ``tuple`` of + two ``float`` s ``a`` and ``b`` (crop/pad by a fraction from + ``[a, b]``), a ``list`` of ``float`` s (crop/pad by a random + value that is contained in the list). + pad_mode (int): OpenCV border mode. + pad_cval (number, Sequence[number]): + The constant value to use if the pad mode is ``BORDER_CONSTANT``. + * If ``number``, then that value will be used. + * If a ``tuple`` of two ``number`` s and at least one of them is + a ``float``, then a random number will be uniformly sampled per + image from the continuous interval ``[a, b]`` and used as the + value. If both ``number`` s are ``int`` s, the interval is + discrete. + * If a ``list`` of ``number``, then a random value will be chosen + from the elements of the ``list`` and used as the value. + pad_cval_mask (number, Sequence[number]): Same as pad_cval but only for masks. + keep_size (bool): + After cropping and padding, the result image will usually have a + different height/width compared to the original input image. If this + parameter is set to ``True``, then the cropped/padded image will be + resized to the input image's size, i.e. the output shape is always identical to the input shape. + sample_independently (bool): + If ``False`` *and* the values for `px`/`percent` result in exactly + *one* probability distribution for all image sides, only one single + value will be sampled from that probability distribution and used for + all sides. I.e. the crop/pad amount then is the same for all sides. + If ``True``, four values will be sampled independently, one per side. + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + + Targets: + image, mask, bboxes, keypoints + + Image types: + any + """ + + def __init__( + self, + px: Optional[Union[int, Sequence[float], Sequence[Tuple]]] = None, + percent: Optional[Union[float, Sequence[float], Sequence[Tuple]]] = None, + pad_mode: int = cv2.BORDER_CONSTANT, + pad_cval: Union[float, Sequence[float]] = 0, + pad_cval_mask: Union[float, Sequence[float]] = 0, + keep_size: bool = True, + sample_independently: bool = True, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 1.0, + ): + super().__init__(always_apply, p) + + if px is None and percent is None: + raise ValueError("px and percent are empty!") + if px is not None and percent is not None: + raise ValueError("Only px or percent may be set!") + + self.px = px + self.percent = percent + + self.pad_mode = pad_mode + self.pad_cval = pad_cval + self.pad_cval_mask = pad_cval_mask + + self.keep_size = keep_size + self.sample_independently = sample_independently + + self.interpolation = interpolation + + def apply( + self, + img: np.ndarray, + crop_params: Sequence[int] = (), + pad_params: Sequence[int] = (), + pad_value: Union[int, float] = 0, + rows: int = 0, + cols: int = 0, + interpolation: int = cv2.INTER_LINEAR, + **params + ) -> np.ndarray: + return F.crop_and_pad( + img, crop_params, pad_params, pad_value, rows, cols, interpolation, self.pad_mode, self.keep_size + ) + + def apply_to_mask( + self, + img: np.ndarray, + crop_params: Optional[Sequence[int]] = None, + pad_params: Optional[Sequence[int]] = None, + pad_value_mask: Optional[float] = None, + rows: int = 0, + cols: int = 0, + interpolation: int = cv2.INTER_NEAREST, + **params + ) -> np.ndarray: + return F.crop_and_pad( + img, crop_params, pad_params, pad_value_mask, rows, cols, interpolation, self.pad_mode, self.keep_size + ) + + def apply_to_bbox( + self, + bbox: BoxInternalType, + crop_params: Optional[Sequence[int]] = None, + pad_params: Optional[Sequence[int]] = None, + rows: int = 0, + cols: int = 0, + result_rows: int = 0, + result_cols: int = 0, + **params + ) -> BoxInternalType: + return F.crop_and_pad_bbox(bbox, crop_params, pad_params, rows, cols, result_rows, result_cols) + + def apply_to_keypoint( + self, + keypoint: KeypointInternalType, + crop_params: Optional[Sequence[int]] = None, + pad_params: Optional[Sequence[int]] = None, + rows: int = 0, + cols: int = 0, + result_rows: int = 0, + result_cols: int = 0, + **params + ) -> KeypointInternalType: + return F.crop_and_pad_keypoint( + keypoint, crop_params, pad_params, rows, cols, result_rows, result_cols, self.keep_size + ) + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + @staticmethod + def __prevent_zero(val1: int, val2: int, max_val: int) -> Tuple[int, int]: + regain = abs(max_val) + 1 + regain1 = regain // 2 + regain2 = regain // 2 + if regain1 + regain2 < regain: + regain1 += 1 + + if regain1 > val1: + diff = regain1 - val1 + regain1 = val1 + regain2 += diff + elif regain2 > val2: + diff = regain2 - val2 + regain2 = val2 + regain1 += diff + + val1 = val1 - regain1 + val2 = val2 - regain2 + + return val1, val2 + + @staticmethod + def _prevent_zero(crop_params: List[int], height: int, width: int) -> Sequence[int]: + top, right, bottom, left = crop_params + + remaining_height = height - (top + bottom) + remaining_width = width - (left + right) + + if remaining_height < 1: + top, bottom = CropAndPad.__prevent_zero(top, bottom, height) + if remaining_width < 1: + left, right = CropAndPad.__prevent_zero(left, right, width) + + return [max(top, 0), max(right, 0), max(bottom, 0), max(left, 0)] + + def get_params_dependent_on_targets(self, params) -> dict: + height, width = params["image"].shape[:2] + + if self.px is not None: + params = self._get_px_params() + else: + params = self._get_percent_params() + params[0] = int(params[0] * height) + params[1] = int(params[1] * width) + params[2] = int(params[2] * height) + params[3] = int(params[3] * width) + + pad_params = [max(i, 0) for i in params] + + crop_params = self._prevent_zero([-min(i, 0) for i in params], height, width) + + top, right, bottom, left = crop_params + crop_params = [left, top, width - right, height - bottom] + result_rows = crop_params[3] - crop_params[1] + result_cols = crop_params[2] - crop_params[0] + if result_cols == width and result_rows == height: + crop_params = [] + + top, right, bottom, left = pad_params + pad_params = [top, bottom, left, right] + if any(pad_params): + result_rows += top + bottom + result_cols += left + right + else: + pad_params = [] + + return { + "crop_params": crop_params or None, + "pad_params": pad_params or None, + "pad_value": None if pad_params is None else self._get_pad_value(self.pad_cval), + "pad_value_mask": None if pad_params is None else self._get_pad_value(self.pad_cval_mask), + "result_rows": result_rows, + "result_cols": result_cols, + } + + def _get_px_params(self) -> List[int]: + if self.px is None: + raise ValueError("px is not set") + + if isinstance(self.px, int): + params = [self.px] * 4 + elif len(self.px) == 2: + if self.sample_independently: + params = [random.randrange(*self.px) for _ in range(4)] + else: + px = random.randrange(*self.px) + params = [px] * 4 + else: + params = [i if isinstance(i, int) else random.randrange(*i) for i in self.px] # type: ignore + + return params # [top, right, bottom, left] + + def _get_percent_params(self) -> List[float]: + if self.percent is None: + raise ValueError("percent is not set") + + if isinstance(self.percent, float): + params = [self.percent] * 4 + elif len(self.percent) == 2: + if self.sample_independently: + params = [random.uniform(*self.percent) for _ in range(4)] + else: + px = random.uniform(*self.percent) + params = [px] * 4 + else: + params = [i if isinstance(i, (int, float)) else random.uniform(*i) for i in self.percent] + + return params # params = [top, right, bottom, left] + + @staticmethod + def _get_pad_value(pad_value: Union[float, Sequence[float]]) -> Union[int, float]: + if isinstance(pad_value, (int, float)): + return pad_value + + if len(pad_value) == 2: + a, b = pad_value + if isinstance(a, int) and isinstance(b, int): + return random.randint(a, b) + + return random.uniform(a, b) + + return random.choice(pad_value) + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ( + "px", + "percent", + "pad_mode", + "pad_cval", + "pad_cval_mask", + "keep_size", + "sample_independently", + "interpolation", + ) + + +class RandomCropFromBorders(DualTransform): + """Crop bbox from image randomly cut parts from borders without resize at the end + + Args: + crop_left (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut + from left side in range [0, crop_left * width) + crop_right (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut + from right side in range [(1 - crop_right) * width, width) + crop_top (float): singlefloat value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut + from top side in range [0, crop_top * height) + crop_bottom (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut + from bottom side in range [(1 - crop_bottom) * height, height) + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + crop_left=0.1, + crop_right=0.1, + crop_top=0.1, + crop_bottom=0.1, + always_apply=False, + p=1.0, + ): + super(RandomCropFromBorders, self).__init__(always_apply, p) + self.crop_left = crop_left + self.crop_right = crop_right + self.crop_top = crop_top + self.crop_bottom = crop_bottom + + def get_params_dependent_on_targets(self, params): + img = params["image"] + x_min = random.randint(0, int(self.crop_left * img.shape[1])) + x_max = random.randint(max(x_min + 1, int((1 - self.crop_right) * img.shape[1])), img.shape[1]) + y_min = random.randint(0, int(self.crop_top * img.shape[0])) + y_max = random.randint(max(y_min + 1, int((1 - self.crop_bottom) * img.shape[0])), img.shape[0]) + return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max} + + def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.clamping_crop(img, x_min, y_min, x_max, y_max) + + def apply_to_mask(self, mask, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.clamping_crop(mask, x_min, y_min, x_max, y_max) + + def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params): + rows, cols = params["rows"], params["cols"] + return F.bbox_crop(bbox, x_min, y_min, x_max, y_max, rows, cols) + + def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params): + return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max)) + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return "crop_left", "crop_right", "crop_top", "crop_bottom" diff --git a/custom_albumentations/augmentations/domain_adaptation.py b/custom_albumentations/augmentations/domain_adaptation.py new file mode 100644 index 0000000000000000000000000000000000000000..f40ba31ced7280612880d4f0c0fd701efed87f84 --- /dev/null +++ b/custom_albumentations/augmentations/domain_adaptation.py @@ -0,0 +1,337 @@ +import random +from typing import Any, Callable, Literal, Sequence, Tuple + +import cv2 +import numpy as np +from custom_qudida import DomainAdapter +from skimage.exposure import match_histograms +from sklearn.decomposition import PCA +from sklearn.preprocessing import MinMaxScaler, StandardScaler + +from custom_albumentations.augmentations.utils import ( + clipped, + get_opencv_dtype_from_numpy, + is_grayscale_image, + is_multispectral_image, + preserve_shape, + read_rgb_image, +) + +from ..core.transforms_interface import ImageOnlyTransform, ScaleFloatType, to_tuple + +__all__ = [ + "HistogramMatching", + "FDA", + "PixelDistributionAdaptation", + "fourier_domain_adaptation", + "apply_histogram", + "adapt_pixel_distribution", +] + + +@clipped +@preserve_shape +def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray: + """ + Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA + + Args: + img: source image + target_img: target image for domain adaptation + beta: coefficient from source paper + + Returns: + transformed image + + """ + + img = np.squeeze(img) + target_img = np.squeeze(target_img) + + if target_img.shape != img.shape: + raise ValueError( + "The source and target images must have the same shape," + " but got {} and {} respectively.".format(img.shape, target_img.shape) + ) + + # get fft of both source and target + fft_src = np.fft.fft2(img.astype(np.float32), axes=(0, 1)) + fft_trg = np.fft.fft2(target_img.astype(np.float32), axes=(0, 1)) + + # extract amplitude and phase of both fft-s + amplitude_src, phase_src = np.abs(fft_src), np.angle(fft_src) + amplitude_trg = np.abs(fft_trg) + + # mutate the amplitude part of source with target + amplitude_src = np.fft.fftshift(amplitude_src, axes=(0, 1)) + amplitude_trg = np.fft.fftshift(amplitude_trg, axes=(0, 1)) + height, width = amplitude_src.shape[:2] + border = np.floor(min(height, width) * beta).astype(int) + center_y, center_x = np.floor([height / 2.0, width / 2.0]).astype(int) + + y1, y2 = center_y - border, center_y + border + 1 + x1, x2 = center_x - border, center_x + border + 1 + + amplitude_src[y1:y2, x1:x2] = amplitude_trg[y1:y2, x1:x2] + amplitude_src = np.fft.ifftshift(amplitude_src, axes=(0, 1)) + + # get mutated image + src_image_transformed = np.fft.ifft2(amplitude_src * np.exp(1j * phase_src), axes=(0, 1)) + src_image_transformed = np.real(src_image_transformed) + + return src_image_transformed + + +@preserve_shape +def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray: + if img.dtype != reference_image.dtype: + raise RuntimeError( + f"Dtype of image and reference image must be the same. Got {img.dtype} and {reference_image.dtype}" + ) + if img.shape[:2] != reference_image.shape[:2]: + reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0])) + + img, reference_image = np.squeeze(img), np.squeeze(reference_image) + + try: + matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == 3 else None) + except TypeError: + matched = match_histograms(img, reference_image, multichannel=True) # case for scikit-image<0.19.1 + img = cv2.addWeighted( + matched, + blend_ratio, + img, + 1 - blend_ratio, + 0, + dtype=get_opencv_dtype_from_numpy(img.dtype), + ) + return img + + +@preserve_shape +def adapt_pixel_distribution( + img: np.ndarray, ref: np.ndarray, transform_type: str = "pca", weight: float = 0.5 +) -> np.ndarray: + initial_type = img.dtype + transformer = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}[transform_type]() + adapter = DomainAdapter(transformer=transformer, ref_img=ref) + result = adapter(img).astype("float32") + blended = (img.astype("float32") * (1 - weight) + result * weight).astype(initial_type) + return blended + + +class HistogramMatching(ImageOnlyTransform): + """ + Apply histogram matching. It manipulates the pixels of an input image so that its histogram matches + the histogram of the reference image. If the images have multiple channels, the matching is done independently + for each channel, as long as the number of channels is equal in the input image and the reference. + + Histogram matching can be used as a lightweight normalisation for image processing, + such as feature matching, especially in circumstances where the images have been taken from different + sources or in different conditions (i.e. lighting). + + See: + https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_histogram_matching.html + + Args: + reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, + it expects a sequence of paths to images. + blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original + with random blend factor for increased diversity of generated images. + read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images` + and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image + + Image types: + uint8, uint16, float32 + """ + + def __init__( + self, + reference_images: Sequence[Any], + blend_ratio: Tuple[float, float] = (0.5, 1.0), + read_fn: Callable[[Any], np.ndarray] = read_rgb_image, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + self.reference_images = reference_images + self.read_fn = read_fn + self.blend_ratio = blend_ratio + + def apply(self, img, reference_image=None, blend_ratio=0.5, **params): + return apply_histogram(img, reference_image, blend_ratio) + + def get_params(self): + return { + "reference_image": self.read_fn(random.choice(self.reference_images)), + "blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]), + } + + def get_transform_init_args_names(self): + return ("reference_images", "blend_ratio", "read_fn") + + def _to_dict(self): + raise NotImplementedError("HistogramMatching can not be serialized.") + + +class FDA(ImageOnlyTransform): + """ + Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA + Simple "style transfer". + + Args: + reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, + it expects a sequence of paths to images. + beta_limit (float or tuple of float): coefficient beta from paper. Recommended less 0.3. + read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images` + and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array. + + Targets: + image + + Image types: + uint8, float32 + + Reference: + https://github.com/YanchaoYang/FDA + https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf + + Example: + >>> import numpy as np + >>> import custom_albumentations as albumentations as A + >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8) + >>> target_image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8) + >>> aug = A.Compose([A.FDA([target_image], p=1, read_fn=lambda x: x)]) + >>> result = aug(image=image) + + """ + + def __init__( + self, + reference_images: Sequence[Any], + beta_limit: ScaleFloatType = 0.1, + read_fn: Callable[[Any], np.ndarray] = read_rgb_image, + always_apply: bool = False, + p: float = 0.5, + ): + super(FDA, self).__init__(always_apply=always_apply, p=p) + self.reference_images = reference_images + self.read_fn = read_fn + self.beta_limit = to_tuple(beta_limit, low=0) + + def apply(self, img, target_image=None, beta=0.1, **params): + return fourier_domain_adaptation(img=img, target_img=target_image, beta=beta) + + def get_params_dependent_on_targets(self, params): + img = params["image"] + target_img = self.read_fn(random.choice(self.reference_images)) + target_img = cv2.resize(target_img, dsize=(img.shape[1], img.shape[0])) + + return {"target_image": target_img} + + def get_params(self): + return {"beta": random.uniform(self.beta_limit[0], self.beta_limit[1])} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return ("reference_images", "beta_limit", "read_fn") + + def _to_dict(self): + raise NotImplementedError("FDA can not be serialized.") + + +class PixelDistributionAdaptation(ImageOnlyTransform): + """ + Another naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler + or MinMaxScaler) on both original and reference image, transforms original image with transform trained on this + image and then performs inverse transformation using transform fitted on reference image. + + Args: + reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default, + it expects a sequence of paths to images. + blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original + with random blend factor for increased diversity of generated images. + read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images` + and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array. + transform_type (str): type of transform; "pca", "standard", "minmax" are allowed. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image + + Image types: + uint8, float32 + + See also: https://github.com/arsenyinfo/qudida + """ + + def __init__( + self, + reference_images: Sequence[Any], + blend_ratio: Tuple[float, float] = (0.25, 1.0), + read_fn: Callable[[Any], np.ndarray] = read_rgb_image, + transform_type: Literal["pca", "standard", "minmax"] = "pca", + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + self.reference_images = reference_images + self.read_fn = read_fn + self.blend_ratio = blend_ratio + expected_transformers = ("pca", "standard", "minmax") + if transform_type not in expected_transformers: + raise ValueError(f"Got unexpected transform_type {transform_type}. Expected one of {expected_transformers}") + self.transform_type = transform_type + + @staticmethod + def _validate_shape(img: np.ndarray): + if is_grayscale_image(img) or is_multispectral_image(img): + raise ValueError( + f"Unexpected image shape: expected 3 dimensions, got {len(img.shape)}." + f"Is it a grayscale or multispectral image? It's not supported for now." + ) + + def ensure_uint8(self, img: np.ndarray) -> Tuple[np.ndarray, bool]: + if img.dtype == np.float32: + if img.min() < 0 or img.max() > 1: + message = ( + "PixelDistributionAdaptation uses uint8 under the hood, so float32 should be converted," + "Can not do it automatically when the image is out of [0..1] range." + ) + raise TypeError(message) + return (img * 255).astype("uint8"), True + return img, False + + def apply(self, img, reference_image, blend_ratio, **params): + self._validate_shape(img) + reference_image, _ = self.ensure_uint8(reference_image) + img, needs_reconvert = self.ensure_uint8(img) + + adapted = adapt_pixel_distribution( + img=img, + ref=reference_image, + weight=blend_ratio, + transform_type=self.transform_type, + ) + if needs_reconvert: + adapted = adapted.astype("float32") * (1 / 255) + return adapted + + def get_params(self): + return { + "reference_image": self.read_fn(random.choice(self.reference_images)), + "blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]), + } + + def get_transform_init_args_names(self): + return ("reference_images", "blend_ratio", "read_fn", "transform_type") + + def _to_dict(self): + raise NotImplementedError("PixelDistributionAdaptation can not be serialized.") diff --git a/custom_albumentations/augmentations/dropout/__init__.py b/custom_albumentations/augmentations/dropout/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aa29f452ab1729e9f87d8bf5174f877844a7b25e --- /dev/null +++ b/custom_albumentations/augmentations/dropout/__init__.py @@ -0,0 +1,5 @@ +from .channel_dropout import * +from .coarse_dropout import * +from .cutout import * +from .grid_dropout import * +from .mask_dropout import * diff --git a/custom_albumentations/augmentations/dropout/channel_dropout.py b/custom_albumentations/augmentations/dropout/channel_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..b3cdbc2bf77dbb12e26060f287ff5ebcb0cea225 --- /dev/null +++ b/custom_albumentations/augmentations/dropout/channel_dropout.py @@ -0,0 +1,72 @@ +import random +from typing import Any, Mapping, Tuple, Union + +import numpy as np + +from custom_albumentations.core.transforms_interface import ImageOnlyTransform + +from .functional import channel_dropout + +__all__ = ["ChannelDropout"] + + +class ChannelDropout(ImageOnlyTransform): + """Randomly Drop Channels in the input Image. + + Args: + channel_drop_range (int, int): range from which we choose the number of channels to drop. + fill_value (int, float): pixel value for the dropped channel. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, uint16, unit32, float32 + """ + + def __init__( + self, + channel_drop_range: Tuple[int, int] = (1, 1), + fill_value: Union[int, float] = 0, + always_apply: bool = False, + p: float = 0.5, + ): + super(ChannelDropout, self).__init__(always_apply, p) + + self.channel_drop_range = channel_drop_range + + self.min_channels = channel_drop_range[0] + self.max_channels = channel_drop_range[1] + + if not 1 <= self.min_channels <= self.max_channels: + raise ValueError("Invalid channel_drop_range. Got: {}".format(channel_drop_range)) + + self.fill_value = fill_value + + def apply(self, img: np.ndarray, channels_to_drop: Tuple[int, ...] = (0,), **params) -> np.ndarray: + return channel_dropout(img, channels_to_drop, self.fill_value) + + def get_params_dependent_on_targets(self, params: Mapping[str, Any]): + img = params["image"] + + num_channels = img.shape[-1] + + if len(img.shape) == 2 or num_channels == 1: + raise NotImplementedError("Images has one channel. ChannelDropout is not defined.") + + if self.max_channels >= num_channels: + raise ValueError("Can not drop all channels in ChannelDropout.") + + num_drop_channels = random.randint(self.min_channels, self.max_channels) + + channels_to_drop = random.sample(range(num_channels), k=num_drop_channels) + + return {"channels_to_drop": channels_to_drop} + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return "channel_drop_range", "fill_value" + + @property + def targets_as_params(self): + return ["image"] diff --git a/custom_albumentations/augmentations/dropout/coarse_dropout.py b/custom_albumentations/augmentations/dropout/coarse_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..64017d70c93cdf3b02307c7d27a8bb0ccc4e78c8 --- /dev/null +++ b/custom_albumentations/augmentations/dropout/coarse_dropout.py @@ -0,0 +1,187 @@ +import random +from typing import Iterable, List, Optional, Sequence, Tuple, Union + +import numpy as np + +from ...core.transforms_interface import DualTransform, KeypointType +from .functional import cutout + +__all__ = ["CoarseDropout"] + + +class CoarseDropout(DualTransform): + """CoarseDropout of the rectangular regions in the image. + + Args: + max_holes (int): Maximum number of regions to zero out. + max_height (int, float): Maximum height of the hole. + If float, it is calculated as a fraction of the image height. + max_width (int, float): Maximum width of the hole. + If float, it is calculated as a fraction of the image width. + min_holes (int): Minimum number of regions to zero out. If `None`, + `min_holes` is be set to `max_holes`. Default: `None`. + min_height (int, float): Minimum height of the hole. Default: None. If `None`, + `min_height` is set to `max_height`. Default: `None`. + If float, it is calculated as a fraction of the image height. + min_width (int, float): Minimum width of the hole. If `None`, `min_height` is + set to `max_width`. Default: `None`. + If float, it is calculated as a fraction of the image width. + + fill_value (int, float, list of int, list of float): value for dropped pixels. + mask_fill_value (int, float, list of int, list of float): fill value for dropped pixels + in mask. If `None` - mask is not affected. Default: `None`. + + Targets: + image, mask, keypoints + + Image types: + uint8, float32 + + Reference: + | https://arxiv.org/abs/1708.04552 + | https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py + | https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py + """ + + def __init__( + self, + max_holes: int = 8, + max_height: int = 8, + max_width: int = 8, + min_holes: Optional[int] = None, + min_height: Optional[int] = None, + min_width: Optional[int] = None, + fill_value: int = 0, + mask_fill_value: Optional[int] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(CoarseDropout, self).__init__(always_apply, p) + self.max_holes = max_holes + self.max_height = max_height + self.max_width = max_width + self.min_holes = min_holes if min_holes is not None else max_holes + self.min_height = min_height if min_height is not None else max_height + self.min_width = min_width if min_width is not None else max_width + self.fill_value = fill_value + self.mask_fill_value = mask_fill_value + if not 0 < self.min_holes <= self.max_holes: + raise ValueError("Invalid combination of min_holes and max_holes. Got: {}".format([min_holes, max_holes])) + + self.check_range(self.max_height) + self.check_range(self.min_height) + self.check_range(self.max_width) + self.check_range(self.min_width) + + if not 0 < self.min_height <= self.max_height: + raise ValueError( + "Invalid combination of min_height and max_height. Got: {}".format([min_height, max_height]) + ) + if not 0 < self.min_width <= self.max_width: + raise ValueError("Invalid combination of min_width and max_width. Got: {}".format([min_width, max_width])) + + def check_range(self, dimension): + if isinstance(dimension, float) and not 0 <= dimension < 1.0: + raise ValueError( + "Invalid value {}. If using floats, the value should be in the range [0.0, 1.0)".format(dimension) + ) + + def apply( + self, + img: np.ndarray, + fill_value: Union[int, float] = 0, + holes: Iterable[Tuple[int, int, int, int]] = (), + **params + ) -> np.ndarray: + return cutout(img, holes, fill_value) + + def apply_to_mask( + self, + img: np.ndarray, + mask_fill_value: Union[int, float] = 0, + holes: Iterable[Tuple[int, int, int, int]] = (), + **params + ) -> np.ndarray: + if mask_fill_value is None: + return img + return cutout(img, holes, mask_fill_value) + + def get_params_dependent_on_targets(self, params): + img = params["image"] + height, width = img.shape[:2] + + holes = [] + for _n in range(random.randint(self.min_holes, self.max_holes)): + if all( + [ + isinstance(self.min_height, int), + isinstance(self.min_width, int), + isinstance(self.max_height, int), + isinstance(self.max_width, int), + ] + ): + hole_height = random.randint(self.min_height, self.max_height) + hole_width = random.randint(self.min_width, self.max_width) + elif all( + [ + isinstance(self.min_height, float), + isinstance(self.min_width, float), + isinstance(self.max_height, float), + isinstance(self.max_width, float), + ] + ): + hole_height = int(height * random.uniform(self.min_height, self.max_height)) + hole_width = int(width * random.uniform(self.min_width, self.max_width)) + else: + raise ValueError( + "Min width, max width, \ + min height and max height \ + should all either be ints or floats. \ + Got: {} respectively".format( + [ + type(self.min_width), + type(self.max_width), + type(self.min_height), + type(self.max_height), + ] + ) + ) + + y1 = random.randint(0, height - hole_height) + x1 = random.randint(0, width - hole_width) + y2 = y1 + hole_height + x2 = x1 + hole_width + holes.append((x1, y1, x2, y2)) + + return {"holes": holes} + + @property + def targets_as_params(self): + return ["image"] + + def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool: + x1, y1, x2, y2 = hole + x, y = keypoint[:2] + return x1 <= x < x2 and y1 <= y < y2 + + def apply_to_keypoints( + self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params + ) -> List[KeypointType]: + result = set(keypoints) + for hole in holes: + for kp in keypoints: + if self._keypoint_in_hole(kp, hole): + result.discard(kp) + return list(result) + + def get_transform_init_args_names(self): + return ( + "max_holes", + "max_height", + "max_width", + "min_holes", + "min_height", + "min_width", + "fill_value", + "mask_fill_value", + ) diff --git a/custom_albumentations/augmentations/dropout/cutout.py b/custom_albumentations/augmentations/dropout/cutout.py new file mode 100644 index 0000000000000000000000000000000000000000..1dcba8b400bfa4655d4c0adbd6d5570dfaf39301 --- /dev/null +++ b/custom_albumentations/augmentations/dropout/cutout.py @@ -0,0 +1,79 @@ +import random +import warnings +from typing import Any, Dict, Tuple, Union + +import numpy as np + +from custom_albumentations.core.transforms_interface import ImageOnlyTransform + +from .functional import cutout + +__all__ = ["Cutout"] + + +class Cutout(ImageOnlyTransform): + """CoarseDropout of the square regions in the image. + + Args: + num_holes (int): number of regions to zero out + max_h_size (int): maximum height of the hole + max_w_size (int): maximum width of the hole + fill_value (int, float, list of int, list of float): value for dropped pixels. + + Targets: + image + + Image types: + uint8, float32 + + Reference: + | https://arxiv.org/abs/1708.04552 + | https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py + | https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py + """ + + def __init__( + self, + num_holes: int = 8, + max_h_size: int = 8, + max_w_size: int = 8, + fill_value: Union[int, float] = 0, + always_apply: bool = False, + p: float = 0.5, + ): + super(Cutout, self).__init__(always_apply, p) + self.num_holes = num_holes + self.max_h_size = max_h_size + self.max_w_size = max_w_size + self.fill_value = fill_value + warnings.warn( + f"{self.__class__.__name__} has been deprecated. Please use CoarseDropout", + FutureWarning, + ) + + def apply(self, img: np.ndarray, fill_value: Union[int, float] = 0, holes=(), **params): + return cutout(img, holes, fill_value) + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + img = params["image"] + height, width = img.shape[:2] + + holes = [] + for _n in range(self.num_holes): + y = random.randint(0, height) + x = random.randint(0, width) + + y1 = np.clip(y - self.max_h_size // 2, 0, height) + y2 = np.clip(y1 + self.max_h_size, 0, height) + x1 = np.clip(x - self.max_w_size // 2, 0, width) + x2 = np.clip(x1 + self.max_w_size, 0, width) + holes.append((x1, y1, x2, y2)) + + return {"holes": holes} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ("num_holes", "max_h_size", "max_w_size") diff --git a/custom_albumentations/augmentations/dropout/functional.py b/custom_albumentations/augmentations/dropout/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..f9bddcae0ddac9e75f4355bc982ee0fd9bb4b5ca --- /dev/null +++ b/custom_albumentations/augmentations/dropout/functional.py @@ -0,0 +1,29 @@ +from typing import Iterable, List, Tuple, Union + +import numpy as np + +from custom_albumentations.augmentations.utils import preserve_shape + +__all__ = ["cutout", "channel_dropout"] + + +@preserve_shape +def channel_dropout( + img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: Union[int, float] = 0 +) -> np.ndarray: + if len(img.shape) == 2 or img.shape[2] == 1: + raise NotImplementedError("Only one channel. ChannelDropout is not defined.") + + img = img.copy() + img[..., channels_to_drop] = fill_value + return img + + +def cutout( + img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: Union[int, float] = 0 +) -> np.ndarray: + # Make a copy of the input image since we don't want to modify it directly + img = img.copy() + for x1, y1, x2, y2 in holes: + img[y1:y2, x1:x2] = fill_value + return img diff --git a/custom_albumentations/augmentations/dropout/grid_dropout.py b/custom_albumentations/augmentations/dropout/grid_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..094b989aa2e0298df3036de01bf6b96f784f620a --- /dev/null +++ b/custom_albumentations/augmentations/dropout/grid_dropout.py @@ -0,0 +1,155 @@ +import random +from typing import Iterable, Optional, Tuple + +import numpy as np + +from ...core.transforms_interface import DualTransform +from . import functional as F + +__all__ = ["GridDropout"] + + +class GridDropout(DualTransform): + """GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion. + + Args: + ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions). + Must be between 0 and 1. Default: 0.5. + unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge. + If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`. + unit_size_max (int): maximum size of the grid unit. Must be between 2 and the image shorter edge. + If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`. + holes_number_x (int): the number of grid units in x direction. Must be between 1 and image width//2. + If 'None', grid unit width is set as image_width//10. Default: `None`. + holes_number_y (int): the number of grid units in y direction. Must be between 1 and image height//2. + If `None`, grid unit height is set equal to the grid unit width or image height, whatever is smaller. + shift_x (int): offsets of the grid start in x direction from (0,0) coordinate. + Clipped between 0 and grid unit_width - hole_width. Default: 0. + shift_y (int): offsets of the grid start in y direction from (0,0) coordinate. + Clipped between 0 and grid unit height - hole_height. Default: 0. + random_offset (boolean): weather to offset the grid randomly between 0 and grid unit size - hole size + If 'True', entered shift_x, shift_y are ignored and set randomly. Default: `False`. + fill_value (int): value for the dropped pixels. Default = 0 + mask_fill_value (int): value for the dropped pixels in mask. + If `None`, transformation is not applied to the mask. Default: `None`. + + Targets: + image, mask + + Image types: + uint8, float32 + + References: + https://arxiv.org/abs/2001.04086 + + """ + + def __init__( + self, + ratio: float = 0.5, + unit_size_min: Optional[int] = None, + unit_size_max: Optional[int] = None, + holes_number_x: Optional[int] = None, + holes_number_y: Optional[int] = None, + shift_x: int = 0, + shift_y: int = 0, + random_offset: bool = False, + fill_value: int = 0, + mask_fill_value: Optional[int] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(GridDropout, self).__init__(always_apply, p) + self.ratio = ratio + self.unit_size_min = unit_size_min + self.unit_size_max = unit_size_max + self.holes_number_x = holes_number_x + self.holes_number_y = holes_number_y + self.shift_x = shift_x + self.shift_y = shift_y + self.random_offset = random_offset + self.fill_value = fill_value + self.mask_fill_value = mask_fill_value + if not 0 < self.ratio <= 1: + raise ValueError("ratio must be between 0 and 1.") + + def apply(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray: + return F.cutout(img, holes, self.fill_value) + + def apply_to_mask(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray: + if self.mask_fill_value is None: + return img + + return F.cutout(img, holes, self.mask_fill_value) + + def get_params_dependent_on_targets(self, params): + img = params["image"] + height, width = img.shape[:2] + # set grid using unit size limits + if self.unit_size_min and self.unit_size_max: + if not 2 <= self.unit_size_min <= self.unit_size_max: + raise ValueError("Max unit size should be >= min size, both at least 2 pixels.") + if self.unit_size_max > min(height, width): + raise ValueError("Grid size limits must be within the shortest image edge.") + unit_width = random.randint(self.unit_size_min, self.unit_size_max + 1) + unit_height = unit_width + else: + # set grid using holes numbers + if self.holes_number_x is None: + unit_width = max(2, width // 10) + else: + if not 1 <= self.holes_number_x <= width // 2: + raise ValueError("The hole_number_x must be between 1 and image width//2.") + unit_width = width // self.holes_number_x + if self.holes_number_y is None: + unit_height = max(min(unit_width, height), 2) + else: + if not 1 <= self.holes_number_y <= height // 2: + raise ValueError("The hole_number_y must be between 1 and image height//2.") + unit_height = height // self.holes_number_y + + hole_width = int(unit_width * self.ratio) + hole_height = int(unit_height * self.ratio) + # min 1 pixel and max unit length - 1 + hole_width = min(max(hole_width, 1), unit_width - 1) + hole_height = min(max(hole_height, 1), unit_height - 1) + # set offset of the grid + if self.shift_x is None: + shift_x = 0 + else: + shift_x = min(max(0, self.shift_x), unit_width - hole_width) + if self.shift_y is None: + shift_y = 0 + else: + shift_y = min(max(0, self.shift_y), unit_height - hole_height) + if self.random_offset: + shift_x = random.randint(0, unit_width - hole_width) + shift_y = random.randint(0, unit_height - hole_height) + holes = [] + for i in range(width // unit_width + 1): + for j in range(height // unit_height + 1): + x1 = min(shift_x + unit_width * i, width) + y1 = min(shift_y + unit_height * j, height) + x2 = min(x1 + hole_width, width) + y2 = min(y1 + hole_height, height) + holes.append((x1, y1, x2, y2)) + + return {"holes": holes} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return ( + "ratio", + "unit_size_min", + "unit_size_max", + "holes_number_x", + "holes_number_y", + "shift_x", + "shift_y", + "random_offset", + "fill_value", + "mask_fill_value", + ) diff --git a/custom_albumentations/augmentations/dropout/mask_dropout.py b/custom_albumentations/augmentations/dropout/mask_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..fe67a65401a2be4cf3943e0076675c6536051bb6 --- /dev/null +++ b/custom_albumentations/augmentations/dropout/mask_dropout.py @@ -0,0 +1,99 @@ +import random +from typing import Any, Dict, Optional, Tuple, Union + +import cv2 +import numpy as np +from skimage.measure import label + +from ...core.transforms_interface import DualTransform, to_tuple + +__all__ = ["MaskDropout"] + + +class MaskDropout(DualTransform): + """ + Image & mask augmentation that zero out mask and image regions corresponding + to randomly chosen object instance from mask. + + Mask must be single-channel image, zero values treated as background. + Image can be any number of channels. + + Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254 + + Args: + max_objects: Maximum number of labels that can be zeroed out. Can be tuple, in this case it's [min, max] + image_fill_value: Fill value to use when filling image. + Can be 'inpaint' to apply inpaining (works only for 3-chahnel images) + mask_fill_value: Fill value to use when filling mask. + + Targets: + image, mask + + Image types: + uint8, float32 + """ + + def __init__( + self, + max_objects: int = 1, + image_fill_value: Union[int, float, str] = 0, + mask_fill_value: Union[int, float] = 0, + always_apply: bool = False, + p: float = 0.5, + ): + super(MaskDropout, self).__init__(always_apply, p) + self.max_objects = to_tuple(max_objects, 1) + self.image_fill_value = image_fill_value + self.mask_fill_value = mask_fill_value + + @property + def targets_as_params(self): + return ["mask"] + + def get_params_dependent_on_targets(self, params) -> Dict[str, Any]: + mask = params["mask"] + + label_image, num_labels = label(mask, return_num=True) + + if num_labels == 0: + dropout_mask = None + else: + objects_to_drop = random.randint(int(self.max_objects[0]), int(self.max_objects[1])) + objects_to_drop = min(num_labels, objects_to_drop) + + if objects_to_drop == num_labels: + dropout_mask = mask > 0 + else: + labels_index = random.sample(range(1, num_labels + 1), objects_to_drop) + dropout_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=bool) + for label_index in labels_index: + dropout_mask |= label_image == label_index + + params.update({"dropout_mask": dropout_mask}) + return params + + def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray: + if dropout_mask is None: + return img + + if self.image_fill_value == "inpaint": + dropout_mask = dropout_mask.astype(np.uint8) + _, _, w, h = cv2.boundingRect(dropout_mask) + radius = min(3, max(w, h) // 2) + img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS) + else: + img = img.copy() + img[dropout_mask] = self.image_fill_value + + return img + + def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray: + if dropout_mask is None: + return img + + img = img.copy() + img[dropout_mask] = self.mask_fill_value + return img + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return "max_objects", "image_fill_value", "mask_fill_value" diff --git a/custom_albumentations/augmentations/functional.py b/custom_albumentations/augmentations/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..9852dc749559bb4a0092270dbed72bf44b543ee4 --- /dev/null +++ b/custom_albumentations/augmentations/functional.py @@ -0,0 +1,1380 @@ +from __future__ import division + +from typing import Optional, Sequence, Union +from warnings import warn + +import cv2 +import numpy as np +import skimage + +from custom_albumentations import random_utils +from custom_albumentations.augmentations.utils import ( + MAX_VALUES_BY_DTYPE, + _maybe_process_in_chunks, + clip, + clipped, + ensure_contiguous, + is_grayscale_image, + is_rgb_image, + non_rgb_warning, + preserve_channel_dim, + preserve_shape, +) + +__all__ = [ + "add_fog", + "add_rain", + "add_shadow", + "add_gravel", + "add_snow", + "add_sun_flare", + "add_weighted", + "adjust_brightness_torchvision", + "adjust_contrast_torchvision", + "adjust_hue_torchvision", + "adjust_saturation_torchvision", + "brightness_contrast_adjust", + "channel_shuffle", + "clahe", + "convolve", + "downscale", + "equalize", + "fancy_pca", + "from_float", + "gamma_transform", + "gauss_noise", + "image_compression", + "invert", + "iso_noise", + "linear_transformation_rgb", + "move_tone_curve", + "multiply", + "noop", + "normalize", + "posterize", + "shift_hsv", + "shift_rgb", + "solarize", + "superpixels", + "swap_tiles_on_image", + "to_float", + "to_gray", + "gray_to_rgb", + "unsharp_mask", +] + + +def normalize_cv2(img, mean, denominator): + if mean.shape and len(mean) != 4 and mean.shape != img.shape: + mean = np.array(mean.tolist() + [0] * (4 - len(mean)), dtype=np.float64) + if not denominator.shape: + denominator = np.array([denominator.tolist()] * 4, dtype=np.float64) + elif len(denominator) != 4 and denominator.shape != img.shape: + denominator = np.array(denominator.tolist() + [1] * (4 - len(denominator)), dtype=np.float64) + + img = np.ascontiguousarray(img.astype("float32")) + cv2.subtract(img, mean.astype(np.float64), img) + cv2.multiply(img, denominator.astype(np.float64), img) + return img + + +def normalize_numpy(img, mean, denominator): + img = img.astype(np.float32) + img -= mean + img *= denominator + return img + + +def normalize(img, mean, std, max_pixel_value=255.0): + mean = np.array(mean, dtype=np.float32) + mean *= max_pixel_value + + std = np.array(std, dtype=np.float32) + std *= max_pixel_value + + denominator = np.reciprocal(std, dtype=np.float32) + + if img.ndim == 3 and img.shape[-1] == 3: + return normalize_cv2(img, mean, denominator) + return normalize_numpy(img, mean, denominator) + + +def _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift): + dtype = img.dtype + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + hue, sat, val = cv2.split(img) + + if hue_shift != 0: + lut_hue = np.arange(0, 256, dtype=np.int16) + lut_hue = np.mod(lut_hue + hue_shift, 180).astype(dtype) + hue = cv2.LUT(hue, lut_hue) + + if sat_shift != 0: + lut_sat = np.arange(0, 256, dtype=np.int16) + lut_sat = np.clip(lut_sat + sat_shift, 0, 255).astype(dtype) + sat = cv2.LUT(sat, lut_sat) + + if val_shift != 0: + lut_val = np.arange(0, 256, dtype=np.int16) + lut_val = np.clip(lut_val + val_shift, 0, 255).astype(dtype) + val = cv2.LUT(val, lut_val) + + img = cv2.merge((hue, sat, val)).astype(dtype) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img + + +def _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift): + dtype = img.dtype + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + hue, sat, val = cv2.split(img) + + if hue_shift != 0: + hue = cv2.add(hue, hue_shift) + hue = np.mod(hue, 360) # OpenCV fails with negative values + + if sat_shift != 0: + sat = clip(cv2.add(sat, sat_shift), dtype, 1.0) + + if val_shift != 0: + val = clip(cv2.add(val, val_shift), dtype, 1.0) + + img = cv2.merge((hue, sat, val)) + img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + return img + + +@preserve_shape +def shift_hsv(img, hue_shift, sat_shift, val_shift): + if hue_shift == 0 and sat_shift == 0 and val_shift == 0: + return img + + is_gray = is_grayscale_image(img) + if is_gray: + if hue_shift != 0 or sat_shift != 0: + hue_shift = 0 + sat_shift = 0 + warn( + "HueSaturationValue: hue_shift and sat_shift are not applicable to grayscale image. " + "Set them to 0 or use RGB image" + ) + img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + + if img.dtype == np.uint8: + img = _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift) + else: + img = _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift) + + if is_gray: + img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + return img + + +def solarize(img, threshold=128): + """Invert all pixel values above a threshold. + + Args: + img (numpy.ndarray): The image to solarize. + threshold (int): All pixels above this greyscale level are inverted. + + Returns: + numpy.ndarray: Solarized image. + + """ + dtype = img.dtype + max_val = MAX_VALUES_BY_DTYPE[dtype] + + if dtype == np.dtype("uint8"): + lut = [(i if i < threshold else max_val - i) for i in range(max_val + 1)] + + prev_shape = img.shape + img = cv2.LUT(img, np.array(lut, dtype=dtype)) + + if len(prev_shape) != len(img.shape): + img = np.expand_dims(img, -1) + return img + + result_img = img.copy() + cond = img >= threshold + result_img[cond] = max_val - result_img[cond] + return result_img + + +@preserve_shape +def posterize(img, bits): + """Reduce the number of bits for each color channel. + + Args: + img (numpy.ndarray): image to posterize. + bits (int): number of high bits. Must be in range [0, 8] + + Returns: + numpy.ndarray: Image with reduced color channels. + + """ + bits = np.uint8(bits) + + if img.dtype != np.uint8: + raise TypeError("Image must have uint8 channel type") + if np.any((bits < 0) | (bits > 8)): + raise ValueError("bits must be in range [0, 8]") + + if not bits.shape or len(bits) == 1: + if bits == 0: + return np.zeros_like(img) + if bits == 8: + return img.copy() + + lut = np.arange(0, 256, dtype=np.uint8) + mask = ~np.uint8(2 ** (8 - bits) - 1) + lut &= mask + + return cv2.LUT(img, lut) + + if not is_rgb_image(img): + raise TypeError("If bits is iterable image must be RGB") + + result_img = np.empty_like(img) + for i, channel_bits in enumerate(bits): + if channel_bits == 0: + result_img[..., i] = np.zeros_like(img[..., i]) + elif channel_bits == 8: + result_img[..., i] = img[..., i].copy() + else: + lut = np.arange(0, 256, dtype=np.uint8) + mask = ~np.uint8(2 ** (8 - channel_bits) - 1) + lut &= mask + + result_img[..., i] = cv2.LUT(img[..., i], lut) + + return result_img + + +def _equalize_pil(img, mask=None): + histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel() + h = [_f for _f in histogram if _f] + + if len(h) <= 1: + return img.copy() + + step = np.sum(h[:-1]) // 255 + if not step: + return img.copy() + + lut = np.empty(256, dtype=np.uint8) + n = step // 2 + for i in range(256): + lut[i] = min(n // step, 255) + n += histogram[i] + + return cv2.LUT(img, np.array(lut)) + + +def _equalize_cv(img, mask=None): + if mask is None: + return cv2.equalizeHist(img) + + histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel() + i = 0 + for val in histogram: + if val > 0: + break + i += 1 + i = min(i, 255) + + total = np.sum(histogram) + if histogram[i] == total: + return np.full_like(img, i) + + scale = 255.0 / (total - histogram[i]) + _sum = 0 + + lut = np.zeros(256, dtype=np.uint8) + i += 1 + for i in range(i, len(histogram)): + _sum += histogram[i] + lut[i] = clip(round(_sum * scale), np.dtype("uint8"), 255) + + return cv2.LUT(img, lut) + + +@preserve_channel_dim +def equalize(img, mask=None, mode="cv", by_channels=True): + """Equalize the image histogram. + + Args: + img (numpy.ndarray): RGB or grayscale image. + mask (numpy.ndarray): An optional mask. If given, only the pixels selected by + the mask are included in the analysis. Maybe 1 channel or 3 channel array. + mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method. + by_channels (bool): If True, use equalization by channels separately, + else convert image to YCbCr representation and use equalization by `Y` channel. + + Returns: + numpy.ndarray: Equalized image. + + """ + if img.dtype != np.uint8: + raise TypeError("Image must have uint8 channel type") + + modes = ["cv", "pil"] + + if mode not in modes: + raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode)) + if mask is not None: + if is_rgb_image(mask) and is_grayscale_image(img): + raise ValueError("Wrong mask shape. Image shape: {}. " "Mask shape: {}".format(img.shape, mask.shape)) + if not by_channels and not is_grayscale_image(mask): + raise ValueError( + "When by_channels=False only 1-channel mask supports. " "Mask shape: {}".format(mask.shape) + ) + + if mode == "pil": + function = _equalize_pil + else: + function = _equalize_cv + + if mask is not None: + mask = mask.astype(np.uint8) + + if is_grayscale_image(img): + return function(img, mask) + + if not by_channels: + result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb) + result_img[..., 0] = function(result_img[..., 0], mask) + return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB) + + result_img = np.empty_like(img) + for i in range(3): + if mask is None: + _mask = None + elif is_grayscale_image(mask): + _mask = mask + else: + _mask = mask[..., i] + + result_img[..., i] = function(img[..., i], _mask) + + return result_img + + +@preserve_shape +def move_tone_curve(img, low_y, high_y): + """Rescales the relationship between bright and dark areas of the image by manipulating its tone curve. + + Args: + img (numpy.ndarray): RGB or grayscale image. + low_y (float): y-position of a Bezier control point used + to adjust the tone curve, must be in range [0, 1] + high_y (float): y-position of a Bezier control point used + to adjust image tone curve, must be in range [0, 1] + """ + input_dtype = img.dtype + + if low_y < 0 or low_y > 1: + raise ValueError("low_shift must be in range [0, 1]") + if high_y < 0 or high_y > 1: + raise ValueError("high_shift must be in range [0, 1]") + + if input_dtype != np.uint8: + raise ValueError("Unsupported image type {}".format(input_dtype)) + + t = np.linspace(0.0, 1.0, 256) + + # Defines responze of a four-point bezier curve + def evaluate_bez(t): + return 3 * (1 - t) ** 2 * t * low_y + 3 * (1 - t) * t**2 * high_y + t**3 + + evaluate_bez = np.vectorize(evaluate_bez) + remapping = np.rint(evaluate_bez(t) * 255).astype(np.uint8) + + lut_fn = _maybe_process_in_chunks(cv2.LUT, lut=remapping) + img = lut_fn(img) + return img + + +@clipped +def _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift): + if r_shift == g_shift == b_shift: + return img + r_shift + + result_img = np.empty_like(img) + shifts = [r_shift, g_shift, b_shift] + for i, shift in enumerate(shifts): + result_img[..., i] = img[..., i] + shift + + return result_img + + +def _shift_image_uint8(img, value): + max_value = MAX_VALUES_BY_DTYPE[img.dtype] + + lut = np.arange(0, max_value + 1).astype("float32") + lut += value + + lut = np.clip(lut, 0, max_value).astype(img.dtype) + return cv2.LUT(img, lut) + + +@preserve_shape +def _shift_rgb_uint8(img, r_shift, g_shift, b_shift): + if r_shift == g_shift == b_shift: + h, w, c = img.shape + img = img.reshape([h, w * c]) + + return _shift_image_uint8(img, r_shift) + + result_img = np.empty_like(img) + shifts = [r_shift, g_shift, b_shift] + for i, shift in enumerate(shifts): + result_img[..., i] = _shift_image_uint8(img[..., i], shift) + + return result_img + + +def shift_rgb(img, r_shift, g_shift, b_shift): + if img.dtype == np.uint8: + return _shift_rgb_uint8(img, r_shift, g_shift, b_shift) + + return _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift) + + +@clipped +def linear_transformation_rgb(img, transformation_matrix): + result_img = cv2.transform(img, transformation_matrix) + + return result_img + + +@preserve_channel_dim +def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)): + if img.dtype != np.uint8: + raise TypeError("clahe supports only uint8 inputs") + + clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size) + + if len(img.shape) == 2 or img.shape[2] == 1: + img = clahe_mat.apply(img) + else: + img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB) + img[:, :, 0] = clahe_mat.apply(img[:, :, 0]) + img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB) + + return img + + +@preserve_shape +def convolve(img, kernel): + conv_fn = _maybe_process_in_chunks(cv2.filter2D, ddepth=-1, kernel=kernel) + return conv_fn(img) + + +@preserve_shape +def image_compression(img, quality, image_type): + if image_type in [".jpeg", ".jpg"]: + quality_flag = cv2.IMWRITE_JPEG_QUALITY + elif image_type == ".webp": + quality_flag = cv2.IMWRITE_WEBP_QUALITY + else: + NotImplementedError("Only '.jpg' and '.webp' compression transforms are implemented. ") + + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + warn( + "Image compression augmentation " + "is most effective with uint8 inputs, " + "{} is used as input.".format(input_dtype), + UserWarning, + ) + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for image augmentation".format(input_dtype)) + + _, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality)) + img = cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED) + + if needs_float: + img = to_float(img, max_value=255) + return img + + +@preserve_shape +def add_snow(img, snow_point, brightness_coeff): + """Bleaches out pixels, imitation snow. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): Image. + snow_point: Number of show points. + brightness_coeff: Brightness coefficient. + + Returns: + numpy.ndarray: Image. + + """ + non_rgb_warning(img) + + input_dtype = img.dtype + needs_float = False + + snow_point *= 127.5 # = 255 / 2 + snow_point += 85 # = 255 / 3 + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for RandomSnow augmentation".format(input_dtype)) + + image_HLS = cv2.cvtColor(img, cv2.COLOR_RGB2HLS) + image_HLS = np.array(image_HLS, dtype=np.float32) + + image_HLS[:, :, 1][image_HLS[:, :, 1] < snow_point] *= brightness_coeff + + image_HLS[:, :, 1] = clip(image_HLS[:, :, 1], np.uint8, 255) + + image_HLS = np.array(image_HLS, dtype=np.uint8) + + image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB) + + if needs_float: + image_RGB = to_float(image_RGB, max_value=255) + + return image_RGB + + +@preserve_shape +def add_rain( + img, + slant, + drop_length, + drop_width, + drop_color, + blur_value, + brightness_coefficient, + rain_drops, +): + """ + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): Image. + slant (int): + drop_length: + drop_width: + drop_color: + blur_value (int): Rainy view are blurry. + brightness_coefficient (float): Rainy days are usually shady. + rain_drops: + + Returns: + numpy.ndarray: Image. + + """ + non_rgb_warning(img) + + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for RandomRain augmentation".format(input_dtype)) + + image = img.copy() + + for rain_drop_x0, rain_drop_y0 in rain_drops: + rain_drop_x1 = rain_drop_x0 + slant + rain_drop_y1 = rain_drop_y0 + drop_length + + cv2.line( + image, + (rain_drop_x0, rain_drop_y0), + (rain_drop_x1, rain_drop_y1), + drop_color, + drop_width, + ) + + image = cv2.blur(image, (blur_value, blur_value)) # rainy view are blurry + image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32) + image_hsv[:, :, 2] *= brightness_coefficient + + image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB) + + if needs_float: + image_rgb = to_float(image_rgb, max_value=255) + + return image_rgb + + +@preserve_shape +def add_fog(img, fog_coef, alpha_coef, haze_list): + """Add fog to the image. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): Image. + fog_coef (float): Fog coefficient. + alpha_coef (float): Alpha coefficient. + haze_list (list): + + Returns: + numpy.ndarray: Image. + + """ + non_rgb_warning(img) + + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for RandomFog augmentation".format(input_dtype)) + + width = img.shape[1] + + hw = max(int(width // 3 * fog_coef), 10) + + for haze_points in haze_list: + x, y = haze_points + overlay = img.copy() + output = img.copy() + alpha = alpha_coef * fog_coef + rad = hw // 2 + point = (x + hw // 2, y + hw // 2) + cv2.circle(overlay, point, int(rad), (255, 255, 255), -1) + cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output) + + img = output.copy() + + image_rgb = cv2.blur(img, (hw // 10, hw // 10)) + + if needs_float: + image_rgb = to_float(image_rgb, max_value=255) + + return image_rgb + + +@preserve_shape +def add_sun_flare(img, flare_center_x, flare_center_y, src_radius, src_color, circles): + """Add sun flare. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): + flare_center_x (float): + flare_center_y (float): + src_radius: + src_color (int, int, int): + circles (list): + + Returns: + numpy.ndarray: + + """ + non_rgb_warning(img) + + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for RandomSunFlareaugmentation".format(input_dtype)) + + overlay = img.copy() + output = img.copy() + + for alpha, (x, y), rad3, (r_color, g_color, b_color) in circles: + cv2.circle(overlay, (x, y), rad3, (r_color, g_color, b_color), -1) + + cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output) + + point = (int(flare_center_x), int(flare_center_y)) + + overlay = output.copy() + num_times = src_radius // 10 + alpha = np.linspace(0.0, 1, num=num_times) + rad = np.linspace(1, src_radius, num=num_times) + for i in range(num_times): + cv2.circle(overlay, point, int(rad[i]), src_color, -1) + alp = alpha[num_times - i - 1] * alpha[num_times - i - 1] * alpha[num_times - i - 1] + cv2.addWeighted(overlay, alp, output, 1 - alp, 0, output) + + image_rgb = output + + if needs_float: + image_rgb = to_float(image_rgb, max_value=255) + + return image_rgb + + +@ensure_contiguous +@preserve_shape +def add_shadow(img, vertices_list): + """Add shadows to the image. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): + vertices_list (list): + + Returns: + numpy.ndarray: + + """ + non_rgb_warning(img) + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for RandomShadow augmentation".format(input_dtype)) + + image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS) + mask = np.zeros_like(img) + + # adding all shadow polygons on empty mask, single 255 denotes only red channel + for vertices in vertices_list: + cv2.fillPoly(mask, vertices, 255) + + # if red channel is hot, image's "Lightness" channel's brightness is lowered + red_max_value_ind = mask[:, :, 0] == 255 + image_hls[:, :, 1][red_max_value_ind] = image_hls[:, :, 1][red_max_value_ind] * 0.5 + + image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB) + + if needs_float: + image_rgb = to_float(image_rgb, max_value=255) + + return image_rgb + + +@ensure_contiguous +@preserve_shape +def add_gravel(img: np.ndarray, gravels: list): + """Add gravel to the image. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + img (numpy.ndarray): image to add gravel to + gravels (list): list of gravel parameters. (float, float, float, float): + (top-left x, top-left y, bottom-right x, bottom right y) + + Returns: + numpy.ndarray: + """ + non_rgb_warning(img) + input_dtype = img.dtype + needs_float = False + + if input_dtype == np.float32: + img = from_float(img, dtype=np.dtype("uint8")) + needs_float = True + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for AddGravel augmentation".format(input_dtype)) + + image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS) + + for gravel in gravels: + y1, y2, x1, x2, sat = gravel + image_hls[x1:x2, y1:y2, 1] = sat + + image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB) + + if needs_float: + image_rgb = to_float(image_rgb, max_value=255) + + return image_rgb + + +def invert(img: np.ndarray) -> np.ndarray: + # Supports all the valid dtypes + # clips the img to avoid unexpected behaviour. + return MAX_VALUES_BY_DTYPE[img.dtype] - img + + +def channel_shuffle(img, channels_shuffled): + img = img[..., channels_shuffled] + return img + + +@preserve_shape +def gamma_transform(img, gamma): + if img.dtype == np.uint8: + table = (np.arange(0, 256.0 / 255, 1.0 / 255) ** gamma) * 255 + img = cv2.LUT(img, table.astype(np.uint8)) + else: + img = np.power(img, gamma) + + return img + + +@clipped +def gauss_noise(image, gauss): + image = image.astype("float32") + return image + gauss + + +@clipped +def _brightness_contrast_adjust_non_uint(img, alpha=1, beta=0, beta_by_max=False): + dtype = img.dtype + img = img.astype("float32") + + if alpha != 1: + img *= alpha + if beta != 0: + if beta_by_max: + max_value = MAX_VALUES_BY_DTYPE[dtype] + img += beta * max_value + else: + img += beta * np.mean(img) + return img + + +@preserve_shape +def _brightness_contrast_adjust_uint(img, alpha=1, beta=0, beta_by_max=False): + dtype = np.dtype("uint8") + + max_value = MAX_VALUES_BY_DTYPE[dtype] + + lut = np.arange(0, max_value + 1).astype("float32") + + if alpha != 1: + lut *= alpha + if beta != 0: + if beta_by_max: + lut += beta * max_value + else: + lut += (alpha * beta) * np.mean(img) + + lut = np.clip(lut, 0, max_value).astype(dtype) + img = cv2.LUT(img, lut) + return img + + +def brightness_contrast_adjust(img, alpha=1, beta=0, beta_by_max=False): + if img.dtype == np.uint8: + return _brightness_contrast_adjust_uint(img, alpha, beta, beta_by_max) + + return _brightness_contrast_adjust_non_uint(img, alpha, beta, beta_by_max) + + +@clipped +def iso_noise(image, color_shift=0.05, intensity=0.5, random_state=None, **kwargs): + """ + Apply poisson noise to image to simulate camera sensor noise. + + Args: + image (numpy.ndarray): Input image, currently, only RGB, uint8 images are supported. + color_shift (float): + intensity (float): Multiplication factor for noise values. Values of ~0.5 are produce noticeable, + yet acceptable level of noise. + random_state: + **kwargs: + + Returns: + numpy.ndarray: Noised image + + """ + if image.dtype != np.uint8: + raise TypeError("Image must have uint8 channel type") + if not is_rgb_image(image): + raise TypeError("Image must be RGB") + + one_over_255 = float(1.0 / 255.0) + image = np.multiply(image, one_over_255, dtype=np.float32) + hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS) + _, stddev = cv2.meanStdDev(hls) + + luminance_noise = random_utils.poisson(stddev[1] * intensity * 255, size=hls.shape[:2], random_state=random_state) + color_noise = random_utils.normal(0, color_shift * 360 * intensity, size=hls.shape[:2], random_state=random_state) + + hue = hls[..., 0] + hue += color_noise + hue[hue < 0] += 360 + hue[hue > 360] -= 360 + + luminance = hls[..., 1] + luminance += (luminance_noise / 255) * (1.0 - luminance) + + image = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB) * 255 + return image.astype(np.uint8) + + +def to_gray(img): + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB) + + +def gray_to_rgb(img): + return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB) + + +@preserve_shape +def downscale(img, scale, down_interpolation=cv2.INTER_AREA, up_interpolation=cv2.INTER_LINEAR): + h, w = img.shape[:2] + + need_cast = ( + up_interpolation != cv2.INTER_NEAREST or down_interpolation != cv2.INTER_NEAREST + ) and img.dtype == np.uint8 + if need_cast: + img = to_float(img) + downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=down_interpolation) + upscaled = cv2.resize(downscaled, (w, h), interpolation=up_interpolation) + if need_cast: + upscaled = from_float(np.clip(upscaled, 0, 1), dtype=np.dtype("uint8")) + return upscaled + + +def to_float(img, max_value=None): + if max_value is None: + try: + max_value = MAX_VALUES_BY_DTYPE[img.dtype] + except KeyError: + raise RuntimeError( + "Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by " + "passing the max_value argument".format(img.dtype) + ) + return img.astype("float32") / max_value + + +def from_float(img, dtype, max_value=None): + if max_value is None: + try: + max_value = MAX_VALUES_BY_DTYPE[dtype] + except KeyError: + raise RuntimeError( + "Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by " + "passing the max_value argument".format(dtype) + ) + return (img * max_value).astype(dtype) + + +def noop(input_obj, **params): # skipcq: PYL-W0613 + return input_obj + + +def swap_tiles_on_image(image, tiles): + """ + Swap tiles on image. + + Args: + image (np.ndarray): Input image. + tiles (np.ndarray): array of tuples( + current_left_up_corner_row, current_left_up_corner_col, + old_left_up_corner_row, old_left_up_corner_col, + height_tile, width_tile) + + Returns: + np.ndarray: Output image. + + """ + new_image = image.copy() + + for tile in tiles: + new_image[tile[0] : tile[0] + tile[4], tile[1] : tile[1] + tile[5]] = image[ + tile[2] : tile[2] + tile[4], tile[3] : tile[3] + tile[5] + ] + + return new_image + + +@clipped +def _multiply_uint8(img, multiplier): + img = img.astype(np.float32) + return np.multiply(img, multiplier) + + +@preserve_shape +def _multiply_uint8_optimized(img, multiplier): + if is_grayscale_image(img) or len(multiplier) == 1: + multiplier = multiplier[0] + lut = np.arange(0, 256, dtype=np.float32) + lut *= multiplier + lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype]) + func = _maybe_process_in_chunks(cv2.LUT, lut=lut) + return func(img) + + channels = img.shape[-1] + lut = [np.arange(0, 256, dtype=np.float32)] * channels + lut = np.stack(lut, axis=-1) + + lut *= multiplier + lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype]) + + images = [] + for i in range(channels): + func = _maybe_process_in_chunks(cv2.LUT, lut=lut[:, i]) + images.append(func(img[:, :, i])) + return np.stack(images, axis=-1) + + +@clipped +def _multiply_non_uint8(img, multiplier): + return img * multiplier + + +def multiply(img, multiplier): + """ + Args: + img (numpy.ndarray): Image. + multiplier (numpy.ndarray): Multiplier coefficient. + + Returns: + numpy.ndarray: Image multiplied by `multiplier` coefficient. + + """ + if img.dtype == np.uint8: + if len(multiplier.shape) == 1: + return _multiply_uint8_optimized(img, multiplier) + + return _multiply_uint8(img, multiplier) + + return _multiply_non_uint8(img, multiplier) + + +def bbox_from_mask(mask): + """Create bounding box from binary mask (fast version) + + Args: + mask (numpy.ndarray): binary mask. + + Returns: + tuple: A bounding box tuple `(x_min, y_min, x_max, y_max)`. + + """ + rows = np.any(mask, axis=1) + if not rows.any(): + return -1, -1, -1, -1 + cols = np.any(mask, axis=0) + y_min, y_max = np.where(rows)[0][[0, -1]] + x_min, x_max = np.where(cols)[0][[0, -1]] + return x_min, y_min, x_max + 1, y_max + 1 + + +def mask_from_bbox(img, bbox): + """Create binary mask from bounding box + + Args: + img (numpy.ndarray): input image + bbox: A bounding box tuple `(x_min, y_min, x_max, y_max)` + + Returns: + mask (numpy.ndarray): binary mask + + """ + + mask = np.zeros(img.shape[:2], dtype=np.uint8) + x_min, y_min, x_max, y_max = bbox + mask[y_min:y_max, x_min:x_max] = 1 + return mask + + +def fancy_pca(img, alpha=0.1): + """Perform 'Fancy PCA' augmentation from: + http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf + + Args: + img (numpy.ndarray): numpy array with (h, w, rgb) shape, as ints between 0-255 + alpha (float): how much to perturb/scale the eigen vecs and vals + the paper used std=0.1 + + Returns: + numpy.ndarray: numpy image-like array as uint8 range(0, 255) + + """ + if not is_rgb_image(img) or img.dtype != np.uint8: + raise TypeError("Image must be RGB image in uint8 format.") + + orig_img = img.astype(float).copy() + + img = img / 255.0 # rescale to 0 to 1 range + + # flatten image to columns of RGB + img_rs = img.reshape(-1, 3) + # img_rs shape (640000, 3) + + # center mean + img_centered = img_rs - np.mean(img_rs, axis=0) + + # paper says 3x3 covariance matrix + img_cov = np.cov(img_centered, rowvar=False) + + # eigen values and eigen vectors + eig_vals, eig_vecs = np.linalg.eigh(img_cov) + + # sort values and vector + sort_perm = eig_vals[::-1].argsort() + eig_vals[::-1].sort() + eig_vecs = eig_vecs[:, sort_perm] + + # get [p1, p2, p3] + m1 = np.column_stack((eig_vecs)) + + # get 3x1 matrix of eigen values multiplied by random variable draw from normal + # distribution with mean of 0 and standard deviation of 0.1 + m2 = np.zeros((3, 1)) + # according to the paper alpha should only be draw once per augmentation (not once per channel) + # alpha = np.random.normal(0, alpha_std) + + # broad cast to speed things up + m2[:, 0] = alpha * eig_vals[:] + + # this is the vector that we're going to add to each pixel in a moment + add_vect = np.matrix(m1) * np.matrix(m2) + + for idx in range(3): # RGB + orig_img[..., idx] += add_vect[idx] * 255 + + # for image processing it was found that working with float 0.0 to 1.0 + # was easier than integers between 0-255 + # orig_img /= 255.0 + orig_img = np.clip(orig_img, 0.0, 255.0) + + # orig_img *= 255 + orig_img = orig_img.astype(np.uint8) + + return orig_img + + +def _adjust_brightness_torchvision_uint8(img, factor): + lut = np.arange(0, 256) * factor + lut = np.clip(lut, 0, 255).astype(np.uint8) + return cv2.LUT(img, lut) + + +@preserve_shape +def adjust_brightness_torchvision(img, factor): + if factor == 0: + return np.zeros_like(img) + elif factor == 1: + return img + + if img.dtype == np.uint8: + return _adjust_brightness_torchvision_uint8(img, factor) + + return clip(img * factor, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype]) + + +def _adjust_contrast_torchvision_uint8(img, factor, mean): + lut = np.arange(0, 256) * factor + lut = lut + mean * (1 - factor) + lut = clip(lut, img.dtype, 255) + + return cv2.LUT(img, lut) + + +@preserve_shape +def adjust_contrast_torchvision(img, factor): + if factor == 1: + return img + + if is_grayscale_image(img): + mean = img.mean() + else: + mean = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean() + + if factor == 0: + if img.dtype != np.float32: + mean = int(mean + 0.5) + return np.full_like(img, mean, dtype=img.dtype) + + if img.dtype == np.uint8: + return _adjust_contrast_torchvision_uint8(img, factor, mean) + + return clip( + img.astype(np.float32) * factor + mean * (1 - factor), + img.dtype, + MAX_VALUES_BY_DTYPE[img.dtype], + ) + + +@preserve_shape +def adjust_saturation_torchvision(img, factor, gamma=0): + if factor == 1: + return img + + if is_grayscale_image(img): + gray = img + return gray + else: + gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB) + + if factor == 0: + return gray + + result = cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma) + if img.dtype == np.uint8: + return result + + # OpenCV does not clip values for float dtype + return clip(result, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype]) + + +def _adjust_hue_torchvision_uint8(img, factor): + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + + lut = np.arange(0, 256, dtype=np.int16) + lut = np.mod(lut + 180 * factor, 180).astype(np.uint8) + img[..., 0] = cv2.LUT(img[..., 0], lut) + + return cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + + +def adjust_hue_torchvision(img, factor): + if is_grayscale_image(img): + return img + + if factor == 0: + return img + + if img.dtype == np.uint8: + return _adjust_hue_torchvision_uint8(img, factor) + + img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV) + img[..., 0] = np.mod(img[..., 0] + factor * 360, 360) + return cv2.cvtColor(img, cv2.COLOR_HSV2RGB) + + +@preserve_shape +def superpixels( + image: np.ndarray, n_segments: int, replace_samples: Sequence[bool], max_size: Optional[int], interpolation: int +) -> np.ndarray: + if not np.any(replace_samples): + return image + + orig_shape = image.shape + if max_size is not None: + size = max(image.shape[:2]) + if size > max_size: + scale = max_size / size + height, width = image.shape[:2] + new_height, new_width = int(height * scale), int(width * scale) + resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(new_width, new_height), interpolation=interpolation) + image = resize_fn(image) + + segments = skimage.segmentation.slic( + image, n_segments=n_segments, compactness=10, channel_axis=-1 if image.ndim > 2 else None + ) + + min_value = 0 + max_value = MAX_VALUES_BY_DTYPE[image.dtype] + image = np.copy(image) + if image.ndim == 2: + image = image.reshape(*image.shape, 1) + nb_channels = image.shape[2] + for c in range(nb_channels): + # segments+1 here because otherwise regionprops always misses the last label + regions = skimage.measure.regionprops(segments + 1, intensity_image=image[..., c]) + for ridx, region in enumerate(regions): + # with mod here, because slic can sometimes create more superpixel than requested. + # replace_samples then does not have enough values, so we just start over with the first one again. + if replace_samples[ridx % len(replace_samples)]: + mean_intensity = region.mean_intensity + image_sp_c = image[..., c] + + if image_sp_c.dtype.kind in ["i", "u", "b"]: + # After rounding the value can end up slightly outside of the value_range. Hence, we need to clip. + # We do clip via min(max(...)) instead of np.clip because + # the latter one does not seem to keep dtypes for dtypes with large itemsizes (e.g. uint64). + value: Union[int, float] + value = int(np.round(mean_intensity)) + value = min(max(value, min_value), max_value) + else: + value = mean_intensity + + image_sp_c[segments == ridx] = value + + if orig_shape != image.shape: + resize_fn = _maybe_process_in_chunks( + cv2.resize, dsize=(orig_shape[1], orig_shape[0]), interpolation=interpolation + ) + image = resize_fn(image) + + return image + + +@clipped +def add_weighted(img1, alpha, img2, beta): + return img1.astype(float) * alpha + img2.astype(float) * beta + + +@clipped +@preserve_shape +def unsharp_mask(image: np.ndarray, ksize: int, sigma: float = 0.0, alpha: float = 0.2, threshold: int = 10): + blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma) + + input_dtype = image.dtype + if input_dtype == np.uint8: + image = to_float(image) + elif input_dtype not in (np.uint8, np.float32): + raise ValueError("Unexpected dtype {} for UnsharpMask augmentation".format(input_dtype)) + + blur = blur_fn(image) + residual = image - blur + + # Do not sharpen noise + mask = np.abs(residual) * 255 > threshold + mask = mask.astype("float32") + + sharp = image + alpha * residual + # Avoid color noise artefacts. + sharp = np.clip(sharp, 0, 1) + + soft_mask = blur_fn(mask) + output = soft_mask * sharp + (1 - soft_mask) * image + return from_float(output, dtype=input_dtype) + + +@preserve_shape +def pixel_dropout(image: np.ndarray, drop_mask: np.ndarray, drop_value: Union[float, Sequence[float]]) -> np.ndarray: + if isinstance(drop_value, (int, float)) and drop_value == 0: + drop_values = np.zeros_like(image) + else: + drop_values = np.full_like(image, drop_value) # type: ignore + return np.where(drop_mask, drop_values, image) + + +@clipped +@preserve_shape +def spatter( + img: np.ndarray, + non_mud: Optional[np.ndarray], + mud: Optional[np.ndarray], + rain: Optional[np.ndarray], + mode: str, +) -> np.ndarray: + non_rgb_warning(img) + + coef = MAX_VALUES_BY_DTYPE[img.dtype] + img = img.astype(np.float32) * (1 / coef) + + if mode == "rain": + assert rain is not None + img = img + rain + elif mode == "mud": + assert non_mud is not None and mud is not None + img = img * non_mud + mud + else: + raise ValueError("Unsupported spatter mode: " + str(mode)) + + return img * 255 diff --git a/custom_albumentations/augmentations/geometric/__init__.py b/custom_albumentations/augmentations/geometric/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b18da9e260418c267c73a7d3d56acadbdc3b3636 --- /dev/null +++ b/custom_albumentations/augmentations/geometric/__init__.py @@ -0,0 +1,4 @@ +from .functional import * +from .resize import * +from .rotate import * +from .transforms import * diff --git a/custom_albumentations/augmentations/geometric/functional.py b/custom_albumentations/augmentations/geometric/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..13d98f54f1d3089d690640b9fdea1efa1c0a9299 --- /dev/null +++ b/custom_albumentations/augmentations/geometric/functional.py @@ -0,0 +1,1300 @@ +import math +from typing import List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +import skimage.transform +from scipy.ndimage import gaussian_filter + +from custom_albumentations.augmentations.utils import ( + _maybe_process_in_chunks, + angle_2pi_range, + clipped, + preserve_channel_dim, + preserve_shape, +) + +from ... import random_utils +from ...core.bbox_utils import denormalize_bbox, normalize_bbox +from ...core.transforms_interface import ( + BoxInternalType, + FillValueType, + ImageColorType, + KeypointInternalType, +) + +__all__ = [ + "optical_distortion", + "elastic_transform_approx", + "grid_distortion", + "pad", + "pad_with_params", + "bbox_rot90", + "keypoint_rot90", + "rotate", + "bbox_rotate", + "keypoint_rotate", + "shift_scale_rotate", + "keypoint_shift_scale_rotate", + "bbox_shift_scale_rotate", + "elastic_transform", + "resize", + "scale", + "keypoint_scale", + "py3round", + "_func_max_size", + "longest_max_size", + "smallest_max_size", + "perspective", + "perspective_bbox", + "rotation2DMatrixToEulerAngles", + "perspective_keypoint", + "_is_identity_matrix", + "warp_affine", + "keypoint_affine", + "bbox_affine", + "safe_rotate", + "bbox_safe_rotate", + "keypoint_safe_rotate", + "piecewise_affine", + "to_distance_maps", + "from_distance_maps", + "keypoint_piecewise_affine", + "bbox_piecewise_affine", + "bbox_flip", + "bbox_hflip", + "bbox_transpose", + "bbox_vflip", + "hflip", + "hflip_cv2", + "transpose", + "keypoint_flip", + "keypoint_hflip", + "keypoint_transpose", + "keypoint_vflip", +] + + +def bbox_rot90(bbox: BoxInternalType, factor: int, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613 + """Rotates a bounding box by 90 degrees CCW (see np.rot90) + + Args: + bbox: A bounding box tuple (x_min, y_min, x_max, y_max). + factor: Number of CCW rotations. Must be in set {0, 1, 2, 3} See np.rot90. + rows: Image rows. + cols: Image cols. + + Returns: + tuple: A bounding box tuple (x_min, y_min, x_max, y_max). + + """ + if factor not in {0, 1, 2, 3}: + raise ValueError("Parameter n must be in set {0, 1, 2, 3}") + x_min, y_min, x_max, y_max = bbox[:4] + if factor == 1: + bbox = y_min, 1 - x_max, y_max, 1 - x_min + elif factor == 2: + bbox = 1 - x_max, 1 - y_max, 1 - x_min, 1 - y_min + elif factor == 3: + bbox = 1 - y_max, x_min, 1 - y_min, x_max + return bbox + + +@angle_2pi_range +def keypoint_rot90(keypoint: KeypointInternalType, factor: int, rows: int, cols: int, **params) -> KeypointInternalType: + """Rotates a keypoint by 90 degrees CCW (see np.rot90) + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + factor: Number of CCW rotations. Must be in range [0;3] See np.rot90. + rows: Image height. + cols: Image width. + + Returns: + tuple: A keypoint `(x, y, angle, scale)`. + + Raises: + ValueError: if factor not in set {0, 1, 2, 3} + + """ + x, y, angle, scale = keypoint[:4] + + if factor not in {0, 1, 2, 3}: + raise ValueError("Parameter n must be in set {0, 1, 2, 3}") + + if factor == 1: + x, y, angle = y, (cols - 1) - x, angle - math.pi / 2 + elif factor == 2: + x, y, angle = (cols - 1) - x, (rows - 1) - y, angle - math.pi + elif factor == 3: + x, y, angle = (rows - 1) - y, x, angle + math.pi / 2 + + return x, y, angle, scale + + +@preserve_channel_dim +def rotate( + img: np.ndarray, + angle: float, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, +): + height, width = img.shape[:2] + # for images we use additional shifts of (0.5, 0.5) as otherwise + # we get an ugly black border for 90deg rotations + matrix = cv2.getRotationMatrix2D((width / 2 - 0.5, height / 2 - 0.5), angle, 1.0) + + warp_fn = _maybe_process_in_chunks( + cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value + ) + return warp_fn(img) + + +def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, cols: int) -> BoxInternalType: + """Rotates a bounding box by angle degrees. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + angle: Angle of rotation in degrees. + method: Rotation method used. Should be one of: "largest_box", "ellipse". Default: "largest_box". + rows: Image rows. + cols: Image cols. + + Returns: + A bounding box `(x_min, y_min, x_max, y_max)`. + + References: + https://arxiv.org/abs/2109.13488 + + """ + x_min, y_min, x_max, y_max = bbox[:4] + scale = cols / float(rows) + if method == "largest_box": + x = np.array([x_min, x_max, x_max, x_min]) - 0.5 + y = np.array([y_min, y_min, y_max, y_max]) - 0.5 + elif method == "ellipse": + w = (x_max - x_min) / 2 + h = (y_max - y_min) / 2 + data = np.arange(0, 360, dtype=np.float32) + x = w * np.sin(np.radians(data)) + (w + x_min - 0.5) + y = h * np.cos(np.radians(data)) + (h + y_min - 0.5) + else: + raise ValueError(f"Method {method} is not a valid rotation method.") + angle = np.deg2rad(angle) + x_t = (np.cos(angle) * x * scale + np.sin(angle) * y) / scale + y_t = -np.sin(angle) * x * scale + np.cos(angle) * y + x_t = x_t + 0.5 + y_t = y_t + 0.5 + + x_min, x_max = min(x_t), max(x_t) + y_min, y_max = min(y_t), max(y_t) + + return x_min, y_min, x_max, y_max + + +@angle_2pi_range +def keypoint_rotate(keypoint, angle, rows, cols, **params): + """Rotate a keypoint by angle. + + Args: + keypoint (tuple): A keypoint `(x, y, angle, scale)`. + angle (float): Rotation angle. + rows (int): Image height. + cols (int): Image width. + + Returns: + tuple: A keypoint `(x, y, angle, scale)`. + + """ + center = (cols - 1) * 0.5, (rows - 1) * 0.5 + matrix = cv2.getRotationMatrix2D(center, angle, 1.0) + x, y, a, s = keypoint[:4] + x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze() + return x, y, a + math.radians(angle), s + + +@preserve_channel_dim +def shift_scale_rotate( + img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=None +): + height, width = img.shape[:2] + # for images we use additional shifts of (0.5, 0.5) as otherwise + # we get an ugly black border for 90deg rotations + center = (width / 2 - 0.5, height / 2 - 0.5) + matrix = cv2.getRotationMatrix2D(center, angle, scale) + matrix[0, 2] += dx * width + matrix[1, 2] += dy * height + + warp_affine_fn = _maybe_process_in_chunks( + cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value + ) + return warp_affine_fn(img) + + +@angle_2pi_range +def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params): + ( + x, + y, + a, + s, + ) = keypoint[:4] + height, width = rows, cols + center = (cols - 1) * 0.5, (rows - 1) * 0.5 + matrix = cv2.getRotationMatrix2D(center, angle, scale) + matrix[0, 2] += dx * width + matrix[1, 2] += dy * height + + x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze() + angle = a + math.radians(angle) + scale = s * scale + + return x, y, angle, scale + + +def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rotate_method, rows, cols, **kwargs): # skipcq: PYL-W0613 + """Rotates, shifts and scales a bounding box. Rotation is made by angle degrees, + scaling is made by scale factor and shifting is made by dx and dy. + + + Args: + bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`. + angle (int): Angle of rotation in degrees. + scale (int): Scale factor. + dx (int): Shift along x-axis in pixel units. + dy (int): Shift along y-axis in pixel units. + rotate_method(str): Rotation method used. Should be one of: "largest_box", "ellipse". + Default: "largest_box". + rows (int): Image rows. + cols (int): Image cols. + + Returns: + A bounding box `(x_min, y_min, x_max, y_max)`. + + """ + height, width = rows, cols + center = (width / 2, height / 2) + if rotate_method == "ellipse": + x_min, y_min, x_max, y_max = bbox_rotate(bbox, angle, rotate_method, rows, cols) + matrix = cv2.getRotationMatrix2D(center, 0, scale) + else: + x_min, y_min, x_max, y_max = bbox[:4] + matrix = cv2.getRotationMatrix2D(center, angle, scale) + matrix[0, 2] += dx * width + matrix[1, 2] += dy * height + x = np.array([x_min, x_max, x_max, x_min]) + y = np.array([y_min, y_min, y_max, y_max]) + ones = np.ones(shape=(len(x))) + points_ones = np.vstack([x, y, ones]).transpose() + points_ones[:, 0] *= width + points_ones[:, 1] *= height + tr_points = matrix.dot(points_ones.T).T + tr_points[:, 0] /= width + tr_points[:, 1] /= height + + x_min, x_max = min(tr_points[:, 0]), max(tr_points[:, 0]) + y_min, y_max = min(tr_points[:, 1]), max(tr_points[:, 1]) + + return x_min, y_min, x_max, y_max + + +@preserve_shape +def elastic_transform( + img: np.ndarray, + alpha: float, + sigma: float, + alpha_affine: float, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, + random_state: Optional[np.random.RandomState] = None, + approximate: bool = False, + same_dxdy: bool = False, +): + """Elastic deformation of images as described in [Simard2003]_ (with modifications). + Based on https://gist.github.com/ernestum/601cdf56d2b424757de5 + + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + height, width = img.shape[:2] + + # Random affine + center_square = np.array((height, width), dtype=np.float32) // 2 + square_size = min((height, width)) // 3 + alpha = float(alpha) + sigma = float(sigma) + alpha_affine = float(alpha_affine) + + pts1 = np.array( + [ + center_square + square_size, + [center_square[0] + square_size, center_square[1] - square_size], + center_square - square_size, + ], + dtype=np.float32, + ) + pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype( + np.float32 + ) + matrix = cv2.getAffineTransform(pts1, pts2) + + warp_fn = _maybe_process_in_chunks( + cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value + ) + img = warp_fn(img) + + if approximate: + # Approximate computation smooth displacement map with a large enough kernel. + # On large images (512+) this is approximately 2X times faster + dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1 + cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx) + dx *= alpha + if same_dxdy: + # Speed up even more + dy = dx + else: + dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1 + cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy) + dy *= alpha + else: + dx = np.float32( + gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha + ) + if same_dxdy: + # Speed up + dy = dx + else: + dy = np.float32( + gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha + ) + + x, y = np.meshgrid(np.arange(width), np.arange(height)) + + map_x = np.float32(x + dx) + map_y = np.float32(y + dy) + + remap_fn = _maybe_process_in_chunks( + cv2.remap, map1=map_x, map2=map_y, interpolation=interpolation, borderMode=border_mode, borderValue=value + ) + return remap_fn(img) + + +@preserve_channel_dim +def resize(img, height, width, interpolation=cv2.INTER_LINEAR): + img_height, img_width = img.shape[:2] + if height == img_height and width == img_width: + return img + resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation) + return resize_fn(img) + + +@preserve_channel_dim +def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray: + height, width = img.shape[:2] + new_height, new_width = int(height * scale), int(width * scale) + return resize(img, new_height, new_width, interpolation) + + +def keypoint_scale(keypoint: KeypointInternalType, scale_x: float, scale_y: float) -> KeypointInternalType: + """Scales a keypoint by scale_x and scale_y. + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + scale_x: Scale coefficient x-axis. + scale_y: Scale coefficient y-axis. + + Returns: + A keypoint `(x, y, angle, scale)`. + + """ + x, y, angle, scale = keypoint[:4] + return x * scale_x, y * scale_y, angle, scale * max(scale_x, scale_y) + + +def py3round(number): + """Unified rounding in all python versions.""" + if abs(round(number) - number) == 0.5: + return int(2.0 * round(number / 2.0)) + + return int(round(number)) + + +def _func_max_size(img, max_size, interpolation, func): + height, width = img.shape[:2] + + scale = max_size / float(func(width, height)) + + if scale != 1.0: + new_height, new_width = tuple(py3round(dim * scale) for dim in (height, width)) + img = resize(img, height=new_height, width=new_width, interpolation=interpolation) + return img + + +@preserve_channel_dim +def longest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray: + return _func_max_size(img, max_size, interpolation, max) + + +@preserve_channel_dim +def smallest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray: + return _func_max_size(img, max_size, interpolation, min) + + +@preserve_channel_dim +def perspective( + img: np.ndarray, + matrix: np.ndarray, + max_width: int, + max_height: int, + border_val: Union[int, float, List[int], List[float], np.ndarray], + border_mode: int, + keep_size: bool, + interpolation: int, +): + h, w = img.shape[:2] + perspective_func = _maybe_process_in_chunks( + cv2.warpPerspective, + M=matrix, + dsize=(max_width, max_height), + borderMode=border_mode, + borderValue=border_val, + flags=interpolation, + ) + warped = perspective_func(img) + + if keep_size: + return resize(warped, h, w, interpolation=interpolation) + + return warped + + +def perspective_bbox( + bbox: BoxInternalType, + height: int, + width: int, + matrix: np.ndarray, + max_width: int, + max_height: int, + keep_size: bool, +) -> BoxInternalType: + x1, y1, x2, y2 = denormalize_bbox(bbox, height, width)[:4] + + points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32) + + x1, y1, x2, y2 = float("inf"), float("inf"), 0, 0 + for pt in points: + pt = perspective_keypoint(pt.tolist() + [0, 0], height, width, matrix, max_width, max_height, keep_size) + x, y = pt[:2] + x1 = min(x1, x) + x2 = max(x2, x) + y1 = min(y1, y) + y2 = max(y2, y) + + return normalize_bbox((x1, y1, x2, y2), height if keep_size else max_height, width if keep_size else max_width) + + +def rotation2DMatrixToEulerAngles(matrix: np.ndarray, y_up: bool = False) -> float: + """ + Args: + matrix (np.ndarray): Rotation matrix + y_up (bool): is Y axis looks up or down + """ + if y_up: + return np.arctan2(matrix[1, 0], matrix[0, 0]) + return np.arctan2(-matrix[1, 0], matrix[0, 0]) + + +@angle_2pi_range +def perspective_keypoint( + keypoint: KeypointInternalType, + height: int, + width: int, + matrix: np.ndarray, + max_width: int, + max_height: int, + keep_size: bool, +) -> KeypointInternalType: + x, y, angle, scale = keypoint + + keypoint_vector = np.array([x, y], dtype=np.float32).reshape([1, 1, 2]) + + x, y = cv2.perspectiveTransform(keypoint_vector, matrix)[0, 0] + angle += rotation2DMatrixToEulerAngles(matrix[:2, :2], y_up=True) + + scale_x = np.sign(matrix[0, 0]) * np.sqrt(matrix[0, 0] ** 2 + matrix[0, 1] ** 2) + scale_y = np.sign(matrix[1, 1]) * np.sqrt(matrix[1, 0] ** 2 + matrix[1, 1] ** 2) + scale *= max(scale_x, scale_y) + + if keep_size: + scale_x = width / max_width + scale_y = height / max_height + return keypoint_scale((x, y, angle, scale), scale_x, scale_y) + + return x, y, angle, scale + + +def _is_identity_matrix(matrix: skimage.transform.ProjectiveTransform) -> bool: + return np.allclose(matrix.params, np.eye(3, dtype=np.float32)) + + +@preserve_channel_dim +def warp_affine( + image: np.ndarray, + matrix: skimage.transform.ProjectiveTransform, + interpolation: int, + cval: Union[int, float, Sequence[int], Sequence[float]], + mode: int, + output_shape: Sequence[int], +) -> np.ndarray: + if _is_identity_matrix(matrix): + return image + + dsize = int(np.round(output_shape[1])), int(np.round(output_shape[0])) + warp_fn = _maybe_process_in_chunks( + cv2.warpAffine, M=matrix.params[:2], dsize=dsize, flags=interpolation, borderMode=mode, borderValue=cval + ) + tmp = warp_fn(image) + return tmp + + +@angle_2pi_range +def keypoint_affine( + keypoint: KeypointInternalType, + matrix: skimage.transform.ProjectiveTransform, + scale: dict, +) -> KeypointInternalType: + if _is_identity_matrix(matrix): + return keypoint + + x, y, a, s = keypoint[:4] + x, y = cv2.transform(np.array([[[x, y]]]), matrix.params[:2]).squeeze() + a += rotation2DMatrixToEulerAngles(matrix.params[:2]) + s *= np.max([scale["x"], scale["y"]]) + return x, y, a, s + + +def bbox_affine( + bbox: BoxInternalType, + matrix: skimage.transform.ProjectiveTransform, + rotate_method: str, + rows: int, + cols: int, + output_shape: Sequence[int], +) -> BoxInternalType: + if _is_identity_matrix(matrix): + return bbox + x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4] + if rotate_method == "largest_box": + points = np.array( + [ + [x_min, y_min], + [x_max, y_min], + [x_max, y_max], + [x_min, y_max], + ] + ) + elif rotate_method == "ellipse": + w = (x_max - x_min) / 2 + h = (y_max - y_min) / 2 + data = np.arange(0, 360, dtype=np.float32) + x = w * np.sin(np.radians(data)) + (w + x_min - 0.5) + y = h * np.cos(np.radians(data)) + (h + y_min - 0.5) + points = np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)]) + else: + raise ValueError(f"Method {rotate_method} is not a valid rotation method.") + points = skimage.transform.matrix_transform(points, matrix.params) + x_min = np.min(points[:, 0]) + x_max = np.max(points[:, 0]) + y_min = np.min(points[:, 1]) + y_max = np.max(points[:, 1]) + + return normalize_bbox((x_min, y_min, x_max, y_max), output_shape[0], output_shape[1]) + + +@preserve_channel_dim +def safe_rotate( + img: np.ndarray, + matrix: np.ndarray, + interpolation: int, + value: FillValueType = None, + border_mode: int = cv2.BORDER_REFLECT_101, +) -> np.ndarray: + h, w = img.shape[:2] + warp_fn = _maybe_process_in_chunks( + cv2.warpAffine, + M=matrix, + dsize=(w, h), + flags=interpolation, + borderMode=border_mode, + borderValue=value, + ) + return warp_fn(img) + + +def bbox_safe_rotate(bbox: BoxInternalType, matrix: np.ndarray, cols: int, rows: int) -> BoxInternalType: + x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4] + points = np.array( + [ + [x1, y1, 1], + [x2, y1, 1], + [x2, y2, 1], + [x1, y2, 1], + ] + ) + points = points @ matrix.T + x1 = points[:, 0].min() + x2 = points[:, 0].max() + y1 = points[:, 1].min() + y2 = points[:, 1].max() + + def fix_point(pt1: float, pt2: float, max_val: float) -> Tuple[float, float]: + # In my opinion, these errors should be very low, around 1-2 pixels. + if pt1 < 0: + return 0, pt2 + pt1 + if pt2 > max_val: + return pt1 - (pt2 - max_val), max_val + return pt1, pt2 + + x1, x2 = fix_point(x1, x2, cols) + y1, y2 = fix_point(y1, y2, rows) + + return normalize_bbox((x1, y1, x2, y2), rows, cols) + + +def keypoint_safe_rotate( + keypoint: KeypointInternalType, + matrix: np.ndarray, + angle: float, + scale_x: float, + scale_y: float, + cols: int, + rows: int, +) -> KeypointInternalType: + x, y, a, s = keypoint[:4] + point = np.array([[x, y, 1]]) + x, y = (point @ matrix.T)[0] + + # To avoid problems with float errors + x = np.clip(x, 0, cols - 1) + y = np.clip(y, 0, rows - 1) + + a += angle + s *= max(scale_x, scale_y) + return x, y, a, s + + +@clipped +def piecewise_affine( + img: np.ndarray, + matrix: Optional[skimage.transform.PiecewiseAffineTransform], + interpolation: int, + mode: str, + cval: float, +) -> np.ndarray: + if matrix is None: + return img + return skimage.transform.warp( + img, matrix, order=interpolation, mode=mode, cval=cval, preserve_range=True, output_shape=img.shape + ) + + +def to_distance_maps( + keypoints: Sequence[Tuple[float, float]], height: int, width: int, inverted: bool = False +) -> np.ndarray: + """Generate a ``(H,W,N)`` array of distance maps for ``N`` keypoints. + + The ``n``-th distance map contains at every location ``(y, x)`` the + euclidean distance to the ``n``-th keypoint. + + This function can be used as a helper when augmenting keypoints with a + method that only supports the augmentation of images. + + Args: + keypoint: keypoint coordinates + height: image height + width: image width + inverted (bool): If ``True``, inverted distance maps are returned where each + distance value d is replaced by ``d/(d+1)``, i.e. the distance + maps have values in the range ``(0.0, 1.0]`` with ``1.0`` denoting + exactly the position of the respective keypoint. + + Returns: + (H, W, N) ndarray + A ``float32`` array containing ``N`` distance maps for ``N`` + keypoints. Each location ``(y, x, n)`` in the array denotes the + euclidean distance at ``(y, x)`` to the ``n``-th keypoint. + If `inverted` is ``True``, the distance ``d`` is replaced + by ``d/(d+1)``. The height and width of the array match the + height and width in ``KeypointsOnImage.shape``. + """ + distance_maps = np.zeros((height, width, len(keypoints)), dtype=np.float32) + + yy = np.arange(0, height) + xx = np.arange(0, width) + grid_xx, grid_yy = np.meshgrid(xx, yy) + + for i, (x, y) in enumerate(keypoints): + distance_maps[:, :, i] = (grid_xx - x) ** 2 + (grid_yy - y) ** 2 + + distance_maps = np.sqrt(distance_maps) + if inverted: + return 1 / (distance_maps + 1) + return distance_maps + + +def from_distance_maps( + distance_maps: np.ndarray, + inverted: bool, + if_not_found_coords: Optional[Union[Sequence[int], dict]], + threshold: Optional[float] = None, +) -> List[Tuple[float, float]]: + """Convert outputs of ``to_distance_maps()`` to ``KeypointsOnImage``. + This is the inverse of `to_distance_maps`. + + Args: + distance_maps (np.ndarray): The distance maps. ``N`` is the number of keypoints. + inverted (bool): Whether the given distance maps were generated in inverted mode + (i.e. :func:`KeypointsOnImage.to_distance_maps` was called with ``inverted=True``) or in non-inverted mode. + if_not_found_coords (tuple, list, dict or None, optional): + Coordinates to use for keypoints that cannot be found in `distance_maps`. + + * If this is a ``list``/``tuple``, it must contain two ``int`` values. + * If it is a ``dict``, it must contain the keys ``x`` and ``y`` with each containing one ``int`` value. + * If this is ``None``, then the keypoint will not be added. + threshold (float): The search for keypoints works by searching for the + argmin (non-inverted) or argmax (inverted) in each channel. This + parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit + as a keypoint. Use ``None`` to use no min/max. + nb_channels (None, int): Number of channels of the image on which the keypoints are placed. + Some keypoint augmenters require that information. If set to ``None``, the keypoint's shape will be set + to ``(height, width)``, otherwise ``(height, width, nb_channels)``. + """ + if distance_maps.ndim != 3: + raise ValueError( + f"Expected three-dimensional input, " + f"got {distance_maps.ndim} dimensions and shape {distance_maps.shape}." + ) + height, width, nb_keypoints = distance_maps.shape + + drop_if_not_found = False + if if_not_found_coords is None: + drop_if_not_found = True + if_not_found_x = -1 + if_not_found_y = -1 + elif isinstance(if_not_found_coords, (tuple, list)): + if len(if_not_found_coords) != 2: + raise ValueError( + f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, " + f"got {len(if_not_found_coords)}." + ) + if_not_found_x = if_not_found_coords[0] + if_not_found_y = if_not_found_coords[1] + elif isinstance(if_not_found_coords, dict): + if_not_found_x = if_not_found_coords["x"] + if_not_found_y = if_not_found_coords["y"] + else: + raise ValueError( + f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}." + ) + + keypoints = [] + for i in range(nb_keypoints): + if inverted: + hitidx_flat = np.argmax(distance_maps[..., i]) + else: + hitidx_flat = np.argmin(distance_maps[..., i]) + hitidx_ndim = np.unravel_index(hitidx_flat, (height, width)) + if not inverted and threshold is not None: + found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] < threshold + elif inverted and threshold is not None: + found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] >= threshold + else: + found = True + if found: + keypoints.append((float(hitidx_ndim[1]), float(hitidx_ndim[0]))) + else: + if not drop_if_not_found: + keypoints.append((if_not_found_x, if_not_found_y)) + + return keypoints + + +def keypoint_piecewise_affine( + keypoint: KeypointInternalType, + matrix: Optional[skimage.transform.PiecewiseAffineTransform], + h: int, + w: int, + keypoints_threshold: float, +) -> KeypointInternalType: + if matrix is None: + return keypoint + x, y, a, s = keypoint[:4] + dist_maps = to_distance_maps([(x, y)], h, w, True) + dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) + x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)[0] + return x, y, a, s + + +def bbox_piecewise_affine( + bbox: BoxInternalType, + matrix: Optional[skimage.transform.PiecewiseAffineTransform], + h: int, + w: int, + keypoints_threshold: float, +) -> BoxInternalType: + if matrix is None: + return bbox + x1, y1, x2, y2 = denormalize_bbox(bbox, h, w)[:4] + keypoints = [ + (x1, y1), + (x2, y1), + (x2, y2), + (x1, y2), + ] + dist_maps = to_distance_maps(keypoints, h, w, True) + dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0) + keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold) + keypoints = [i for i in keypoints if 0 <= i[0] < w and 0 <= i[1] < h] + keypoints_arr = np.array(keypoints) + x1 = keypoints_arr[:, 0].min() + y1 = keypoints_arr[:, 1].min() + x2 = keypoints_arr[:, 0].max() + y2 = keypoints_arr[:, 1].max() + return normalize_bbox((x1, y1, x2, y2), h, w) + + +def vflip(img: np.ndarray) -> np.ndarray: + return np.ascontiguousarray(img[::-1, ...]) + + +def hflip(img: np.ndarray) -> np.ndarray: + return np.ascontiguousarray(img[:, ::-1, ...]) + + +def hflip_cv2(img: np.ndarray) -> np.ndarray: + return cv2.flip(img, 1) + + +@preserve_shape +def random_flip(img: np.ndarray, code: int) -> np.ndarray: + return cv2.flip(img, code) + + +def transpose(img: np.ndarray) -> np.ndarray: + return img.transpose(1, 0, 2) if len(img.shape) > 2 else img.transpose(1, 0) + + +def rot90(img: np.ndarray, factor: int) -> np.ndarray: + img = np.rot90(img, factor) + return np.ascontiguousarray(img) + + +def bbox_vflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613 + """Flip a bounding box vertically around the x-axis. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image rows. + cols: Image cols. + + Returns: + tuple: A bounding box `(x_min, y_min, x_max, y_max)`. + + """ + x_min, y_min, x_max, y_max = bbox[:4] + return x_min, 1 - y_max, x_max, 1 - y_min + + +def bbox_hflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613 + """Flip a bounding box horizontally around the y-axis. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image rows. + cols: Image cols. + + Returns: + A bounding box `(x_min, y_min, x_max, y_max)`. + + """ + x_min, y_min, x_max, y_max = bbox[:4] + return 1 - x_max, y_min, 1 - x_min, y_max + + +def bbox_flip(bbox: BoxInternalType, d: int, rows: int, cols: int) -> BoxInternalType: + """Flip a bounding box either vertically, horizontally or both depending on the value of `d`. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + d: dimension. 0 for vertical flip, 1 for horizontal, -1 for transpose + rows: Image rows. + cols: Image cols. + + Returns: + A bounding box `(x_min, y_min, x_max, y_max)`. + + Raises: + ValueError: if value of `d` is not -1, 0 or 1. + + """ + if d == 0: + bbox = bbox_vflip(bbox, rows, cols) + elif d == 1: + bbox = bbox_hflip(bbox, rows, cols) + elif d == -1: + bbox = bbox_hflip(bbox, rows, cols) + bbox = bbox_vflip(bbox, rows, cols) + else: + raise ValueError("Invalid d value {}. Valid values are -1, 0 and 1".format(d)) + return bbox + + +def bbox_transpose( + bbox: KeypointInternalType, axis: int, rows: int, cols: int +) -> KeypointInternalType: # skipcq: PYL-W0613 + """Transposes a bounding box along given axis. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + axis: 0 - main axis, 1 - secondary axis. + rows: Image rows. + cols: Image cols. + + Returns: + A bounding box tuple `(x_min, y_min, x_max, y_max)`. + + Raises: + ValueError: If axis not equal to 0 or 1. + + """ + x_min, y_min, x_max, y_max = bbox[:4] + if axis not in {0, 1}: + raise ValueError("Axis must be either 0 or 1.") + if axis == 0: + bbox = (y_min, x_min, y_max, x_max) + if axis == 1: + bbox = (1 - y_max, 1 - x_max, 1 - y_min, 1 - x_min) + return bbox + + +@angle_2pi_range +def keypoint_vflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType: + """Flip a keypoint vertically around the x-axis. + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + rows: Image height. + cols: Image width. + + Returns: + tuple: A keypoint `(x, y, angle, scale)`. + + """ + x, y, angle, scale = keypoint[:4] + angle = -angle + return x, (rows - 1) - y, angle, scale + + +@angle_2pi_range +def keypoint_hflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType: + """Flip a keypoint horizontally around the y-axis. + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + rows: Image height. + cols: Image width. + + Returns: + A keypoint `(x, y, angle, scale)`. + + """ + x, y, angle, scale = keypoint[:4] + angle = math.pi - angle + return (cols - 1) - x, y, angle, scale + + +def keypoint_flip(keypoint: KeypointInternalType, d: int, rows: int, cols: int) -> KeypointInternalType: + """Flip a keypoint either vertically, horizontally or both depending on the value of `d`. + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + d: Number of flip. Must be -1, 0 or 1: + * 0 - vertical flip, + * 1 - horizontal flip, + * -1 - vertical and horizontal flip. + rows: Image height. + cols: Image width. + + Returns: + A keypoint `(x, y, angle, scale)`. + + Raises: + ValueError: if value of `d` is not -1, 0 or 1. + + """ + if d == 0: + keypoint = keypoint_vflip(keypoint, rows, cols) + elif d == 1: + keypoint = keypoint_hflip(keypoint, rows, cols) + elif d == -1: + keypoint = keypoint_hflip(keypoint, rows, cols) + keypoint = keypoint_vflip(keypoint, rows, cols) + else: + raise ValueError(f"Invalid d value {d}. Valid values are -1, 0 and 1") + return keypoint + + +def keypoint_transpose(keypoint: KeypointInternalType) -> KeypointInternalType: + """Rotate a keypoint by angle. + + Args: + keypoint: A keypoint `(x, y, angle, scale)`. + + Returns: + A keypoint `(x, y, angle, scale)`. + + """ + x, y, angle, scale = keypoint[:4] + + if angle <= np.pi: + angle = np.pi - angle + else: + angle = 3 * np.pi - angle + + return y, x, angle, scale + + +@preserve_channel_dim +def pad( + img: np.ndarray, + min_height: int, + min_width: int, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, +) -> np.ndarray: + height, width = img.shape[:2] + + if height < min_height: + h_pad_top = int((min_height - height) / 2.0) + h_pad_bottom = min_height - height - h_pad_top + else: + h_pad_top = 0 + h_pad_bottom = 0 + + if width < min_width: + w_pad_left = int((min_width - width) / 2.0) + w_pad_right = min_width - width - w_pad_left + else: + w_pad_left = 0 + w_pad_right = 0 + + img = pad_with_params(img, h_pad_top, h_pad_bottom, w_pad_left, w_pad_right, border_mode, value) + + if img.shape[:2] != (max(min_height, height), max(min_width, width)): + raise RuntimeError( + "Invalid result shape. Got: {}. Expected: {}".format( + img.shape[:2], (max(min_height, height), max(min_width, width)) + ) + ) + + return img + + +@preserve_channel_dim +def pad_with_params( + img: np.ndarray, + h_pad_top: int, + h_pad_bottom: int, + w_pad_left: int, + w_pad_right: int, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, +) -> np.ndarray: + pad_fn = _maybe_process_in_chunks( + cv2.copyMakeBorder, + top=h_pad_top, + bottom=h_pad_bottom, + left=w_pad_left, + right=w_pad_right, + borderType=border_mode, + value=value, + ) + return pad_fn(img) + + +@preserve_shape +def optical_distortion( + img: np.ndarray, + k: int = 0, + dx: int = 0, + dy: int = 0, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, +) -> np.ndarray: + """Barrel / pincushion distortion. Unconventional augment. + + Reference: + | https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion + | https://stackoverflow.com/questions/10364201/image-transformation-in-opencv + | https://stackoverflow.com/questions/2477774/correcting-fisheye-distortion-programmatically + | http://www.coldvision.io/2017/03/02/advanced-lane-finding-using-opencv/ + """ + height, width = img.shape[:2] + + fx = width + fy = height + + cx = width * 0.5 + dx + cy = height * 0.5 + dy + + camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32) + + distortion = np.array([k, k, 0, 0, 0], dtype=np.float32) + map1, map2 = cv2.initUndistortRectifyMap( + camera_matrix, distortion, None, None, (width, height), cv2.CV_32FC1 # type: ignore[attr-defined] + ) + return cv2.remap(img, map1, map2, interpolation=interpolation, borderMode=border_mode, borderValue=value) + + +@preserve_shape +def grid_distortion( + img: np.ndarray, + num_steps: int = 10, + xsteps: Tuple = (), + ysteps: Tuple = (), + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, +) -> np.ndarray: + """Perform a grid distortion of an input image. + + Reference: + http://pythology.blogspot.sg/2014/03/interpolation-on-regular-distorted-grid.html + """ + height, width = img.shape[:2] + + x_step = width // num_steps + xx = np.zeros(width, np.float32) + prev = 0 + for idx in range(num_steps + 1): + x = idx * x_step + start = int(x) + end = int(x) + x_step + if end > width: + end = width + cur = width + else: + cur = prev + x_step * xsteps[idx] + + xx[start:end] = np.linspace(prev, cur, end - start) + prev = cur + + y_step = height // num_steps + yy = np.zeros(height, np.float32) + prev = 0 + for idx in range(num_steps + 1): + y = idx * y_step + start = int(y) + end = int(y) + y_step + if end > height: + end = height + cur = height + else: + cur = prev + y_step * ysteps[idx] + + yy[start:end] = np.linspace(prev, cur, end - start) + prev = cur + + map_x, map_y = np.meshgrid(xx, yy) + map_x = map_x.astype(np.float32) + map_y = map_y.astype(np.float32) + + remap_fn = _maybe_process_in_chunks( + cv2.remap, + map1=map_x, + map2=map_y, + interpolation=interpolation, + borderMode=border_mode, + borderValue=value, + ) + return remap_fn(img) + + +@preserve_shape +def elastic_transform_approx( + img: np.ndarray, + alpha: float, + sigma: float, + alpha_affine: float, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, + random_state: Optional[np.random.RandomState] = None, +) -> np.ndarray: + """Elastic deformation of images as described in [Simard2003]_ (with modifications for speed). + Based on https://gist.github.com/ernestum/601cdf56d2b424757de5 + + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + """ + height, width = img.shape[:2] + + # Random affine + center_square = np.array((height, width), dtype=np.float32) // 2 + square_size = min((height, width)) // 3 + alpha = float(alpha) + sigma = float(sigma) + alpha_affine = float(alpha_affine) + + pts1 = np.array( + [ + center_square + square_size, + [center_square[0] + square_size, center_square[1] - square_size], + center_square - square_size, + ], + dtype=np.float32, + ) + pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype( + np.float32 + ) + matrix = cv2.getAffineTransform(pts1, pts2) + + warp_fn = _maybe_process_in_chunks( + cv2.warpAffine, + M=matrix, + dsize=(width, height), + flags=interpolation, + borderMode=border_mode, + borderValue=value, + ) + img = warp_fn(img) + + dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1 + cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx) + dx *= alpha + + dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1 + cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy) + dy *= alpha + + x, y = np.meshgrid(np.arange(width), np.arange(height)) + + map_x = np.float32(x + dx) + map_y = np.float32(y + dy) + + remap_fn = _maybe_process_in_chunks( + cv2.remap, + map1=map_x, + map2=map_y, + interpolation=interpolation, + borderMode=border_mode, + borderValue=value, + ) + return remap_fn(img) diff --git a/custom_albumentations/augmentations/geometric/resize.py b/custom_albumentations/augmentations/geometric/resize.py new file mode 100644 index 0000000000000000000000000000000000000000..913c08429916289e2dfc41bcfcc561aa41f28aae --- /dev/null +++ b/custom_albumentations/augmentations/geometric/resize.py @@ -0,0 +1,198 @@ +import random +from typing import Dict, Sequence, Tuple, Union + +import cv2 +import numpy as np + +from ...core.transforms_interface import ( + BoxInternalType, + DualTransform, + KeypointInternalType, + to_tuple, +) +from . import functional as F + +__all__ = ["RandomScale", "LongestMaxSize", "SmallestMaxSize", "Resize"] + + +class RandomScale(DualTransform): + """Randomly resize the input. Output image size is different from the input image size. + + Args: + scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the + range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1. + If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high). + Default: (-0.1, 0.1). + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, scale_limit=0.1, interpolation=cv2.INTER_LINEAR, always_apply=False, p=0.5): + super(RandomScale, self).__init__(always_apply, p) + self.scale_limit = to_tuple(scale_limit, bias=1.0) + self.interpolation = interpolation + + def get_params(self): + return {"scale": random.uniform(self.scale_limit[0], self.scale_limit[1])} + + def apply(self, img, scale=0, interpolation=cv2.INTER_LINEAR, **params): + return F.scale(img, scale, interpolation) + + def apply_to_bbox(self, bbox, **params): + # Bounding box coordinates are scale invariant + return bbox + + def apply_to_keypoint(self, keypoint, scale=0, **params): + return F.keypoint_scale(keypoint, scale, scale) + + def get_transform_init_args(self): + return {"interpolation": self.interpolation, "scale_limit": to_tuple(self.scale_limit, bias=-1.0)} + + +class LongestMaxSize(DualTransform): + """Rescale an image so that maximum side is equal to max_size, keeping the aspect ratio of the initial image. + + Args: + max_size (int, list of int): maximum size of the image after the transformation. When using a list, max size + will be randomly selected from the values in the list. + interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + max_size: Union[int, Sequence[int]] = 1024, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 1, + ): + super(LongestMaxSize, self).__init__(always_apply, p) + self.interpolation = interpolation + self.max_size = max_size + + def apply( + self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params + ) -> np.ndarray: + return F.longest_max_size(img, max_size=max_size, interpolation=interpolation) + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + # Bounding box coordinates are scale invariant + return bbox + + def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType: + height = params["rows"] + width = params["cols"] + + scale = max_size / max([height, width]) + return F.keypoint_scale(keypoint, scale, scale) + + def get_params(self) -> Dict[str, int]: + return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)} + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ("max_size", "interpolation") + + +class SmallestMaxSize(DualTransform): + """Rescale an image so that minimum side is equal to max_size, keeping the aspect ratio of the initial image. + + Args: + max_size (int, list of int): maximum size of smallest side of the image after the transformation. When using a + list, max size will be randomly selected from the values in the list. + interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + max_size: Union[int, Sequence[int]] = 1024, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 1, + ): + super(SmallestMaxSize, self).__init__(always_apply, p) + self.interpolation = interpolation + self.max_size = max_size + + def apply( + self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params + ) -> np.ndarray: + return F.smallest_max_size(img, max_size=max_size, interpolation=interpolation) + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return bbox + + def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType: + height = params["rows"] + width = params["cols"] + + scale = max_size / min([height, width]) + return F.keypoint_scale(keypoint, scale, scale) + + def get_params(self) -> Dict[str, int]: + return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)} + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + return ("max_size", "interpolation") + + +class Resize(DualTransform): + """Resize the input to the given height and width. + + Args: + height (int): desired height of the output. + width (int): desired width of the output. + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 1. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1): + super(Resize, self).__init__(always_apply, p) + self.height = height + self.width = width + self.interpolation = interpolation + + def apply(self, img, interpolation=cv2.INTER_LINEAR, **params): + return F.resize(img, height=self.height, width=self.width, interpolation=interpolation) + + def apply_to_bbox(self, bbox, **params): + # Bounding box coordinates are scale invariant + return bbox + + def apply_to_keypoint(self, keypoint, **params): + height = params["rows"] + width = params["cols"] + scale_x = self.width / width + scale_y = self.height / height + return F.keypoint_scale(keypoint, scale_x, scale_y) + + def get_transform_init_args_names(self): + return ("height", "width", "interpolation") diff --git a/custom_albumentations/augmentations/geometric/rotate.py b/custom_albumentations/augmentations/geometric/rotate.py new file mode 100644 index 0000000000000000000000000000000000000000..7578e60eafb416738104f308aa716592b3aa60bf --- /dev/null +++ b/custom_albumentations/augmentations/geometric/rotate.py @@ -0,0 +1,294 @@ +import math +import random +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np + +from ...core.transforms_interface import ( + BoxInternalType, + DualTransform, + FillValueType, + KeypointInternalType, + to_tuple, +) +from ..crops import functional as FCrops +from . import functional as F + +__all__ = ["Rotate", "RandomRotate90", "SafeRotate"] + + +class RandomRotate90(DualTransform): + """Randomly rotate the input by 90 degrees zero or more times. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def apply(self, img, factor=0, **params): + """ + Args: + factor (int): number of times the input will be rotated by 90 degrees. + """ + return np.ascontiguousarray(np.rot90(img, factor)) + + def get_params(self): + # Random int in the range [0, 3] + return {"factor": random.randint(0, 3)} + + def apply_to_bbox(self, bbox, factor=0, **params): + return F.bbox_rot90(bbox, factor, **params) + + def apply_to_keypoint(self, keypoint, factor=0, **params): + return F.keypoint_rot90(keypoint, factor, **params) + + def get_transform_init_args_names(self): + return () + + +class Rotate(DualTransform): + """Rotate the input by an angle selected randomly from the uniform distribution. + + Args: + limit ((int, int) or int): range from which a random angle is picked. If limit is a single int + an angle is picked from (-limit, limit). Default: (-90, 90) + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of ints, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse". + Default: "largest_box" + crop_border (bool): If True would make a largest possible crop within rotated image + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + limit=90, + interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_REFLECT_101, + value=None, + mask_value=None, + rotate_method="largest_box", + crop_border=False, + always_apply=False, + p=0.5, + ): + super(Rotate, self).__init__(always_apply, p) + self.limit = to_tuple(limit) + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + self.rotate_method = rotate_method + self.crop_border = crop_border + + if rotate_method not in ["largest_box", "ellipse"]: + raise ValueError(f"Rotation method {self.rotate_method} is not valid.") + + def apply( + self, img, angle=0, interpolation=cv2.INTER_LINEAR, x_min=None, x_max=None, y_min=None, y_max=None, **params + ): + img_out = F.rotate(img, angle, interpolation, self.border_mode, self.value) + if self.crop_border: + img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max) + return img_out + + def apply_to_mask(self, img, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, **params): + img_out = F.rotate(img, angle, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + if self.crop_border: + img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max) + return img_out + + def apply_to_bbox(self, bbox, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params): + bbox_out = F.bbox_rotate(bbox, angle, self.rotate_method, rows, cols) + if self.crop_border: + bbox_out = FCrops.bbox_crop(bbox_out, x_min, y_min, x_max, y_max, rows, cols) + return bbox_out + + def apply_to_keypoint( + self, keypoint, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params + ): + keypoint_out = F.keypoint_rotate(keypoint, angle, rows, cols, **params) + if self.crop_border: + keypoint_out = FCrops.crop_keypoint_by_coords(keypoint_out, (x_min, y_min, x_max, y_max)) + return keypoint_out + + @staticmethod + def _rotated_rect_with_max_area(h, w, angle): + """ + Given a rectangle of size wxh that has been rotated by 'angle' (in + degrees), computes the width and height of the largest possible + axis-aligned rectangle (maximal area) within the rotated rectangle. + + Code from: https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders + """ + + angle = math.radians(angle) + width_is_longer = w >= h + side_long, side_short = (w, h) if width_is_longer else (h, w) + + # since the solutions for angle, -angle and 180-angle are all the same, + # it is sufficient to look at the first quadrant and the absolute values of sin,cos: + sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle)) + if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10: + # half constrained case: two crop corners touch the longer side, + # the other two corners are on the mid-line parallel to the longer line + x = 0.5 * side_short + wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a) + else: + # fully constrained case: crop touches all 4 sides + cos_2a = cos_a * cos_a - sin_a * sin_a + wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a + + return dict( + x_min=max(0, int(w / 2 - wr / 2)), + x_max=min(w, int(w / 2 + wr / 2)), + y_min=max(0, int(h / 2 - hr / 2)), + y_max=min(h, int(h / 2 + hr / 2)), + ) + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + out_params = {"angle": random.uniform(self.limit[0], self.limit[1])} + if self.crop_border: + h, w = params["image"].shape[:2] + out_params.update(self._rotated_rect_with_max_area(h, w, out_params["angle"])) + return out_params + + def get_transform_init_args_names(self): + return ("limit", "interpolation", "border_mode", "value", "mask_value", "rotate_method", "crop_border") + + +class SafeRotate(DualTransform): + """Rotate the input inside the input's frame by an angle selected randomly from the uniform distribution. + + The resulting image may have artifacts in it. After rotation, the image may have a different aspect ratio, and + after resizing, it returns to its original shape with the original aspect ratio of the image. For these reason we + may see some artifacts. + + Args: + limit ((int, int) or int): range from which a random angle is picked. If limit is a single int + an angle is picked from (-limit, limit). Default: (-90, 90) + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of ints, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + limit: Union[float, Tuple[float, float]] = 90, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: FillValueType = None, + mask_value: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(SafeRotate, self).__init__(always_apply, p) + self.limit = to_tuple(limit) + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + + def apply(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray: + return F.safe_rotate(img, matrix, self.interpolation, self.value, self.border_mode) + + def apply_to_mask(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray: + return F.safe_rotate(img, matrix, cv2.INTER_NEAREST, self.mask_value, self.border_mode) + + def apply_to_bbox(self, bbox: BoxInternalType, cols: int = 0, rows: int = 0, **params) -> BoxInternalType: + return F.bbox_safe_rotate(bbox, params["matrix"], cols, rows) + + def apply_to_keypoint( + self, + keypoint: KeypointInternalType, + angle: float = 0, + scale_x: float = 0, + scale_y: float = 0, + cols: int = 0, + rows: int = 0, + **params + ) -> KeypointInternalType: + return F.keypoint_safe_rotate(keypoint, params["matrix"], angle, scale_x, scale_y, cols, rows) + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + angle = random.uniform(self.limit[0], self.limit[1]) + + image = params["image"] + h, w = image.shape[:2] + + # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides + image_center = (w / 2, h / 2) + + # Rotation Matrix + rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0) + + # rotation calculates the cos and sin, taking absolutes of those. + abs_cos = abs(rotation_mat[0, 0]) + abs_sin = abs(rotation_mat[0, 1]) + + # find the new width and height bounds + new_w = math.ceil(h * abs_sin + w * abs_cos) + new_h = math.ceil(h * abs_cos + w * abs_sin) + + scale_x = w / new_w + scale_y = h / new_h + + # Shift the image to create padding + rotation_mat[0, 2] += new_w / 2 - image_center[0] + rotation_mat[1, 2] += new_h / 2 - image_center[1] + + # Rescale to original size + scale_mat = np.diag(np.ones(3)) + scale_mat[0, 0] *= scale_x + scale_mat[1, 1] *= scale_y + _tmp = np.diag(np.ones(3)) + _tmp[:2] = rotation_mat + _tmp = scale_mat @ _tmp + rotation_mat = _tmp[:2] + + return {"matrix": rotation_mat, "angle": angle, "scale_x": scale_x, "scale_y": scale_y} + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str]: + return ("limit", "interpolation", "border_mode", "value", "mask_value") diff --git a/custom_albumentations/augmentations/geometric/transforms.py b/custom_albumentations/augmentations/geometric/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..9d0b561b0e40c20d72f7ae94b4a73d8a34645118 --- /dev/null +++ b/custom_albumentations/augmentations/geometric/transforms.py @@ -0,0 +1,1499 @@ +import math +import random +from enum import Enum +from typing import Dict, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +import skimage.transform + +from custom_albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox + +from ... import random_utils +from ...core.transforms_interface import ( + BoxInternalType, + DualTransform, + ImageColorType, + KeypointInternalType, + ScaleFloatType, + to_tuple, +) +from ..functional import bbox_from_mask +from . import functional as F + +__all__ = [ + "ShiftScaleRotate", + "ElasticTransform", + "Perspective", + "Affine", + "PiecewiseAffine", + "VerticalFlip", + "HorizontalFlip", + "Flip", + "Transpose", + "OpticalDistortion", + "GridDistortion", + "PadIfNeeded", +] + + +class ShiftScaleRotate(DualTransform): + """Randomly apply affine transforms: translate, scale and rotate the input. + + Args: + shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit + is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and + upper bounds should lie in range [0, 1]. Default: (-0.0625, 0.0625). + scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the + range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1. + If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high). + Default: (-0.1, 0.1). + rotate_limit ((int, int) or int): rotation range. If rotate_limit is a single int value, the + range will be (-rotate_limit, rotate_limit). Default: (-45, 45). + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of int, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value + instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value, + the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in + the range [0, 1]. Default: None. + shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value + instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value, + the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie + in the range [0, 1]. Default: None. + rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse". + Default: "largest_box" + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, keypoints + + Image types: + uint8, float32 + """ + + def __init__( + self, + shift_limit=0.0625, + scale_limit=0.1, + rotate_limit=45, + interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_REFLECT_101, + value=None, + mask_value=None, + shift_limit_x=None, + shift_limit_y=None, + rotate_method="largest_box", + always_apply=False, + p=0.5, + ): + super(ShiftScaleRotate, self).__init__(always_apply, p) + self.shift_limit_x = to_tuple(shift_limit_x if shift_limit_x is not None else shift_limit) + self.shift_limit_y = to_tuple(shift_limit_y if shift_limit_y is not None else shift_limit) + self.scale_limit = to_tuple(scale_limit, bias=1.0) + self.rotate_limit = to_tuple(rotate_limit) + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + self.rotate_method = rotate_method + + if self.rotate_method not in ["largest_box", "ellipse"]: + raise ValueError(f"Rotation method {self.rotate_method} is not valid.") + + def apply(self, img, angle=0, scale=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params): + return F.shift_scale_rotate(img, angle, scale, dx, dy, interpolation, self.border_mode, self.value) + + def apply_to_mask(self, img, angle=0, scale=0, dx=0, dy=0, **params): + return F.shift_scale_rotate(img, angle, scale, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + + def apply_to_keypoint(self, keypoint, angle=0, scale=0, dx=0, dy=0, rows=0, cols=0, **params): + return F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols) + + def get_params(self): + return { + "angle": random.uniform(self.rotate_limit[0], self.rotate_limit[1]), + "scale": random.uniform(self.scale_limit[0], self.scale_limit[1]), + "dx": random.uniform(self.shift_limit_x[0], self.shift_limit_x[1]), + "dy": random.uniform(self.shift_limit_y[0], self.shift_limit_y[1]), + } + + def apply_to_bbox(self, bbox, angle, scale, dx, dy, **params): + return F.bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, self.rotate_method, **params) + + def get_transform_init_args(self): + return { + "shift_limit_x": self.shift_limit_x, + "shift_limit_y": self.shift_limit_y, + "scale_limit": to_tuple(self.scale_limit, bias=-1.0), + "rotate_limit": self.rotate_limit, + "interpolation": self.interpolation, + "border_mode": self.border_mode, + "value": self.value, + "mask_value": self.mask_value, + "rotate_method": self.rotate_method, + } + + +class ElasticTransform(DualTransform): + """Elastic deformation of images as described in [Simard2003]_ (with modifications). + Based on https://gist.github.com/ernestum/601cdf56d2b424757de5 + + .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for + Convolutional Neural Networks applied to Visual Document Analysis", in + Proc. of the International Conference on Document Analysis and + Recognition, 2003. + + Args: + alpha (float): + sigma (float): Gaussian filter parameter. + alpha_affine (float): The range will be (-alpha_affine, alpha_affine) + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of ints, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + approximate (boolean): Whether to smooth displacement map with fixed kernel size. + Enabling this option gives ~2X speedup on large images. + same_dxdy (boolean): Whether to use same random generated shift for x and y. + Enabling this option gives ~2X speedup. + + Targets: + image, mask, bbox + + Image types: + uint8, float32 + """ + + def __init__( + self, + alpha=1, + sigma=50, + alpha_affine=50, + interpolation=cv2.INTER_LINEAR, + border_mode=cv2.BORDER_REFLECT_101, + value=None, + mask_value=None, + always_apply=False, + approximate=False, + same_dxdy=False, + p=0.5, + ): + super(ElasticTransform, self).__init__(always_apply, p) + self.alpha = alpha + self.alpha_affine = alpha_affine + self.sigma = sigma + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + self.approximate = approximate + self.same_dxdy = same_dxdy + + def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params): + return F.elastic_transform( + img, + self.alpha, + self.sigma, + self.alpha_affine, + interpolation, + self.border_mode, + self.value, + np.random.RandomState(random_state), + self.approximate, + self.same_dxdy, + ) + + def apply_to_mask(self, img, random_state=None, **params): + return F.elastic_transform( + img, + self.alpha, + self.sigma, + self.alpha_affine, + cv2.INTER_NEAREST, + self.border_mode, + self.mask_value, + np.random.RandomState(random_state), + self.approximate, + self.same_dxdy, + ) + + def apply_to_bbox(self, bbox, random_state=None, **params): + rows, cols = params["rows"], params["cols"] + mask = np.zeros((rows, cols), dtype=np.uint8) + bbox_denorm = F.denormalize_bbox(bbox, rows, cols) + x_min, y_min, x_max, y_max = bbox_denorm[:4] + x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max) + mask[y_min:y_max, x_min:x_max] = 1 + mask = F.elastic_transform( + mask, + self.alpha, + self.sigma, + self.alpha_affine, + cv2.INTER_NEAREST, + self.border_mode, + self.mask_value, + np.random.RandomState(random_state), + self.approximate, + ) + bbox_returned = bbox_from_mask(mask) + bbox_returned = F.normalize_bbox(bbox_returned, rows, cols) + return bbox_returned + + def get_params(self): + return {"random_state": random.randint(0, 10000)} + + def get_transform_init_args_names(self): + return ( + "alpha", + "sigma", + "alpha_affine", + "interpolation", + "border_mode", + "value", + "mask_value", + "approximate", + "same_dxdy", + ) + + +class Perspective(DualTransform): + """Perform a random four point perspective transform of the input. + + Args: + scale (float or (float, float)): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. + If scale is a single float value, the range will be (0, scale). Default: (0.05, 0.1). + keep_size (bool): Whether to resize image’s back to their original size after applying the perspective + transform. If set to False, the resulting images may end up having different shapes + and will always be a list, never an array. Default: True + pad_mode (OpenCV flag): OpenCV border mode. + pad_val (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + Default: 0 + mask_pad_val (int, float, list of int, list of float): padding value for mask + if border_mode is cv2.BORDER_CONSTANT. Default: 0 + fit_output (bool): If True, the image plane size and position will be adjusted to still capture + the whole image after perspective transformation. (Followed by image resizing if keep_size is set to True.) + Otherwise, parts of the transformed image may be outside of the image plane. + This setting should not be set to True when using large scale values as it could lead to very large images. + Default: False + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, keypoints, bboxes + + Image types: + uint8, float32 + """ + + def __init__( + self, + scale=(0.05, 0.1), + keep_size=True, + pad_mode=cv2.BORDER_CONSTANT, + pad_val=0, + mask_pad_val=0, + fit_output=False, + interpolation=cv2.INTER_LINEAR, + always_apply=False, + p=0.5, + ): + super().__init__(always_apply, p) + self.scale = to_tuple(scale, 0) + self.keep_size = keep_size + self.pad_mode = pad_mode + self.pad_val = pad_val + self.mask_pad_val = mask_pad_val + self.fit_output = fit_output + self.interpolation = interpolation + + def apply(self, img, matrix=None, max_height=None, max_width=None, **params): + return F.perspective( + img, matrix, max_width, max_height, self.pad_val, self.pad_mode, self.keep_size, params["interpolation"] + ) + + def apply_to_bbox(self, bbox, matrix=None, max_height=None, max_width=None, **params): + return F.perspective_bbox(bbox, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size) + + def apply_to_keypoint(self, keypoint, matrix=None, max_height=None, max_width=None, **params): + return F.perspective_keypoint( + keypoint, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size + ) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + h, w = params["image"].shape[:2] + + scale = random_utils.uniform(*self.scale) + points = random_utils.normal(0, scale, [4, 2]) + points = np.mod(np.abs(points), 0.32) + + # top left -- no changes needed, just use jitter + # top right + points[1, 0] = 1.0 - points[1, 0] # w = 1.0 - jitter + # bottom right + points[2] = 1.0 - points[2] # w = 1.0 - jitt + # bottom left + points[3, 1] = 1.0 - points[3, 1] # h = 1.0 - jitter + + points[:, 0] *= w + points[:, 1] *= h + + # Obtain a consistent order of the points and unpack them individually. + # Warning: don't just do (tl, tr, br, bl) = _order_points(...) + # here, because the reordered points is used further below. + points = self._order_points(points) + tl, tr, br, bl = points + + # compute the width of the new image, which will be the + # maximum distance between bottom-right and bottom-left + # x-coordiates or the top-right and top-left x-coordinates + min_width = None + max_width = None + while min_width is None or min_width < 2: + width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) + width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) + max_width = int(max(width_top, width_bottom)) + min_width = int(min(width_top, width_bottom)) + if min_width < 2: + step_size = (2 - min_width) / 2 + tl[0] -= step_size + tr[0] += step_size + bl[0] -= step_size + br[0] += step_size + + # compute the height of the new image, which will be the maximum distance between the top-right + # and bottom-right y-coordinates or the top-left and bottom-left y-coordinates + min_height = None + max_height = None + while min_height is None or min_height < 2: + height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) + height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) + max_height = int(max(height_right, height_left)) + min_height = int(min(height_right, height_left)) + if min_height < 2: + step_size = (2 - min_height) / 2 + tl[1] -= step_size + tr[1] -= step_size + bl[1] += step_size + br[1] += step_size + + # now that we have the dimensions of the new image, construct + # the set of destination points to obtain a "birds eye view", + # (i.e. top-down view) of the image, again specifying points + # in the top-left, top-right, bottom-right, and bottom-left order + # do not use width-1 or height-1 here, as for e.g. width=3, height=2 + # the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0) + dst = np.array([[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]], dtype=np.float32) + + # compute the perspective transform matrix and then apply it + m = cv2.getPerspectiveTransform(points, dst) + + if self.fit_output: + m, max_width, max_height = self._expand_transform(m, (h, w)) + + return {"matrix": m, "max_height": max_height, "max_width": max_width, "interpolation": self.interpolation} + + @classmethod + def _expand_transform(cls, matrix, shape): + height, width = shape + # do not use width-1 or height-1 here, as for e.g. width=3, height=2, max_height + # the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0) + rect = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32) + dst = cv2.perspectiveTransform(np.array([rect]), matrix)[0] + + # get min x, y over transformed 4 points + # then modify target points by subtracting these minima => shift to (0, 0) + dst -= dst.min(axis=0, keepdims=True) + dst = np.around(dst, decimals=0) + + matrix_expanded = cv2.getPerspectiveTransform(rect, dst) + max_width, max_height = dst.max(axis=0) + return matrix_expanded, int(max_width), int(max_height) + + @staticmethod + def _order_points(pts: np.ndarray) -> np.ndarray: + pts = np.array(sorted(pts, key=lambda x: x[0])) + left = pts[:2] # points with smallest x coordinate - left points + right = pts[2:] # points with greatest x coordinate - right points + + if left[0][1] < left[1][1]: + tl, bl = left + else: + bl, tl = left + + if right[0][1] < right[1][1]: + tr, br = right + else: + br, tr = right + + return np.array([tl, tr, br, bl], dtype=np.float32) + + def get_transform_init_args_names(self): + return "scale", "keep_size", "pad_mode", "pad_val", "mask_pad_val", "fit_output", "interpolation" + + +class Affine(DualTransform): + """Augmentation to apply affine transformations to images. + This is mostly a wrapper around the corresponding classes and functions in OpenCV. + + Affine transformations involve: + + - Translation ("move" image on the x-/y-axis) + - Rotation + - Scaling ("zoom" in/out) + - Shear (move one side of the image, turning a square into a trapezoid) + + All such transformations can create "new" pixels in the image without a defined content, e.g. + if the image is translated to the left, pixels are created on the right. + A method has to be defined to deal with these pixel values. + The parameters `cval` and `mode` of this class deal with this. + + Some transformations involve interpolations between several pixels + of the input image to generate output pixel values. The parameters `interpolation` and + `mask_interpolation` deals with the method of interpolation used for this. + + Args: + scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and + ``0.5`` is zoomed out to ``50`` percent of the original size. + * If a single number, then that value will be used for all images. + * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``. + That the same range will be used for both x- and y-axis. To keep the aspect ratio, set + ``keep_ratio=True``, then the same value will be used for both x- and y-axis. + * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``. + Each of these keys can have the same values as described above. + Using a dictionary allows to set different values for the two axis and sampling will then happen + *independently* per axis, resulting in samples that differ between the axes. Note that when + the ``keep_ratio=True``, the x- and y-axis ranges should be the same. + translate_percent (None, number, tuple of number or dict): Translation as a fraction of the image height/width + (x-translation, y-translation), where ``0`` denotes "no change" + and ``0.5`` denotes "half of the axis size". + * If ``None`` then equivalent to ``0.0`` unless `translate_px` has a value other than ``None``. + * If a single number, then that value will be used for all images. + * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``. + That sampled fraction value will be used identically for both x- and y-axis. + * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``. + Each of these keys can have the same values as described above. + Using a dictionary allows to set different values for the two axis and sampling will then happen + *independently* per axis, resulting in samples that differ between the axes. + translate_px (None, int, tuple of int or dict): Translation in pixels. + * If ``None`` then equivalent to ``0`` unless `translate_percent` has a value other than ``None``. + * If a single int, then that value will be used for all images. + * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from + the discrete interval ``[a..b]``. That number will be used identically for both x- and y-axis. + * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``. + Each of these keys can have the same values as described above. + Using a dictionary allows to set different values for the two axis and sampling will then happen + *independently* per axis, resulting in samples that differ between the axes. + rotate (number or tuple of number): Rotation in degrees (**NOT** radians), i.e. expected value range is + around ``[-360, 360]``. Rotation happens around the *center* of the image, + not the top left corner as in some other frameworks. + * If a number, then that value will be used for all images. + * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]`` + and used as the rotation value. + shear (number, tuple of number or dict): Shear in degrees (**NOT** radians), i.e. expected value range is + around ``[-360, 360]``, with reasonable values being in the range of ``[-45, 45]``. + * If a number, then that value will be used for all images as + the shear on the x-axis (no shear on the y-axis will be done). + * If a tuple ``(a, b)``, then two value will be uniformly sampled per image + from the interval ``[a, b]`` and be used as the x- and y-shear value. + * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``. + Each of these keys can have the same values as described above. + Using a dictionary allows to set different values for the two axis and sampling will then happen + *independently* per axis, resulting in samples that differ between the axes. + interpolation (int): OpenCV interpolation flag. + mask_interpolation (int): OpenCV interpolation flag. + cval (number or sequence of number): The constant value to use when filling in newly created pixels. + (E.g. translating by 1px to the right will create a new 1px-wide column of pixels + on the left of the image). + The value is only used when `mode=constant`. The expected value range is ``[0, 255]`` for ``uint8`` images. + cval_mask (number or tuple of number): Same as cval but only for masks. + mode (int): OpenCV border flag. + fit_output (bool): If True, the image plane size and position will be adjusted to tightly capture + the whole image after affine transformation (`translate_percent` and `translate_px` are ignored). + Otherwise (``False``), parts of the transformed image may end up outside the image plane. + Fitting the output shape can be useful to avoid corners of the image being outside the image plane + after applying rotations. Default: False + keep_ratio (bool): When True, the original aspect ratio will be kept when the random scale is applied. + Default: False. + rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or + "ellipse"[1]. + Default: "largest_box" + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, keypoints, bboxes + + Image types: + uint8, float32 + + Reference: + [1] https://arxiv.org/abs/2109.13488 + """ + + def __init__( + self, + scale: Optional[Union[float, Sequence[float], dict]] = None, + translate_percent: Optional[Union[float, Sequence[float], dict]] = None, + translate_px: Optional[Union[int, Sequence[int], dict]] = None, + rotate: Optional[Union[float, Sequence[float]]] = None, + shear: Optional[Union[float, Sequence[float], dict]] = None, + interpolation: int = cv2.INTER_LINEAR, + mask_interpolation: int = cv2.INTER_NEAREST, + cval: Union[int, float, Sequence[int], Sequence[float]] = 0, + cval_mask: Union[int, float, Sequence[int], Sequence[float]] = 0, + mode: int = cv2.BORDER_CONSTANT, + fit_output: bool = False, + keep_ratio: bool = False, + rotate_method: str = "largest_box", + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + + params = [scale, translate_percent, translate_px, rotate, shear] + if all([p is None for p in params]): + scale = {"x": (0.9, 1.1), "y": (0.9, 1.1)} + translate_percent = {"x": (-0.1, 0.1), "y": (-0.1, 0.1)} + rotate = (-15, 15) + shear = {"x": (-10, 10), "y": (-10, 10)} + else: + scale = scale if scale is not None else 1.0 + rotate = rotate if rotate is not None else 0.0 + shear = shear if shear is not None else 0.0 + + self.interpolation = interpolation + self.mask_interpolation = mask_interpolation + self.cval = cval + self.cval_mask = cval_mask + self.mode = mode + self.scale = self._handle_dict_arg(scale, "scale") + self.translate_percent, self.translate_px = self._handle_translate_arg(translate_px, translate_percent) + self.rotate = to_tuple(rotate, rotate) + self.fit_output = fit_output + self.shear = self._handle_dict_arg(shear, "shear") + self.keep_ratio = keep_ratio + self.rotate_method = rotate_method + + if self.keep_ratio and self.scale["x"] != self.scale["y"]: + raise ValueError( + "When keep_ratio is True, the x and y scale range should be identical. got {}".format(self.scale) + ) + + def get_transform_init_args_names(self): + return ( + "interpolation", + "mask_interpolation", + "cval", + "mode", + "scale", + "translate_percent", + "translate_px", + "rotate", + "fit_output", + "shear", + "cval_mask", + "keep_ratio", + "rotate_method", + ) + + @staticmethod + def _handle_dict_arg(val: Union[float, Sequence[float], dict], name: str, default: float = 1.0): + if isinstance(val, dict): + if "x" not in val and "y" not in val: + raise ValueError( + f'Expected {name} dictionary to contain at least key "x" or ' 'key "y". Found neither of them.' + ) + x = val.get("x", default) + y = val.get("y", default) + return {"x": to_tuple(x, x), "y": to_tuple(y, y)} + return {"x": to_tuple(val, val), "y": to_tuple(val, val)} + + @classmethod + def _handle_translate_arg( + cls, + translate_px: Optional[Union[float, Sequence[float], dict]], + translate_percent: Optional[Union[float, Sequence[float], dict]], + ): + if translate_percent is None and translate_px is None: + translate_px = 0 + + if translate_percent is not None and translate_px is not None: + raise ValueError( + "Expected either translate_percent or translate_px to be " "provided, " "but neither of them was." + ) + + if translate_percent is not None: + # translate by percent + return cls._handle_dict_arg(translate_percent, "translate_percent", default=0.0), translate_px + + if translate_px is None: + raise ValueError("translate_px is None.") + # translate by pixels + return translate_percent, cls._handle_dict_arg(translate_px, "translate_px") + + def apply( + self, + img: np.ndarray, + matrix: skimage.transform.ProjectiveTransform = None, + output_shape: Sequence[int] = (), + **params + ) -> np.ndarray: + return F.warp_affine( + img, + matrix, + interpolation=self.interpolation, + cval=self.cval, + mode=self.mode, + output_shape=output_shape, + ) + + def apply_to_mask( + self, + img: np.ndarray, + matrix: skimage.transform.ProjectiveTransform = None, + output_shape: Sequence[int] = (), + **params + ) -> np.ndarray: + return F.warp_affine( + img, + matrix, + interpolation=self.mask_interpolation, + cval=self.cval_mask, + mode=self.mode, + output_shape=output_shape, + ) + + def apply_to_bbox( + self, + bbox: BoxInternalType, + matrix: skimage.transform.ProjectiveTransform = None, + rows: int = 0, + cols: int = 0, + output_shape: Sequence[int] = (), + **params + ) -> BoxInternalType: + return F.bbox_affine(bbox, matrix, self.rotate_method, rows, cols, output_shape) + + def apply_to_keypoint( + self, + keypoint: KeypointInternalType, + matrix: Optional[skimage.transform.ProjectiveTransform] = None, + scale: Optional[dict] = None, + **params + ) -> KeypointInternalType: + assert scale is not None and matrix is not None + return F.keypoint_affine(keypoint, matrix=matrix, scale=scale) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params: dict) -> dict: + h, w = params["image"].shape[:2] + + translate: Dict[str, Union[int, float]] + if self.translate_px is not None: + translate = {key: random.randint(*value) for key, value in self.translate_px.items()} + elif self.translate_percent is not None: + translate = {key: random.uniform(*value) for key, value in self.translate_percent.items()} + translate["x"] = translate["x"] * w + translate["y"] = translate["y"] * h + else: + translate = {"x": 0, "y": 0} + + # Look to issue https://github.com/albumentations-team/albumentations/issues/1079 + shear = {key: -random.uniform(*value) for key, value in self.shear.items()} + scale = {key: random.uniform(*value) for key, value in self.scale.items()} + if self.keep_ratio: + scale["y"] = scale["x"] + + # Look to issue https://github.com/albumentations-team/albumentations/issues/1079 + rotate = -random.uniform(*self.rotate) + + # for images we use additional shifts of (0.5, 0.5) as otherwise + # we get an ugly black border for 90deg rotations + shift_x = w / 2 - 0.5 + shift_y = h / 2 - 0.5 + + matrix_to_topleft = skimage.transform.SimilarityTransform(translation=[-shift_x, -shift_y]) + matrix_shear_y_rot = skimage.transform.AffineTransform(rotation=-np.pi / 2) + matrix_shear_y = skimage.transform.AffineTransform(shear=np.deg2rad(shear["y"])) + matrix_shear_y_rot_inv = skimage.transform.AffineTransform(rotation=np.pi / 2) + matrix_transforms = skimage.transform.AffineTransform( + scale=(scale["x"], scale["y"]), + translation=(translate["x"], translate["y"]), + rotation=np.deg2rad(rotate), + shear=np.deg2rad(shear["x"]), + ) + matrix_to_center = skimage.transform.SimilarityTransform(translation=[shift_x, shift_y]) + matrix = ( + matrix_to_topleft + + matrix_shear_y_rot + + matrix_shear_y + + matrix_shear_y_rot_inv + + matrix_transforms + + matrix_to_center + ) + if self.fit_output: + matrix, output_shape = self._compute_affine_warp_output_shape(matrix, params["image"].shape) + else: + output_shape = params["image"].shape + + return { + "rotate": rotate, + "scale": scale, + "matrix": matrix, + "output_shape": output_shape, + } + + @staticmethod + def _compute_affine_warp_output_shape( + matrix: skimage.transform.ProjectiveTransform, input_shape: Sequence[int] + ) -> Tuple[skimage.transform.ProjectiveTransform, Sequence[int]]: + height, width = input_shape[:2] + + if height == 0 or width == 0: + return matrix, input_shape + + # determine shape of output image + corners = np.array([[0, 0], [0, height - 1], [width - 1, height - 1], [width - 1, 0]]) + corners = matrix(corners) + minc = corners[:, 0].min() + minr = corners[:, 1].min() + maxc = corners[:, 0].max() + maxr = corners[:, 1].max() + out_height = maxr - minr + 1 + out_width = maxc - minc + 1 + if len(input_shape) == 3: + output_shape = np.ceil((out_height, out_width, input_shape[2])) + else: + output_shape = np.ceil((out_height, out_width)) + output_shape_tuple = tuple([int(v) for v in output_shape.tolist()]) + # fit output image in new shape + translation = (-minc, -minr) + matrix_to_fit = skimage.transform.SimilarityTransform(translation=translation) + matrix = matrix + matrix_to_fit + return matrix, output_shape_tuple + + +class PiecewiseAffine(DualTransform): + """Apply affine transformations that differ between local neighbourhoods. + This augmentation places a regular grid of points on an image and randomly moves the neighbourhood of these point + around via affine transformations. This leads to local distortions. + + This is mostly a wrapper around scikit-image's ``PiecewiseAffine``. + See also ``Affine`` for a similar technique. + + Note: + This augmenter is very slow. Try to use ``ElasticTransformation`` instead, which is at least 10x faster. + + Note: + For coordinate-based inputs (keypoints, bounding boxes, polygons, ...), + this augmenter still has to perform an image-based augmentation, + which will make it significantly slower and not fully correct for such inputs than other transforms. + + Args: + scale (float, tuple of float): Each point on the regular grid is moved around via a normal distribution. + This scale factor is equivalent to the normal distribution's sigma. + Note that the jitter (how far each point is moved in which direction) is multiplied by the height/width of + the image if ``absolute_scale=False`` (default), so this scale can be the same for different sized images. + Recommended values are in the range ``0.01`` to ``0.05`` (weak to strong augmentations). + * If a single ``float``, then that value will always be used as the scale. + * If a tuple ``(a, b)`` of ``float`` s, then a random value will + be uniformly sampled per image from the interval ``[a, b]``. + nb_rows (int, tuple of int): Number of rows of points that the regular grid should have. + Must be at least ``2``. For large images, you might want to pick a higher value than ``4``. + You might have to then adjust scale to lower values. + * If a single ``int``, then that value will always be used as the number of rows. + * If a tuple ``(a, b)``, then a value from the discrete interval + ``[a..b]`` will be uniformly sampled per image. + nb_cols (int, tuple of int): Number of columns. Analogous to `nb_rows`. + interpolation (int): The order of interpolation. The order has to be in the range 0-5: + - 0: Nearest-neighbor + - 1: Bi-linear (default) + - 2: Bi-quadratic + - 3: Bi-cubic + - 4: Bi-quartic + - 5: Bi-quintic + mask_interpolation (int): same as interpolation but for mask. + cval (number): The constant value to use when filling in newly created pixels. + cval_mask (number): Same as cval but only for masks. + mode (str): {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}, optional + Points outside the boundaries of the input are filled according + to the given mode. Modes match the behaviour of `numpy.pad`. + absolute_scale (bool): Take `scale` as an absolute value rather than a relative value. + keypoints_threshold (float): Used as threshold in conversion from distance maps to keypoints. + The search for keypoints works by searching for the + argmin (non-inverted) or argmax (inverted) in each channel. This + parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit + as a keypoint. Use ``None`` to use no min/max. Default: 0.01 + + Targets: + image, mask, keypoints, bboxes + + Image types: + uint8, float32 + + """ + + def __init__( + self, + scale: ScaleFloatType = (0.03, 0.05), + nb_rows: Union[int, Sequence[int]] = 4, + nb_cols: Union[int, Sequence[int]] = 4, + interpolation: int = 1, + mask_interpolation: int = 0, + cval: int = 0, + cval_mask: int = 0, + mode: str = "constant", + absolute_scale: bool = False, + always_apply: bool = False, + keypoints_threshold: float = 0.01, + p: float = 0.5, + ): + super(PiecewiseAffine, self).__init__(always_apply, p) + + self.scale = to_tuple(scale, scale) + self.nb_rows = to_tuple(nb_rows, nb_rows) + self.nb_cols = to_tuple(nb_cols, nb_cols) + self.interpolation = interpolation + self.mask_interpolation = mask_interpolation + self.cval = cval + self.cval_mask = cval_mask + self.mode = mode + self.absolute_scale = absolute_scale + self.keypoints_threshold = keypoints_threshold + + def get_transform_init_args_names(self): + return ( + "scale", + "nb_rows", + "nb_cols", + "interpolation", + "mask_interpolation", + "cval", + "cval_mask", + "mode", + "absolute_scale", + "keypoints_threshold", + ) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params) -> dict: + h, w = params["image"].shape[:2] + + nb_rows = np.clip(random.randint(*self.nb_rows), 2, None) + nb_cols = np.clip(random.randint(*self.nb_cols), 2, None) + nb_cells = nb_cols * nb_rows + scale = random.uniform(*self.scale) + + jitter: np.ndarray = random_utils.normal(0, scale, (nb_cells, 2)) + if not np.any(jitter > 0): + for i in range(10): # See: https://github.com/albumentations-team/albumentations/issues/1442 + jitter = random_utils.normal(0, scale, (nb_cells, 2)) + if np.any(jitter > 0): + break + if not np.any(jitter > 0): + return {"matrix": None} + + y = np.linspace(0, h, nb_rows) + x = np.linspace(0, w, nb_cols) + + # (H, W) and (H, W) for H=rows, W=cols + xx_src, yy_src = np.meshgrid(x, y) + + # (1, HW, 2) => (HW, 2) for H=rows, W=cols + points_src = np.dstack([yy_src.flat, xx_src.flat])[0] + + if self.absolute_scale: + jitter[:, 0] = jitter[:, 0] / h if h > 0 else 0.0 + jitter[:, 1] = jitter[:, 1] / w if w > 0 else 0.0 + + jitter[:, 0] = jitter[:, 0] * h + jitter[:, 1] = jitter[:, 1] * w + + points_dest = np.copy(points_src) + points_dest[:, 0] = points_dest[:, 0] + jitter[:, 0] + points_dest[:, 1] = points_dest[:, 1] + jitter[:, 1] + + # Restrict all destination points to be inside the image plane. + # This is necessary, as otherwise keypoints could be augmented + # outside of the image plane and these would be replaced by + # (-1, -1), which would not conform with the behaviour of the other augmenters. + points_dest[:, 0] = np.clip(points_dest[:, 0], 0, h - 1) + points_dest[:, 1] = np.clip(points_dest[:, 1], 0, w - 1) + + matrix = skimage.transform.PiecewiseAffineTransform() + matrix.estimate(points_src[:, ::-1], points_dest[:, ::-1]) + + return { + "matrix": matrix, + } + + def apply( + self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params + ) -> np.ndarray: + return F.piecewise_affine(img, matrix, self.interpolation, self.mode, self.cval) + + def apply_to_mask( + self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params + ) -> np.ndarray: + return F.piecewise_affine(img, matrix, self.mask_interpolation, self.mode, self.cval_mask) + + def apply_to_bbox( + self, + bbox: BoxInternalType, + rows: int = 0, + cols: int = 0, + matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, + **params + ) -> BoxInternalType: + return F.bbox_piecewise_affine(bbox, matrix, rows, cols, self.keypoints_threshold) + + def apply_to_keypoint( + self, + keypoint: KeypointInternalType, + rows: int = 0, + cols: int = 0, + matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, + **params + ): + return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols, self.keypoints_threshold) + + +class PadIfNeeded(DualTransform): + """Pad side of the image / max if side is less than desired number. + + Args: + min_height (int): minimal result image height. + min_width (int): minimal result image width. + pad_height_divisor (int): if not None, ensures image height is dividable by value of this argument. + pad_width_divisor (int): if not None, ensures image width is dividable by value of this argument. + position (Union[str, PositionType]): Position of the image. should be PositionType.CENTER or + PositionType.TOP_LEFT or PositionType.TOP_RIGHT or PositionType.BOTTOM_LEFT or PositionType.BOTTOM_RIGHT. + or PositionType.RANDOM. Default: PositionType.CENTER. + border_mode (OpenCV flag): OpenCV border mode. + value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of int, + list of float): padding value for mask if border_mode is cv2.BORDER_CONSTANT. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image, mask, bbox, keypoints + + Image types: + uint8, float32 + """ + + class PositionType(Enum): + CENTER = "center" + TOP_LEFT = "top_left" + TOP_RIGHT = "top_right" + BOTTOM_LEFT = "bottom_left" + BOTTOM_RIGHT = "bottom_right" + RANDOM = "random" + + def __init__( + self, + min_height: Optional[int] = 1024, + min_width: Optional[int] = 1024, + pad_height_divisor: Optional[int] = None, + pad_width_divisor: Optional[int] = None, + position: Union[PositionType, str] = PositionType.CENTER, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, + mask_value: Optional[ImageColorType] = None, + always_apply: bool = False, + p: float = 1.0, + ): + if (min_height is None) == (pad_height_divisor is None): + raise ValueError("Only one of 'min_height' and 'pad_height_divisor' parameters must be set") + + if (min_width is None) == (pad_width_divisor is None): + raise ValueError("Only one of 'min_width' and 'pad_width_divisor' parameters must be set") + + super(PadIfNeeded, self).__init__(always_apply, p) + self.min_height = min_height + self.min_width = min_width + self.pad_width_divisor = pad_width_divisor + self.pad_height_divisor = pad_height_divisor + self.position = PadIfNeeded.PositionType(position) + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + + def update_params(self, params, **kwargs): + params = super(PadIfNeeded, self).update_params(params, **kwargs) + rows = params["rows"] + cols = params["cols"] + + if self.min_height is not None: + if rows < self.min_height: + h_pad_top = int((self.min_height - rows) / 2.0) + h_pad_bottom = self.min_height - rows - h_pad_top + else: + h_pad_top = 0 + h_pad_bottom = 0 + else: + pad_remained = rows % self.pad_height_divisor + pad_rows = self.pad_height_divisor - pad_remained if pad_remained > 0 else 0 + + h_pad_top = pad_rows // 2 + h_pad_bottom = pad_rows - h_pad_top + + if self.min_width is not None: + if cols < self.min_width: + w_pad_left = int((self.min_width - cols) / 2.0) + w_pad_right = self.min_width - cols - w_pad_left + else: + w_pad_left = 0 + w_pad_right = 0 + else: + pad_remainder = cols % self.pad_width_divisor + pad_cols = self.pad_width_divisor - pad_remainder if pad_remainder > 0 else 0 + + w_pad_left = pad_cols // 2 + w_pad_right = pad_cols - w_pad_left + + h_pad_top, h_pad_bottom, w_pad_left, w_pad_right = self.__update_position_params( + h_top=h_pad_top, h_bottom=h_pad_bottom, w_left=w_pad_left, w_right=w_pad_right + ) + + params.update( + { + "pad_top": h_pad_top, + "pad_bottom": h_pad_bottom, + "pad_left": w_pad_left, + "pad_right": w_pad_right, + } + ) + return params + + def apply( + self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params + ) -> np.ndarray: + return F.pad_with_params( + img, + pad_top, + pad_bottom, + pad_left, + pad_right, + border_mode=self.border_mode, + value=self.value, + ) + + def apply_to_mask( + self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params + ) -> np.ndarray: + return F.pad_with_params( + img, + pad_top, + pad_bottom, + pad_left, + pad_right, + border_mode=self.border_mode, + value=self.mask_value, + ) + + def apply_to_bbox( + self, + bbox: BoxInternalType, + pad_top: int = 0, + pad_bottom: int = 0, + pad_left: int = 0, + pad_right: int = 0, + rows: int = 0, + cols: int = 0, + **params + ) -> BoxInternalType: + x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4] + bbox = x_min + pad_left, y_min + pad_top, x_max + pad_left, y_max + pad_top + return normalize_bbox(bbox, rows + pad_top + pad_bottom, cols + pad_left + pad_right) + + def apply_to_keypoint( + self, + keypoint: KeypointInternalType, + pad_top: int = 0, + pad_bottom: int = 0, + pad_left: int = 0, + pad_right: int = 0, + **params + ) -> KeypointInternalType: + x, y, angle, scale = keypoint[:4] + return x + pad_left, y + pad_top, angle, scale + + def get_transform_init_args_names(self): + return ( + "min_height", + "min_width", + "pad_height_divisor", + "pad_width_divisor", + "border_mode", + "value", + "mask_value", + ) + + def __update_position_params( + self, h_top: int, h_bottom: int, w_left: int, w_right: int + ) -> Tuple[int, int, int, int]: + if self.position == PadIfNeeded.PositionType.TOP_LEFT: + h_bottom += h_top + w_right += w_left + h_top = 0 + w_left = 0 + + elif self.position == PadIfNeeded.PositionType.TOP_RIGHT: + h_bottom += h_top + w_left += w_right + h_top = 0 + w_right = 0 + + elif self.position == PadIfNeeded.PositionType.BOTTOM_LEFT: + h_top += h_bottom + w_right += w_left + h_bottom = 0 + w_left = 0 + + elif self.position == PadIfNeeded.PositionType.BOTTOM_RIGHT: + h_top += h_bottom + w_left += w_right + h_bottom = 0 + w_right = 0 + + elif self.position == PadIfNeeded.PositionType.RANDOM: + h_pad = h_top + h_bottom + w_pad = w_left + w_right + h_top = random.randint(0, h_pad) + h_bottom = h_pad - h_top + w_left = random.randint(0, w_pad) + w_right = w_pad - w_left + + return h_top, h_bottom, w_left, w_right + + +class VerticalFlip(DualTransform): + """Flip the input vertically around the x-axis. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def apply(self, img: np.ndarray, **params) -> np.ndarray: + return F.vflip(img) + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return F.bbox_vflip(bbox, **params) + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + return F.keypoint_vflip(keypoint, **params) + + def get_transform_init_args_names(self): + return () + + +class HorizontalFlip(DualTransform): + """Flip the input horizontally around the y-axis. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def apply(self, img: np.ndarray, **params) -> np.ndarray: + if img.ndim == 3 and img.shape[2] > 1 and img.dtype == np.uint8: + # Opencv is faster than numpy only in case of + # non-gray scale 8bits images + return F.hflip_cv2(img) + + return F.hflip(img) + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return F.bbox_hflip(bbox, **params) + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + return F.keypoint_hflip(keypoint, **params) + + def get_transform_init_args_names(self): + return () + + +class Flip(DualTransform): + """Flip the input either horizontally, vertically or both horizontally and vertically. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def apply(self, img: np.ndarray, d: int = 0, **params) -> np.ndarray: + """Args: + d (int): code that specifies how to flip the input. 0 for vertical flipping, 1 for horizontal flipping, + -1 for both vertical and horizontal flipping (which is also could be seen as rotating the input by + 180 degrees). + """ + return F.random_flip(img, d) + + def get_params(self): + # Random int in the range [-1, 1] + return {"d": random.randint(-1, 1)} + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return F.bbox_flip(bbox, **params) + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + return F.keypoint_flip(keypoint, **params) + + def get_transform_init_args_names(self): + return () + + +class Transpose(DualTransform): + """Transpose the input by swapping rows and columns. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask, bboxes, keypoints + + Image types: + uint8, float32 + """ + + def apply(self, img: np.ndarray, **params) -> np.ndarray: + return F.transpose(img) + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return F.bbox_transpose(bbox, 0, **params) + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + return F.keypoint_transpose(keypoint) + + def get_transform_init_args_names(self): + return () + + +class OpticalDistortion(DualTransform): + """ + Args: + distort_limit (float, (float, float)): If distort_limit is a single float, the range + will be (-distort_limit, distort_limit). Default: (-0.05, 0.05). + shift_limit (float, (float, float))): If shift_limit is a single float, the range + will be (-shift_limit, shift_limit). Default: (-0.05, 0.05). + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of ints, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + + Targets: + image, mask, bbox + + Image types: + uint8, float32 + """ + + def __init__( + self, + distort_limit: ScaleFloatType = 0.05, + shift_limit: ScaleFloatType = 0.05, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, + mask_value: Optional[ImageColorType] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(OpticalDistortion, self).__init__(always_apply, p) + self.shift_limit = to_tuple(shift_limit) + self.distort_limit = to_tuple(distort_limit) + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + + def apply( + self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, interpolation: int = cv2.INTER_LINEAR, **params + ) -> np.ndarray: + return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value) + + def apply_to_mask(self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, **params) -> np.ndarray: + return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + + def apply_to_bbox(self, bbox: BoxInternalType, k: int = 0, dx: int = 0, dy: int = 0, **params) -> BoxInternalType: + rows, cols = params["rows"], params["cols"] + mask = np.zeros((rows, cols), dtype=np.uint8) + bbox_denorm = F.denormalize_bbox(bbox, rows, cols) + x_min, y_min, x_max, y_max = bbox_denorm[:4] + x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max) + mask[y_min:y_max, x_min:x_max] = 1 + mask = F.optical_distortion(mask, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value) + bbox_returned = bbox_from_mask(mask) + bbox_returned = F.normalize_bbox(bbox_returned, rows, cols) + return bbox_returned + + def get_params(self): + return { + "k": random.uniform(self.distort_limit[0], self.distort_limit[1]), + "dx": round(random.uniform(self.shift_limit[0], self.shift_limit[1])), + "dy": round(random.uniform(self.shift_limit[0], self.shift_limit[1])), + } + + def get_transform_init_args_names(self): + return ( + "distort_limit", + "shift_limit", + "interpolation", + "border_mode", + "value", + "mask_value", + ) + + +class GridDistortion(DualTransform): + """ + Args: + num_steps (int): count of grid cells on each side. + distort_limit (float, (float, float)): If distort_limit is a single float, the range + will be (-distort_limit, distort_limit). Default: (-0.03, 0.03). + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of: + cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101. + Default: cv2.BORDER_REFLECT_101 + value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT. + mask_value (int, float, + list of ints, + list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks. + normalized (bool): if true, distortion will be normalized to do not go outside the image. Default: False + See for more information: https://github.com/albumentations-team/albumentations/pull/722 + + Targets: + image, mask + + Image types: + uint8, float32 + """ + + def __init__( + self, + num_steps: int = 5, + distort_limit: ScaleFloatType = 0.3, + interpolation: int = cv2.INTER_LINEAR, + border_mode: int = cv2.BORDER_REFLECT_101, + value: Optional[ImageColorType] = None, + mask_value: Optional[ImageColorType] = None, + normalized: bool = False, + always_apply: bool = False, + p: float = 0.5, + ): + super(GridDistortion, self).__init__(always_apply, p) + self.num_steps = num_steps + self.distort_limit = to_tuple(distort_limit) + self.interpolation = interpolation + self.border_mode = border_mode + self.value = value + self.mask_value = mask_value + self.normalized = normalized + + def apply( + self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), interpolation: int = cv2.INTER_LINEAR, **params + ) -> np.ndarray: + return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value) + + def apply_to_mask(self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> np.ndarray: + return F.grid_distortion( + img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value + ) + + def apply_to_bbox(self, bbox: BoxInternalType, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> BoxInternalType: + rows, cols = params["rows"], params["cols"] + mask = np.zeros((rows, cols), dtype=np.uint8) + bbox_denorm = F.denormalize_bbox(bbox, rows, cols) + x_min, y_min, x_max, y_max = bbox_denorm[:4] + x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max) + mask[y_min:y_max, x_min:x_max] = 1 + mask = F.grid_distortion( + mask, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value + ) + bbox_returned = bbox_from_mask(mask) + bbox_returned = F.normalize_bbox(bbox_returned, rows, cols) + return bbox_returned + + def _normalize(self, h, w, xsteps, ysteps): + # compensate for smaller last steps in source image. + x_step = w // self.num_steps + last_x_step = min(w, ((self.num_steps + 1) * x_step)) - (self.num_steps * x_step) + xsteps[-1] *= last_x_step / x_step + + y_step = h // self.num_steps + last_y_step = min(h, ((self.num_steps + 1) * y_step)) - (self.num_steps * y_step) + ysteps[-1] *= last_y_step / y_step + + # now normalize such that distortion never leaves image bounds. + tx = w / math.floor(w / self.num_steps) + ty = h / math.floor(h / self.num_steps) + xsteps = np.array(xsteps) * (tx / np.sum(xsteps)) + ysteps = np.array(ysteps) * (ty / np.sum(ysteps)) + + return {"stepsx": xsteps, "stepsy": ysteps} + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + h, w = params["image"].shape[:2] + + stepsx = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)] + stepsy = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)] + + if self.normalized: + return self._normalize(h, w, stepsx, stepsy) + + return {"stepsx": stepsx, "stepsy": stepsy} + + def get_transform_init_args_names(self): + return "num_steps", "distort_limit", "interpolation", "border_mode", "value", "mask_value", "normalized" diff --git a/custom_albumentations/augmentations/transforms.py b/custom_albumentations/augmentations/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..892f32933d45aff91bde8fc485a818ad7a7ae86c --- /dev/null +++ b/custom_albumentations/augmentations/transforms.py @@ -0,0 +1,2667 @@ +from __future__ import absolute_import, division + +import math +import numbers +import random +import warnings +from enum import IntEnum +from types import LambdaType +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import cv2 +import numpy as np +from scipy import special +from scipy.ndimage import gaussian_filter + +from custom_albumentations import random_utils +from custom_albumentations.augmentations.blur.functional import blur +from custom_albumentations.augmentations.utils import ( + get_num_channels, + is_grayscale_image, + is_rgb_image, +) + +from ..core.transforms_interface import ( + DualTransform, + ImageOnlyTransform, + NoOp, + ScaleFloatType, + to_tuple, +) +from ..core.utils import format_args +from . import functional as F + +__all__ = [ + "Normalize", + "RandomGamma", + "RandomGridShuffle", + "HueSaturationValue", + "RGBShift", + "RandomBrightness", + "RandomContrast", + "GaussNoise", + "CLAHE", + "ChannelShuffle", + "InvertImg", + "ToGray", + "ToRGB", + "ToSepia", + "JpegCompression", + "ImageCompression", + "ToFloat", + "FromFloat", + "RandomBrightnessContrast", + "RandomSnow", + "RandomGravel", + "RandomRain", + "RandomFog", + "RandomSunFlare", + "RandomShadow", + "RandomToneCurve", + "Lambda", + "ISONoise", + "Solarize", + "Equalize", + "Posterize", + "Downscale", + "MultiplicativeNoise", + "FancyPCA", + "ColorJitter", + "Sharpen", + "Emboss", + "Superpixels", + "TemplateTransform", + "RingingOvershoot", + "UnsharpMask", + "PixelDropout", + "Spatter", +] + + +class RandomGridShuffle(DualTransform): + """ + Random shuffle grid's cells on image. + + Args: + grid ((int, int)): size of grid for splitting image. + + Targets: + image, mask, keypoints + + Image types: + uint8, float32 + """ + + def __init__(self, grid: Tuple[int, int] = (3, 3), always_apply: bool = False, p: float = 0.5): + super(RandomGridShuffle, self).__init__(always_apply, p) + self.grid = grid + + def apply(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params): + return F.swap_tiles_on_image(img, tiles) + + def apply_to_mask(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params): + return F.swap_tiles_on_image(img, tiles) + + def apply_to_keypoint( + self, keypoint: Tuple[float, ...], tiles: np.ndarray = np.array(None), rows: int = 0, cols: int = 0, **params + ): + for ( + current_left_up_corner_row, + current_left_up_corner_col, + old_left_up_corner_row, + old_left_up_corner_col, + height_tile, + width_tile, + ) in tiles: + x, y = keypoint[:2] + + if (old_left_up_corner_row <= y < (old_left_up_corner_row + height_tile)) and ( + old_left_up_corner_col <= x < (old_left_up_corner_col + width_tile) + ): + x = x - old_left_up_corner_col + current_left_up_corner_col + y = y - old_left_up_corner_row + current_left_up_corner_row + keypoint = (x, y) + tuple(keypoint[2:]) + break + + return keypoint + + def get_params_dependent_on_targets(self, params): + height, width = params["image"].shape[:2] + n, m = self.grid + + if n <= 0 or m <= 0: + raise ValueError("Grid's values must be positive. Current grid [%s, %s]" % (n, m)) + + if n > height // 2 or m > width // 2: + raise ValueError("Incorrect size cell of grid. Just shuffle pixels of image") + + height_split = np.linspace(0, height, n + 1, dtype=np.int32) + width_split = np.linspace(0, width, m + 1, dtype=np.int32) + + height_matrix, width_matrix = np.meshgrid(height_split, width_split, indexing="ij") + + index_height_matrix = height_matrix[:-1, :-1] + index_width_matrix = width_matrix[:-1, :-1] + + shifted_index_height_matrix = height_matrix[1:, 1:] + shifted_index_width_matrix = width_matrix[1:, 1:] + + height_tile_sizes = shifted_index_height_matrix - index_height_matrix + width_tile_sizes = shifted_index_width_matrix - index_width_matrix + + tiles_sizes = np.stack((height_tile_sizes, width_tile_sizes), axis=2) + + index_matrix = np.indices((n, m)) + new_index_matrix = np.stack(index_matrix, axis=2) + + for bbox_size in np.unique(tiles_sizes.reshape(-1, 2), axis=0): + eq_mat = np.all(tiles_sizes == bbox_size, axis=2) + new_index_matrix[eq_mat] = random_utils.permutation(new_index_matrix[eq_mat]) + + new_index_matrix = np.split(new_index_matrix, 2, axis=2) + + old_x = index_height_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1) + old_y = index_width_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1) + + shift_x = height_tile_sizes.reshape(-1) + shift_y = width_tile_sizes.reshape(-1) + + curr_x = index_height_matrix.reshape(-1) + curr_y = index_width_matrix.reshape(-1) + + tiles = np.stack([curr_x, curr_y, old_x, old_y, shift_x, shift_y], axis=1) + + return {"tiles": tiles} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return ("grid",) + + +class Normalize(ImageOnlyTransform): + """Normalization is applied by the formula: `img = (img - mean * max_pixel_value) / (std * max_pixel_value)` + + Args: + mean (float, list of float): mean values + std (float, list of float): std values + max_pixel_value (float): maximum possible pixel value + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + mean=(0.485, 0.456, 0.406), + std=(0.229, 0.224, 0.225), + max_pixel_value=255.0, + always_apply=False, + p=1.0, + ): + super(Normalize, self).__init__(always_apply, p) + self.mean = mean + self.std = std + self.max_pixel_value = max_pixel_value + + def apply(self, image, **params): + return F.normalize(image, self.mean, self.std, self.max_pixel_value) + + def get_transform_init_args_names(self): + return ("mean", "std", "max_pixel_value") + + +class ImageCompression(ImageOnlyTransform): + """Decreases image quality by Jpeg, WebP compression of an image. + + Args: + quality_lower (float): lower bound on the image quality. + Should be in [0, 100] range for jpeg and [1, 100] for webp. + quality_upper (float): upper bound on the image quality. + Should be in [0, 100] range for jpeg and [1, 100] for webp. + compression_type (ImageCompressionType): should be ImageCompressionType.JPEG or ImageCompressionType.WEBP. + Default: ImageCompressionType.JPEG + + Targets: + image + + Image types: + uint8, float32 + """ + + class ImageCompressionType(IntEnum): + JPEG = 0 + WEBP = 1 + + def __init__( + self, + quality_lower=99, + quality_upper=100, + compression_type=ImageCompressionType.JPEG, + always_apply=False, + p=0.5, + ): + super(ImageCompression, self).__init__(always_apply, p) + + self.compression_type = ImageCompression.ImageCompressionType(compression_type) + low_thresh_quality_assert = 0 + + if self.compression_type == ImageCompression.ImageCompressionType.WEBP: + low_thresh_quality_assert = 1 + + if not low_thresh_quality_assert <= quality_lower <= 100: + raise ValueError("Invalid quality_lower. Got: {}".format(quality_lower)) + if not low_thresh_quality_assert <= quality_upper <= 100: + raise ValueError("Invalid quality_upper. Got: {}".format(quality_upper)) + + self.quality_lower = quality_lower + self.quality_upper = quality_upper + + def apply(self, image, quality=100, image_type=".jpg", **params): + if not image.ndim == 2 and image.shape[-1] not in (1, 3, 4): + raise TypeError("ImageCompression transformation expects 1, 3 or 4 channel images.") + return F.image_compression(image, quality, image_type) + + def get_params(self): + image_type = ".jpg" + + if self.compression_type == ImageCompression.ImageCompressionType.WEBP: + image_type = ".webp" + + return { + "quality": random.randint(self.quality_lower, self.quality_upper), + "image_type": image_type, + } + + def get_transform_init_args(self): + return { + "quality_lower": self.quality_lower, + "quality_upper": self.quality_upper, + "compression_type": self.compression_type.value, + } + + +class JpegCompression(ImageCompression): + """Decreases image quality by Jpeg compression of an image. + + Args: + quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range + quality_upper (float): upper bound on the jpeg quality. Should be in [0, 100] range + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, quality_lower=99, quality_upper=100, always_apply=False, p=0.5): + super(JpegCompression, self).__init__( + quality_lower=quality_lower, + quality_upper=quality_upper, + compression_type=ImageCompression.ImageCompressionType.JPEG, + always_apply=always_apply, + p=p, + ) + warnings.warn( + f"{self.__class__.__name__} has been deprecated. Please use ImageCompression", + FutureWarning, + ) + + def get_transform_init_args(self): + return { + "quality_lower": self.quality_lower, + "quality_upper": self.quality_upper, + } + + +class RandomSnow(ImageOnlyTransform): + """Bleach out some pixel values simulating snow. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + snow_point_lower (float): lower_bond of the amount of snow. Should be in [0, 1] range + snow_point_upper (float): upper_bond of the amount of snow. Should be in [0, 1] range + brightness_coeff (float): larger number will lead to a more snow on the image. Should be >= 0 + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + snow_point_lower=0.1, + snow_point_upper=0.3, + brightness_coeff=2.5, + always_apply=False, + p=0.5, + ): + super(RandomSnow, self).__init__(always_apply, p) + + if not 0 <= snow_point_lower <= snow_point_upper <= 1: + raise ValueError( + "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format( + (snow_point_lower, snow_point_upper) + ) + ) + if brightness_coeff < 0: + raise ValueError("brightness_coeff must be greater than 0. Got: {}".format(brightness_coeff)) + + self.snow_point_lower = snow_point_lower + self.snow_point_upper = snow_point_upper + self.brightness_coeff = brightness_coeff + + def apply(self, image, snow_point=0.1, **params): + return F.add_snow(image, snow_point, self.brightness_coeff) + + def get_params(self): + return {"snow_point": random.uniform(self.snow_point_lower, self.snow_point_upper)} + + def get_transform_init_args_names(self): + return ("snow_point_lower", "snow_point_upper", "brightness_coeff") + + +class RandomGravel(ImageOnlyTransform): + """Add gravels. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + gravel_roi (float, float, float, float): (top-left x, top-left y, + bottom-right x, bottom right y). Should be in [0, 1] range + number_of_patches (int): no. of gravel patches required + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + gravel_roi: tuple = (0.1, 0.4, 0.9, 0.9), + number_of_patches: int = 2, + always_apply: bool = False, + p: float = 0.5, + ): + super(RandomGravel, self).__init__(always_apply, p) + + (gravel_lower_x, gravel_lower_y, gravel_upper_x, gravel_upper_y) = gravel_roi + + if not 0 <= gravel_lower_x < gravel_upper_x <= 1 or not 0 <= gravel_lower_y < gravel_upper_y <= 1: + raise ValueError("Invalid gravel_roi. Got: %s." % gravel_roi) + if number_of_patches < 1: + raise ValueError("Invalid gravel number_of_patches. Got: %s." % number_of_patches) + + self.gravel_roi = gravel_roi + self.number_of_patches = number_of_patches + + def generate_gravel_patch(self, rectangular_roi): + x1, y1, x2, y2 = rectangular_roi + gravels = [] + area = abs((x2 - x1) * (y2 - y1)) + count = area // 10 + gravels = np.empty([count, 2], dtype=np.int64) + gravels[:, 0] = random_utils.randint(x1, x2, count) + gravels[:, 1] = random_utils.randint(y1, y2, count) + return gravels + + def apply(self, image, gravels_infos=(), **params): + return F.add_gravel(image, gravels_infos) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + img = params["image"] + height, width = img.shape[:2] + + x_min, y_min, x_max, y_max = self.gravel_roi + x_min = int(x_min * width) + x_max = int(x_max * width) + y_min = int(y_min * height) + y_max = int(y_max * height) + + max_height = 200 + max_width = 30 + + rectangular_rois = np.zeros([self.number_of_patches, 4], dtype=np.int64) + xx1 = random_utils.randint(x_min + 1, x_max, self.number_of_patches) # xmax + xx2 = random_utils.randint(x_min, xx1) # xmin + yy1 = random_utils.randint(y_min + 1, y_max, self.number_of_patches) # ymax + yy2 = random_utils.randint(y_min, yy1) # ymin + + rectangular_rois[:, 0] = xx2 + rectangular_rois[:, 1] = yy2 + rectangular_rois[:, 2] = [min(tup) for tup in zip(xx1, xx2 + max_height)] + rectangular_rois[:, 3] = [min(tup) for tup in zip(yy1, yy2 + max_width)] + + minx = [] + maxx = [] + miny = [] + maxy = [] + val = [] + for roi in rectangular_rois: + gravels = self.generate_gravel_patch(roi) + x = gravels[:, 0] + y = gravels[:, 1] + r = random_utils.randint(1, 4, len(gravels)) + sat = random_utils.randint(0, 255, len(gravels)) + miny.append(np.maximum(y - r, 0)) + maxy.append(np.minimum(y + r, y)) + minx.append(np.maximum(x - r, 0)) + maxx.append(np.minimum(x + r, x)) + val.append(sat) + + return { + "gravels_infos": np.stack( + [ + np.concatenate(miny), + np.concatenate(maxy), + np.concatenate(minx), + np.concatenate(maxx), + np.concatenate(val), + ], + 1, + ) + } + + def get_transform_init_args_names(self): + return {"gravel_roi": self.gravel_roi, "number_of_patches": self.number_of_patches} + + +class RandomRain(ImageOnlyTransform): + """Adds rain effects. + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + slant_lower: should be in range [-20, 20]. + slant_upper: should be in range [-20, 20]. + drop_length: should be in range [0, 100]. + drop_width: should be in range [1, 5]. + drop_color (list of (r, g, b)): rain lines color. + blur_value (int): rainy view are blurry + brightness_coefficient (float): rainy days are usually shady. Should be in range [0, 1]. + rain_type: One of [None, "drizzle", "heavy", "torrential"] + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + slant_lower=-10, + slant_upper=10, + drop_length=20, + drop_width=1, + drop_color=(200, 200, 200), + blur_value=7, + brightness_coefficient=0.7, + rain_type=None, + always_apply=False, + p=0.5, + ): + super(RandomRain, self).__init__(always_apply, p) + + if rain_type not in ["drizzle", "heavy", "torrential", None]: + raise ValueError( + "raint_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type) + ) + if not -20 <= slant_lower <= slant_upper <= 20: + raise ValueError( + "Invalid combination of slant_lower and slant_upper. Got: {}".format((slant_lower, slant_upper)) + ) + if not 1 <= drop_width <= 5: + raise ValueError("drop_width must be in range [1, 5]. Got: {}".format(drop_width)) + if not 0 <= drop_length <= 100: + raise ValueError("drop_length must be in range [0, 100]. Got: {}".format(drop_length)) + if not 0 <= brightness_coefficient <= 1: + raise ValueError("brightness_coefficient must be in range [0, 1]. Got: {}".format(brightness_coefficient)) + + self.slant_lower = slant_lower + self.slant_upper = slant_upper + + self.drop_length = drop_length + self.drop_width = drop_width + self.drop_color = drop_color + self.blur_value = blur_value + self.brightness_coefficient = brightness_coefficient + self.rain_type = rain_type + + def apply(self, image, slant=10, drop_length=20, rain_drops=(), **params): + return F.add_rain( + image, + slant, + drop_length, + self.drop_width, + self.drop_color, + self.blur_value, + self.brightness_coefficient, + rain_drops, + ) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + img = params["image"] + slant = int(random.uniform(self.slant_lower, self.slant_upper)) + + height, width = img.shape[:2] + area = height * width + + if self.rain_type == "drizzle": + num_drops = area // 770 + drop_length = 10 + elif self.rain_type == "heavy": + num_drops = width * height // 600 + drop_length = 30 + elif self.rain_type == "torrential": + num_drops = area // 500 + drop_length = 60 + else: + drop_length = self.drop_length + num_drops = area // 600 + + rain_drops = [] + + for _i in range(num_drops): # If You want heavy rain, try increasing this + if slant < 0: + x = random.randint(slant, width) + else: + x = random.randint(0, width - slant) + + y = random.randint(0, height - drop_length) + + rain_drops.append((x, y)) + + return {"drop_length": drop_length, "slant": slant, "rain_drops": rain_drops} + + def get_transform_init_args_names(self): + return ( + "slant_lower", + "slant_upper", + "drop_length", + "drop_width", + "drop_color", + "blur_value", + "brightness_coefficient", + "rain_type", + ) + + +class RandomFog(ImageOnlyTransform): + """Simulates fog for the image + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + fog_coef_lower (float): lower limit for fog intensity coefficient. Should be in [0, 1] range. + fog_coef_upper (float): upper limit for fog intensity coefficient. Should be in [0, 1] range. + alpha_coef (float): transparency of the fog circles. Should be in [0, 1] range. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + fog_coef_lower=0.3, + fog_coef_upper=1, + alpha_coef=0.08, + always_apply=False, + p=0.5, + ): + super(RandomFog, self).__init__(always_apply, p) + + if not 0 <= fog_coef_lower <= fog_coef_upper <= 1: + raise ValueError( + "Invalid combination if fog_coef_lower and fog_coef_upper. Got: {}".format( + (fog_coef_lower, fog_coef_upper) + ) + ) + if not 0 <= alpha_coef <= 1: + raise ValueError("alpha_coef must be in range [0, 1]. Got: {}".format(alpha_coef)) + + self.fog_coef_lower = fog_coef_lower + self.fog_coef_upper = fog_coef_upper + self.alpha_coef = alpha_coef + + def apply(self, image, fog_coef=0.1, haze_list=(), **params): + return F.add_fog(image, fog_coef, self.alpha_coef, haze_list) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + img = params["image"] + fog_coef = random.uniform(self.fog_coef_lower, self.fog_coef_upper) + + height, width = imshape = img.shape[:2] + + hw = max(1, int(width // 3 * fog_coef)) + + haze_list = [] + midx = width // 2 - 2 * hw + midy = height // 2 - hw + index = 1 + + while midx > -hw or midy > -hw: + for _i in range(hw // 10 * index): + x = random.randint(midx, width - midx - hw) + y = random.randint(midy, height - midy - hw) + haze_list.append((x, y)) + + midx -= 3 * hw * width // sum(imshape) + midy -= 3 * hw * height // sum(imshape) + index += 1 + + return {"haze_list": haze_list, "fog_coef": fog_coef} + + def get_transform_init_args_names(self): + return ("fog_coef_lower", "fog_coef_upper", "alpha_coef") + + +class RandomSunFlare(ImageOnlyTransform): + """Simulates Sun Flare for the image + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + flare_roi (float, float, float, float): region of the image where flare will + appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1]. + angle_lower (float): should be in range [0, `angle_upper`]. + angle_upper (float): should be in range [`angle_lower`, 1]. + num_flare_circles_lower (int): lower limit for the number of flare circles. + Should be in range [0, `num_flare_circles_upper`]. + num_flare_circles_upper (int): upper limit for the number of flare circles. + Should be in range [`num_flare_circles_lower`, inf]. + src_radius (int): + src_color ((int, int, int)): color of the flare + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + flare_roi=(0, 0, 1, 0.5), + angle_lower=0, + angle_upper=1, + num_flare_circles_lower=6, + num_flare_circles_upper=10, + src_radius=400, + src_color=(255, 255, 255), + always_apply=False, + p=0.5, + ): + super(RandomSunFlare, self).__init__(always_apply, p) + + ( + flare_center_lower_x, + flare_center_lower_y, + flare_center_upper_x, + flare_center_upper_y, + ) = flare_roi + + if ( + not 0 <= flare_center_lower_x < flare_center_upper_x <= 1 + or not 0 <= flare_center_lower_y < flare_center_upper_y <= 1 + ): + raise ValueError("Invalid flare_roi. Got: {}".format(flare_roi)) + if not 0 <= angle_lower < angle_upper <= 1: + raise ValueError( + "Invalid combination of angle_lower nad angle_upper. Got: {}".format((angle_lower, angle_upper)) + ) + if not 0 <= num_flare_circles_lower < num_flare_circles_upper: + raise ValueError( + "Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format( + (num_flare_circles_lower, num_flare_circles_upper) + ) + ) + + self.flare_center_lower_x = flare_center_lower_x + self.flare_center_upper_x = flare_center_upper_x + + self.flare_center_lower_y = flare_center_lower_y + self.flare_center_upper_y = flare_center_upper_y + + self.angle_lower = angle_lower + self.angle_upper = angle_upper + self.num_flare_circles_lower = num_flare_circles_lower + self.num_flare_circles_upper = num_flare_circles_upper + + self.src_radius = src_radius + self.src_color = src_color + + def apply(self, image, flare_center_x=0.5, flare_center_y=0.5, circles=(), **params): + return F.add_sun_flare( + image, + flare_center_x, + flare_center_y, + self.src_radius, + self.src_color, + circles, + ) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + img = params["image"] + height, width = img.shape[:2] + + angle = 2 * math.pi * random.uniform(self.angle_lower, self.angle_upper) + + flare_center_x = random.uniform(self.flare_center_lower_x, self.flare_center_upper_x) + flare_center_y = random.uniform(self.flare_center_lower_y, self.flare_center_upper_y) + + flare_center_x = int(width * flare_center_x) + flare_center_y = int(height * flare_center_y) + + num_circles = random.randint(self.num_flare_circles_lower, self.num_flare_circles_upper) + + circles = [] + + x = [] + y = [] + + def line(t): + return (flare_center_x + t * math.cos(angle), flare_center_y + t * math.sin(angle)) + + for t_val in range(-flare_center_x, width - flare_center_x, 10): + rand_x, rand_y = line(t_val) + x.append(rand_x) + y.append(rand_y) + + for _i in range(num_circles): + alpha = random.uniform(0.05, 0.2) + r = random.randint(0, len(x) - 1) + rad = random.randint(1, max(height // 100 - 2, 2)) + + r_color = random.randint(max(self.src_color[0] - 50, 0), self.src_color[0]) + g_color = random.randint(max(self.src_color[1] - 50, 0), self.src_color[1]) + b_color = random.randint(max(self.src_color[2] - 50, 0), self.src_color[2]) + + circles += [ + ( + alpha, + (int(x[r]), int(y[r])), + pow(rad, 3), + (r_color, g_color, b_color), + ) + ] + + return { + "circles": circles, + "flare_center_x": flare_center_x, + "flare_center_y": flare_center_y, + } + + def get_transform_init_args(self): + return { + "flare_roi": ( + self.flare_center_lower_x, + self.flare_center_lower_y, + self.flare_center_upper_x, + self.flare_center_upper_y, + ), + "angle_lower": self.angle_lower, + "angle_upper": self.angle_upper, + "num_flare_circles_lower": self.num_flare_circles_lower, + "num_flare_circles_upper": self.num_flare_circles_upper, + "src_radius": self.src_radius, + "src_color": self.src_color, + } + + +class RandomShadow(ImageOnlyTransform): + """Simulates shadows for the image + + From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library + + Args: + shadow_roi (float, float, float, float): region of the image where shadows + will appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1]. + num_shadows_lower (int): Lower limit for the possible number of shadows. + Should be in range [0, `num_shadows_upper`]. + num_shadows_upper (int): Lower limit for the possible number of shadows. + Should be in range [`num_shadows_lower`, inf]. + shadow_dimension (int): number of edges in the shadow polygons + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + shadow_roi=(0, 0.5, 1, 1), + num_shadows_lower=1, + num_shadows_upper=2, + shadow_dimension=5, + always_apply=False, + p=0.5, + ): + super(RandomShadow, self).__init__(always_apply, p) + + (shadow_lower_x, shadow_lower_y, shadow_upper_x, shadow_upper_y) = shadow_roi + + if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1: + raise ValueError("Invalid shadow_roi. Got: {}".format(shadow_roi)) + if not 0 <= num_shadows_lower <= num_shadows_upper: + raise ValueError( + "Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format( + (num_shadows_lower, num_shadows_upper) + ) + ) + + self.shadow_roi = shadow_roi + + self.num_shadows_lower = num_shadows_lower + self.num_shadows_upper = num_shadows_upper + + self.shadow_dimension = shadow_dimension + + def apply(self, image, vertices_list=(), **params): + return F.add_shadow(image, vertices_list) + + @property + def targets_as_params(self): + return ["image"] + + def get_params_dependent_on_targets(self, params): + img = params["image"] + height, width = img.shape[:2] + + num_shadows = random.randint(self.num_shadows_lower, self.num_shadows_upper) + + x_min, y_min, x_max, y_max = self.shadow_roi + + x_min = int(x_min * width) + x_max = int(x_max * width) + y_min = int(y_min * height) + y_max = int(y_max * height) + + vertices_list = [] + + for _index in range(num_shadows): + vertex = [] + for _dimension in range(self.shadow_dimension): + vertex.append((random.randint(x_min, x_max), random.randint(y_min, y_max))) + + vertices = np.array([vertex], dtype=np.int32) + vertices_list.append(vertices) + + return {"vertices_list": vertices_list} + + def get_transform_init_args_names(self): + return ( + "shadow_roi", + "num_shadows_lower", + "num_shadows_upper", + "shadow_dimension", + ) + + +class RandomToneCurve(ImageOnlyTransform): + """Randomly change the relationship between bright and dark areas of the image by manipulating its tone curve. + + Args: + scale (float): standard deviation of the normal distribution. + Used to sample random distances to move two control points that modify the image's curve. + Values should be in range [0, 1]. Default: 0.1 + + + Targets: + image + + Image types: + uint8 + """ + + def __init__( + self, + scale=0.1, + always_apply=False, + p=0.5, + ): + super(RandomToneCurve, self).__init__(always_apply, p) + self.scale = scale + + def apply(self, image, low_y, high_y, **params): + return F.move_tone_curve(image, low_y, high_y) + + def get_params(self): + return { + "low_y": np.clip(random_utils.normal(loc=0.25, scale=self.scale), 0, 1), + "high_y": np.clip(random_utils.normal(loc=0.75, scale=self.scale), 0, 1), + } + + def get_transform_init_args_names(self): + return ("scale",) + + +class HueSaturationValue(ImageOnlyTransform): + """Randomly change hue, saturation and value of the input image. + + Args: + hue_shift_limit ((int, int) or int): range for changing hue. If hue_shift_limit is a single int, the range + will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20). + sat_shift_limit ((int, int) or int): range for changing saturation. If sat_shift_limit is a single int, + the range will be (-sat_shift_limit, sat_shift_limit). Default: (-30, 30). + val_shift_limit ((int, int) or int): range for changing value. If val_shift_limit is a single int, the range + will be (-val_shift_limit, val_shift_limit). Default: (-20, 20). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + hue_shift_limit=20, + sat_shift_limit=30, + val_shift_limit=20, + always_apply=False, + p=0.5, + ): + super(HueSaturationValue, self).__init__(always_apply, p) + self.hue_shift_limit = to_tuple(hue_shift_limit) + self.sat_shift_limit = to_tuple(sat_shift_limit) + self.val_shift_limit = to_tuple(val_shift_limit) + + def apply(self, image, hue_shift=0, sat_shift=0, val_shift=0, **params): + if not is_rgb_image(image) and not is_grayscale_image(image): + raise TypeError("HueSaturationValue transformation expects 1-channel or 3-channel images.") + return F.shift_hsv(image, hue_shift, sat_shift, val_shift) + + def get_params(self): + return { + "hue_shift": random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1]), + "sat_shift": random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]), + "val_shift": random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]), + } + + def get_transform_init_args_names(self): + return ("hue_shift_limit", "sat_shift_limit", "val_shift_limit") + + +class Solarize(ImageOnlyTransform): + """Invert all pixel values above a threshold. + + Args: + threshold ((int, int) or int, or (float, float) or float): range for solarizing threshold. + If threshold is a single value, the range will be [threshold, threshold]. Default: 128. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + any + """ + + def __init__(self, threshold=128, always_apply=False, p=0.5): + super(Solarize, self).__init__(always_apply, p) + + if isinstance(threshold, (int, float)): + self.threshold = to_tuple(threshold, low=threshold) + else: + self.threshold = to_tuple(threshold, low=0) + + def apply(self, image, threshold=0, **params): + return F.solarize(image, threshold) + + def get_params(self): + return {"threshold": random.uniform(self.threshold[0], self.threshold[1])} + + def get_transform_init_args_names(self): + return ("threshold",) + + +class Posterize(ImageOnlyTransform): + """Reduce the number of bits for each color channel. + + Args: + num_bits ((int, int) or int, + or list of ints [r, g, b], + or list of ints [[r1, r1], [g1, g2], [b1, b2]]): number of high bits. + If num_bits is a single value, the range will be [num_bits, num_bits]. + Must be in range [0, 8]. Default: 4. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8 + """ + + def __init__(self, num_bits=4, always_apply=False, p=0.5): + super(Posterize, self).__init__(always_apply, p) + + if isinstance(num_bits, (list, tuple)): + if len(num_bits) == 3: + self.num_bits = [to_tuple(i, 0) for i in num_bits] + else: + self.num_bits = to_tuple(num_bits, 0) + else: + self.num_bits = to_tuple(num_bits, num_bits) + + def apply(self, image, num_bits=1, **params): + return F.posterize(image, num_bits) + + def get_params(self): + if len(self.num_bits) == 3: + return {"num_bits": [random.randint(i[0], i[1]) for i in self.num_bits]} + return {"num_bits": random.randint(self.num_bits[0], self.num_bits[1])} + + def get_transform_init_args_names(self): + return ("num_bits",) + + +class Equalize(ImageOnlyTransform): + """Equalize the image histogram. + + Args: + mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method. + by_channels (bool): If True, use equalization by channels separately, + else convert image to YCbCr representation and use equalization by `Y` channel. + mask (np.ndarray, callable): If given, only the pixels selected by + the mask are included in the analysis. Maybe 1 channel or 3 channel array or callable. + Function signature must include `image` argument. + mask_params (list of str): Params for mask function. + + Targets: + image + + Image types: + uint8 + """ + + def __init__( + self, + mode="cv", + by_channels=True, + mask=None, + mask_params=(), + always_apply=False, + p=0.5, + ): + modes = ["cv", "pil"] + if mode not in modes: + raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode)) + + super(Equalize, self).__init__(always_apply, p) + self.mode = mode + self.by_channels = by_channels + self.mask = mask + self.mask_params = mask_params + + def apply(self, image, mask=None, **params): + return F.equalize(image, mode=self.mode, by_channels=self.by_channels, mask=mask) + + def get_params_dependent_on_targets(self, params): + if not callable(self.mask): + return {"mask": self.mask} + + return {"mask": self.mask(**params)} + + @property + def targets_as_params(self): + return ["image"] + list(self.mask_params) + + def get_transform_init_args_names(self): + return ("mode", "by_channels") + + +class RGBShift(ImageOnlyTransform): + """Randomly shift values for each channel of the input RGB image. + + Args: + r_shift_limit ((int, int) or int): range for changing values for the red channel. If r_shift_limit is a single + int, the range will be (-r_shift_limit, r_shift_limit). Default: (-20, 20). + g_shift_limit ((int, int) or int): range for changing values for the green channel. If g_shift_limit is a + single int, the range will be (-g_shift_limit, g_shift_limit). Default: (-20, 20). + b_shift_limit ((int, int) or int): range for changing values for the blue channel. If b_shift_limit is a single + int, the range will be (-b_shift_limit, b_shift_limit). Default: (-20, 20). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + r_shift_limit=20, + g_shift_limit=20, + b_shift_limit=20, + always_apply=False, + p=0.5, + ): + super(RGBShift, self).__init__(always_apply, p) + self.r_shift_limit = to_tuple(r_shift_limit) + self.g_shift_limit = to_tuple(g_shift_limit) + self.b_shift_limit = to_tuple(b_shift_limit) + + def apply(self, image, r_shift=0, g_shift=0, b_shift=0, **params): + if not is_rgb_image(image): + raise TypeError("RGBShift transformation expects 3-channel images.") + return F.shift_rgb(image, r_shift, g_shift, b_shift) + + def get_params(self): + return { + "r_shift": random.uniform(self.r_shift_limit[0], self.r_shift_limit[1]), + "g_shift": random.uniform(self.g_shift_limit[0], self.g_shift_limit[1]), + "b_shift": random.uniform(self.b_shift_limit[0], self.b_shift_limit[1]), + } + + def get_transform_init_args_names(self): + return ("r_shift_limit", "g_shift_limit", "b_shift_limit") + + +class RandomBrightnessContrast(ImageOnlyTransform): + """Randomly change brightness and contrast of the input image. + + Args: + brightness_limit ((float, float) or float): factor range for changing brightness. + If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2). + contrast_limit ((float, float) or float): factor range for changing contrast. + If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2). + brightness_by_max (Boolean): If True adjust contrast by image dtype maximum, + else adjust contrast by image mean. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__( + self, + brightness_limit=0.2, + contrast_limit=0.2, + brightness_by_max=True, + always_apply=False, + p=0.5, + ): + super(RandomBrightnessContrast, self).__init__(always_apply, p) + self.brightness_limit = to_tuple(brightness_limit) + self.contrast_limit = to_tuple(contrast_limit) + self.brightness_by_max = brightness_by_max + + def apply(self, img, alpha=1.0, beta=0.0, **params): + return F.brightness_contrast_adjust(img, alpha, beta, self.brightness_by_max) + + def get_params(self): + return { + "alpha": 1.0 + random.uniform(self.contrast_limit[0], self.contrast_limit[1]), + "beta": 0.0 + random.uniform(self.brightness_limit[0], self.brightness_limit[1]), + } + + def get_transform_init_args_names(self): + return ("brightness_limit", "contrast_limit", "brightness_by_max") + + +class RandomBrightness(RandomBrightnessContrast): + """Randomly change brightness of the input image. + + Args: + limit ((float, float) or float): factor range for changing brightness. + If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, limit=0.2, always_apply=False, p=0.5): + super(RandomBrightness, self).__init__(brightness_limit=limit, contrast_limit=0, always_apply=always_apply, p=p) + warnings.warn( + "This class has been deprecated. Please use RandomBrightnessContrast", + FutureWarning, + ) + + def get_transform_init_args(self): + return {"limit": self.brightness_limit} + + +class RandomContrast(RandomBrightnessContrast): + """Randomly change contrast of the input image. + + Args: + limit ((float, float) or float): factor range for changing contrast. + If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, limit=0.2, always_apply=False, p=0.5): + super(RandomContrast, self).__init__(brightness_limit=0, contrast_limit=limit, always_apply=always_apply, p=p) + warnings.warn( + f"{self.__class__.__name__} has been deprecated. Please use RandomBrightnessContrast", + FutureWarning, + ) + + def get_transform_init_args(self): + return {"limit": self.contrast_limit} + + +class GaussNoise(ImageOnlyTransform): + """Apply gaussian noise to the input image. + + Args: + var_limit ((float, float) or float): variance range for noise. If var_limit is a single float, the range + will be (0, var_limit). Default: (10.0, 50.0). + mean (float): mean of the noise. Default: 0 + per_channel (bool): if set to True, noise will be sampled for each channel independently. + Otherwise, the noise will be sampled once for all channels. Default: True + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, var_limit=(10.0, 50.0), mean=0, per_channel=True, always_apply=False, p=0.5): + super(GaussNoise, self).__init__(always_apply, p) + if isinstance(var_limit, (tuple, list)): + if var_limit[0] < 0: + raise ValueError("Lower var_limit should be non negative.") + if var_limit[1] < 0: + raise ValueError("Upper var_limit should be non negative.") + self.var_limit = var_limit + elif isinstance(var_limit, (int, float)): + if var_limit < 0: + raise ValueError("var_limit should be non negative.") + + self.var_limit = (0, var_limit) + else: + raise TypeError( + "Expected var_limit type to be one of (int, float, tuple, list), got {}".format(type(var_limit)) + ) + + self.mean = mean + self.per_channel = per_channel + + def apply(self, img, gauss=None, **params): + return F.gauss_noise(img, gauss=gauss) + + def get_params_dependent_on_targets(self, params): + image = params["image"] + var = random.uniform(self.var_limit[0], self.var_limit[1]) + sigma = var**0.5 + + if self.per_channel: + gauss = random_utils.normal(self.mean, sigma, image.shape) + else: + gauss = random_utils.normal(self.mean, sigma, image.shape[:2]) + if len(image.shape) == 3: + gauss = np.expand_dims(gauss, -1) + + return {"gauss": gauss} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return ("var_limit", "per_channel", "mean") + + +class ISONoise(ImageOnlyTransform): + """ + Apply camera sensor noise. + + Args: + color_shift (float, float): variance range for color hue change. + Measured as a fraction of 360 degree Hue angle in HLS colorspace. + intensity ((float, float): Multiplicative factor that control strength + of color and luminace noise. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8 + """ + + def __init__(self, color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=0.5): + super(ISONoise, self).__init__(always_apply, p) + self.intensity = intensity + self.color_shift = color_shift + + def apply(self, img, color_shift=0.05, intensity=1.0, random_state=None, **params): + return F.iso_noise(img, color_shift, intensity, np.random.RandomState(random_state)) + + def get_params(self): + return { + "color_shift": random.uniform(self.color_shift[0], self.color_shift[1]), + "intensity": random.uniform(self.intensity[0], self.intensity[1]), + "random_state": random.randint(0, 65536), + } + + def get_transform_init_args_names(self): + return ("intensity", "color_shift") + + +class CLAHE(ImageOnlyTransform): + """Apply Contrast Limited Adaptive Histogram Equalization to the input image. + + Args: + clip_limit (float or (float, float)): upper threshold value for contrast limiting. + If clip_limit is a single float value, the range will be (1, clip_limit). Default: (1, 4). + tile_grid_size ((int, int)): size of grid for histogram equalization. Default: (8, 8). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8 + """ + + def __init__(self, clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5): + super(CLAHE, self).__init__(always_apply, p) + self.clip_limit = to_tuple(clip_limit, 1) + self.tile_grid_size = tuple(tile_grid_size) + + def apply(self, img, clip_limit=2, **params): + if not is_rgb_image(img) and not is_grayscale_image(img): + raise TypeError("CLAHE transformation expects 1-channel or 3-channel images.") + + return F.clahe(img, clip_limit, self.tile_grid_size) + + def get_params(self): + return {"clip_limit": random.uniform(self.clip_limit[0], self.clip_limit[1])} + + def get_transform_init_args_names(self): + return ("clip_limit", "tile_grid_size") + + +class ChannelShuffle(ImageOnlyTransform): + """Randomly rearrange channels of the input RGB image. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + @property + def targets_as_params(self): + return ["image"] + + def apply(self, img, channels_shuffled=(0, 1, 2), **params): + return F.channel_shuffle(img, channels_shuffled) + + def get_params_dependent_on_targets(self, params): + img = params["image"] + ch_arr = list(range(img.shape[2])) + random.shuffle(ch_arr) + return {"channels_shuffled": ch_arr} + + def get_transform_init_args_names(self): + return () + + +class InvertImg(ImageOnlyTransform): + """Invert the input image by subtracting pixel values from max values of the image types, + i.e., 255 for uint8 and 1.0 for float32. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def apply(self, img, **params): + return F.invert(img) + + def get_transform_init_args_names(self): + return () + + +class RandomGamma(ImageOnlyTransform): + """ + Args: + gamma_limit (float or (float, float)): If gamma_limit is a single float value, + the range will be (-gamma_limit, gamma_limit). Default: (80, 120). + eps: Deprecated. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5): + super(RandomGamma, self).__init__(always_apply, p) + self.gamma_limit = to_tuple(gamma_limit) + self.eps = eps + + def apply(self, img, gamma=1, **params): + return F.gamma_transform(img, gamma=gamma) + + def get_params(self): + return {"gamma": random.uniform(self.gamma_limit[0], self.gamma_limit[1]) / 100.0} + + def get_transform_init_args_names(self): + return ("gamma_limit", "eps") + + +class ToGray(ImageOnlyTransform): + """Convert the input RGB image to grayscale. If the mean pixel value for the resulting image is greater + than 127, invert the resulting grayscale image. + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def apply(self, img, **params): + if is_grayscale_image(img): + warnings.warn("The image is already gray.") + return img + if not is_rgb_image(img): + raise TypeError("ToGray transformation expects 3-channel images.") + + return F.to_gray(img) + + def get_transform_init_args_names(self): + return () + + +class ToRGB(ImageOnlyTransform): + """Convert the input grayscale image to RGB. + + Args: + p (float): probability of applying the transform. Default: 1. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, always_apply=True, p=1.0): + super(ToRGB, self).__init__(always_apply=always_apply, p=p) + + def apply(self, img, **params): + if is_rgb_image(img): + warnings.warn("The image is already an RGB.") + return img + if not is_grayscale_image(img): + raise TypeError("ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1.") + + return F.gray_to_rgb(img) + + def get_transform_init_args_names(self): + return () + + +class ToSepia(ImageOnlyTransform): + """Applies sepia filter to the input RGB image + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + """ + + def __init__(self, always_apply=False, p=0.5): + super(ToSepia, self).__init__(always_apply, p) + self.sepia_transformation_matrix = np.array( + [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]] + ) + + def apply(self, image, **params): + if not is_rgb_image(image): + raise TypeError("ToSepia transformation expects 3-channel images.") + return F.linear_transformation_rgb(image, self.sepia_transformation_matrix) + + def get_transform_init_args_names(self): + return () + + +class ToFloat(ImageOnlyTransform): + """Divide pixel values by `max_value` to get a float32 output array where all values lie in the range [0, 1.0]. + If `max_value` is None the transform will try to infer the maximum value by inspecting the data type of the input + image. + + See Also: + :class:`~albumentations.augmentations.transforms.FromFloat` + + Args: + max_value (float): maximum possible input value. Default: None. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image + + Image types: + any type + + """ + + def __init__(self, max_value=None, always_apply=False, p=1.0): + super(ToFloat, self).__init__(always_apply, p) + self.max_value = max_value + + def apply(self, img, **params): + return F.to_float(img, self.max_value) + + def get_transform_init_args_names(self): + return ("max_value",) + + +class FromFloat(ImageOnlyTransform): + """Take an input array where all values should lie in the range [0, 1.0], multiply them by `max_value` and then + cast the resulted value to a type specified by `dtype`. If `max_value` is None the transform will try to infer + the maximum value for the data type from the `dtype` argument. + + This is the inverse transform for :class:`~albumentations.augmentations.transforms.ToFloat`. + + Args: + max_value (float): maximum possible input value. Default: None. + dtype (string or numpy data type): data type of the output. See the `'Data types' page from the NumPy docs`_. + Default: 'uint16'. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image + + Image types: + float32 + + .. _'Data types' page from the NumPy docs: + https://docs.scipy.org/doc/numpy/user/basics.types.html + """ + + def __init__(self, dtype="uint16", max_value=None, always_apply=False, p=1.0): + super(FromFloat, self).__init__(always_apply, p) + self.dtype = np.dtype(dtype) + self.max_value = max_value + + def apply(self, img, **params): + return F.from_float(img, self.dtype, self.max_value) + + def get_transform_init_args(self): + return {"dtype": self.dtype.name, "max_value": self.max_value} + + +class Downscale(ImageOnlyTransform): + """Decreases image quality by downscaling and upscaling back. + + Args: + scale_min (float): lower bound on the image scale. Should be < 1. + scale_max (float): lower bound on the image scale. Should be . + interpolation: cv2 interpolation method. Could be: + - single cv2 interpolation flag - selected method will be used for downscale and upscale. + - dict(downscale=flag, upscale=flag) + - Downscale.Interpolation(downscale=flag, upscale=flag) - + Default: Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST) + + Targets: + image + + Image types: + uint8, float32 + """ + + class Interpolation: + def __init__(self, *, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_NEAREST): + self.downscale = downscale + self.upscale = upscale + + def __init__( + self, + scale_min: float = 0.25, + scale_max: float = 0.25, + interpolation: Optional[Union[int, Interpolation, Dict[str, int]]] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super(Downscale, self).__init__(always_apply, p) + if interpolation is None: + self.interpolation = self.Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST) + warnings.warn( + "Using default interpolation INTER_NEAREST, which is sub-optimal." + "Please specify interpolation mode for downscale and upscale explicitly." + "For additional information see this PR https://github.com/albumentations-team/albumentations/pull/584" + ) + elif isinstance(interpolation, int): + self.interpolation = self.Interpolation(downscale=interpolation, upscale=interpolation) + elif isinstance(interpolation, self.Interpolation): + self.interpolation = interpolation + elif isinstance(interpolation, dict): + self.interpolation = self.Interpolation(**interpolation) + else: + raise ValueError( + "Wrong interpolation data type. Supported types: `Optional[Union[int, Interpolation, Dict[str, int]]]`." + f" Got: {type(interpolation)}" + ) + + if scale_min > scale_max: + raise ValueError("Expected scale_min be less or equal scale_max, got {} {}".format(scale_min, scale_max)) + if scale_max >= 1: + raise ValueError("Expected scale_max to be less than 1, got {}".format(scale_max)) + self.scale_min = scale_min + self.scale_max = scale_max + + def apply(self, img: np.ndarray, scale: Optional[float] = None, **params) -> np.ndarray: + return F.downscale( + img, + scale=scale, + down_interpolation=self.interpolation.downscale, + up_interpolation=self.interpolation.upscale, + ) + + def get_params(self) -> Dict[str, Any]: + return {"scale": random.uniform(self.scale_min, self.scale_max)} + + def get_transform_init_args_names(self) -> Tuple[str, str]: + return "scale_min", "scale_max" + + def _to_dict(self) -> Dict[str, Any]: + result = super()._to_dict() + result["interpolation"] = {"upscale": self.interpolation.upscale, "downscale": self.interpolation.downscale} + return result + + +class Lambda(NoOp): + """A flexible transformation class for using user-defined transformation functions per targets. + Function signature must include **kwargs to accept optinal arguments like interpolation method, image size, etc: + + Args: + image (callable): Image transformation function. + mask (callable): Mask transformation function. + keypoint (callable): Keypoint transformation function. + bbox (callable): BBox transformation function. + always_apply (bool): Indicates whether this transformation should be always applied. + p (float): probability of applying the transform. Default: 1.0. + + Targets: + image, mask, bboxes, keypoints + + Image types: + Any + """ + + def __init__( + self, + image=None, + mask=None, + keypoint=None, + bbox=None, + name=None, + always_apply=False, + p=1.0, + ): + super(Lambda, self).__init__(always_apply, p) + + self.name = name + self.custom_apply_fns = {target_name: F.noop for target_name in ("image", "mask", "keypoint", "bbox")} + for target_name, custom_apply_fn in { + "image": image, + "mask": mask, + "keypoint": keypoint, + "bbox": bbox, + }.items(): + if custom_apply_fn is not None: + if isinstance(custom_apply_fn, LambdaType) and custom_apply_fn.__name__ == "": + warnings.warn( + "Using lambda is incompatible with multiprocessing. " + "Consider using regular functions or partial()." + ) + + self.custom_apply_fns[target_name] = custom_apply_fn + + def apply(self, img, **params): + fn = self.custom_apply_fns["image"] + return fn(img, **params) + + def apply_to_mask(self, mask, **params): + fn = self.custom_apply_fns["mask"] + return fn(mask, **params) + + def apply_to_bbox(self, bbox, **params): + fn = self.custom_apply_fns["bbox"] + return fn(bbox, **params) + + def apply_to_keypoint(self, keypoint, **params): + fn = self.custom_apply_fns["keypoint"] + return fn(keypoint, **params) + + @classmethod + def is_serializable(cls): + return False + + def _to_dict(self): + if self.name is None: + raise ValueError( + "To make a Lambda transform serializable you should provide the `name` argument, " + "e.g. `Lambda(name='my_transform', image=, ...)`." + ) + return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name} + + def __repr__(self): + state = {"name": self.name} + state.update(self.custom_apply_fns.items()) + state.update(self.get_base_init_args()) + return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state)) + + +class MultiplicativeNoise(ImageOnlyTransform): + """Multiply image to random number or array of numbers. + + Args: + multiplier (float or tuple of floats): If single float image will be multiplied to this number. + If tuple of float multiplier will be in range `[multiplier[0], multiplier[1])`. Default: (0.9, 1.1). + per_channel (bool): If `False`, same values for all channels will be used. + If `True` use sample values for each channels. Default False. + elementwise (bool): If `False` multiply multiply all pixels in an image with a random value sampled once. + If `True` Multiply image pixels with values that are pixelwise randomly sampled. Defaule: False. + + Targets: + image + + Image types: + Any + """ + + def __init__( + self, + multiplier=(0.9, 1.1), + per_channel=False, + elementwise=False, + always_apply=False, + p=0.5, + ): + super(MultiplicativeNoise, self).__init__(always_apply, p) + self.multiplier = to_tuple(multiplier, multiplier) + self.per_channel = per_channel + self.elementwise = elementwise + + def apply(self, img, multiplier=np.array([1]), **kwargs): + return F.multiply(img, multiplier) + + def get_params_dependent_on_targets(self, params): + if self.multiplier[0] == self.multiplier[1]: + return {"multiplier": np.array([self.multiplier[0]])} + + img = params["image"] + + h, w = img.shape[:2] + + if self.per_channel: + c = 1 if is_grayscale_image(img) else img.shape[-1] + else: + c = 1 + + if self.elementwise: + shape = [h, w, c] + else: + shape = [c] + + multiplier = random_utils.uniform(self.multiplier[0], self.multiplier[1], shape) + if is_grayscale_image(img) and img.ndim == 2: + multiplier = np.squeeze(multiplier) + + return {"multiplier": multiplier} + + @property + def targets_as_params(self): + return ["image"] + + def get_transform_init_args_names(self): + return "multiplier", "per_channel", "elementwise" + + +class FancyPCA(ImageOnlyTransform): + """Augment RGB image using FancyPCA from Krizhevsky's paper + "ImageNet Classification with Deep Convolutional Neural Networks" + + Args: + alpha (float): how much to perturb/scale the eigen vecs and vals. + scale is samples from gaussian distribution (mu=0, sigma=alpha) + + Targets: + image + + Image types: + 3-channel uint8 images only + + Credit: + http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf + https://deshanadesai.github.io/notes/Fancy-PCA-with-Scikit-Image + https://pixelatedbrian.github.io/2018-04-29-fancy_pca/ + """ + + def __init__(self, alpha=0.1, always_apply=False, p=0.5): + super(FancyPCA, self).__init__(always_apply=always_apply, p=p) + self.alpha = alpha + + def apply(self, img, alpha=0.1, **params): + img = F.fancy_pca(img, alpha) + return img + + def get_params(self): + return {"alpha": random.gauss(0, self.alpha)} + + def get_transform_init_args_names(self): + return ("alpha",) + + +class ColorJitter(ImageOnlyTransform): + """Randomly changes the brightness, contrast, and saturation of an image. Compared to ColorJitter from torchvision, + this transform gives a little bit different results because Pillow (used in torchvision) and OpenCV (used in + Albumentations) transform an image to HSV format by different formulas. Another difference - Pillow uses uint8 + overflow, but we use value saturation. + + Args: + brightness (float or tuple of float (min, max)): How much to jitter brightness. + brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness] + or the given [min, max]. Should be non negative numbers. + contrast (float or tuple of float (min, max)): How much to jitter contrast. + contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast] + or the given [min, max]. Should be non negative numbers. + saturation (float or tuple of float (min, max)): How much to jitter saturation. + saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation] + or the given [min, max]. Should be non negative numbers. + hue (float or tuple of float (min, max)): How much to jitter hue. + hue_factor is chosen uniformly from [-hue, hue] or the given [min, max]. + Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5. + """ + + def __init__( + self, + brightness=0.2, + contrast=0.2, + saturation=0.2, + hue=0.2, + always_apply=False, + p=0.5, + ): + super(ColorJitter, self).__init__(always_apply=always_apply, p=p) + + self.brightness = self.__check_values(brightness, "brightness") + self.contrast = self.__check_values(contrast, "contrast") + self.saturation = self.__check_values(saturation, "saturation") + self.hue = self.__check_values(hue, "hue", offset=0, bounds=[-0.5, 0.5], clip=False) + + self.transforms = [ + F.adjust_brightness_torchvision, + F.adjust_contrast_torchvision, + F.adjust_saturation_torchvision, + F.adjust_hue_torchvision, + ] + + @staticmethod + def __check_values(value, name, offset=1, bounds=(0, float("inf")), clip=True): + if isinstance(value, numbers.Number): + if value < 0: + raise ValueError("If {} is a single number, it must be non negative.".format(name)) + value = [offset - value, offset + value] + if clip: + value[0] = max(value[0], 0) + elif isinstance(value, (tuple, list)) and len(value) == 2: + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError("{} values should be between {}".format(name, bounds)) + else: + raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name)) + + return value + + def get_params(self): + brightness = random.uniform(self.brightness[0], self.brightness[1]) + contrast = random.uniform(self.contrast[0], self.contrast[1]) + saturation = random.uniform(self.saturation[0], self.saturation[1]) + hue = random.uniform(self.hue[0], self.hue[1]) + + order = [0, 1, 2, 3] + random.shuffle(order) + + return { + "brightness": brightness, + "contrast": contrast, + "saturation": saturation, + "hue": hue, + "order": order, + } + + def apply(self, img, brightness=1.0, contrast=1.0, saturation=1.0, hue=0, order=[0, 1, 2, 3], **params): + if not is_rgb_image(img) and not is_grayscale_image(img): + raise TypeError("ColorJitter transformation expects 1-channel or 3-channel images.") + params = [brightness, contrast, saturation, hue] + for i in order: + img = self.transforms[i](img, params[i]) + return img + + def get_transform_init_args_names(self): + return ("brightness", "contrast", "saturation", "hue") + + +class Sharpen(ImageOnlyTransform): + """Sharpen the input image and overlays the result with the original image. + + Args: + alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is + visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5). + lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5): + super(Sharpen, self).__init__(always_apply, p) + self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0)) + self.lightness = self.__check_values(to_tuple(lightness, 0.0), name="lightness") + + @staticmethod + def __check_values(value, name, bounds=(0, float("inf"))): + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError("{} values should be between {}".format(name, bounds)) + return value + + @staticmethod + def __generate_sharpening_matrix(alpha_sample, lightness_sample): + matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32) + matrix_effect = np.array( + [[-1, -1, -1], [-1, 8 + lightness_sample, -1], [-1, -1, -1]], + dtype=np.float32, + ) + + matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect + return matrix + + def get_params(self): + alpha = random.uniform(*self.alpha) + lightness = random.uniform(*self.lightness) + sharpening_matrix = self.__generate_sharpening_matrix(alpha_sample=alpha, lightness_sample=lightness) + return {"sharpening_matrix": sharpening_matrix} + + def apply(self, img, sharpening_matrix=None, **params): + return F.convolve(img, sharpening_matrix) + + def get_transform_init_args_names(self): + return ("alpha", "lightness") + + +class Emboss(ImageOnlyTransform): + """Emboss the input image and overlays the result with the original image. + + Args: + alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is + visible,at 1.0 only its embossed version is visible. Default: (0.2, 0.5). + strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5): + super(Emboss, self).__init__(always_apply, p) + self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0)) + self.strength = self.__check_values(to_tuple(strength, 0.0), name="strength") + + @staticmethod + def __check_values(value, name, bounds=(0, float("inf"))): + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError("{} values should be between {}".format(name, bounds)) + return value + + @staticmethod + def __generate_emboss_matrix(alpha_sample, strength_sample): + matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32) + matrix_effect = np.array( + [ + [-1 - strength_sample, 0 - strength_sample, 0], + [0 - strength_sample, 1, 0 + strength_sample], + [0, 0 + strength_sample, 1 + strength_sample], + ], + dtype=np.float32, + ) + matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect + return matrix + + def get_params(self): + alpha = random.uniform(*self.alpha) + strength = random.uniform(*self.strength) + emboss_matrix = self.__generate_emboss_matrix(alpha_sample=alpha, strength_sample=strength) + return {"emboss_matrix": emboss_matrix} + + def apply(self, img, emboss_matrix=None, **params): + return F.convolve(img, emboss_matrix) + + def get_transform_init_args_names(self): + return ("alpha", "strength") + + +class Superpixels(ImageOnlyTransform): + """Transform images partially/completely to their superpixel representation. + This implementation uses skimage's version of the SLIC algorithm. + + Args: + p_replace (float or tuple of float): Defines for any segment the probability that the pixels within that + segment are replaced by their average color (otherwise, the pixels are not changed). + Examples: + * A probability of ``0.0`` would mean, that the pixels in no + segment are replaced by their average color (image is not + changed at all). + * A probability of ``0.5`` would mean, that around half of all + segments are replaced by their average color. + * A probability of ``1.0`` would mean, that all segments are + replaced by their average color (resulting in a voronoi + image). + Behaviour based on chosen data types for this parameter: + * If a ``float``, then that ``flat`` will always be used. + * If ``tuple`` ``(a, b)``, then a random probability will be + sampled from the interval ``[a, b]`` per image. + n_segments (int, or tuple of int): Rough target number of how many superpixels to generate (the algorithm + may deviate from this number). Lower value will lead to coarser superpixels. + Higher values are computationally more intensive and will hence lead to a slowdown + * If a single ``int``, then that value will always be used as the + number of segments. + * If a ``tuple`` ``(a, b)``, then a value from the discrete + interval ``[a..b]`` will be sampled per image. + max_size (int or None): Maximum image size at which the augmentation is performed. + If the width or height of an image exceeds this value, it will be + downscaled before the augmentation so that the longest side matches `max_size`. + This is done to speed up the process. The final output image has the same size as the input image. + Note that in case `p_replace` is below ``1.0``, + the down-/upscaling will affect the not-replaced pixels too. + Use ``None`` to apply no down-/upscaling. + interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of: + cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4. + Default: cv2.INTER_LINEAR. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__( + self, + p_replace: Union[float, Sequence[float]] = 0.1, + n_segments: Union[int, Sequence[int]] = 100, + max_size: Optional[int] = 128, + interpolation: int = cv2.INTER_LINEAR, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + self.p_replace = to_tuple(p_replace, p_replace) + self.n_segments = to_tuple(n_segments, n_segments) + self.max_size = max_size + self.interpolation = interpolation + + if min(self.n_segments) < 1: + raise ValueError(f"n_segments must be >= 1. Got: {n_segments}") + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str]: + return ("p_replace", "n_segments", "max_size", "interpolation") + + def get_params(self) -> dict: + n_segments = random.randint(*self.n_segments) + p = random.uniform(*self.p_replace) + return {"replace_samples": random_utils.random(n_segments) < p, "n_segments": n_segments} + + def apply(self, img: np.ndarray, replace_samples: Sequence[bool] = (False,), n_segments: int = 1, **kwargs): + return F.superpixels(img, n_segments, replace_samples, self.max_size, self.interpolation) + + +class TemplateTransform(ImageOnlyTransform): + """ + Apply blending of input image with specified templates + Args: + templates (numpy array or list of numpy arrays): Images as template for transform. + img_weight ((float, float) or float): If single float will be used as weight for input image. + If tuple of float img_weight will be in range `[img_weight[0], img_weight[1])`. Default: 0.5. + template_weight ((float, float) or float): If single float will be used as weight for template. + If tuple of float template_weight will be in range `[template_weight[0], template_weight[1])`. + Default: 0.5. + template_transform: transformation object which could be applied to template, + must produce template the same size as input image. + name (string): (Optional) Name of transform, used only for deserialization. + p (float): probability of applying the transform. Default: 0.5. + Targets: + image + Image types: + uint8, float32 + """ + + def __init__( + self, + templates, + img_weight=0.5, + template_weight=0.5, + template_transform=None, + name=None, + always_apply=False, + p=0.5, + ): + super().__init__(always_apply, p) + + self.templates = templates if isinstance(templates, (list, tuple)) else [templates] + self.img_weight = to_tuple(img_weight, img_weight) + self.template_weight = to_tuple(template_weight, template_weight) + self.template_transform = template_transform + self.name = name + + def apply(self, img, template=None, img_weight=0.5, template_weight=0.5, **params): + return F.add_weighted(img, img_weight, template, template_weight) + + def get_params(self): + return { + "img_weight": random.uniform(self.img_weight[0], self.img_weight[1]), + "template_weight": random.uniform(self.template_weight[0], self.template_weight[1]), + } + + def get_params_dependent_on_targets(self, params): + img = params["image"] + template = random.choice(self.templates) + + if self.template_transform is not None: + template = self.template_transform(image=template)["image"] + + if get_num_channels(template) not in [1, get_num_channels(img)]: + raise ValueError( + "Template must be a single channel or " + "has the same number of channels as input image ({}), got {}".format( + get_num_channels(img), get_num_channels(template) + ) + ) + + if template.dtype != img.dtype: + raise ValueError("Image and template must be the same image type") + + if img.shape[:2] != template.shape[:2]: + raise ValueError( + "Image and template must be the same size, got {} and {}".format(img.shape[:2], template.shape[:2]) + ) + + if get_num_channels(template) == 1 and get_num_channels(img) > 1: + template = np.stack((template,) * get_num_channels(img), axis=-1) + + # in order to support grayscale image with dummy dim + template = template.reshape(img.shape) + + return {"template": template} + + @classmethod + def is_serializable(cls): + return False + + @property + def targets_as_params(self): + return ["image"] + + def _to_dict(self): + if self.name is None: + raise ValueError( + "To make a TemplateTransform serializable you should provide the `name` argument, " + "e.g. `TemplateTransform(name='my_transform', ...)`." + ) + return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name} + + +class RingingOvershoot(ImageOnlyTransform): + """Create ringing or overshoot artefacts by conlvolving image with 2D sinc filter. + + Args: + blur_limit (int, (int, int)): maximum kernel size for sinc filter. + Should be in range [3, inf). Default: (7, 15). + cutoff (float, (float, float)): range to choose the cutoff frequency in radians. + Should be in range (0, np.pi) + Default: (np.pi / 4, np.pi / 2). + p (float): probability of applying the transform. Default: 0.5. + + Reference: + dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + https://arxiv.org/abs/2107.10833 + + Targets: + image + """ + + def __init__( + self, + blur_limit: Union[int, Sequence[int]] = (7, 15), + cutoff: Union[float, Sequence[float]] = (np.pi / 4, np.pi / 2), + always_apply=False, + p=0.5, + ): + super(RingingOvershoot, self).__init__(always_apply, p) + self.blur_limit = to_tuple(blur_limit, 3) + self.cutoff = self.__check_values(to_tuple(cutoff, np.pi / 2), name="cutoff", bounds=(0, np.pi)) + + @staticmethod + def __check_values(value, name, bounds=(0, float("inf"))): + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError(f"{name} values should be between {bounds}") + return value + + def get_params(self): + ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2) + if ksize % 2 == 0: + raise ValueError(f"Kernel size must be odd. Got: {ksize}") + + cutoff = random.uniform(*self.cutoff) + + # From dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter + with np.errstate(divide="ignore", invalid="ignore"): + kernel = np.fromfunction( + lambda x, y: cutoff + * special.j1(cutoff * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2)) + / (2 * np.pi * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2)), + [ksize, ksize], + ) + kernel[(ksize - 1) // 2, (ksize - 1) // 2] = cutoff**2 / (4 * np.pi) + + # Normalize kernel + kernel = kernel.astype(np.float32) / np.sum(kernel) + + return {"kernel": kernel} + + def apply(self, img, kernel=None, **params): + return F.convolve(img, kernel) + + def get_transform_init_args_names(self): + return ("blur_limit", "cutoff") + + +class UnsharpMask(ImageOnlyTransform): + """ + Sharpen the input image using Unsharp Masking processing and overlays the result with the original image. + + Args: + blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image. + Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma + as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`. + If set single value `blur_limit` will be in range (0, blur_limit). + Default: (3, 7). + sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf). + If set single value `sigma_limit` will be in range (0, sigma_limit). + If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0. + alpha (float, (float, float)): range to choose the visibility of the sharpened image. + At 0, only the original image is visible, at 1.0 only its sharpened version is visible. + Default: (0.2, 0.5). + threshold (int): Value to limit sharpening only for areas with high pixel difference between original image + and it's smoothed version. Higher threshold means less sharpening on flat areas. + Must be in range [0, 255]. Default: 10. + p (float): probability of applying the transform. Default: 0.5. + + Reference: + arxiv.org/pdf/2107.10833.pdf + + Targets: + image + """ + + def __init__( + self, + blur_limit: Union[int, Sequence[int]] = (3, 7), + sigma_limit: Union[float, Sequence[float]] = 0.0, + alpha: Union[float, Sequence[float]] = (0.2, 0.5), + threshold: int = 10, + always_apply=False, + p=0.5, + ): + super(UnsharpMask, self).__init__(always_apply, p) + self.blur_limit = to_tuple(blur_limit, 3) + self.sigma_limit = self.__check_values(to_tuple(sigma_limit, 0.0), name="sigma_limit") + self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0)) + self.threshold = threshold + + if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0: + self.blur_limit = 3, max(3, self.blur_limit[1]) + raise ValueError("blur_limit and sigma_limit minimum value can not be both equal to 0.") + + if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or ( + self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1 + ): + raise ValueError("UnsharpMask supports only odd blur limits.") + + @staticmethod + def __check_values(value, name, bounds=(0, float("inf"))): + if not bounds[0] <= value[0] <= value[1] <= bounds[1]: + raise ValueError(f"{name} values should be between {bounds}") + return value + + def get_params(self): + return { + "ksize": random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2), + "sigma": random.uniform(*self.sigma_limit), + "alpha": random.uniform(*self.alpha), + } + + def apply(self, img, ksize=3, sigma=0, alpha=0.2, **params): + return F.unsharp_mask(img, ksize, sigma=sigma, alpha=alpha, threshold=self.threshold) + + def get_transform_init_args_names(self): + return ("blur_limit", "sigma_limit", "alpha", "threshold") + + +class PixelDropout(DualTransform): + """Set pixels to 0 with some probability. + + Args: + dropout_prob (float): pixel drop probability. Default: 0.01 + per_channel (bool): if set to `True` drop mask will be sampled fo each channel, + otherwise the same mask will be sampled for all channels. Default: False + drop_value (number or sequence of numbers or None): Value that will be set in dropped place. + If set to None value will be sampled randomly, default ranges will be used: + - uint8 - [0, 255] + - uint16 - [0, 65535] + - uint32 - [0, 4294967295] + - float, double - [0, 1] + Default: 0 + mask_drop_value (number or sequence of numbers or None): Value that will be set in dropped place in masks. + If set to None masks will be unchanged. Default: 0 + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + Image types: + any + """ + + def __init__( + self, + dropout_prob: float = 0.01, + per_channel: bool = False, + drop_value: Optional[Union[float, Sequence[float]]] = 0, + mask_drop_value: Optional[Union[float, Sequence[float]]] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply, p) + self.dropout_prob = dropout_prob + self.per_channel = per_channel + self.drop_value = drop_value + self.mask_drop_value = mask_drop_value + + if self.mask_drop_value is not None and self.per_channel: + raise ValueError("PixelDropout supports mask only with per_channel=False") + + def apply( + self, + img: np.ndarray, + drop_mask: np.ndarray = np.array(None), + drop_value: Union[float, Sequence[float]] = (), + **params + ) -> np.ndarray: + return F.pixel_dropout(img, drop_mask, drop_value) + + def apply_to_mask(self, img: np.ndarray, drop_mask: np.ndarray = np.array(None), **params) -> np.ndarray: + if self.mask_drop_value is None: + return img + + if img.ndim == 2: + drop_mask = np.squeeze(drop_mask) + + return F.pixel_dropout(img, drop_mask, self.mask_drop_value) + + def apply_to_bbox(self, bbox, **params): + return bbox + + def apply_to_keypoint(self, keypoint, **params): + return keypoint + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + img = params["image"] + shape = img.shape if self.per_channel else img.shape[:2] + + rnd = np.random.RandomState(random.randint(0, 1 << 31)) + # Use choice to create boolean matrix, if we will use binomial after that we will need type conversion + drop_mask = rnd.choice([True, False], shape, p=[self.dropout_prob, 1 - self.dropout_prob]) + + drop_value: Union[float, Sequence[float], np.ndarray] + if drop_mask.ndim != img.ndim: + drop_mask = np.expand_dims(drop_mask, -1) + if self.drop_value is None: + drop_shape = 1 if is_grayscale_image(img) else int(img.shape[-1]) + + if img.dtype in (np.uint8, np.uint16, np.uint32): + drop_value = rnd.randint(0, int(F.MAX_VALUES_BY_DTYPE[img.dtype]), drop_shape, img.dtype) + elif img.dtype in [np.float32, np.double]: + drop_value = rnd.uniform(0, 1, drop_shape).astype(img.dtype) + else: + raise ValueError(f"Unsupported dtype: {img.dtype}") + else: + drop_value = self.drop_value + + return {"drop_mask": drop_mask, "drop_value": drop_value} + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str]: + return ("dropout_prob", "per_channel", "drop_value", "mask_drop_value") + + +class Spatter(ImageOnlyTransform): + """ + Apply spatter transform. It simulates corruption which can occlude a lens in the form of rain or mud. + + Args: + mean (float, or tuple of floats): Mean value of normal distribution for generating liquid layer. + If single float it will be used as mean. + If tuple of float mean will be sampled from range `[mean[0], mean[1])`. Default: (0.65). + std (float, or tuple of floats): Standard deviation value of normal distribution for generating liquid layer. + If single float it will be used as std. + If tuple of float std will be sampled from range `[std[0], std[1])`. Default: (0.3). + gauss_sigma (float, or tuple of floats): Sigma value for gaussian filtering of liquid layer. + If single float it will be used as gauss_sigma. + If tuple of float gauss_sigma will be sampled from range `[sigma[0], sigma[1])`. Default: (2). + cutout_threshold (float, or tuple of floats): Threshold for filtering liqued layer + (determines number of drops). If single float it will used as cutout_threshold. + If tuple of float cutout_threshold will be sampled from range `[cutout_threshold[0], cutout_threshold[1])`. + Default: (0.68). + intensity (float, or tuple of floats): Intensity of corruption. + If single float it will be used as intensity. + If tuple of float intensity will be sampled from range `[intensity[0], intensity[1])`. Default: (0.6). + mode (string, or list of strings): Type of corruption. Currently, supported options are 'rain' and 'mud'. + If list is provided type of corruption will be sampled list. Default: ("rain"). + color (list of (r, g, b) or dict or None): Corruption elements color. + If list uses provided list as color for specified mode. + If dict uses provided color for specified mode. Color for each specified mode should be provided in dict. + If None uses default colors (rain: (238, 238, 175), mud: (20, 42, 63)). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + + Image types: + uint8, float32 + + Reference: + | https://arxiv.org/pdf/1903.12261.pdf + | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py + """ + + def __init__( + self, + mean: ScaleFloatType = 0.65, + std: ScaleFloatType = 0.3, + gauss_sigma: ScaleFloatType = 2, + cutout_threshold: ScaleFloatType = 0.68, + intensity: ScaleFloatType = 0.6, + mode: Union[str, Sequence[str]] = "rain", + color: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None, + always_apply: bool = False, + p: float = 0.5, + ): + super().__init__(always_apply=always_apply, p=p) + + self.mean = to_tuple(mean, mean) + self.std = to_tuple(std, std) + self.gauss_sigma = to_tuple(gauss_sigma, gauss_sigma) + self.intensity = to_tuple(intensity, intensity) + self.cutout_threshold = to_tuple(cutout_threshold, cutout_threshold) + self.color = ( + color + if color is not None + else { + "rain": [238, 238, 175], + "mud": [20, 42, 63], + } + ) + self.mode = mode if isinstance(mode, (list, tuple)) else [mode] + + if len(set(self.mode)) > 1 and not isinstance(self.color, dict): + raise ValueError(f"Unsupported color: {self.color}. Please specify color for each mode (use dict for it).") + + for i in self.mode: + if i not in ["rain", "mud"]: + raise ValueError(f"Unsupported color mode: {mode}. Transform supports only `rain` and `mud` mods.") + if isinstance(self.color, dict): + if i not in self.color: + raise ValueError(f"Wrong color definition: {self.color}. Color for mode: {i} not specified.") + if len(self.color[i]) != 3: + raise ValueError( + f"Unsupported color: {self.color[i]} for mode {i}. Color should be presented in RGB format." + ) + + if isinstance(self.color, (list, tuple)): + if len(self.color) != 3: + raise ValueError(f"Unsupported color: {self.color}. Color should be presented in RGB format.") + self.color = {self.mode[0]: self.color} + + def apply( + self, + img: np.ndarray, + non_mud: Optional[np.ndarray] = None, + mud: Optional[np.ndarray] = None, + drops: Optional[np.ndarray] = None, + mode: str = "", + **params + ) -> np.ndarray: + return F.spatter(img, non_mud, mud, drops, mode) + + @property + def targets_as_params(self) -> List[str]: + return ["image"] + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + h, w = params["image"].shape[:2] + + mean = random.uniform(self.mean[0], self.mean[1]) + std = random.uniform(self.std[0], self.std[1]) + cutout_threshold = random.uniform(self.cutout_threshold[0], self.cutout_threshold[1]) + sigma = random.uniform(self.gauss_sigma[0], self.gauss_sigma[1]) + mode = random.choice(self.mode) + intensity = random.uniform(self.intensity[0], self.intensity[1]) + color = np.array(self.color[mode]) / 255.0 + + liquid_layer = random_utils.normal(size=(h, w), loc=mean, scale=std) + liquid_layer = gaussian_filter(liquid_layer, sigma=sigma, mode="nearest") + liquid_layer[liquid_layer < cutout_threshold] = 0 + + if mode == "rain": + liquid_layer = (liquid_layer * 255).astype(np.uint8) + dist = 255 - cv2.Canny(liquid_layer, 50, 150) + dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5) + _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC) + dist = blur(dist, 3).astype(np.uint8) + dist = F.equalize(dist) + + ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]]) + dist = F.convolve(dist, ker) + dist = blur(dist, 3).astype(np.float32) + + m = liquid_layer * dist + m *= 1 / np.max(m, axis=(0, 1)) + + drops = m[:, :, None] * color * intensity + mud = None + non_mud = None + else: + m = np.where(liquid_layer > cutout_threshold, 1, 0) + m = gaussian_filter(m.astype(np.float32), sigma=sigma, mode="nearest") + m[m < 1.2 * cutout_threshold] = 0 + m = m[..., np.newaxis] + + mud = m * color + non_mud = 1 - m + drops = None + + return { + "non_mud": non_mud, + "mud": mud, + "drops": drops, + "mode": mode, + } + + def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]: + return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color" diff --git a/custom_albumentations/augmentations/utils.py b/custom_albumentations/augmentations/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..32789e206d4f8ce578b4268eb1194db2f929a671 --- /dev/null +++ b/custom_albumentations/augmentations/utils.py @@ -0,0 +1,211 @@ +from functools import wraps +from typing import Callable, Union + +import cv2 +import numpy as np +from typing_extensions import Concatenate, ParamSpec + +from custom_albumentations.core.keypoints_utils import angle_to_2pi_range +from custom_albumentations.core.transforms_interface import KeypointInternalType + +__all__ = [ + "read_bgr_image", + "read_rgb_image", + "MAX_VALUES_BY_DTYPE", + "NPDTYPE_TO_OPENCV_DTYPE", + "clipped", + "get_opencv_dtype_from_numpy", + "angle_2pi_range", + "clip", + "preserve_shape", + "preserve_channel_dim", + "ensure_contiguous", + "is_rgb_image", + "is_grayscale_image", + "is_multispectral_image", + "get_num_channels", + "non_rgb_warning", + "_maybe_process_in_chunks", +] + +P = ParamSpec("P") + +MAX_VALUES_BY_DTYPE = { + np.dtype("uint8"): 255, + np.dtype("uint16"): 65535, + np.dtype("uint32"): 4294967295, + np.dtype("float32"): 1.0, +} + +NPDTYPE_TO_OPENCV_DTYPE = { + np.uint8: cv2.CV_8U, # type: ignore[attr-defined] + np.uint16: cv2.CV_16U, # type: ignore[attr-defined] + np.int32: cv2.CV_32S, # type: ignore[attr-defined] + np.float32: cv2.CV_32F, # type: ignore[attr-defined] + np.float64: cv2.CV_64F, # type: ignore[attr-defined] + np.dtype("uint8"): cv2.CV_8U, # type: ignore[attr-defined] + np.dtype("uint16"): cv2.CV_16U, # type: ignore[attr-defined] + np.dtype("int32"): cv2.CV_32S, # type: ignore[attr-defined] + np.dtype("float32"): cv2.CV_32F, # type: ignore[attr-defined] + np.dtype("float64"): cv2.CV_64F, # type: ignore[attr-defined] +} + + +def read_bgr_image(path): + return cv2.imread(path, cv2.IMREAD_COLOR) + + +def read_rgb_image(path): + image = cv2.imread(path, cv2.IMREAD_COLOR) + return cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + +def clipped(func: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + dtype = img.dtype + maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0) + return clip(func(img, *args, **kwargs), dtype, maxval) + + return wrapped_function + + +def clip(img: np.ndarray, dtype: np.dtype, maxval: float) -> np.ndarray: + return np.clip(img, 0, maxval).astype(dtype) + + +def get_opencv_dtype_from_numpy(value: Union[np.ndarray, int, np.dtype, object]) -> int: + """ + Return a corresponding OpenCV dtype for a numpy's dtype + :param value: Input dtype of numpy array + :return: Corresponding dtype for OpenCV + """ + if isinstance(value, np.ndarray): + value = value.dtype + return NPDTYPE_TO_OPENCV_DTYPE[value] + + +def angle_2pi_range( + func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType] +) -> Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]: + @wraps(func) + def wrapped_function(keypoint: KeypointInternalType, *args: P.args, **kwargs: P.kwargs) -> KeypointInternalType: + (x, y, a, s) = func(keypoint, *args, **kwargs)[:4] + return (x, y, angle_to_2pi_range(a), s) + + return wrapped_function + + +def preserve_shape( + func: Callable[Concatenate[np.ndarray, P], np.ndarray] +) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + """Preserve shape of the image""" + + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + shape = img.shape + result = func(img, *args, **kwargs) + result = result.reshape(shape) + return result + + return wrapped_function + + +def preserve_channel_dim( + func: Callable[Concatenate[np.ndarray, P], np.ndarray] +) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + """Preserve dummy channel dim.""" + + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + shape = img.shape + result = func(img, *args, **kwargs) + if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2: + result = np.expand_dims(result, axis=-1) + return result + + return wrapped_function + + +def ensure_contiguous( + func: Callable[Concatenate[np.ndarray, P], np.ndarray] +) -> Callable[Concatenate[np.ndarray, P], np.ndarray]: + """Ensure that input img is contiguous.""" + + @wraps(func) + def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray: + img = np.require(img, requirements=["C_CONTIGUOUS"]) + result = func(img, *args, **kwargs) + return result + + return wrapped_function + + +def is_rgb_image(image: np.ndarray) -> bool: + return len(image.shape) == 3 and image.shape[-1] == 3 + + +def is_grayscale_image(image: np.ndarray) -> bool: + return (len(image.shape) == 2) or (len(image.shape) == 3 and image.shape[-1] == 1) + + +def is_multispectral_image(image: np.ndarray) -> bool: + return len(image.shape) == 3 and image.shape[-1] not in [1, 3] + + +def get_num_channels(image: np.ndarray) -> int: + return image.shape[2] if len(image.shape) == 3 else 1 + + +def non_rgb_warning(image: np.ndarray) -> None: + if not is_rgb_image(image): + message = "This transformation expects 3-channel images" + if is_grayscale_image(image): + message += "\nYou can convert your grayscale image to RGB using cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))" + if is_multispectral_image(image): # Any image with a number of channels other than 1 and 3 + message += "\nThis transformation cannot be applied to multi-spectral images" + + raise ValueError(message) + + +def _maybe_process_in_chunks( + process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs +) -> Callable[[np.ndarray], np.ndarray]: + """ + Wrap OpenCV function to enable processing images with more than 4 channels. + + Limitations: + This wrapper requires image to be the first argument and rest must be sent via named arguments. + + Args: + process_fn: Transform function (e.g cv2.resize). + kwargs: Additional parameters. + + Returns: + numpy.ndarray: Transformed image. + + """ + + @wraps(process_fn) + def __process_fn(img: np.ndarray) -> np.ndarray: + num_channels = get_num_channels(img) + if num_channels > 4: + chunks = [] + for index in range(0, num_channels, 4): + if num_channels - index == 2: + # Many OpenCV functions cannot work with 2-channel images + for i in range(2): + chunk = img[:, :, index + i : index + i + 1] + chunk = process_fn(chunk, **kwargs) + chunk = np.expand_dims(chunk, -1) + chunks.append(chunk) + else: + chunk = img[:, :, index : index + 4] + chunk = process_fn(chunk, **kwargs) + chunks.append(chunk) + img = np.dstack(chunks) + else: + img = process_fn(img, **kwargs) + return img + + return __process_fn diff --git a/custom_albumentations/core/__init__.py b/custom_albumentations/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_albumentations/core/bbox_utils.py b/custom_albumentations/core/bbox_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ae0913ace7c539f505b7355d911aae12f3a3ba6 --- /dev/null +++ b/custom_albumentations/core/bbox_utils.py @@ -0,0 +1,522 @@ +from __future__ import division + +from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, cast + +import numpy as np + +from .transforms_interface import BoxInternalType, BoxType +from .utils import DataProcessor, Params + +__all__ = [ + "normalize_bbox", + "denormalize_bbox", + "normalize_bboxes", + "denormalize_bboxes", + "calculate_bbox_area", + "filter_bboxes_by_visibility", + "convert_bbox_to_albumentations", + "convert_bbox_from_albumentations", + "convert_bboxes_to_albumentations", + "convert_bboxes_from_albumentations", + "check_bbox", + "check_bboxes", + "filter_bboxes", + "union_of_bboxes", + "BboxProcessor", + "BboxParams", +] + +TBox = TypeVar("TBox", BoxType, BoxInternalType) + + +class BboxParams(Params): + """ + Parameters of bounding boxes + + Args: + format (str): format of bounding boxes. Should be 'coco', 'pascal_voc', 'albumentations' or 'yolo'. + + The `coco` format + `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200]. + The `pascal_voc` format + `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212]. + The `albumentations` format + is like `pascal_voc`, but normalized, + in other words: `[x_min, y_min, x_max, y_max]`, e.g. [0.2, 0.3, 0.4, 0.5]. + The `yolo` format + `[x, y, width, height]`, e.g. [0.1, 0.2, 0.3, 0.4]; + `x`, `y` - normalized bbox center; `width`, `height` - normalized bbox width and height. + label_fields (list): list of fields that are joined with boxes, e.g labels. + Should be same type as boxes. + min_area (float): minimum area of a bounding box. All bounding boxes whose + visible area in pixels is less than this value will be removed. Default: 0.0. + min_visibility (float): minimum fraction of area for a bounding box + to remain this box in list. Default: 0.0. + min_width (float): Minimum width of a bounding box. All bounding boxes whose width is + less than this value will be removed. Default: 0.0. + min_height (float): Minimum height of a bounding box. All bounding boxes whose height is + less than this value will be removed. Default: 0.0. + check_each_transform (bool): if `True`, then bboxes will be checked after each dual transform. + Default: `True` + """ + + def __init__( + self, + format: str, + label_fields: Optional[Sequence[str]] = None, + min_area: float = 0.0, + min_visibility: float = 0.0, + min_width: float = 0.0, + min_height: float = 0.0, + check_each_transform: bool = True, + ): + super(BboxParams, self).__init__(format, label_fields) + self.min_area = min_area + self.min_visibility = min_visibility + self.min_width = min_width + self.min_height = min_height + self.check_each_transform = check_each_transform + + def _to_dict(self) -> Dict[str, Any]: + data = super(BboxParams, self)._to_dict() + data.update( + { + "min_area": self.min_area, + "min_visibility": self.min_visibility, + "min_width": self.min_width, + "min_height": self.min_height, + "check_each_transform": self.check_each_transform, + } + ) + return data + + @classmethod + def is_serializable(cls) -> bool: + return True + + @classmethod + def get_class_fullname(cls) -> str: + return "BboxParams" + + +class BboxProcessor(DataProcessor): + def __init__(self, params: BboxParams, additional_targets: Optional[Dict[str, str]] = None): + super().__init__(params, additional_targets) + + @property + def default_data_name(self) -> str: + return "bboxes" + + def ensure_data_valid(self, data: Dict[str, Any]) -> None: + for data_name in self.data_fields: + data_exists = data_name in data and len(data[data_name]) + if data_exists and len(data[data_name][0]) < 5: + if self.params.label_fields is None: + raise ValueError( + "Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox " + "because bboxes must have labels" + ) + if self.params.label_fields: + if not all(i in data.keys() for i in self.params.label_fields): + raise ValueError("Your 'label_fields' are not valid - them must have same names as params in dict") + + def filter(self, data: Sequence, rows: int, cols: int) -> List: + self.params: BboxParams + return filter_bboxes( + data, + rows, + cols, + min_area=self.params.min_area, + min_visibility=self.params.min_visibility, + min_width=self.params.min_width, + min_height=self.params.min_height, + ) + + def check(self, data: Sequence, rows: int, cols: int) -> None: + check_bboxes(data) + + def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> List[BoxType]: + return convert_bboxes_from_albumentations(data, self.params.format, rows, cols, check_validity=True) + + def convert_to_albumentations(self, data: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]: + return convert_bboxes_to_albumentations(data, self.params.format, rows, cols, check_validity=True) + + +def normalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox: + """Normalize coordinates of a bounding box. Divide x-coordinates by image width and y-coordinates + by image height. + + Args: + bbox: Denormalized bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image height. + cols: Image width. + + Returns: + Normalized bounding box `(x_min, y_min, x_max, y_max)`. + + Raises: + ValueError: If rows or cols is less or equal zero + + """ + + if rows <= 0: + raise ValueError("Argument rows must be positive integer") + if cols <= 0: + raise ValueError("Argument cols must be positive integer") + + tail: Tuple[Any, ...] + (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:]) + + x_min, x_max = x_min / cols, x_max / cols + y_min, y_max = y_min / rows, y_max / rows + + return cast(BoxType, (x_min, y_min, x_max, y_max) + tail) # type: ignore + + +def denormalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox: + """Denormalize coordinates of a bounding box. Multiply x-coordinates by image width and y-coordinates + by image height. This is an inverse operation for :func:`~albumentations.augmentations.bbox.normalize_bbox`. + + Args: + bbox: Normalized bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image height. + cols: Image width. + + Returns: + Denormalized bounding box `(x_min, y_min, x_max, y_max)`. + + Raises: + ValueError: If rows or cols is less or equal zero + + """ + tail: Tuple[Any, ...] + (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:]) + + if rows <= 0: + raise ValueError("Argument rows must be positive integer") + if cols <= 0: + raise ValueError("Argument cols must be positive integer") + + x_min, x_max = x_min * cols, x_max * cols + y_min, y_max = y_min * rows, y_max * rows + + return cast(BoxType, (x_min, y_min, x_max, y_max) + tail) # type: ignore + + +def normalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]: + """Normalize a list of bounding boxes. + + Args: + bboxes: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`. + rows: Image height. + cols: Image width. + + Returns: + Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`. + + """ + return [normalize_bbox(bbox, rows, cols) for bbox in bboxes] + + +def denormalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]: + """Denormalize a list of bounding boxes. + + Args: + bboxes: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`. + rows: Image height. + cols: Image width. + + Returns: + List: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`. + + """ + return [denormalize_bbox(bbox, rows, cols) for bbox in bboxes] + + +def calculate_bbox_area(bbox: BoxType, rows: int, cols: int) -> float: + """Calculate the area of a bounding box in (fractional) pixels. + + Args: + bbox: A bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image height. + cols: Image width. + + Return: + Area in (fractional) pixels of the (denormalized) bounding box. + + """ + bbox = denormalize_bbox(bbox, rows, cols) + x_min, y_min, x_max, y_max = bbox[:4] + area = (x_max - x_min) * (y_max - y_min) + return area + + +def filter_bboxes_by_visibility( + original_shape: Sequence[int], + bboxes: Sequence[BoxType], + transformed_shape: Sequence[int], + transformed_bboxes: Sequence[BoxType], + threshold: float = 0.0, + min_area: float = 0.0, +) -> List[BoxType]: + """Filter bounding boxes and return only those boxes whose visibility after transformation is above + the threshold and minimal area of bounding box in pixels is more then min_area. + + Args: + original_shape: Original image shape `(height, width, ...)`. + bboxes: Original bounding boxes `[(x_min, y_min, x_max, y_max)]`. + transformed_shape: Transformed image shape `(height, width)`. + transformed_bboxes: Transformed bounding boxes `[(x_min, y_min, x_max, y_max)]`. + threshold: visibility threshold. Should be a value in the range [0.0, 1.0]. + min_area: Minimal area threshold. + + Returns: + Filtered bounding boxes `[(x_min, y_min, x_max, y_max)]`. + + """ + img_height, img_width = original_shape[:2] + transformed_img_height, transformed_img_width = transformed_shape[:2] + + visible_bboxes = [] + for bbox, transformed_bbox in zip(bboxes, transformed_bboxes): + if not all(0.0 <= value <= 1.0 for value in transformed_bbox[:4]): + continue + bbox_area = calculate_bbox_area(bbox, img_height, img_width) + transformed_bbox_area = calculate_bbox_area(transformed_bbox, transformed_img_height, transformed_img_width) + if transformed_bbox_area < min_area: + continue + visibility = transformed_bbox_area / bbox_area + if visibility >= threshold: + visible_bboxes.append(transformed_bbox) + return visible_bboxes + + +def convert_bbox_to_albumentations( + bbox: BoxType, source_format: str, rows: int, cols: int, check_validity: bool = False +) -> BoxType: + """Convert a bounding box from a format specified in `source_format` to the format used by albumentations: + normalized coordinates of top-left and bottom-right corners of the bounding box in a form of + `(x_min, y_min, x_max, y_max)` e.g. `(0.15, 0.27, 0.67, 0.5)`. + + Args: + bbox: A bounding box tuple. + source_format: format of the bounding box. Should be 'coco', 'pascal_voc', or 'yolo'. + check_validity: Check if all boxes are valid boxes. + rows: Image height. + cols: Image width. + + Returns: + tuple: A bounding box `(x_min, y_min, x_max, y_max)`. + + Note: + The `coco` format of a bounding box looks like `(x_min, y_min, width, height)`, e.g. (97, 12, 150, 200). + The `pascal_voc` format of a bounding box looks like `(x_min, y_min, x_max, y_max)`, e.g. (97, 12, 247, 212). + The `yolo` format of a bounding box looks like `(x, y, width, height)`, e.g. (0.3, 0.1, 0.05, 0.07); + where `x`, `y` coordinates of the center of the box, all values normalized to 1 by image height and width. + + Raises: + ValueError: if `target_format` is not equal to `coco` or `pascal_voc`, or `yolo`. + ValueError: If in YOLO format all labels not in range (0, 1). + + """ + if source_format not in {"coco", "pascal_voc", "yolo"}: + raise ValueError( + f"Unknown source_format {source_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'" + ) + + if source_format == "coco": + (x_min, y_min, width, height), tail = bbox[:4], bbox[4:] + x_max = x_min + width + y_max = y_min + height + elif source_format == "yolo": + # https://github.com/pjreddie/darknet/blob/f6d861736038da22c9eb0739dca84003c5a5e275/scripts/voc_label.py#L12 + _bbox = np.array(bbox[:4]) + if check_validity and np.any((_bbox <= 0) | (_bbox > 1)): + raise ValueError("In YOLO format all coordinates must be float and in range (0, 1]") + + (x, y, w, h), tail = bbox[:4], bbox[4:] + + w_half, h_half = w / 2, h / 2 + x_min = x - w_half + y_min = y - h_half + x_max = x_min + w + y_max = y_min + h + else: + (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:] + + bbox = (x_min, y_min, x_max, y_max) + tuple(tail) # type: ignore + + if source_format != "yolo": + bbox = normalize_bbox(bbox, rows, cols) + if check_validity: + check_bbox(bbox) + return bbox + + +def convert_bbox_from_albumentations( + bbox: BoxType, target_format: str, rows: int, cols: int, check_validity: bool = False +) -> BoxType: + """Convert a bounding box from the format used by albumentations to a format, specified in `target_format`. + + Args: + bbox: An albumentations bounding box `(x_min, y_min, x_max, y_max)`. + target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'. + rows: Image height. + cols: Image width. + check_validity: Check if all boxes are valid boxes. + + Returns: + tuple: A bounding box. + + Note: + The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200]. + The `pascal_voc` format of a bounding box looks like `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212]. + The `yolo` format of a bounding box looks like `[x, y, width, height]`, e.g. [0.3, 0.1, 0.05, 0.07]. + + Raises: + ValueError: if `target_format` is not equal to `coco`, `pascal_voc` or `yolo`. + + """ + if target_format not in {"coco", "pascal_voc", "yolo"}: + raise ValueError( + f"Unknown target_format {target_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'" + ) + if check_validity: + check_bbox(bbox) + + if target_format != "yolo": + bbox = denormalize_bbox(bbox, rows, cols) + if target_format == "coco": + (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:]) + width = x_max - x_min + height = y_max - y_min + bbox = cast(BoxType, (x_min, y_min, width, height) + tail) + elif target_format == "yolo": + (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:] + x = (x_min + x_max) / 2.0 + y = (y_min + y_max) / 2.0 + w = x_max - x_min + h = y_max - y_min + bbox = cast(BoxType, (x, y, w, h) + tail) + return bbox + + +def convert_bboxes_to_albumentations( + bboxes: Sequence[BoxType], source_format, rows, cols, check_validity=False +) -> List[BoxType]: + """Convert a list bounding boxes from a format specified in `source_format` to the format used by albumentations""" + return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes] + + +def convert_bboxes_from_albumentations( + bboxes: Sequence[BoxType], target_format: str, rows: int, cols: int, check_validity: bool = False +) -> List[BoxType]: + """Convert a list of bounding boxes from the format used by albumentations to a format, specified + in `target_format`. + + Args: + bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`. + target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'. + rows: Image height. + cols: Image width. + check_validity: Check if all boxes are valid boxes. + + Returns: + List of bounding boxes. + + """ + return [convert_bbox_from_albumentations(bbox, target_format, rows, cols, check_validity) for bbox in bboxes] + + +def check_bbox(bbox: BoxType) -> None: + """Check if bbox boundaries are in range 0, 1 and minimums are lesser then maximums""" + for name, value in zip(["x_min", "y_min", "x_max", "y_max"], bbox[:4]): + if not 0 <= value <= 1 and not np.isclose(value, 0) and not np.isclose(value, 1): + raise ValueError(f"Expected {name} for bbox {bbox} to be in the range [0.0, 1.0], got {value}.") + x_min, y_min, x_max, y_max = bbox[:4] + if x_max <= x_min: + raise ValueError(f"x_max is less than or equal to x_min for bbox {bbox}.") + if y_max <= y_min: + raise ValueError(f"y_max is less than or equal to y_min for bbox {bbox}.") + + +def check_bboxes(bboxes: Sequence[BoxType]) -> None: + """Check if bboxes boundaries are in range 0, 1 and minimums are lesser then maximums""" + for bbox in bboxes: + check_bbox(bbox) + + +def filter_bboxes( + bboxes: Sequence[BoxType], + rows: int, + cols: int, + min_area: float = 0.0, + min_visibility: float = 0.0, + min_width: float = 0.0, + min_height: float = 0.0, +) -> List[BoxType]: + """Remove bounding boxes that either lie outside of the visible area by more then min_visibility + or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size. + + Args: + bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`. + rows: Image height. + cols: Image width. + min_area: Minimum area of a bounding box. All bounding boxes whose visible area in pixels. + is less than this value will be removed. Default: 0.0. + min_visibility: Minimum fraction of area for a bounding box to remain this box in list. Default: 0.0. + min_width: Minimum width of a bounding box. All bounding boxes whose width is + less than this value will be removed. Default: 0.0. + min_height: Minimum height of a bounding box. All bounding boxes whose height is + less than this value will be removed. Default: 0.0. + + Returns: + List of bounding boxes. + + """ + resulting_boxes: List[BoxType] = [] + for bbox in bboxes: + # Calculate areas of bounding box before and after clipping. + transformed_box_area = calculate_bbox_area(bbox, rows, cols) + bbox, tail = cast(BoxType, tuple(np.clip(bbox[:4], 0, 1.0))), tuple(bbox[4:]) + clipped_box_area = calculate_bbox_area(bbox, rows, cols) + + # Calculate width and height of the clipped bounding box. + x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4] + clipped_width, clipped_height = x_max - x_min, y_max - y_min + + if ( + clipped_box_area != 0 # to ensure transformed_box_area!=0 and to handle min_area=0 or min_visibility=0 + and clipped_box_area >= min_area + and clipped_box_area / transformed_box_area >= min_visibility + and clipped_width >= min_width + and clipped_height >= min_height + ): + resulting_boxes.append(cast(BoxType, bbox + tail)) + return resulting_boxes + + +def union_of_bboxes(height: int, width: int, bboxes: Sequence[BoxType], erosion_rate: float = 0.0) -> BoxType: + """Calculate union of bounding boxes. + + Args: + height (float): Height of image or space. + width (float): Width of image or space. + bboxes (List[tuple]): List like bounding boxes. Format is `[(x_min, y_min, x_max, y_max)]`. + erosion_rate (float): How much each bounding box can be shrinked, useful for erosive cropping. + Set this in range [0, 1]. 0 will not be erosive at all, 1.0 can make any bbox to lose its volume. + + Returns: + tuple: A bounding box `(x_min, y_min, x_max, y_max)`. + + """ + x1, y1 = width, height + x2, y2 = 0, 0 + for bbox in bboxes: + x_min, y_min, x_max, y_max = bbox[:4] + w, h = x_max - x_min, y_max - y_min + lim_x1, lim_y1 = x_min + erosion_rate * w, y_min + erosion_rate * h + lim_x2, lim_y2 = x_max - erosion_rate * w, y_max - erosion_rate * h + x1, y1 = np.min([x1, lim_x1]), np.min([y1, lim_y1]) + x2, y2 = np.max([x2, lim_x2]), np.max([y2, lim_y2]) + return x1, y1, x2, y2 diff --git a/custom_albumentations/core/composition.py b/custom_albumentations/core/composition.py new file mode 100644 index 0000000000000000000000000000000000000000..ed8e1d71184d0fa69284d3e5faefa7063632d667 --- /dev/null +++ b/custom_albumentations/core/composition.py @@ -0,0 +1,552 @@ +from __future__ import division + +import random +import typing +import warnings +from collections import defaultdict + +import numpy as np + +from .. import random_utils +from .bbox_utils import BboxParams, BboxProcessor +from .keypoints_utils import KeypointParams, KeypointsProcessor +from .serialization import ( + SERIALIZABLE_REGISTRY, + Serializable, + get_shortest_class_fullname, + instantiate_nonserializable, +) +from .transforms_interface import BasicTransform +from .utils import format_args, get_shape + +__all__ = [ + "BaseCompose", + "Compose", + "SomeOf", + "OneOf", + "OneOrOther", + "BboxParams", + "KeypointParams", + "ReplayCompose", + "Sequential", +] + + +REPR_INDENT_STEP = 2 +TransformType = typing.Union[BasicTransform, "BaseCompose"] +TransformsSeqType = typing.Sequence[TransformType] + + +def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) -> TransformsSeqType: + new_transforms: typing.List[TransformType] = [] + for transform in transforms: # type: ignore + if isinstance(transform, BaseCompose): + new_transforms.extend(get_always_apply(transform)) + elif transform.always_apply: + new_transforms.append(transform) + return new_transforms + + +class BaseCompose(Serializable): + def __init__(self, transforms: TransformsSeqType, p: float): + if isinstance(transforms, (BaseCompose, BasicTransform)): + warnings.warn( + "transforms is single transform, but a sequence is expected! Transform will be wrapped into list." + ) + transforms = [transforms] + + self.transforms = transforms + self.p = p + + self.replay_mode = False + self.applied_in_replay = False + + def __len__(self) -> int: + return len(self.transforms) + + def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]: + raise NotImplementedError + + def __getitem__(self, item: int) -> TransformType: # type: ignore + return self.transforms[item] + + def __repr__(self) -> str: + return self.indented_repr() + + def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str: + args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")} + repr_string = self.__class__.__name__ + "([" + for t in self.transforms: + repr_string += "\n" + if hasattr(t, "indented_repr"): + t_repr = t.indented_repr(indent + REPR_INDENT_STEP) # type: ignore + else: + t_repr = repr(t) + repr_string += " " * indent + t_repr + "," + repr_string += "\n" + " " * (indent - REPR_INDENT_STEP) + "], {args})".format(args=format_args(args)) + return repr_string + + @classmethod + def get_class_fullname(cls) -> str: + return get_shortest_class_fullname(cls) + + @classmethod + def is_serializable(cls) -> bool: + return True + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + return { + "__class_fullname__": self.get_class_fullname(), + "p": self.p, + "transforms": [t._to_dict() for t in self.transforms], # skipcq: PYL-W0212 + } + + def get_dict_with_id(self) -> typing.Dict[str, typing.Any]: + return { + "__class_fullname__": self.get_class_fullname(), + "id": id(self), + "params": None, + "transforms": [t.get_dict_with_id() for t in self.transforms], + } + + def add_targets(self, additional_targets: typing.Optional[typing.Dict[str, str]]) -> None: + if additional_targets: + for t in self.transforms: + t.add_targets(additional_targets) + + def set_deterministic(self, flag: bool, save_key: str = "replay") -> None: + for t in self.transforms: + t.set_deterministic(flag, save_key) + + +class Compose(BaseCompose): + """Compose transforms and handle all transformations regarding bounding boxes + + Args: + transforms (list): list of transformations to compose. + bbox_params (BboxParams): Parameters for bounding boxes transforms + keypoint_params (KeypointParams): Parameters for keypoints transforms + additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'} + p (float): probability of applying all list of transforms. Default: 1.0. + is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you + would like to disable this check - pass False (do it only if you are sure in your data consistency). + """ + + def __init__( + self, + transforms: TransformsSeqType, + bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None, + keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None, + additional_targets: typing.Optional[typing.Dict[str, str]] = None, + p: float = 1.0, + is_check_shapes: bool = True, + ): + super(Compose, self).__init__(transforms, p) + + self.processors: typing.Dict[str, typing.Union[BboxProcessor, KeypointsProcessor]] = {} + if bbox_params: + if isinstance(bbox_params, dict): + b_params = BboxParams(**bbox_params) + elif isinstance(bbox_params, BboxParams): + b_params = bbox_params + else: + raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`") + self.processors["bboxes"] = BboxProcessor(b_params, additional_targets) + + if keypoint_params: + if isinstance(keypoint_params, dict): + k_params = KeypointParams(**keypoint_params) + elif isinstance(keypoint_params, KeypointParams): + k_params = keypoint_params + else: + raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`") + self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets) + + if additional_targets is None: + additional_targets = {} + + self.additional_targets = additional_targets + + for proc in self.processors.values(): + proc.ensure_transforms_valid(self.transforms) + + self.add_targets(additional_targets) + + self.is_check_args = True + self._disable_check_args_for_transforms(self.transforms) + + self.is_check_shapes = is_check_shapes + + @staticmethod + def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None: + for transform in transforms: + if isinstance(transform, BaseCompose): + Compose._disable_check_args_for_transforms(transform.transforms) + if isinstance(transform, Compose): + transform._disable_check_args() + + def _disable_check_args(self) -> None: + self.is_check_args = False + + def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: + if args: + raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)") + if self.is_check_args: + self._check_args(**data) + assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type" + need_to_run = force_apply or random.random() < self.p + for p in self.processors.values(): + p.ensure_data_valid(data) + transforms = self.transforms if need_to_run else get_always_apply(self.transforms) + + check_each_transform = any( + getattr(item.params, "check_each_transform", False) for item in self.processors.values() + ) + + for p in self.processors.values(): + p.preprocess(data) + + for idx, t in enumerate(transforms): + data = t(**data) + + if check_each_transform: + data = self._check_data_post_transform(data) + data = Compose._make_targets_contiguous(data) # ensure output targets are contiguous + + for p in self.processors.values(): + p.postprocess(data) + + return data + + def _check_data_post_transform(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: + rows, cols = get_shape(data["image"]) + + for p in self.processors.values(): + if not getattr(p.params, "check_each_transform", False): + continue + + for data_name in p.data_fields: + data[data_name] = p.filter(data[data_name], rows, cols) + return data + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + dictionary = super(Compose, self)._to_dict() + bbox_processor = self.processors.get("bboxes") + keypoints_processor = self.processors.get("keypoints") + dictionary.update( + { + "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212 + "keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212 + if keypoints_processor + else None, + "additional_targets": self.additional_targets, + "is_check_shapes": self.is_check_shapes, + } + ) + return dictionary + + def get_dict_with_id(self) -> typing.Dict[str, typing.Any]: + dictionary = super().get_dict_with_id() + bbox_processor = self.processors.get("bboxes") + keypoints_processor = self.processors.get("keypoints") + dictionary.update( + { + "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212 + "keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212 + if keypoints_processor + else None, + "additional_targets": self.additional_targets, + "params": None, + "is_check_shapes": self.is_check_shapes, + } + ) + return dictionary + + def _check_args(self, **kwargs) -> None: + checked_single = ["image", "mask"] + checked_multi = ["masks"] + check_bbox_param = ["bboxes"] + # ["bboxes", "keypoints"] could be almost any type, no need to check them + shapes = [] + for data_name, data in kwargs.items(): + internal_data_name = self.additional_targets.get(data_name, data_name) + if internal_data_name in checked_single: + if not isinstance(data, np.ndarray): + raise TypeError("{} must be numpy array type".format(data_name)) + shapes.append(data.shape[:2]) + if internal_data_name in checked_multi: + if data is not None and len(data): + if not isinstance(data[0], np.ndarray): + raise TypeError("{} must be list of numpy arrays".format(data_name)) + shapes.append(data[0].shape[:2]) + if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None: + raise ValueError("bbox_params must be specified for bbox transformations") + + if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes): + raise ValueError( + "Height and Width of image, mask or masks should be equal. You can disable shapes check " + "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure " + "about your data consistency)." + ) + + @staticmethod + def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]: + result = {} + for key, value in data.items(): + if isinstance(value, np.ndarray): + value = np.ascontiguousarray(value) + result[key] = value + return result + + +class OneOf(BaseCompose): + """Select one of transforms to apply. Selected transform will be called with `force_apply=True`. + Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights. + + Args: + transforms (list): list of transformations to compose. + p (float): probability of applying selected transform. Default: 0.5. + """ + + def __init__(self, transforms: TransformsSeqType, p: float = 0.5): + super(OneOf, self).__init__(transforms, p) + transforms_ps = [t.p for t in self.transforms] + s = sum(transforms_ps) + self.transforms_ps = [t / s for t in transforms_ps] + + def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: + if self.replay_mode: + for t in self.transforms: + data = t(**data) + return data + + if self.transforms_ps and (force_apply or random.random() < self.p): + idx: int = random_utils.choice(len(self.transforms), p=self.transforms_ps) + t = self.transforms[idx] + data = t(force_apply=True, **data) + return data + + +class SomeOf(BaseCompose): + """Select N transforms to apply. Selected transforms will be called with `force_apply=True`. + Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights. + + Args: + transforms (list): list of transformations to compose. + n (int): number of transforms to apply. + replace (bool): Whether the sampled transforms are with or without replacement. Default: True. + p (float): probability of applying selected transform. Default: 1. + """ + + def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1): + super(SomeOf, self).__init__(transforms, p) + self.n = n + self.replace = replace + transforms_ps = [t.p for t in self.transforms] + s = sum(transforms_ps) + self.transforms_ps = [t / s for t in transforms_ps] + + def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: + if self.replay_mode: + for t in self.transforms: + data = t(**data) + return data + + if self.transforms_ps and (force_apply or random.random() < self.p): + idx = random_utils.choice(len(self.transforms), size=self.n, replace=self.replace, p=self.transforms_ps) + for i in idx: # type: ignore + t = self.transforms[i] + data = t(force_apply=True, **data) + return data + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + dictionary = super(SomeOf, self)._to_dict() + dictionary.update({"n": self.n, "replace": self.replace}) + return dictionary + + +class OneOrOther(BaseCompose): + """Select one or another transform to apply. Selected transform will be called with `force_apply=True`.""" + + def __init__( + self, + first: typing.Optional[TransformType] = None, + second: typing.Optional[TransformType] = None, + transforms: typing.Optional[TransformsSeqType] = None, + p: float = 0.5, + ): + if transforms is None: + if first is None or second is None: + raise ValueError("You must set both first and second or set transforms argument.") + transforms = [first, second] + super(OneOrOther, self).__init__(transforms, p) + if len(self.transforms) != 2: + warnings.warn("Length of transforms is not equal to 2.") + + def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: + if self.replay_mode: + for t in self.transforms: + data = t(**data) + return data + + if random.random() < self.p: + return self.transforms[0](force_apply=True, **data) + + return self.transforms[-1](force_apply=True, **data) + + +class PerChannel(BaseCompose): + """Apply transformations per-channel + + Args: + transforms (list): list of transformations to compose. + channels (sequence): channels to apply the transform to. Pass None to apply to all. + Default: None (apply to all) + p (float): probability of applying the transform. Default: 0.5. + """ + + def __init__( + self, transforms: TransformsSeqType, channels: typing.Optional[typing.Sequence[int]] = None, p: float = 0.5 + ): + super(PerChannel, self).__init__(transforms, p) + self.channels = channels + + def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]: + if force_apply or random.random() < self.p: + image = data["image"] + + # Expand mono images to have a single channel + if len(image.shape) == 2: + image = np.expand_dims(image, -1) + + if self.channels is None: + self.channels = range(image.shape[2]) + + for c in self.channels: + for t in self.transforms: + image[:, :, c] = t(image=image[:, :, c])["image"] + + data["image"] = image + + return data + + +class ReplayCompose(Compose): + def __init__( + self, + transforms: TransformsSeqType, + bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None, + keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None, + additional_targets: typing.Optional[typing.Dict[str, str]] = None, + p: float = 1.0, + is_check_shapes: bool = True, + save_key: str = "replay", + ): + super(ReplayCompose, self).__init__( + transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes + ) + self.set_deterministic(True, save_key=save_key) + self.save_key = save_key + + def __call__(self, *args, force_apply: bool = False, **kwargs) -> typing.Dict[str, typing.Any]: + kwargs[self.save_key] = defaultdict(dict) + result = super(ReplayCompose, self).__call__(force_apply=force_apply, **kwargs) + serialized = self.get_dict_with_id() + self.fill_with_params(serialized, result[self.save_key]) + self.fill_applied(serialized) + result[self.save_key] = serialized + return result + + @staticmethod + def replay(saved_augmentations: typing.Dict[str, typing.Any], **kwargs) -> typing.Dict[str, typing.Any]: + augs = ReplayCompose._restore_for_replay(saved_augmentations) + return augs(force_apply=True, **kwargs) + + @staticmethod + def _restore_for_replay( + transform_dict: typing.Dict[str, typing.Any], lambda_transforms: typing.Optional[dict] = None + ) -> TransformType: + """ + Args: + lambda_transforms (dict): A dictionary that contains lambda transforms, that + is instances of the Lambda class. + This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys + in that dictionary should be named same as `name` arguments in respective lambda transforms from + a serialized pipeline. + """ + applied = transform_dict["applied"] + params = transform_dict["params"] + lmbd = instantiate_nonserializable(transform_dict, lambda_transforms) + if lmbd: + transform = lmbd + else: + name = transform_dict["__class_fullname__"] + args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]} + cls = SERIALIZABLE_REGISTRY[name] + if "transforms" in args: + args["transforms"] = [ + ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms) + for t in args["transforms"] + ] + transform = cls(**args) + + transform = typing.cast(BasicTransform, transform) + if isinstance(transform, BasicTransform): + transform.params = params + transform.replay_mode = True + transform.applied_in_replay = applied + return transform + + def fill_with_params(self, serialized: dict, all_params: dict) -> None: + params = all_params.get(serialized.get("id")) + serialized["params"] = params + del serialized["id"] + for transform in serialized.get("transforms", []): + self.fill_with_params(transform, all_params) + + def fill_applied(self, serialized: typing.Dict[str, typing.Any]) -> bool: + if "transforms" in serialized: + applied = [self.fill_applied(t) for t in serialized["transforms"]] + serialized["applied"] = any(applied) + else: + serialized["applied"] = serialized.get("params") is not None + return serialized["applied"] + + def _to_dict(self) -> typing.Dict[str, typing.Any]: + dictionary = super(ReplayCompose, self)._to_dict() + dictionary.update({"save_key": self.save_key}) + return dictionary + + +class Sequential(BaseCompose): + """Sequentially applies all transforms to targets. + + Note: + This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose` + the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to + create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly + chose sequence to input data (see the `Example` section for an example definition of such pipeline). + + Example: + >>> import custom_albumentations as albumentations as A + >>> transform = A.Compose([ + >>> A.OneOf([ + >>> A.Sequential([ + >>> A.HorizontalFlip(p=0.5), + >>> A.ShiftScaleRotate(p=0.5), + >>> ]), + >>> A.Sequential([ + >>> A.VerticalFlip(p=0.5), + >>> A.RandomBrightnessContrast(p=0.5), + >>> ]), + >>> ], p=1) + >>> ]) + """ + + def __init__(self, transforms: TransformsSeqType, p: float = 0.5): + super().__init__(transforms, p) + + def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]: + for t in self.transforms: + data = t(**data) + return data diff --git a/custom_albumentations/core/keypoints_utils.py b/custom_albumentations/core/keypoints_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..389853aa24249d283adfb5ba3e63420d2c5ec8f4 --- /dev/null +++ b/custom_albumentations/core/keypoints_utils.py @@ -0,0 +1,286 @@ +from __future__ import division + +import math +import typing +import warnings +from typing import Any, Dict, List, Optional, Sequence, Tuple + +from .utils import DataProcessor, Params + +__all__ = [ + "angle_to_2pi_range", + "check_keypoints", + "convert_keypoints_from_albumentations", + "convert_keypoints_to_albumentations", + "filter_keypoints", + "KeypointsProcessor", + "KeypointParams", +] + +keypoint_formats = {"xy", "yx", "xya", "xys", "xyas", "xysa"} + + +def angle_to_2pi_range(angle: float) -> float: + two_pi = 2 * math.pi + return angle % two_pi + + +class KeypointParams(Params): + """ + Parameters of keypoints + + Args: + format (str): format of keypoints. Should be 'xy', 'yx', 'xya', 'xys', 'xyas', 'xysa'. + + x - X coordinate, + + y - Y coordinate + + s - Keypoint scale + + a - Keypoint orientation in radians or degrees (depending on KeypointParams.angle_in_degrees) + label_fields (list): list of fields that are joined with keypoints, e.g labels. + Should be same type as keypoints. + remove_invisible (bool): to remove invisible points after transform or not + angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints + check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform. + Default: `True` + """ + + def __init__( + self, + format: str, # skipcq: PYL-W0622 + label_fields: Optional[Sequence[str]] = None, + remove_invisible: bool = True, + angle_in_degrees: bool = True, + check_each_transform: bool = True, + ): + super(KeypointParams, self).__init__(format, label_fields) + self.remove_invisible = remove_invisible + self.angle_in_degrees = angle_in_degrees + self.check_each_transform = check_each_transform + + def _to_dict(self) -> Dict[str, Any]: + data = super(KeypointParams, self)._to_dict() + data.update( + { + "remove_invisible": self.remove_invisible, + "angle_in_degrees": self.angle_in_degrees, + "check_each_transform": self.check_each_transform, + } + ) + return data + + @classmethod + def is_serializable(cls) -> bool: + return True + + @classmethod + def get_class_fullname(cls) -> str: + return "KeypointParams" + + +class KeypointsProcessor(DataProcessor): + def __init__(self, params: KeypointParams, additional_targets: Optional[Dict[str, str]] = None): + super().__init__(params, additional_targets) + + @property + def default_data_name(self) -> str: + return "keypoints" + + def ensure_data_valid(self, data: Dict[str, Any]) -> None: + if self.params.label_fields: + if not all(i in data.keys() for i in self.params.label_fields): + raise ValueError( + "Your 'label_fields' are not valid - them must have same names as params in " + "'keypoint_params' dict" + ) + + def ensure_transforms_valid(self, transforms: Sequence[object]) -> None: + # IAA-based augmentations supports only transformation of xy keypoints. + # If your keypoints formats is other than 'xy' we emit warning to let user + # be aware that angle and size will not be modified. + + try: + from custom_albumentations.imgaug.transforms import DualIAATransform + except ImportError: + # imgaug is not installed so we skip imgaug checks. + return + + if self.params.format is not None and self.params.format != "xy": + for transform in transforms: + if isinstance(transform, DualIAATransform): + warnings.warn( + "{} transformation supports only 'xy' keypoints " + "augmentation. You have '{}' keypoints format. Scale " + "and angle WILL NOT BE transformed.".format(transform.__class__.__name__, self.params.format) + ) + break + + def filter(self, data: Sequence[Sequence], rows: int, cols: int) -> Sequence[Sequence]: + self.params: KeypointParams + return filter_keypoints(data, rows, cols, remove_invisible=self.params.remove_invisible) + + def check(self, data: Sequence[Sequence], rows: int, cols: int) -> None: + check_keypoints(data, rows, cols) + + def convert_from_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]: + params = self.params + return convert_keypoints_from_albumentations( + data, + params.format, + rows, + cols, + check_validity=params.remove_invisible, + angle_in_degrees=params.angle_in_degrees, + ) + + def convert_to_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]: + params = self.params + return convert_keypoints_to_albumentations( + data, + params.format, + rows, + cols, + check_validity=params.remove_invisible, + angle_in_degrees=params.angle_in_degrees, + ) + + +def check_keypoint(kp: Sequence, rows: int, cols: int) -> None: + """Check if keypoint coordinates are less than image shapes""" + for name, value, size in zip(["x", "y"], kp[:2], [cols, rows]): + if not 0 <= value < size: + raise ValueError( + "Expected {name} for keypoint {kp} " + "to be in the range [0.0, {size}], got {value}.".format(kp=kp, name=name, value=value, size=size) + ) + + angle = kp[2] + if not (0 <= angle < 2 * math.pi): + raise ValueError("Keypoint angle must be in range [0, 2 * PI). Got: {angle}".format(angle=angle)) + + +def check_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int) -> None: + """Check if keypoints boundaries are less than image shapes""" + for kp in keypoints: + check_keypoint(kp, rows, cols) + + +def filter_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int, remove_invisible: bool) -> Sequence[Sequence]: + if not remove_invisible: + return keypoints + + resulting_keypoints = [] + for kp in keypoints: + x, y = kp[:2] + if x < 0 or x >= cols: + continue + if y < 0 or y >= rows: + continue + resulting_keypoints.append(kp) + return resulting_keypoints + + +def convert_keypoint_to_albumentations( + keypoint: Sequence, + source_format: str, + rows: int, + cols: int, + check_validity: bool = False, + angle_in_degrees: bool = True, +) -> Tuple: + if source_format not in keypoint_formats: + raise ValueError("Unknown target_format {}. Supported formats are: {}".format(source_format, keypoint_formats)) + + if source_format == "xy": + (x, y), tail = keypoint[:2], tuple(keypoint[2:]) + a, s = 0.0, 0.0 + elif source_format == "yx": + (y, x), tail = keypoint[:2], tuple(keypoint[2:]) + a, s = 0.0, 0.0 + elif source_format == "xya": + (x, y, a), tail = keypoint[:3], tuple(keypoint[3:]) + s = 0.0 + elif source_format == "xys": + (x, y, s), tail = keypoint[:3], tuple(keypoint[3:]) + a = 0.0 + elif source_format == "xyas": + (x, y, a, s), tail = keypoint[:4], tuple(keypoint[4:]) + elif source_format == "xysa": + (x, y, s, a), tail = keypoint[:4], tuple(keypoint[4:]) + else: + raise ValueError(f"Unsupported source format. Got {source_format}") + + if angle_in_degrees: + a = math.radians(a) + + keypoint = (x, y, angle_to_2pi_range(a), s) + tail + if check_validity: + check_keypoint(keypoint, rows, cols) + return keypoint + + +def convert_keypoint_from_albumentations( + keypoint: Sequence, + target_format: str, + rows: int, + cols: int, + check_validity: bool = False, + angle_in_degrees: bool = True, +) -> Tuple: + if target_format not in keypoint_formats: + raise ValueError("Unknown target_format {}. Supported formats are: {}".format(target_format, keypoint_formats)) + + (x, y, angle, scale), tail = keypoint[:4], tuple(keypoint[4:]) + angle = angle_to_2pi_range(angle) + if check_validity: + check_keypoint((x, y, angle, scale), rows, cols) + if angle_in_degrees: + angle = math.degrees(angle) + + kp: Tuple + if target_format == "xy": + kp = (x, y) + elif target_format == "yx": + kp = (y, x) + elif target_format == "xya": + kp = (x, y, angle) + elif target_format == "xys": + kp = (x, y, scale) + elif target_format == "xyas": + kp = (x, y, angle, scale) + elif target_format == "xysa": + kp = (x, y, scale, angle) + else: + raise ValueError(f"Invalid target format. Got: {target_format}") + + return kp + tail + + +def convert_keypoints_to_albumentations( + keypoints: Sequence[Sequence], + source_format: str, + rows: int, + cols: int, + check_validity: bool = False, + angle_in_degrees: bool = True, +) -> List[Tuple]: + return [ + convert_keypoint_to_albumentations(kp, source_format, rows, cols, check_validity, angle_in_degrees) + for kp in keypoints + ] + + +def convert_keypoints_from_albumentations( + keypoints: Sequence[Sequence], + target_format: str, + rows: int, + cols: int, + check_validity: bool = False, + angle_in_degrees: bool = True, +) -> List[Tuple]: + return [ + convert_keypoint_from_albumentations(kp, target_format, rows, cols, check_validity, angle_in_degrees) + for kp in keypoints + ] diff --git a/custom_albumentations/core/serialization.py b/custom_albumentations/core/serialization.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6962eb3ae5feabcf40e3486d0b9594edb8c9ed --- /dev/null +++ b/custom_albumentations/core/serialization.py @@ -0,0 +1,247 @@ +from __future__ import absolute_import + +import json +import typing +import warnings +from abc import ABC, ABCMeta, abstractmethod +from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union + +try: + import yaml + + yaml_available = True +except ImportError: + yaml_available = False + + +from custom_albumentations import __version__ + +__all__ = ["to_dict", "from_dict", "save", "load"] + + +SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {} +NON_SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {} + + +def shorten_class_name(class_fullname: str) -> str: + splitted = class_fullname.split(".") + if len(splitted) == 1: + return class_fullname + top_module, *_, class_name = splitted + if top_module == "albumentations": + return class_name + return class_fullname + + +def get_shortest_class_fullname(cls: Type) -> str: + class_fullname = "{cls.__module__}.{cls.__name__}".format(cls=cls) + return shorten_class_name(class_fullname) + + +class SerializableMeta(ABCMeta): + """ + A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY` + so they can be found later while deserializing transformation pipeline using classes full names. + """ + + def __new__(mcs, name: str, bases: Tuple[type, ...], *args, **kwargs) -> "SerializableMeta": + cls_obj = super().__new__(mcs, name, bases, *args, **kwargs) + if name != "Serializable" and ABC not in bases: + if cls_obj.is_serializable(): + SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj + else: + NON_SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj + return cls_obj + + @classmethod + def is_serializable(mcs) -> bool: + return False + + @classmethod + def get_class_fullname(mcs) -> str: + return get_shortest_class_fullname(mcs) + + @classmethod + def _to_dict(mcs) -> Dict[str, Any]: + return {} + + +class Serializable(metaclass=SerializableMeta): + @classmethod + @abstractmethod + def is_serializable(cls) -> bool: + raise NotImplementedError + + @classmethod + @abstractmethod + def get_class_fullname(cls) -> str: + raise NotImplementedError + + @abstractmethod + def _to_dict(self) -> Dict[str, Any]: + raise NotImplementedError + + def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]: + """ + Take a transform pipeline and convert it to a serializable representation that uses only standard + python data types: dictionaries, lists, strings, integers, and floats. + + Args: + self: A transform that should be serialized. If the transform doesn't implement the `to_dict` + method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised. + If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored + but no transform parameters will be serialized. + on_not_implemented_error (str): `raise` or `warn`. + """ + if on_not_implemented_error not in {"raise", "warn"}: + raise ValueError( + "Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format( + on_not_implemented_error + ) + ) + try: + transform_dict = self._to_dict() + except NotImplementedError as e: + if on_not_implemented_error == "raise": + raise e + + transform_dict = {} + warnings.warn( + "Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. " + "Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' " + "method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__) + ) + return {"__version__": __version__, "transform": transform_dict} + + +def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]: + """ + Take a transform pipeline and convert it to a serializable representation that uses only standard + python data types: dictionaries, lists, strings, integers, and floats. + + Args: + transform: A transform that should be serialized. If the transform doesn't implement the `to_dict` + method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised. + If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored + but no transform parameters will be serialized. + on_not_implemented_error (str): `raise` or `warn`. + """ + return transform.to_dict(on_not_implemented_error) + + +def instantiate_nonserializable( + transform: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None +) -> Optional[Serializable]: + if transform.get("__class_fullname__") in NON_SERIALIZABLE_REGISTRY: + name = transform["__name__"] + if nonserializable is None: + raise ValueError( + "To deserialize a non-serializable transform with name {name} you need to pass a dict with" + "this transform as the `lambda_transforms` argument".format(name=name) + ) + result_transform = nonserializable.get(name) + if transform is None: + raise ValueError( + "Non-serializable transform with {name} was not found in `nonserializable`".format(name=name) + ) + return result_transform + return None + + +def from_dict( + transform_dict: Dict[str, Any], + nonserializable: Optional[Dict[str, Any]] = None, + lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated", +) -> Optional[Serializable]: + """ + Args: + transform_dict (dict): A dictionary with serialized transform pipeline. + nonserializable (dict): A dictionary that contains non-serializable transforms. + This dictionary is required when you are restoring a pipeline that contains non-serializable transforms. + Keys in that dictionary should be named same as `name` arguments in respective transforms from + a serialized pipeline. + lambda_transforms (dict): Deprecated. Use 'nonserizalizable' instead. + """ + if lambda_transforms != "deprecated": + warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning) + nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms) + + register_additional_transforms() + transform = transform_dict["transform"] + lmbd = instantiate_nonserializable(transform, nonserializable) + if lmbd: + return lmbd + name = transform["__class_fullname__"] + args = {k: v for k, v in transform.items() if k != "__class_fullname__"} + cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)] + if "transforms" in args: + args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]] + return cls(**args) + + +def check_data_format(data_format: str) -> None: + if data_format not in {"json", "yaml"}: + raise ValueError("Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(data_format)) + + +def save( + transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise" +) -> None: + """ + Take a transform pipeline, serialize it and save a serialized version to a file + using either json or yaml format. + + Args: + transform (obj): Transform to serialize. + filepath (str): Filepath to write to. + data_format (str): Serialization format. Should be either `json` or 'yaml'. + on_not_implemented_error (str): Parameter that describes what to do if a transform doesn't implement + the `to_dict` method. If 'raise' then `NotImplementedError` is raised, if `warn` then the exception will be + ignored and no transform arguments will be saved. + """ + check_data_format(data_format) + transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error) + dump_fn = json.dump if data_format == "json" else yaml.safe_dump + with open(filepath, "w") as f: + dump_fn(transform_dict, f) # type: ignore + + +def load( + filepath: str, + data_format: str = "json", + nonserializable: Optional[Dict[str, Any]] = None, + lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated", +) -> object: + """ + Load a serialized pipeline from a json or yaml file and construct a transform pipeline. + + Args: + filepath (str): Filepath to read from. + data_format (str): Serialization format. Should be either `json` or 'yaml'. + nonserializable (dict): A dictionary that contains non-serializable transforms. + This dictionary is required when you are restoring a pipeline that contains non-serializable transforms. + Keys in that dictionary should be named same as `name` arguments in respective transforms from + a serialized pipeline. + lambda_transforms (dict): Deprecated. Use 'nonserizalizable' instead. + """ + if lambda_transforms != "deprecated": + warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning) + nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms) + + check_data_format(data_format) + load_fn = json.load if data_format == "json" else yaml.safe_load + with open(filepath) as f: + transform_dict = load_fn(f) # type: ignore + + return from_dict(transform_dict, nonserializable=nonserializable) + + +def register_additional_transforms() -> None: + """ + Register transforms that are not imported directly into the `albumentations` module. + """ + try: + # This import will result in ImportError if `torch` is not installed + import custom_albumentations.pytorch + except ImportError: + pass diff --git a/custom_albumentations/core/transforms_interface.py b/custom_albumentations/core/transforms_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..eec2cd304019c947cc529e1606b2282f047a6a81 --- /dev/null +++ b/custom_albumentations/core/transforms_interface.py @@ -0,0 +1,293 @@ +from __future__ import absolute_import + +import random +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast +from warnings import warn + +import cv2 +import numpy as np + +from .serialization import Serializable, get_shortest_class_fullname +from .utils import format_args + +__all__ = [ + "to_tuple", + "BasicTransform", + "DualTransform", + "ImageOnlyTransform", + "NoOp", + "BoxType", + "KeypointType", + "ImageColorType", + "ScaleFloatType", + "ScaleIntType", + "ImageColorType", +] + +NumType = Union[int, float, np.ndarray] +BoxInternalType = Tuple[float, float, float, float] +BoxType = Union[BoxInternalType, Tuple[float, float, float, float, Any]] +KeypointInternalType = Tuple[float, float, float, float] +KeypointType = Union[KeypointInternalType, Tuple[float, float, float, float, Any]] +ImageColorType = Union[float, Sequence[float]] + +ScaleFloatType = Union[float, Tuple[float, float]] +ScaleIntType = Union[int, Tuple[int, int]] + +FillValueType = Optional[Union[int, float, Sequence[int], Sequence[float]]] + + +def to_tuple(param, low=None, bias=None): + """Convert input argument to min-max tuple + Args: + param (scalar, tuple or list of 2+ elements): Input value. + If value is scalar, return value would be (offset - value, offset + value). + If value is tuple, return value would be value + offset (broadcasted). + low: Second element of tuple can be passed as optional argument + bias: An offset factor added to each element + """ + if low is not None and bias is not None: + raise ValueError("Arguments low and bias are mutually exclusive") + + if param is None: + return param + + if isinstance(param, (int, float)): + if low is None: + param = -param, +param + else: + param = (low, param) if low < param else (param, low) + elif isinstance(param, Sequence): + if len(param) != 2: + raise ValueError("to_tuple expects 1 or 2 values") + param = tuple(param) + else: + raise ValueError("Argument param must be either scalar (int, float) or tuple") + + if bias is not None: + return tuple(bias + x for x in param) + + return tuple(param) + + +class BasicTransform(Serializable): + call_backup = None + interpolation: Any + fill_value: Any + mask_fill_value: Any + + def __init__(self, always_apply: bool = False, p: float = 0.5): + self.p = p + self.always_apply = always_apply + self._additional_targets: Dict[str, str] = {} + + # replay mode params + self.deterministic = False + self.save_key = "replay" + self.params: Dict[Any, Any] = {} + self.replay_mode = False + self.applied_in_replay = False + + def __call__(self, *args, force_apply: bool = False, **kwargs) -> Dict[str, Any]: + if args: + raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)") + if self.replay_mode: + if self.applied_in_replay: + return self.apply_with_params(self.params, **kwargs) + + return kwargs + + if (random.random() < self.p) or self.always_apply or force_apply: + params = self.get_params() + + if self.targets_as_params: + assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format( + self.__class__.__name__, self.targets_as_params + ) + targets_as_params = {k: kwargs[k] for k in self.targets_as_params} + params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params) + params.update(params_dependent_on_targets) + if self.deterministic: + if self.targets_as_params: + warn( + self.get_class_fullname() + " could work incorrectly in ReplayMode for other input data" + " because its' params depend on targets." + ) + kwargs[self.save_key][id(self)] = deepcopy(params) + return self.apply_with_params(params, **kwargs) + + return kwargs + + def apply_with_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]: # skipcq: PYL-W0613 + if params is None: + return kwargs + params = self.update_params(params, **kwargs) + res = {} + for key, arg in kwargs.items(): + if arg is not None: + target_function = self._get_target_function(key) + target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])} + res[key] = target_function(arg, **dict(params, **target_dependencies)) + else: + res[key] = None + return res + + def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform": + assert save_key != "params", "params save_key is reserved" + self.deterministic = flag + self.save_key = save_key + return self + + def __repr__(self) -> str: + state = self.get_base_init_args() + state.update(self.get_transform_init_args()) + return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state)) + + def _get_target_function(self, key: str) -> Callable: + transform_key = key + if key in self._additional_targets: + transform_key = self._additional_targets.get(key, key) + + target_function = self.targets.get(transform_key, lambda x, **p: x) + return target_function + + def apply(self, img: np.ndarray, **params) -> np.ndarray: + raise NotImplementedError + + def get_params(self) -> Dict: + return {} + + @property + def targets(self) -> Dict[str, Callable]: + # you must specify targets in subclass + # for example: ('image', 'mask') + # ('image', 'boxes') + raise NotImplementedError + + def update_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]: + if hasattr(self, "interpolation"): + params["interpolation"] = self.interpolation + if hasattr(self, "fill_value"): + params["fill_value"] = self.fill_value + if hasattr(self, "mask_fill_value"): + params["mask_fill_value"] = self.mask_fill_value + params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]}) + return params + + @property + def target_dependence(self) -> Dict: + return {} + + def add_targets(self, additional_targets: Dict[str, str]): + """Add targets to transform them the same way as one of existing targets + ex: {'target_image': 'image'} + ex: {'obj1_mask': 'mask', 'obj2_mask': 'mask'} + by the way you must have at least one object with key 'image' + + Args: + additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'} + """ + self._additional_targets = additional_targets + + @property + def targets_as_params(self) -> List[str]: + return [] + + def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]: + raise NotImplementedError( + "Method get_params_dependent_on_targets is not implemented in class " + self.__class__.__name__ + ) + + @classmethod + def get_class_fullname(cls) -> str: + return get_shortest_class_fullname(cls) + + @classmethod + def is_serializable(cls): + return True + + def get_transform_init_args_names(self) -> Tuple[str, ...]: + raise NotImplementedError( + "Class {name} is not serializable because the `get_transform_init_args_names` method is not " + "implemented".format(name=self.get_class_fullname()) + ) + + def get_base_init_args(self) -> Dict[str, Any]: + return {"always_apply": self.always_apply, "p": self.p} + + def get_transform_init_args(self) -> Dict[str, Any]: + return {k: getattr(self, k) for k in self.get_transform_init_args_names()} + + def _to_dict(self) -> Dict[str, Any]: + state = {"__class_fullname__": self.get_class_fullname()} + state.update(self.get_base_init_args()) + state.update(self.get_transform_init_args()) + return state + + def get_dict_with_id(self) -> Dict[str, Any]: + d = self._to_dict() + d["id"] = id(self) + return d + + +class DualTransform(BasicTransform): + """Transform for segmentation task.""" + + @property + def targets(self) -> Dict[str, Callable]: + return { + "image": self.apply, + "mask": self.apply_to_mask, + "masks": self.apply_to_masks, + "bboxes": self.apply_to_bboxes, + "keypoints": self.apply_to_keypoints, + } + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__) + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + raise NotImplementedError("Method apply_to_keypoint is not implemented in class " + self.__class__.__name__) + + def apply_to_bboxes(self, bboxes: Sequence[BoxType], **params) -> List[BoxType]: + return [self.apply_to_bbox(tuple(bbox[:4]), **params) + tuple(bbox[4:]) for bbox in bboxes] # type: ignore + + def apply_to_keypoints(self, keypoints: Sequence[KeypointType], **params) -> List[KeypointType]: + return [ # type: ignore + self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) # type: ignore + for keypoint in keypoints + ] + + def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray: + return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()}) + + def apply_to_masks(self, masks: Sequence[np.ndarray], **params) -> List[np.ndarray]: + return [self.apply_to_mask(mask, **params) for mask in masks] + + +class ImageOnlyTransform(BasicTransform): + """Transform applied to image only.""" + + @property + def targets(self) -> Dict[str, Callable]: + return {"image": self.apply} + + +class NoOp(DualTransform): + """Does nothing""" + + def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType: + return keypoint + + def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType: + return bbox + + def apply(self, img: np.ndarray, **params) -> np.ndarray: + return img + + def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray: + return img + + def get_transform_init_args_names(self) -> Tuple: + return () diff --git a/custom_albumentations/core/utils.py b/custom_albumentations/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..971700f7f348469843a4a1769875387e07a4c11c --- /dev/null +++ b/custom_albumentations/core/utils.py @@ -0,0 +1,137 @@ +from __future__ import absolute_import + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence, Tuple + +import numpy as np + +from .serialization import Serializable + + +def get_shape(img: Any) -> Tuple[int, int]: + if isinstance(img, np.ndarray): + rows, cols = img.shape[:2] + return rows, cols + + try: + import torch + + if torch.is_tensor(img): + rows, cols = img.shape[-2:] + return rows, cols + except ImportError: + pass + + raise RuntimeError( + f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}" + ) + + +def format_args(args_dict: Dict): + formatted_args = [] + for k, v in args_dict.items(): + if isinstance(v, str): + v = f"'{v}'" + formatted_args.append(f"{k}={v}") + return ", ".join(formatted_args) + + +class Params(Serializable, ABC): + def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None): + self.format = format + self.label_fields = label_fields + + def _to_dict(self) -> Dict[str, Any]: + return {"format": self.format, "label_fields": self.label_fields} + + +class DataProcessor(ABC): + def __init__(self, params: Params, additional_targets: Optional[Dict[str, str]] = None): + self.params = params + self.data_fields = [self.default_data_name] + if additional_targets is not None: + for k, v in additional_targets.items(): + if v == self.default_data_name: + self.data_fields.append(k) + + @property + @abstractmethod + def default_data_name(self) -> str: + raise NotImplementedError + + def ensure_data_valid(self, data: Dict[str, Any]) -> None: + pass + + def ensure_transforms_valid(self, transforms: Sequence[object]) -> None: + pass + + def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]: + rows, cols = get_shape(data["image"]) + + for data_name in self.data_fields: + data[data_name] = self.filter(data[data_name], rows, cols) + data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from") + + data = self.remove_label_fields_from_data(data) + return data + + def preprocess(self, data: Dict[str, Any]) -> None: + data = self.add_label_fields_to_data(data) + + rows, cols = data["image"].shape[:2] + for data_name in self.data_fields: + data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to") + + def check_and_convert(self, data: Sequence, rows: int, cols: int, direction: str = "to") -> Sequence: + if self.params.format == "albumentations": + self.check(data, rows, cols) + return data + + if direction == "to": + return self.convert_to_albumentations(data, rows, cols) + elif direction == "from": + return self.convert_from_albumentations(data, rows, cols) + else: + raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`") + + @abstractmethod + def filter(self, data: Sequence, rows: int, cols: int) -> Sequence: + pass + + @abstractmethod + def check(self, data: Sequence, rows: int, cols: int) -> None: + pass + + @abstractmethod + def convert_to_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence: + pass + + @abstractmethod + def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence: + pass + + def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.params.label_fields is None: + return data + for data_name in self.data_fields: + for field in self.params.label_fields: + assert len(data[data_name]) == len(data[field]) + data_with_added_field = [] + for d, field_value in zip(data[data_name], data[field]): + data_with_added_field.append(list(d) + [field_value]) + data[data_name] = data_with_added_field + return data + + def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]: + if self.params.label_fields is None: + return data + for data_name in self.data_fields: + label_fields_len = len(self.params.label_fields) + for idx, field in enumerate(self.params.label_fields): + field_values = [] + for bbox in data[data_name]: + field_values.append(bbox[-label_fields_len + idx]) + data[field] = field_values + if label_fields_len: + data[data_name] = [d[:-label_fields_len] for d in data[data_name]] + return data diff --git a/custom_albumentations/imgaug/__init__.py b/custom_albumentations/imgaug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_albumentations/imgaug/stubs.py b/custom_albumentations/imgaug/stubs.py new file mode 100644 index 0000000000000000000000000000000000000000..1eac33300d58bf8eb4f1a98f50adb5a1f155fb2b --- /dev/null +++ b/custom_albumentations/imgaug/stubs.py @@ -0,0 +1,77 @@ +__all__ = [ + "IAAEmboss", + "IAASuperpixels", + "IAASharpen", + "IAAAdditiveGaussianNoise", + "IAACropAndPad", + "IAAFliplr", + "IAAFlipud", + "IAAAffine", + "IAAPiecewiseAffine", + "IAAPerspective", +] + + +class IAAStub: + def __init__(self, *args, **kwargs): + cls_name = self.__class__.__name__ + doc_link = "https://albumentations.ai/docs/api_reference/augmentations" + self.doc_link + raise RuntimeError( + f"You are trying to use a deprecated augmentation '{cls_name}' which depends on the imgaug library, " + f"but imgaug is not installed.\n\n" + "There are two options to fix this error:\n" + "1. [Recommended]. Switch to the Albumentations' implementation of the augmentation with the same API: " + f"{self.alternative} - {doc_link}\n" + "2. Install a version of Albumentations that contains imgaug by running " + "'pip install -U albumentations[imgaug]'." + ) + + +class IAACropAndPad(IAAStub): + alternative = "CropAndPad" + doc_link = "/crops/transforms/#albumentations.augmentations.crops.transforms.CropAndPad" + + +class IAAFliplr(IAAStub): + alternative = "HorizontalFlip" + doc_link = "/transforms/#albumentations.augmentations.transforms.HorizontalFlip" + + +class IAAFlipud(IAAStub): + alternative = "VerticalFlip" + doc_link = "/transforms/#albumentations.augmentations.transforms.VerticalFlip" + + +class IAAEmboss(IAAStub): + alternative = "Emboss" + doc_link = "/transforms/#albumentations.augmentations.transforms.Emboss" + + +class IAASuperpixels(IAAStub): + alternative = "Superpixels" + doc_link = "/transforms/#albumentations.augmentations.transforms.Superpixels" + + +class IAASharpen(IAAStub): + alternative = "Sharpen" + doc_link = "/transforms/#albumentations.augmentations.transforms.Sharpen" + + +class IAAAdditiveGaussianNoise(IAAStub): + alternative = "GaussNoise" + doc_link = "/transforms/#albumentations.augmentations.transforms.GaussNoise" + + +class IAAPiecewiseAffine(IAAStub): + alternative = "PiecewiseAffine" + doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.PiecewiseAffine" + + +class IAAAffine(IAAStub): + alternative = "Affine" + doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Affine" + + +class IAAPerspective(IAAStub): + alternative = "Perspective" + doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Perspective" diff --git a/custom_albumentations/imgaug/transforms.py b/custom_albumentations/imgaug/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..96f0c1393697dc0b186ac68f24958c19b2a1def7 --- /dev/null +++ b/custom_albumentations/imgaug/transforms.py @@ -0,0 +1,391 @@ +try: + import imgaug as ia +except ImportError as e: + raise ImportError( + "You are trying to import an augmentation that depends on the imgaug library, but imgaug is not installed. To " + "install a version of Albumentations that contains imgaug please run 'pip install -U albumentations[imgaug]'" + ) from e + +try: + from imgaug import augmenters as iaa +except ImportError: + import imgaug.imgaug.augmenters as iaa + +import warnings + +from custom_albumentations.core.bbox_utils import ( + convert_bboxes_from_albumentations, + convert_bboxes_to_albumentations, +) +from custom_albumentations.core.keypoints_utils import ( + convert_keypoints_from_albumentations, + convert_keypoints_to_albumentations, +) + +from ..augmentations import Perspective +from ..core.transforms_interface import ( + BasicTransform, + DualTransform, + ImageOnlyTransform, + to_tuple, +) + +__all__ = [ + "BasicIAATransform", + "DualIAATransform", + "ImageOnlyIAATransform", + "IAAEmboss", + "IAASuperpixels", + "IAASharpen", + "IAAAdditiveGaussianNoise", + "IAACropAndPad", + "IAAFliplr", + "IAAFlipud", + "IAAAffine", + "IAAPiecewiseAffine", + "IAAPerspective", +] + + +class BasicIAATransform(BasicTransform): + def __init__(self, always_apply=False, p=0.5): + super(BasicIAATransform, self).__init__(always_apply, p) + + @property + def processor(self): + return iaa.Noop() + + def update_params(self, params, **kwargs): + params = super(BasicIAATransform, self).update_params(params, **kwargs) + params["deterministic_processor"] = self.processor.to_deterministic() + return params + + def apply(self, img, deterministic_processor=None, **params): + return deterministic_processor.augment_image(img) + + +class DualIAATransform(DualTransform, BasicIAATransform): + def apply_to_bboxes(self, bboxes, deterministic_processor=None, rows=0, cols=0, **params): + if len(bboxes) > 0: + bboxes = convert_bboxes_from_albumentations(bboxes, "pascal_voc", rows=rows, cols=cols) + + bboxes_t = ia.BoundingBoxesOnImage([ia.BoundingBox(*bbox[:4]) for bbox in bboxes], (rows, cols)) + bboxes_t = deterministic_processor.augment_bounding_boxes([bboxes_t])[0].bounding_boxes + bboxes_t = [ + [bbox.x1, bbox.y1, bbox.x2, bbox.y2] + list(bbox_orig[4:]) + for (bbox, bbox_orig) in zip(bboxes_t, bboxes) + ] + + bboxes = convert_bboxes_to_albumentations(bboxes_t, "pascal_voc", rows=rows, cols=cols) + return bboxes + + """Applies transformation to keypoints. + Notes: + Since IAA supports only xy keypoints, scale and orientation will remain unchanged. + TODO: + Emit a warning message if child classes of DualIAATransform are instantiated + inside Compose with keypoints format other than 'xy'. + """ + + def apply_to_keypoints(self, keypoints, deterministic_processor=None, rows=0, cols=0, **params): + if len(keypoints) > 0: + keypoints = convert_keypoints_from_albumentations(keypoints, "xy", rows=rows, cols=cols) + keypoints_t = ia.KeypointsOnImage([ia.Keypoint(*kp[:2]) for kp in keypoints], (rows, cols)) + keypoints_t = deterministic_processor.augment_keypoints([keypoints_t])[0].keypoints + + bboxes_t = [[kp.x, kp.y] + list(kp_orig[2:]) for (kp, kp_orig) in zip(keypoints_t, keypoints)] + + keypoints = convert_keypoints_to_albumentations(bboxes_t, "xy", rows=rows, cols=cols) + return keypoints + + +class ImageOnlyIAATransform(ImageOnlyTransform, BasicIAATransform): + pass + + +class IAACropAndPad(DualIAATransform): + """This augmentation is deprecated. Please use CropAndPad instead.""" + + def __init__(self, px=None, percent=None, pad_mode="constant", pad_cval=0, keep_size=True, always_apply=False, p=1): + super(IAACropAndPad, self).__init__(always_apply, p) + self.px = px + self.percent = percent + self.pad_mode = pad_mode + self.pad_cval = pad_cval + self.keep_size = keep_size + warnings.warn("IAACropAndPad is deprecated. Please use CropAndPad instead", FutureWarning) + + @property + def processor(self): + return iaa.CropAndPad(self.px, self.percent, self.pad_mode, self.pad_cval, self.keep_size) + + def get_transform_init_args_names(self): + return ("px", "percent", "pad_mode", "pad_cval", "keep_size") + + +class IAAFliplr(DualIAATransform): + """This augmentation is deprecated. Please use HorizontalFlip instead.""" + + def __init__(self, always_apply=False, p=0.5): + super().__init__(always_apply, p) + warnings.warn("IAAFliplr is deprecated. Please use HorizontalFlip instead.", FutureWarning) + + @property + def processor(self): + return iaa.Fliplr(1) + + def get_transform_init_args_names(self): + return () + + +class IAAFlipud(DualIAATransform): + """This augmentation is deprecated. Please use VerticalFlip instead.""" + + def __init__(self, always_apply=False, p=0.5): + super().__init__(always_apply, p) + warnings.warn("IAAFlipud is deprecated. Please use VerticalFlip instead.", FutureWarning) + + @property + def processor(self): + return iaa.Flipud(1) + + def get_transform_init_args_names(self): + return () + + +class IAAEmboss(ImageOnlyIAATransform): + """Emboss the input image and overlays the result with the original image. + This augmentation is deprecated. Please use Emboss instead. + + Args: + alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is + visible,at 1.0 only its embossed version is visible. Default: (0.2, 0.5). + strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5): + super(IAAEmboss, self).__init__(always_apply, p) + self.alpha = to_tuple(alpha, 0.0) + self.strength = to_tuple(strength, 0.0) + warnings.warn("This augmentation is deprecated. Please use Emboss instead", FutureWarning) + + @property + def processor(self): + return iaa.Emboss(self.alpha, self.strength) + + def get_transform_init_args_names(self): + return ("alpha", "strength") + + +class IAASuperpixels(ImageOnlyIAATransform): + """Completely or partially transform the input image to its superpixel representation. Uses skimage's version + of the SLIC algorithm. May be slow. + + This augmentation is deprecated. Please use Superpixels instead. + + Args: + p_replace (float): defines the probability of any superpixel area being replaced by the superpixel, i.e. by + the average pixel color within its area. Default: 0.1. + n_segments (int): target number of superpixels to generate. Default: 100. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, p_replace=0.1, n_segments=100, always_apply=False, p=0.5): + super(IAASuperpixels, self).__init__(always_apply, p) + self.p_replace = p_replace + self.n_segments = n_segments + warnings.warn("IAASuperpixels is deprecated. Please use Superpixels instead.", FutureWarning) + + @property + def processor(self): + return iaa.Superpixels(p_replace=self.p_replace, n_segments=self.n_segments) + + def get_transform_init_args_names(self): + return ("p_replace", "n_segments") + + +class IAASharpen(ImageOnlyIAATransform): + """Sharpen the input image and overlays the result with the original image. + This augmentation is deprecated. Please use Sharpen instead + Args: + alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is + visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5). + lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5): + super(IAASharpen, self).__init__(always_apply, p) + self.alpha = to_tuple(alpha, 0) + self.lightness = to_tuple(lightness, 0) + warnings.warn("IAASharpen is deprecated. Please use Sharpen instead", FutureWarning) + + @property + def processor(self): + return iaa.Sharpen(self.alpha, self.lightness) + + def get_transform_init_args_names(self): + return ("alpha", "lightness") + + +class IAAAdditiveGaussianNoise(ImageOnlyIAATransform): + """Add gaussian noise to the input image. + + This augmentation is deprecated. Please use GaussNoise instead. + + Args: + loc (int): mean of the normal distribution that generates the noise. Default: 0. + scale ((float, float)): standard deviation of the normal distribution that generates the noise. + Default: (0.01 * 255, 0.05 * 255). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image + """ + + def __init__(self, loc=0, scale=(0.01 * 255, 0.05 * 255), per_channel=False, always_apply=False, p=0.5): + super(IAAAdditiveGaussianNoise, self).__init__(always_apply, p) + self.loc = loc + self.scale = to_tuple(scale, 0.0) + self.per_channel = per_channel + warnings.warn("IAAAdditiveGaussianNoise is deprecated. Please use GaussNoise instead", FutureWarning) + + @property + def processor(self): + return iaa.AdditiveGaussianNoise(self.loc, self.scale, self.per_channel) + + def get_transform_init_args_names(self): + return ("loc", "scale", "per_channel") + + +class IAAPiecewiseAffine(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + This augmentation is deprecated. Please use PiecewiseAffine instead. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): factor range that determines how far each point is moved. Default: (0.03, 0.05). + nb_rows (int): number of rows of points that the regular grid should have. Default: 4. + nb_cols (int): number of columns of points that the regular grid should have. Default: 4. + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__( + self, scale=(0.03, 0.05), nb_rows=4, nb_cols=4, order=1, cval=0, mode="constant", always_apply=False, p=0.5 + ): + super(IAAPiecewiseAffine, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 0.0) + self.nb_rows = nb_rows + self.nb_cols = nb_cols + self.order = order + self.cval = cval + self.mode = mode + warnings.warn("This IAAPiecewiseAffine is deprecated. Please use PiecewiseAffine instead", FutureWarning) + + @property + def processor(self): + return iaa.PiecewiseAffine(self.scale, self.nb_rows, self.nb_cols, self.order, self.cval, self.mode) + + def get_transform_init_args_names(self): + return ("scale", "nb_rows", "nb_cols", "order", "cval", "mode") + + +class IAAAffine(DualIAATransform): + """Place a regular grid of points on the input and randomly move the neighbourhood of these point around + via affine transformations. + + This augmentation is deprecated. Please use Affine instead. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__( + self, + scale=1.0, + translate_percent=None, + translate_px=None, + rotate=0.0, + shear=0.0, + order=1, + cval=0, + mode="reflect", + always_apply=False, + p=0.5, + ): + super(IAAAffine, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.translate_percent = to_tuple(translate_percent, 0) + self.translate_px = to_tuple(translate_px, 0) + self.rotate = to_tuple(rotate) + self.shear = to_tuple(shear) + self.order = order + self.cval = cval + self.mode = mode + warnings.warn("This IAAAffine is deprecated. Please use Affine instead", FutureWarning) + + @property + def processor(self): + return iaa.Affine( + self.scale, + self.translate_percent, + self.translate_px, + self.rotate, + self.shear, + self.order, + self.cval, + self.mode, + ) + + def get_transform_init_args_names(self): + return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") + + +class IAAPerspective(Perspective): + """Perform a random four point perspective transform of the input. + This augmentation is deprecated. Please use Perspective instead. + + Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} + + Args: + scale ((float, float): standard deviation of the normal distributions. These are used to sample + the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). + p (float): probability of applying the transform. Default: 0.5. + + Targets: + image, mask + """ + + def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5): + super(IAAPerspective, self).__init__(always_apply, p) + self.scale = to_tuple(scale, 1.0) + self.keep_size = keep_size + warnings.warn("This IAAPerspective is deprecated. Please use Perspective instead", FutureWarning) + + @property + def processor(self): + return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size) + + def get_transform_init_args_names(self): + return ("scale", "keep_size") diff --git a/custom_albumentations/pytorch/__init__.py b/custom_albumentations/pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..25cf792a3a8e44b1973a39e3ea783a3dcd5036cd --- /dev/null +++ b/custom_albumentations/pytorch/__init__.py @@ -0,0 +1,3 @@ +from __future__ import absolute_import + +from .transforms import * diff --git a/custom_albumentations/pytorch/functional.py b/custom_albumentations/pytorch/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..edf3ed9466c99b5b49586f9ca6b78260dcc00243 --- /dev/null +++ b/custom_albumentations/pytorch/functional.py @@ -0,0 +1,31 @@ +from __future__ import division + +import numpy as np +import torch +import torchvision.transforms.functional as F + + +def img_to_tensor(im, normalize=None): + tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32)) + if normalize is not None: + return F.normalize(tensor, **normalize) + return tensor + + +def mask_to_tensor(mask, num_classes, sigmoid): + if num_classes > 1: + if not sigmoid: + # softmax + long_mask = np.zeros((mask.shape[:2]), dtype=np.int64) + if len(mask.shape) == 3: + for c in range(mask.shape[2]): + long_mask[mask[..., c] > 0] = c + else: + long_mask[mask > 127] = 1 + long_mask[mask == 0] = 0 + mask = long_mask + else: + mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32) + else: + mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32) + return torch.from_numpy(mask) diff --git a/custom_albumentations/pytorch/transforms.py b/custom_albumentations/pytorch/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..0c0296acf60f390b1fb52858b722eb67dc765d60 --- /dev/null +++ b/custom_albumentations/pytorch/transforms.py @@ -0,0 +1,104 @@ +from __future__ import absolute_import + +import warnings + +import numpy as np +import torch +from torchvision.transforms import functional as F + +from ..core.transforms_interface import BasicTransform + +__all__ = ["ToTensorV2"] + + +def img_to_tensor(im, normalize=None): + tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32)) + if normalize is not None: + return F.normalize(tensor, **normalize) + return tensor + + +def mask_to_tensor(mask, num_classes, sigmoid): + if num_classes > 1: + if not sigmoid: + # softmax + long_mask = np.zeros((mask.shape[:2]), dtype=np.int64) + if len(mask.shape) == 3: + for c in range(mask.shape[2]): + long_mask[mask[..., c] > 0] = c + else: + long_mask[mask > 127] = 1 + long_mask[mask == 0] = 0 + mask = long_mask + else: + mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32) + else: + mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32) + return torch.from_numpy(mask) + + +class ToTensor(BasicTransform): + """Convert image and mask to `torch.Tensor` and divide by 255 if image or mask are `uint8` type. + This transform is now removed from custom_albumentations. If you need it downgrade the library to version 0.5.2. + + Args: + num_classes (int): only for segmentation + sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not. + normalize (dict, optional): dict with keys [mean, std] to pass it into torchvision.normalize + + """ + + def __init__(self, num_classes=1, sigmoid=True, normalize=None): + raise RuntimeError( + "`ToTensor` is obsolete and it was removed from custom_albumentations. Please use `ToTensorV2` instead - " + "https://albumentations.ai/docs/api_reference/pytorch/transforms/" + "#albumentations.pytorch.transforms.ToTensorV2. " + "\n\nIf you need `ToTensor` downgrade Albumentations to version 0.5.2." + ) + + +class ToTensorV2(BasicTransform): + """Convert image and mask to `torch.Tensor`. The numpy `HWC` image is converted to pytorch `CHW` tensor. + If the image is in `HW` format (grayscale image), it will be converted to pytorch `HW` tensor. + This is a simplified and improved version of the old `ToTensor` + transform (`ToTensor` was deprecated, and now it is not present in Albumentations. You should use `ToTensorV2` + instead). + + Args: + transpose_mask (bool): If True and an input mask has three dimensions, this transform will transpose dimensions + so the shape `[height, width, num_channels]` becomes `[num_channels, height, width]`. The latter format is a + standard format for PyTorch Tensors. Default: False. + always_apply (bool): Indicates whether this transformation should be always applied. Default: True. + p (float): Probability of applying the transform. Default: 1.0. + """ + + def __init__(self, transpose_mask=False, always_apply=True, p=1.0): + super(ToTensorV2, self).__init__(always_apply=always_apply, p=p) + self.transpose_mask = transpose_mask + + @property + def targets(self): + return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks} + + def apply(self, img, **params): # skipcq: PYL-W0613 + if len(img.shape) not in [2, 3]: + raise ValueError("Albumentations only supports images in HW or HWC format") + + if len(img.shape) == 2: + img = np.expand_dims(img, 2) + + return torch.from_numpy(img.transpose(2, 0, 1)) + + def apply_to_mask(self, mask, **params): # skipcq: PYL-W0613 + if self.transpose_mask and mask.ndim == 3: + mask = mask.transpose(2, 0, 1) + return torch.from_numpy(mask) + + def apply_to_masks(self, masks, **params): + return [self.apply_to_mask(mask, **params) for mask in masks] + + def get_transform_init_args_names(self): + return ("transpose_mask",) + + def get_params_dependent_on_targets(self, params): + return {} diff --git a/custom_albumentations/random_utils.py b/custom_albumentations/random_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c7c8798aa33ef36ec4540cc402ac7419f4c3b0f --- /dev/null +++ b/custom_albumentations/random_utils.py @@ -0,0 +1,96 @@ +# Use `Any` as the return type to avoid mypy problems with Union data types, +# because numpy can return single number and ndarray + +import random as py_random +from typing import Any, Optional, Sequence, Type, Union + +import numpy as np + +from .core.transforms_interface import NumType + +IntNumType = Union[int, np.ndarray] +Size = Union[int, Sequence[int]] + + +def get_random_state() -> np.random.RandomState: + return np.random.RandomState(py_random.randint(0, (1 << 32) - 1)) + + +def uniform( + low: NumType = 0.0, + high: NumType = 1.0, + size: Optional[Size] = None, + random_state: Optional[np.random.RandomState] = None, +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.uniform(low, high, size) + + +def rand(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.rand(d0, d1, *more, **kwargs) # type: ignore + + +def randn(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.randn(d0, d1, *more, **kwargs) # type: ignore + + +def normal( + loc: NumType = 0.0, + scale: NumType = 1.0, + size: Optional[Size] = None, + random_state: Optional[np.random.RandomState] = None, +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.normal(loc, scale, size) + + +def poisson( + lam: NumType = 1.0, size: Optional[Size] = None, random_state: Optional[np.random.RandomState] = None +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.poisson(lam, size) + + +def permutation( + x: Union[int, Sequence[float], np.ndarray], random_state: Optional[np.random.RandomState] = None +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.permutation(x) + + +def randint( + low: IntNumType, + high: Optional[IntNumType] = None, + size: Optional[Size] = None, + dtype: Type = np.int32, + random_state: Optional[np.random.RandomState] = None, +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.randint(low, high, size, dtype) + + +def random(size: Optional[NumType] = None, random_state: Optional[np.random.RandomState] = None) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.random(size) # type: ignore + + +def choice( + a: NumType, + size: Optional[Size] = None, + replace: bool = True, + p: Optional[Union[Sequence[float], np.ndarray]] = None, + random_state: Optional[np.random.RandomState] = None, +) -> Any: + if random_state is None: + random_state = get_random_state() + return random_state.choice(a, size, replace, p) # type: ignore diff --git a/custom_controlnet_aux/__init__.py b/custom_controlnet_aux/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8 --- /dev/null +++ b/custom_controlnet_aux/__init__.py @@ -0,0 +1 @@ +#Dummy file ensuring this package will be recognized \ No newline at end of file diff --git a/custom_controlnet_aux/anime_face_segment/__init__.py b/custom_controlnet_aux/anime_face_segment/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b3d8ace514883275a4021d0e9a0bc2640e59ca9a --- /dev/null +++ b/custom_controlnet_aux/anime_face_segment/__init__.py @@ -0,0 +1,66 @@ +from .network import UNet +from .util import seg2img +import torch +import os +import cv2 +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME +from huggingface_hub import hf_hub_download +from PIL import Image +from einops import rearrange +from .anime_segmentation import AnimeSegmentation +import numpy as np + +class AnimeFaceSegmentor: + def __init__(self, model, seg_model): + self.model = model + self.seg_model = seg_model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"): + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators") + seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename) + + model = UNet() + ckpt = torch.load(model_path, map_location="cpu") + model.load_state_dict(ckpt) + model.eval() + + seg_model = AnimeSegmentation(seg_model_path) + seg_model.net.eval() + return cls(model, seg_model) + + def to(self, device): + self.model.to(device) + self.seg_model.net.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + with torch.no_grad(): + if remove_background: + print(input_image.shape) + mask, input_image = self.seg_model(input_image, 0) #Don't resize image as it is resized + image_feed = torch.from_numpy(input_image).float().to(self.device) + image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + image_feed = image_feed / 255 + seg = self.model(image_feed).squeeze(dim=0) + result = seg2img(seg.cpu().detach().numpy()) + + detected_map = HWC3(result) + detected_map = remove_pad(detected_map) + if remove_background: + mask = remove_pad(mask) + H, W, C = detected_map.shape + tmp = np.zeros([H, W, C + 1]) + tmp[:,:,:C] = detected_map + tmp[:,:,3:] = mask + detected_map = tmp + + if output_type == "pil": + detected_map = Image.fromarray(detected_map[..., :3]) + + return detected_map diff --git a/custom_controlnet_aux/anime_face_segment/anime_segmentation.py b/custom_controlnet_aux/anime_face_segment/anime_segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..ac4ab9c0effe4a84b3a76fc1b2cffd373109a035 --- /dev/null +++ b/custom_controlnet_aux/anime_face_segment/anime_segmentation.py @@ -0,0 +1,58 @@ +#https://github.com/SkyTNT/anime-segmentation/tree/main +#Only adapt isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt) +import torch.nn as nn +import torch +from .isnet import ISNetDIS +import numpy as np +import cv2 +from comfy.model_management import get_torch_device +DEVICE = get_torch_device() + +class AnimeSegmentation: + def __init__(self, ckpt_path): + super(AnimeSegmentation).__init__() + sd = torch.load(ckpt_path, map_location="cpu") + self.net = ISNetDIS() + #gt_encoder isn't used during inference + self.net.load_state_dict({k.replace("net.", ''):v for k, v in sd.items() if k.startswith("net.")}) + self.net = self.net.to(DEVICE) + self.net.eval() + + def get_mask(self, input_img, s=640): + input_img = (input_img / 255).astype(np.float32) + if s == 0: + img_input = np.transpose(input_img, (2, 0, 1)) + img_input = img_input[np.newaxis, :] + tmpImg = torch.from_numpy(img_input).float().to(DEVICE) + with torch.no_grad(): + pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47 + pred = pred.cpu().numpy()[0] + pred = np.transpose(pred, (1, 2, 0)) + #pred = pred[:, :, np.newaxis] + return pred + + h, w = h0, w0 = input_img.shape[:-1] + h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s) + ph, pw = s - h, s - w + img_input = np.zeros([s, s, 3], dtype=np.float32) + img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h)) + img_input = np.transpose(img_input, (2, 0, 1)) + img_input = img_input[np.newaxis, :] + tmpImg = torch.from_numpy(img_input).float().to(DEVICE) + with torch.no_grad(): + pred = self.net(tmpImg)[0][0].sigmoid() #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47 + pred = pred.cpu().numpy()[0] + pred = np.transpose(pred, (1, 2, 0)) + pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] + #pred = cv2.resize(pred, (w0, h0))[:, :, np.newaxis] + pred = cv2.resize(pred, (w0, h0)) + return pred + + def __call__(self, np_img, img_size): + mask = self.get_mask(np_img, int(img_size)) + np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8) + mask = (mask * 255).astype(np.uint8) + #np_img = np.concatenate([np_img, mask], axis=2, dtype=np.uint8) + #mask = mask.repeat(3, axis=2) + return mask, np_img + diff --git a/custom_controlnet_aux/anime_face_segment/isnet.py b/custom_controlnet_aux/anime_face_segment/isnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8a0a504ecadf426358c8accee62d3f35129b0a11 --- /dev/null +++ b/custom_controlnet_aux/anime_face_segment/isnet.py @@ -0,0 +1,619 @@ +# Codes are borrowed from +# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import models + +bce_loss = nn.BCEWithLogitsLoss(reduction="mean") + + +def muti_loss_fusion(preds, target): + loss0 = 0.0 + loss = 0.0 + + for i in range(0, len(preds)): + if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]: + tmp_target = F.interpolate( + target, size=preds[i].size()[2:], mode="bilinear", align_corners=True + ) + loss = loss + bce_loss(preds[i], tmp_target) + else: + loss = loss + bce_loss(preds[i], target) + if i == 0: + loss0 = loss + return loss0, loss + + +fea_loss = nn.MSELoss(reduction="mean") +kl_loss = nn.KLDivLoss(reduction="mean") +l1_loss = nn.L1Loss(reduction="mean") +smooth_l1_loss = nn.SmoothL1Loss(reduction="mean") + + +def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"): + loss0 = 0.0 + loss = 0.0 + + for i in range(0, len(preds)): + if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]: + tmp_target = F.interpolate( + target, size=preds[i].size()[2:], mode="bilinear", align_corners=True + ) + loss = loss + bce_loss(preds[i], tmp_target) + else: + loss = loss + bce_loss(preds[i], target) + if i == 0: + loss0 = loss + + for i in range(0, len(dfs)): + df = dfs[i] + fs_i = fs[i] + if mode == "MSE": + loss = loss + fea_loss( + df, fs_i + ) ### add the mse loss of features as additional constraints + elif mode == "KL": + loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1)) + elif mode == "MAE": + loss = loss + l1_loss(df, fs_i) + elif mode == "SmoothL1": + loss = loss + smooth_l1_loss(df, fs_i) + + return loss0, loss + + +class REBNCONV(nn.Module): + def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1): + super(REBNCONV, self).__init__() + + self.conv_s1 = nn.Conv2d( + in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride + ) + self.bn_s1 = nn.BatchNorm2d(out_ch) + self.relu_s1 = nn.ReLU(inplace=True) + + def forward(self, x): + hx = x + xout = self.relu_s1(self.bn_s1(self.conv_s1(hx))) + + return xout + + +## upsample tensor 'src' to have the same spatial size with tensor 'tar' +def _upsample_like(src, tar): + src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False) + + return src + + +### RSU-7 ### +class RSU7(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512): + super(RSU7, self).__init__() + + self.in_ch = in_ch + self.mid_ch = mid_ch + self.out_ch = out_ch + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2 + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + b, c, h, w = x.shape + + hx = x + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + hx = self.pool5(hx5) + + hx6 = self.rebnconv6(hx) + + hx7 = self.rebnconv7(hx6) + + hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1)) + hx6dup = _upsample_like(hx6d, hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-6 ### +class RSU6(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU6, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + hx = self.pool4(hx4) + + hx5 = self.rebnconv5(hx) + + hx6 = self.rebnconv6(hx5) + + hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-5 ### +class RSU5(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU5, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + hx = self.pool3(hx3) + + hx4 = self.rebnconv4(hx) + + hx5 = self.rebnconv5(hx4) + + hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4 ### +class RSU4(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1) + self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx = self.pool1(hx1) + + hx2 = self.rebnconv2(hx) + hx = self.pool2(hx2) + + hx3 = self.rebnconv3(hx) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1)) + + return hx1d + hxin + + +### RSU-4F ### +class RSU4F(nn.Module): + def __init__(self, in_ch=3, mid_ch=12, out_ch=3): + super(RSU4F, self).__init__() + + self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) + + self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1) + self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2) + self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4) + + self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8) + + self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4) + self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2) + self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1) + + def forward(self, x): + hx = x + + hxin = self.rebnconvin(hx) + + hx1 = self.rebnconv1(hxin) + hx2 = self.rebnconv2(hx1) + hx3 = self.rebnconv3(hx2) + + hx4 = self.rebnconv4(hx3) + + hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1)) + hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1)) + hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1)) + + return hx1d + hxin + + +class myrebnconv(nn.Module): + def __init__( + self, + in_ch=3, + out_ch=1, + kernel_size=3, + stride=1, + padding=1, + dilation=1, + groups=1, + ): + super(myrebnconv, self).__init__() + + self.conv = nn.Conv2d( + in_ch, + out_ch, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + ) + self.bn = nn.BatchNorm2d(out_ch) + self.rl = nn.ReLU(inplace=True) + + def forward(self, x): + return self.rl(self.bn(self.conv(x))) + + +class ISNetGTEncoder(nn.Module): + def __init__(self, in_ch=1, out_ch=1): + super(ISNetGTEncoder, self).__init__() + + self.conv_in = myrebnconv( + in_ch, 16, 3, stride=2, padding=1 + ) # nn.Conv2d(in_ch,64,3,stride=2,padding=1) + + self.stage1 = RSU7(16, 16, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 16, 64) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(64, 32, 128) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(128, 32, 256) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(256, 64, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 64, 512) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + + @staticmethod + def compute_loss(args): + preds, targets = args + return muti_loss_fusion(preds, targets) + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + # hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + + # side output + d1 = self.side1(hx1) + d1 = _upsample_like(d1, x) + + d2 = self.side2(hx2) + d2 = _upsample_like(d2, x) + + d3 = self.side3(hx3) + d3 = _upsample_like(d3, x) + + d4 = self.side4(hx4) + d4 = _upsample_like(d4, x) + + d5 = self.side5(hx5) + d5 = _upsample_like(d5, x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6, x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6] + return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6] + + +class ISNetDIS(nn.Module): + def __init__(self, in_ch=3, out_ch=1): + super(ISNetDIS, self).__init__() + + self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1) + self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage1 = RSU7(64, 32, 64) + self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage2 = RSU6(64, 32, 128) + self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage3 = RSU5(128, 64, 256) + self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage4 = RSU4(256, 128, 512) + self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage5 = RSU4F(512, 256, 512) + self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True) + + self.stage6 = RSU4F(512, 256, 512) + + # decoder + self.stage5d = RSU4F(1024, 256, 512) + self.stage4d = RSU4(1024, 128, 256) + self.stage3d = RSU5(512, 64, 128) + self.stage2d = RSU6(256, 32, 64) + self.stage1d = RSU7(128, 16, 64) + + self.side1 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side2 = nn.Conv2d(64, out_ch, 3, padding=1) + self.side3 = nn.Conv2d(128, out_ch, 3, padding=1) + self.side4 = nn.Conv2d(256, out_ch, 3, padding=1) + self.side5 = nn.Conv2d(512, out_ch, 3, padding=1) + self.side6 = nn.Conv2d(512, out_ch, 3, padding=1) + + # self.outconv = nn.Conv2d(6*out_ch,out_ch,1) + + @staticmethod + def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"): + return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode) + + @staticmethod + def compute_loss(args): + if len(args) == 3: + ds, dfs, labels = args + return muti_loss_fusion(ds, labels) + else: + ds, dfs, labels, fs = args + return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE") + + def forward(self, x): + hx = x + + hxin = self.conv_in(hx) + hx = self.pool_in(hxin) + + # stage 1 + hx1 = self.stage1(hxin) + hx = self.pool12(hx1) + + # stage 2 + hx2 = self.stage2(hx) + hx = self.pool23(hx2) + + # stage 3 + hx3 = self.stage3(hx) + hx = self.pool34(hx3) + + # stage 4 + hx4 = self.stage4(hx) + hx = self.pool45(hx4) + + # stage 5 + hx5 = self.stage5(hx) + hx = self.pool56(hx5) + + # stage 6 + hx6 = self.stage6(hx) + hx6up = _upsample_like(hx6, hx5) + + # -------------------- decoder -------------------- + hx5d = self.stage5d(torch.cat((hx6up, hx5), 1)) + hx5dup = _upsample_like(hx5d, hx4) + + hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1)) + hx4dup = _upsample_like(hx4d, hx3) + + hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1)) + hx3dup = _upsample_like(hx3d, hx2) + + hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1)) + hx2dup = _upsample_like(hx2d, hx1) + + hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1)) + + # side output + d1 = self.side1(hx1d) + d1 = _upsample_like(d1, x) + + d2 = self.side2(hx2d) + d2 = _upsample_like(d2, x) + + d3 = self.side3(hx3d) + d3 = _upsample_like(d3, x) + + d4 = self.side4(hx4d) + d4 = _upsample_like(d4, x) + + d5 = self.side5(hx5d) + d5 = _upsample_like(d5, x) + + d6 = self.side6(hx6) + d6 = _upsample_like(d6, x) + + # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1)) + + # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] + return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6] \ No newline at end of file diff --git a/custom_controlnet_aux/anime_face_segment/network.py b/custom_controlnet_aux/anime_face_segment/network.py new file mode 100644 index 0000000000000000000000000000000000000000..1c5929546c82181fe47f604048d9bfe5afb16031 --- /dev/null +++ b/custom_controlnet_aux/anime_face_segment/network.py @@ -0,0 +1,100 @@ +#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from custom_controlnet_aux.util import custom_torch_download + +class UNet(nn.Module): + def __init__(self): + super(UNet, self).__init__() + self.NUM_SEG_CLASSES = 7 # Background, hair, face, eye, mouth, skin, clothes + + mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=False) + mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True) + mob_blocks = mobilenet_v2.features + + # Encoder + self.en_block0 = nn.Sequential( # in_ch=3 out_ch=16 + mob_blocks[0], + mob_blocks[1] + ) + self.en_block1 = nn.Sequential( # in_ch=16 out_ch=24 + mob_blocks[2], + mob_blocks[3], + ) + self.en_block2 = nn.Sequential( # in_ch=24 out_ch=32 + mob_blocks[4], + mob_blocks[5], + mob_blocks[6], + ) + self.en_block3 = nn.Sequential( # in_ch=32 out_ch=96 + mob_blocks[7], + mob_blocks[8], + mob_blocks[9], + mob_blocks[10], + mob_blocks[11], + mob_blocks[12], + mob_blocks[13], + ) + self.en_block4 = nn.Sequential( # in_ch=96 out_ch=160 + mob_blocks[14], + mob_blocks[15], + mob_blocks[16], + ) + + # Decoder + self.de_block4 = nn.Sequential( # in_ch=160 out_ch=96 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(160, 96, kernel_size=3, padding=1), + nn.InstanceNorm2d(96), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block3 = nn.Sequential( # in_ch=96x2 out_ch=32 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(96*2, 32, kernel_size=3, padding=1), + nn.InstanceNorm2d(32), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block2 = nn.Sequential( # in_ch=32x2 out_ch=24 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(32*2, 24, kernel_size=3, padding=1), + nn.InstanceNorm2d(24), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + self.de_block1 = nn.Sequential( # in_ch=24x2 out_ch=16 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(24*2, 16, kernel_size=3, padding=1), + nn.InstanceNorm2d(16), + nn.LeakyReLU(0.1), + nn.Dropout(p=0.2) + ) + + self.de_block0 = nn.Sequential( # in_ch=16x2 out_ch=7 + nn.UpsamplingNearest2d(scale_factor=2), + nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1), + nn.Softmax2d() + ) + + def forward(self, x): + e0 = self.en_block0(x) + e1 = self.en_block1(e0) + e2 = self.en_block2(e1) + e3 = self.en_block3(e2) + e4 = self.en_block4(e3) + + d4 = self.de_block4(e4) + c4 = torch.cat((d4,e3),1) + d3 = self.de_block3(c4) + c3 = torch.cat((d3,e2),1) + d2 = self.de_block2(c3) + c2 =torch.cat((d2,e1),1) + d1 = self.de_block1(c2) + c1 = torch.cat((d1,e0),1) + y = self.de_block0(c1) + + return y \ No newline at end of file diff --git a/custom_controlnet_aux/anime_face_segment/util.py b/custom_controlnet_aux/anime_face_segment/util.py new file mode 100644 index 0000000000000000000000000000000000000000..f6f3d22543675f9b098b8bc57d244c9e437d0636 --- /dev/null +++ b/custom_controlnet_aux/anime_face_segment/util.py @@ -0,0 +1,40 @@ +#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py +#The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py +import cv2 as cv +import glob +import numpy as np +import os + +""" +COLOR_BACKGROUND = (0,255,255) +COLOR_HAIR = (255,0,0) +COLOR_EYE = (0,0,255) +COLOR_MOUTH = (255,255,255) +COLOR_FACE = (0,255,0) +COLOR_SKIN = (255,255,0) +COLOR_CLOTHES = (255,0,255) +""" +COLOR_BACKGROUND = (255,255,0) +COLOR_HAIR = (0,0,255) +COLOR_EYE = (255,0,0) +COLOR_MOUTH = (255,255,255) +COLOR_FACE = (0,255,0) +COLOR_SKIN = (0,255,255) +COLOR_CLOTHES = (255,0,255) +PALETTE = [COLOR_BACKGROUND,COLOR_HAIR,COLOR_EYE,COLOR_MOUTH,COLOR_FACE,COLOR_SKIN,COLOR_CLOTHES] + +def img2seg(path): + src = cv.imread(path) + src = src.reshape(-1, 3) + seg_list = [] + for color in PALETTE: + seg_list.append(np.where(np.all(src==color, axis=1), 1.0, 0.0)) + dst = np.stack(seg_list,axis=1).reshape(512,512,7) + + return dst.astype(np.float32) + +def seg2img(src): + src = np.moveaxis(src,0,2) + dst = [[PALETTE[np.argmax(val)] for val in buf]for buf in src] + + return np.array(dst).astype(np.uint8) \ No newline at end of file diff --git a/custom_controlnet_aux/binary/__init__.py b/custom_controlnet_aux/binary/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..34fd63fa2f29d310a5c1e9785e7997b6fdbd8227 --- /dev/null +++ b/custom_controlnet_aux/binary/__init__.py @@ -0,0 +1,38 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import HWC3, resize_image_with_pad + +class BinaryDetector: + def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + img_gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY) + if bin_threshold == 0 or bin_threshold == 255: + # Otsu's threshold + otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU) + print("Otsu threshold:", otsu_threshold) + else: + _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV) + + detected_map = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB) + detected_map = HWC3(remove_pad(255 - detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/canny/__init__.py b/custom_controlnet_aux/canny/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f91c6dda3716c950fc55d6c0b6c53c81a32b7df --- /dev/null +++ b/custom_controlnet_aux/canny/__init__.py @@ -0,0 +1,17 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3 + +class CannyDetector: + def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + detected_map = cv2.Canny(detected_map, low_threshold, high_threshold) + detected_map = HWC3(remove_pad(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/color/__init__.py b/custom_controlnet_aux/color/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2f010b6901128f49f926aa0c407a07c7d471e3 --- /dev/null +++ b/custom_controlnet_aux/color/__init__.py @@ -0,0 +1,37 @@ +import cv2 +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate + +def cv2_resize_shortest_edge(image, size): + h, w = image.shape[:2] + if h < w: + new_h = size + new_w = int(round(w / h * size)) + else: + new_w = size + new_h = int(round(h / w * size)) + resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA) + return resized_image + +def apply_color(img, res=512): + img = cv2_resize_shortest_edge(img, res) + h, w = img.shape[:2] + + input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC) + input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST) + return input_img_color + +#Color T2I like multiples-of-64, upscale methods are fixed. +class ColorDetector: + def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image = HWC3(input_image) + detected_map = HWC3(apply_color(input_image, detect_resolution)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/densepose/__init__.py b/custom_controlnet_aux/densepose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6f9448c94a027156f4cc9db965f25beaf09f1fb --- /dev/null +++ b/custom_controlnet_aux/densepose/__init__.py @@ -0,0 +1,66 @@ +import torchvision # Fix issue Unknown builtin op: torchvision::nms +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, DENSEPOSE_MODEL_NAME +from .densepose import DensePoseMaskedColormapResultsVisualizer, _extract_i_from_iuvarr, densepose_chart_predictor_output_to_result_with_confidences + +N_PART_LABELS = 24 + +class DenseposeDetector: + def __init__(self, model): + self.dense_pose_estimation = model + self.device = "cpu" + self.result_visualizer = DensePoseMaskedColormapResultsVisualizer( + alpha=1, + data_extractor=_extract_i_from_iuvarr, + segm_extractor=_extract_i_from_iuvarr, + val_scale = 255.0 / N_PART_LABELS + ) + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=DENSEPOSE_MODEL_NAME, filename="densepose_r50_fpn_dl.torchscript"): + torchscript_model_path = custom_hf_download(pretrained_model_or_path, filename) + densepose = torch.jit.load(torchscript_model_path, map_location="cpu") + return cls(densepose) + + def to(self, device): + self.dense_pose_estimation.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", cmap="viridis", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + H, W = input_image.shape[:2] + + hint_image_canvas = np.zeros([H, W], dtype=np.uint8) + hint_image_canvas = np.tile(hint_image_canvas[:, :, np.newaxis], [1, 1, 3]) + + input_image = rearrange(torch.from_numpy(input_image).to(self.device), 'h w c -> c h w') + + pred_boxes, corase_segm, fine_segm, u, v = self.dense_pose_estimation(input_image) + + extractor = densepose_chart_predictor_output_to_result_with_confidences + densepose_results = [extractor(pred_boxes[i:i+1], corase_segm[i:i+1], fine_segm[i:i+1], u[i:i+1], v[i:i+1]) for i in range(len(pred_boxes))] + + if cmap=="viridis": + self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_VIRIDIS + hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results) + hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB) + hint_image[:, :, 0][hint_image[:, :, 0] == 0] = 68 + hint_image[:, :, 1][hint_image[:, :, 1] == 0] = 1 + hint_image[:, :, 2][hint_image[:, :, 2] == 0] = 84 + else: + self.result_visualizer.mask_visualizer.cmap = cv2.COLORMAP_PARULA + hint_image = self.result_visualizer.visualize(hint_image_canvas, densepose_results) + hint_image = cv2.cvtColor(hint_image, cv2.COLOR_BGR2RGB) + + detected_map = remove_pad(HWC3(hint_image)) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + return detected_map diff --git a/custom_controlnet_aux/densepose/densepose.py b/custom_controlnet_aux/densepose/densepose.py new file mode 100644 index 0000000000000000000000000000000000000000..5e43b05fcd76efd5774485ca35e715e64acefdbe --- /dev/null +++ b/custom_controlnet_aux/densepose/densepose.py @@ -0,0 +1,347 @@ +from typing import Tuple +import math +import numpy as np +from enum import IntEnum +from typing import List, Tuple, Union +import torch +from torch.nn import functional as F +import logging +import cv2 + +Image = np.ndarray +Boxes = torch.Tensor +ImageSizeType = Tuple[int, int] +_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray] +IntTupleBox = Tuple[int, int, int, int] + +class BoxMode(IntEnum): + """ + Enum of different ways to represent a box. + """ + + XYXY_ABS = 0 + """ + (x0, y0, x1, y1) in absolute floating points coordinates. + The coordinates in range [0, width or height]. + """ + XYWH_ABS = 1 + """ + (x0, y0, w, h) in absolute floating points coordinates. + """ + XYXY_REL = 2 + """ + Not yet supported! + (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image. + """ + XYWH_REL = 3 + """ + Not yet supported! + (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image. + """ + XYWHA_ABS = 4 + """ + (xc, yc, w, h, a) in absolute floating points coordinates. + (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw. + """ + + @staticmethod + def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType: + """ + Args: + box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5 + from_mode, to_mode (BoxMode) + + Returns: + The converted box of the same type. + """ + if from_mode == to_mode: + return box + + original_type = type(box) + is_numpy = isinstance(box, np.ndarray) + single_box = isinstance(box, (list, tuple)) + if single_box: + assert len(box) == 4 or len(box) == 5, ( + "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor," + " where k == 4 or 5" + ) + arr = torch.tensor(box)[None, :] + else: + # avoid modifying the input box + if is_numpy: + arr = torch.from_numpy(np.asarray(box)).clone() + else: + arr = box.clone() + + assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [ + BoxMode.XYXY_REL, + BoxMode.XYWH_REL, + ], "Relative mode not yet supported!" + + if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS: + assert ( + arr.shape[-1] == 5 + ), "The last dimension of input shape must be 5 for XYWHA format" + original_dtype = arr.dtype + arr = arr.double() + + w = arr[:, 2] + h = arr[:, 3] + a = arr[:, 4] + c = torch.abs(torch.cos(a * math.pi / 180.0)) + s = torch.abs(torch.sin(a * math.pi / 180.0)) + # This basically computes the horizontal bounding rectangle of the rotated box + new_w = c * w + s * h + new_h = c * h + s * w + + # convert center to top-left corner + arr[:, 0] -= new_w / 2.0 + arr[:, 1] -= new_h / 2.0 + # bottom-right corner + arr[:, 2] = arr[:, 0] + new_w + arr[:, 3] = arr[:, 1] + new_h + + arr = arr[:, :4].to(dtype=original_dtype) + elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS: + original_dtype = arr.dtype + arr = arr.double() + arr[:, 0] += arr[:, 2] / 2.0 + arr[:, 1] += arr[:, 3] / 2.0 + angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype) + arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype) + else: + if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS: + arr[:, 2] += arr[:, 0] + arr[:, 3] += arr[:, 1] + elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS: + arr[:, 2] -= arr[:, 0] + arr[:, 3] -= arr[:, 1] + else: + raise NotImplementedError( + "Conversion from BoxMode {} to {} is not supported yet".format( + from_mode, to_mode + ) + ) + + if single_box: + return original_type(arr.flatten().tolist()) + if is_numpy: + return arr.numpy() + else: + return arr + +class MatrixVisualizer: + """ + Base visualizer for matrix data + """ + + def __init__( + self, + inplace=True, + cmap=cv2.COLORMAP_PARULA, + val_scale=1.0, + alpha=0.7, + interp_method_matrix=cv2.INTER_LINEAR, + interp_method_mask=cv2.INTER_NEAREST, + ): + self.inplace = inplace + self.cmap = cmap + self.val_scale = val_scale + self.alpha = alpha + self.interp_method_matrix = interp_method_matrix + self.interp_method_mask = interp_method_mask + + def visualize(self, image_bgr, mask, matrix, bbox_xywh): + self._check_image(image_bgr) + self._check_mask_matrix(mask, matrix) + if self.inplace: + image_target_bgr = image_bgr + else: + image_target_bgr = image_bgr * 0 + x, y, w, h = [int(v) for v in bbox_xywh] + if w <= 0 or h <= 0: + return image_bgr + mask, matrix = self._resize(mask, matrix, w, h) + mask_bg = np.tile((mask == 0)[:, :, np.newaxis], [1, 1, 3]) + matrix_scaled = matrix.astype(np.float32) * self.val_scale + _EPSILON = 1e-6 + if np.any(matrix_scaled > 255 + _EPSILON): + logger = logging.getLogger(__name__) + logger.warning( + f"Matrix has values > {255 + _EPSILON} after " f"scaling, clipping to [0..255]" + ) + matrix_scaled_8u = matrix_scaled.clip(0, 255).astype(np.uint8) + matrix_vis = cv2.applyColorMap(matrix_scaled_8u, self.cmap) + matrix_vis[mask_bg] = image_target_bgr[y : y + h, x : x + w, :][mask_bg] + image_target_bgr[y : y + h, x : x + w, :] = ( + image_target_bgr[y : y + h, x : x + w, :] * (1.0 - self.alpha) + matrix_vis * self.alpha + ) + return image_target_bgr.astype(np.uint8) + + def _resize(self, mask, matrix, w, h): + if (w != mask.shape[1]) or (h != mask.shape[0]): + mask = cv2.resize(mask, (w, h), self.interp_method_mask) + if (w != matrix.shape[1]) or (h != matrix.shape[0]): + matrix = cv2.resize(matrix, (w, h), self.interp_method_matrix) + return mask, matrix + + def _check_image(self, image_rgb): + assert len(image_rgb.shape) == 3 + assert image_rgb.shape[2] == 3 + assert image_rgb.dtype == np.uint8 + + def _check_mask_matrix(self, mask, matrix): + assert len(matrix.shape) == 2 + assert len(mask.shape) == 2 + assert mask.dtype == np.uint8 + +class DensePoseResultsVisualizer: + def visualize( + self, + image_bgr: Image, + results, + ) -> Image: + context = self.create_visualization_context(image_bgr) + for i, result in enumerate(results): + boxes_xywh, labels, uv = result + iuv_array = torch.cat( + (labels[None].type(torch.float32), uv * 255.0) + ).type(torch.uint8) + self.visualize_iuv_arr(context, iuv_array.cpu().numpy(), boxes_xywh) + image_bgr = self.context_to_image_bgr(context) + return image_bgr + + def create_visualization_context(self, image_bgr: Image): + return image_bgr + + def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None: + pass + + def context_to_image_bgr(self, context): + return context + + def get_image_bgr_from_context(self, context): + return context + +class DensePoseMaskedColormapResultsVisualizer(DensePoseResultsVisualizer): + def __init__( + self, + data_extractor, + segm_extractor, + inplace=True, + cmap=cv2.COLORMAP_PARULA, + alpha=0.7, + val_scale=1.0, + **kwargs, + ): + self.mask_visualizer = MatrixVisualizer( + inplace=inplace, cmap=cmap, val_scale=val_scale, alpha=alpha + ) + self.data_extractor = data_extractor + self.segm_extractor = segm_extractor + + def context_to_image_bgr(self, context): + return context + + def visualize_iuv_arr(self, context, iuv_arr: np.ndarray, bbox_xywh) -> None: + image_bgr = self.get_image_bgr_from_context(context) + matrix = self.data_extractor(iuv_arr) + segm = self.segm_extractor(iuv_arr) + mask = np.zeros(matrix.shape, dtype=np.uint8) + mask[segm > 0] = 1 + image_bgr = self.mask_visualizer.visualize(image_bgr, mask, matrix, bbox_xywh) + + +def _extract_i_from_iuvarr(iuv_arr): + return iuv_arr[0, :, :] + + +def _extract_u_from_iuvarr(iuv_arr): + return iuv_arr[1, :, :] + + +def _extract_v_from_iuvarr(iuv_arr): + return iuv_arr[2, :, :] + +def make_int_box(box: torch.Tensor) -> IntTupleBox: + int_box = [0, 0, 0, 0] + int_box[0], int_box[1], int_box[2], int_box[3] = tuple(box.long().tolist()) + return int_box[0], int_box[1], int_box[2], int_box[3] + +def densepose_chart_predictor_output_to_result_with_confidences( + boxes: Boxes, + coarse_segm, + fine_segm, + u, v + +): + boxes_xyxy_abs = boxes.clone() + boxes_xywh_abs = BoxMode.convert(boxes_xyxy_abs, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + box_xywh = make_int_box(boxes_xywh_abs[0]) + + labels = resample_fine_and_coarse_segm_tensors_to_bbox(fine_segm, coarse_segm, box_xywh).squeeze(0) + uv = resample_uv_tensors_to_bbox(u, v, labels, box_xywh) + confidences = [] + return box_xywh, labels, uv + +def resample_fine_and_coarse_segm_tensors_to_bbox( + fine_segm: torch.Tensor, coarse_segm: torch.Tensor, box_xywh_abs: IntTupleBox +): + """ + Resample fine and coarse segmentation tensors to the given + bounding box and derive labels for each pixel of the bounding box + + Args: + fine_segm: float tensor of shape [1, C, Hout, Wout] + coarse_segm: float tensor of shape [1, K, Hout, Wout] + box_xywh_abs (tuple of 4 int): bounding box given by its upper-left + corner coordinates, width (W) and height (H) + Return: + Labels for each pixel of the bounding box, a long tensor of size [1, H, W] + """ + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + # coarse segmentation + coarse_segm_bbox = F.interpolate( + coarse_segm, + (h, w), + mode="bilinear", + align_corners=False, + ).argmax(dim=1) + # combined coarse and fine segmentation + labels = ( + F.interpolate(fine_segm, (h, w), mode="bilinear", align_corners=False).argmax(dim=1) + * (coarse_segm_bbox > 0).long() + ) + return labels + +def resample_uv_tensors_to_bbox( + u: torch.Tensor, + v: torch.Tensor, + labels: torch.Tensor, + box_xywh_abs: IntTupleBox, +) -> torch.Tensor: + """ + Resamples U and V coordinate estimates for the given bounding box + + Args: + u (tensor [1, C, H, W] of float): U coordinates + v (tensor [1, C, H, W] of float): V coordinates + labels (tensor [H, W] of long): labels obtained by resampling segmentation + outputs for the given bounding box + box_xywh_abs (tuple of 4 int): bounding box that corresponds to predictor outputs + Return: + Resampled U and V coordinates - a tensor [2, H, W] of float + """ + x, y, w, h = box_xywh_abs + w = max(int(w), 1) + h = max(int(h), 1) + u_bbox = F.interpolate(u, (h, w), mode="bilinear", align_corners=False) + v_bbox = F.interpolate(v, (h, w), mode="bilinear", align_corners=False) + uv = torch.zeros([2, h, w], dtype=torch.float32, device=u.device) + for part_id in range(1, u_bbox.size(1)): + uv[0][labels == part_id] = u_bbox[0, part_id][labels == part_id] + uv[1][labels == part_id] = v_bbox[0, part_id][labels == part_id] + return uv + diff --git a/custom_controlnet_aux/depth_anything/__init__.py b/custom_controlnet_aux/depth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fc45e6f26e5282656dd86b51b29837770fbeaa03 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/__init__.py @@ -0,0 +1,71 @@ +import numpy as np +import torch +from einops import repeat +from PIL import Image +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DEPTH_ANYTHING_MODEL_NAME +from custom_controlnet_aux.depth_anything.depth_anything.dpt import DPT_DINOv2 +from custom_controlnet_aux.depth_anything.depth_anything.util.transform import Resize, NormalizeImage, PrepareForNet +from torchvision.transforms import Compose +import cv2 +import torch.nn.functional as F + +transform = Compose([ + Resize( + width=518, + height=518, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), +]) + +#https://huggingface.co/LiheYoung/depth_anything_vitl14/raw/main/config.json +DPT_CONFIGS = { + "depth_anything_vitl14.pth": {"encoder": "vitl", "features": 256, "out_channels": [256, 512, 1024, 1024], "use_bn": False, "use_clstoken": False}, + "depth_anything_vitb14.pth": {"encoder": "vitb", "features": 128, "out_channels": [96, 192, 384, 768], "use_bn": False, "use_clstoken": False}, + "depth_anything_vits14.pth": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384], "use_bn": False, "use_clstoken": False} +} + +class DepthAnythingDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=DEPTH_ANYTHING_MODEL_NAME, filename="depth_anything_vitl14.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="checkpoints", repo_type="space") + model = DPT_DINOv2(**DPT_CONFIGS[filename], localhub=True) + model.load_state_dict(torch.load(model_path, map_location="cpu")) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + t, remove_pad = resize_image_with_pad(np.zeros_like(input_image), detect_resolution, upscale_method) + t = remove_pad(t) + + h, w = t.shape[:2] + h, w = int(h), int(w) + image = transform({'image': input_image / 255.})['image'] + image = torch.from_numpy(image).unsqueeze(0).to(self.device) + + with torch.no_grad(): + depth = self.model(image) + depth = F.interpolate(depth[None], (h, w), mode='bilinear', align_corners=False)[0, 0] + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + + detected_map = repeat(depth, "h w -> h w 3").cpu().numpy().astype(np.uint8) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/depth_anything/blocks.py b/custom_controlnet_aux/depth_anything/depth_anything/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..38dbcfeffc0c38ef51bcb20dfd347e50b2a60616 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/depth_anything/blocks.py @@ -0,0 +1,153 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output diff --git a/custom_controlnet_aux/depth_anything/depth_anything/dpt.py b/custom_controlnet_aux/depth_anything/depth_anything/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..b21502675f1f13c92f04f2280e4d30287b6d4060 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/depth_anything/dpt.py @@ -0,0 +1,171 @@ +import torch +import torch.nn as nn + +from .blocks import FeatureFusionBlock, _make_scratch +import torch.nn.functional as F +from custom_controlnet_aux.util import TORCHHUB_PATH + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPTHead(nn.Module): + def __init__(self, nclass, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False): + super(DPTHead, self).__init__() + + self.nclass = nclass + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + if nclass > 1: + self.scratch.output_conv = nn.Sequential( + nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_1, nclass, kernel_size=1, stride=1, padding=0), + ) + else: + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DPT_DINOv2(nn.Module): + def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024], use_bn=False, use_clstoken=False, localhub=True): + super(DPT_DINOv2, self).__init__() + + assert encoder in ['vits', 'vitb', 'vitl'] + + # in case the Internet connection is not stable, please load the DINOv2 locally + if localhub: + self.pretrained = torch.hub.load(TORCHHUB_PATH / 'facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) + else: + self.pretrained = torch.hub.load('facebookresearch/dinov2', 'dinov2_{:}14'.format(encoder), ) + + dim = self.pretrained.blocks[0].attn.qkv.in_features + + self.depth_head = DPTHead(1, dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + h, w = x.shape[-2:] + + features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) + + patch_h, patch_w = h // 14, w // 14 + + depth = self.depth_head(features, patch_h, patch_w) + depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + depth = F.relu(depth) + + return depth.squeeze(1) + + +if __name__ == '__main__': + depth_anything = DPT_DINOv2() + depth_anything.load_state_dict(torch.load('checkpoints/depth_anything_dinov2_vitl14.pth')) + \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/depth_anything/util/transform.py b/custom_controlnet_aux/depth_anything/depth_anything/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..7beab14a9a1e18a8e46e7666fe5bdec223074155 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/depth_anything/util/transform.py @@ -0,0 +1,248 @@ +import random +from PIL import Image, ImageOps, ImageFilter +import torch +from torchvision import transforms +import torch.nn.functional as F + +import numpy as np +import cv2 +import math + + +def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): + """Rezise the sample to ensure the given size. Keeps aspect ratio. + + Args: + sample (dict): sample + size (tuple): image size + + Returns: + tuple: new size + """ + shape = list(sample["disparity"].shape) + + if shape[0] >= size[0] and shape[1] >= size[1]: + return sample + + scale = [0, 0] + scale[0] = size[0] / shape[0] + scale[1] = size[1] / shape[1] + + scale = max(scale) + + shape[0] = math.ceil(scale * shape[0]) + shape[1] = math.ceil(scale * shape[1]) + + # resize + sample["image"] = cv2.resize( + sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method + ) + + sample["disparity"] = cv2.resize( + sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST + ) + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + tuple(shape[::-1]), + interpolation=cv2.INTER_NEAREST, + ) + sample["mask"] = sample["mask"].astype(bool) + + return tuple(shape) + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size( + sample["image"].shape[1], sample["image"].shape[0] + ) + + # resize sample + sample["image"] = cv2.resize( + sample["image"], + (width, height), + interpolation=self.__image_interpolation_method, + ) + + if self.__resize_target: + if "disparity" in sample: + sample["disparity"] = cv2.resize( + sample["disparity"], + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + + if "depth" in sample: + sample["depth"] = cv2.resize( + sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST + ) + + if "semseg_mask" in sample: + # sample["semseg_mask"] = cv2.resize( + # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST + # ) + sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0] + + if "mask" in sample: + sample["mask"] = cv2.resize( + sample["mask"].astype(np.float32), + (width, height), + interpolation=cv2.INTER_NEAREST, + ) + # sample["mask"] = sample["mask"].astype(bool) + + # print(sample['image'].shape, sample['depth'].shape) + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "semseg_mask" in sample: + sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32) + sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"]) + + return sample diff --git a/custom_controlnet_aux/depth_anything/torchhub/README.md b/custom_controlnet_aux/depth_anything/torchhub/README.md new file mode 100644 index 0000000000000000000000000000000000000000..eade757c3b0a25c350ba6bf3b5d2e6f048ad1de6 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/README.md @@ -0,0 +1,3 @@ +# Local PyTorch Hub + +This directory is for loading the DINOv2 encoder locally in case of no Internet connection. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000000000000000000000000000000000..3232ed665566ec047ce55a929db1581dbda266a1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CODE_OF_CONDUCT.md @@ -0,0 +1,80 @@ +# Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to make participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. + +## Our Standards + +Examples of behavior that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behavior by participants include: + +* The use of sexualized language or imagery and unwelcome sexual attention or +advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic +address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a +professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behavior and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behavior. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviors that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies within all project spaces, and it also applies when +an individual is representing the project or its community in public spaces. +Examples of representing a project or community include using an official +project e-mail address, posting via an official social media account, or acting +as an appointed representative at an online or offline event. Representation of +a project may be further defined and clarified by project maintainers. + +This Code of Conduct also applies outside the project spaces when there is a +reasonable belief that an individual's behavior may have a negative impact on +the project or its community. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported by contacting the project team at . All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md new file mode 100644 index 0000000000000000000000000000000000000000..afc89823fc90b920f0758f50e4d808df6a884a34 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/CONTRIBUTING.md @@ -0,0 +1,31 @@ +# Contributing to DINOv2 +We want to make contributing to this project as easy and transparent as +possible. + +## Pull Requests +We actively welcome your pull requests. + +1. Fork the repo and create your branch from `main`. +2. If you've added code that should be tested, add tests. +3. If you've changed APIs, update the documentation. +4. Ensure the test suite passes. +5. Make sure your code lints. +6. If you haven't already, complete the Contributor License Agreement ("CLA"). + +## Contributor License Agreement ("CLA") +In order to accept your pull request, we need you to submit a CLA. You only need +to do this once to work on any of Meta's open source projects. + +Complete your CLA here: + +## Issues +We use GitHub issues to track public bugs. Please ensure your description is +clear and has sufficient instructions to be able to reproduce the issue. + +Meta has a [bounty program](https://www.facebook.com/whitehat/) for the safe +disclosure of security bugs. In those cases, please go through the process +outlined on that page and do not file a public issue. + +## License +By contributing to DINOv2, you agree that your contributions will be licensed +under the LICENSE file in the root directory of this source tree. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/LICENSE b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a115f899f8d09ef3b1def4a16c7bae1a0bd50fbe --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/LICENSE @@ -0,0 +1,400 @@ + +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/MODEL_CARD.md b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/MODEL_CARD.md new file mode 100644 index 0000000000000000000000000000000000000000..5cd35748eb3c5d8f607f83ff068367a0102117c5 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/MODEL_CARD.md @@ -0,0 +1,201 @@ +# Model Card for DINOv2-S/B/L/g + +These are Vision Transformer models trained following the method described in the paper: +"DINOv2: Learning Robust Visual Features without Supervision" + +We provide 4 models: 1 ViT-g trained from scratch, and 3 ViT-S/B/L models distilled from the ViT-g. + +## Model Details +The model takes an image as input and returns a class token and patch tokens. + +The embedding dimension is: +- 384 for ViT-S. +- 768 for ViT-B. +- 1024 for ViT-L. +- 1536 for ViT-g. + +The models follow a Transformer architecture, with a patch size of 14. + +For a 224x224 image, this results in 1 class token + 256 patch tokens. + +The models can accept larger images provided the image shapes are multiples of the patch size (14). +If this condition is not verified, the model will crop to the closest smaller multiple of the patch size. + +### Model Description + +- **Developed by:** Meta AI +- **Model type:** Vision Transformer +- **License:** CC-BY-NC + +- **Repository:** https://github.com/facebookresearch/dinov2 +- **Paper:** https://arxiv.org/abs/2304.07193 +- **Demo:** https://dinov2.metademolab.com/ + +## Uses + +The models are vision backbones providing multi-purpose features for downstream tasks. + +### Direct Use + +The models can be used without fine-tuning, with downstream classifiers as simple as linear layers, to obtain competitive results: +- on depth estimation, semantic segmentation, using linear layers. +- on image classification, using k-NN classifiers on the class token. +- on image classification, with logistic regression classifiers applied on the class token. +- on image classification, with a linear layer applied on the class token and the average of the patch tokens. +- on image retrieval using nearest neighbors. + +### Downstream Use + +It is technically possible to perform fine-tuning on the models, for small gains (we measured +2% on ImageNet-1k classification). +We recommend keeping this as a very last step and only when necessary, as the features already provide good performance out-of-the-box. + +## Bias, Risks, and Limitations + +Despite improvements thanks to the training method not using annotations, we still observe significant biases in our models toward rich households from Western countries. + +### Recommendations + +We expect fine-tuning will increase the biases in the features produced by the model as they will be tuned to the fine-tuning labels. + +## How to Get Started with the Model + +Use the code below to get started with the model. + +```python +import torch +dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') +dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') +dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') +dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') +``` + +## Training Details + +### Training Data + +- **Training data:** LVD-142M (see paper) +- **Training regime:** fp16 using PyTorch-FSDP mixed-precision. + +### Training Procedure + +- **Training objective:** + - DINO self-distillation loss with multi-crop + - iBOT masked-image modeling loss + - KoLeo regularization on [CLS] tokens +- **Architectures:** + - ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN + - ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN + - ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN + - ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN +- **Distillation:** + - Distillation follows the standard DINOv2 pretraining procedure, except the teacher is a pretrained ViT-g, frozen. + +## Evaluation + +We refer users to the associated paper for the evaluation protocols. + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
modelImageNet-1kNYU-Depth v2SUN-RGBDADE20kiNaturalist 2018Oxford-H
taskclassif. (acc)classif. (acc)classif. V2 (acc)depth (RMSE)depth (RMSE)segm. (mAP)classif. (acc)retrieval (mAP)
k-NNlinearlinearlinear
4 layers
NYU-D transfermultiscalelinearnearest neighbor
ViT-S/1479.0%81.1%70.8%0.4170.43147.269.5%43.2
ViT-B/1482.1%84.5%74.9%0.3620.40051.376.3%49.5
ViT-L/1483.5%86.3%77.6%0.3330.39653.179.8%54.0
ViT-g/1483.5%86.5%78.4%0.2980.36253.081.6%52.3
+ +## Environmental Impact + +- **Hardware Type:** Nvidia A100 +- **Hours used:** 22,000 for ViT-g, 4,500 for ViT-S distillation, 5,300 for ViT-B distillation, 8,000 for ViT-L distillation +- **Cloud Provider:** Private infra +- **Compute Region:** USA +- **Carbon Emitted:** 7t CO2eq + +#### Hardware + +Nvidia A100 GPUs + +#### Software + +PyTorch 2.0, +xFormers 0.0.18 + +**BibTeX** + +``` +@misc{oquab2023dinov2, + title={DINOv2: Learning Robust Visual Features without Supervision}, + author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr}, + journal={arXiv:2304.07193}, + year={2023} +} +``` diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/README.md b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/README.md new file mode 100644 index 0000000000000000000000000000000000000000..1ae528844e32695153153c043fed3a46675c9086 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/README.md @@ -0,0 +1,277 @@ +# DINOv2: Learning Robust Visual Features without Supervision + +**[Meta AI Research, FAIR](https://ai.facebook.com/research/)** + +Maxime Oquab, +Timothée Darcet, +Théo Moutakanni, +Huy V. Vo, +Marc Szafraniec, +Vasil Khalidov, +Patrick Labatut, +Armand Joulin, +Piotr Bojanowski + +[[`Paper`](https://arxiv.org/abs/2304.07193)] [[`Blog`](https://ai.facebook.com/blog/dino-v2-computer-vision-self-supervised-learning/)] [[`Demo`](https://dinov2.metademolab.com)] [[`BibTeX`](#citing-dinov2)] + +PyTorch implementation and pretrained models for DINOv2. For details, see the paper: **[DINOv2: Learning Robust Visual Features without Supervision](https://arxiv.org/abs/2304.07193)**. + +DINOv2 models produce high-performance visual features that can be directly employed with classifiers as simple as linear layers on a variety of computer vision tasks; these visual features are robust and perform well across domains without any requirement for fine-tuning. The models were pretrained on a dataset of 142 M images without using any labels or annotations. + +https://github.com/facebookresearch/dinov2/assets/60359573/f168823e-7922-415a-b429-578badf5c356 + +
+ Visualization of the three first principal components of the patch features of all frames, mapped to RGB values. +
+ +## Pretrained models + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
model# of
params
ImageNet
k-NN
ImageNet
linear
download
ViT-S/14 distilled21 M79.0%81.1%backbone only
ViT-B/14 distilled86 M82.1%84.5%backbone only
ViT-L/14 distilled300 M83.5%86.3%backbone only
ViT-g/141,100 M83.5%86.5%backbone only
+ +### Pretrained models via PyTorch Hub + +Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install PyTorch (the only required dependency for loading the model). Installing PyTorch with CUDA support is strongly recommended. + +A corresponding [model card](MODEL_CARD.md) is included in the repository. + +```python +import torch + +dinov2_vits14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14') +dinov2_vitb14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14') +dinov2_vitl14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14') +dinov2_vitg14 = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitg14') +``` + +## Installation + +The training and evaluation code requires PyTorch 2.0 and [xFormers](https://github.com/facebookresearch/xformers) 0.0.18 as well as a number of other 3rd party packages. Note that the code has only been tested with the specified versions and also expects a Linux environment. To setup all the required dependencies for training and evaluation, please follow the instructions below: + +*[conda](https://docs.conda.io/projects/conda/en/latest/user-guide/getting-started.html)* **(Recommended)** - Clone the repository and then create and activate a `dinov2` conda environment using the provided environment definition: + +```shell +conda env create -f conda.yaml +conda activate dinov2 +``` + +*[pip](https://pip.pypa.io/en/stable/getting-started/)* - Clone the repository and then use the provided `requirements.txt` to install the dependencies: + +```shell +pip install -r requirements.txt +``` + +## Data preparation + +### ImageNet-1k + +The root directory of the dataset should hold the following contents: + +- `/test/ILSVRC2012_test_00000001.JPEG` +- `/test/[..]` +- `/test/ILSVRC2012_test_00100000.JPEG` +- `/train/n01440764/n01440764_10026.JPEG` +- `/train/[...]` +- `/train/n15075141/n15075141_9993.JPEG` +- `/val/n01440764/ILSVRC2012_val_00000293.JPEG` +- `/val/[...]` +- `/val/n15075141/ILSVRC2012_val_00049174.JPEG` +- `/labels.txt` + +The provided dataset implementation expects a few additional metadata files to be present under the extra directory: + +- `/class-ids-TRAIN.npy` +- `/class-ids-VAL.npy` +- `/class-names-TRAIN.npy` +- `/class-names-VAL.npy` +- `/entries-TEST.npy` +- `/entries-TRAIN.npy` +- `/entries-VAL.npy` + +These metadata files can be generated (once) with the following lines of Python code: + +```python +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data.datasets import ImageNet + +for split in ImageNet.Split: + dataset = ImageNet(split=split, root="", extra="") + dataset.dump_extra() +``` + +Note that the root and extra directories do not have to be distinct directories. + +### ImageNet-22k + +Please adapt the [dataset class](dinov2/data/datasets/image_net_22k.py) to match your local setup. + +
+ +:warning: To execute the commands provided in the next sections for training and evaluation, the `dinov2` package should be included in the Python module search path, i.e. simply prefix the command to run with `PYTHONPATH=.`. + +## Training + +### Fast setup: training DINOv2 ViT-L/16 on ImageNet-1k + +Run DINOv2 training on 4 A100-80GB nodes (32 GPUs) in a SLURM cluster environment with submitit: + +```shell +python dinov2/run/train/train.py \ + --nodes 4 \ + --config-file dinov2/configs/train/vitl16_short.yaml \ + --output-dir \ + train.dataset_path=ImageNet:split=TRAIN:root=:extra= +``` + +Training time is approximately 1 day and the resulting checkpoint should reach 81.6% on k-NN eval and 82.9% on linear eval. + +The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation. + +### Long setup: training DINOv2 ViT-L/14 on ImageNet-22k + +Run DINOv2 training on 12 A100-80GB nodes (96 GPUs) in a SLURM cluster environment with submitit: + +```shell +python dinov2/run/train/train.py \ + --nodes 12 \ + --config-file dinov2/configs/train/vitl14.yaml \ + --output-dir \ + train.dataset_path=ImageNet22k:root=:extra= +``` + +Training time is approximately 3.3 days and the resulting checkpoint should reach 82.0% on k-NN eval and 84.5% on linear eval. + +The training code saves the weights of the teacher in the `eval` folder every 12500 iterations for evaluation. + + +## Evaluation + +The training code regularly saves the teacher weights. In order to evaluate the model, run the following evaluation on a single node: + +### k-NN classification on ImageNet-1k + +```shell +python dinov2/run/eval/knn.py \ + --config-file /config.yaml \ + --pretrained-weights /eval/training_24999/teacher_checkpoint.pth \ + --output-dir /eval/training_24999/knn \ + --train-dataset ImageNet:split=TRAIN:root=:extra= \ + --val-dataset ImageNet:split=VAL:root=:extra= +``` + +### Logistic regression classification on ImageNet-1k + +```shell +python dinov2/run/eval/log_regression.py \ + --config-file /config.yaml \ + --pretrained-weights /eval/training_24999/teacher_checkpoint.pth \ + --output-dir /eval/training_24999/logreg \ + --train-dataset ImageNet:split=TRAIN:root=:extra= \ + --val-dataset ImageNet:split=VAL:root=:extra= +``` + +### Linear classification with data augmentation on ImageNet-1k + +```shell +python dinov2/run/eval/linear.py \ + --config-file /config.yaml \ + --pretrained-weights /eval/training_24999/teacher_checkpoint.pth \ + --output-dir /eval/training_24999/linear \ + --train-dataset ImageNet:split=TRAIN:root=:extra= \ + --val-dataset ImageNet:split=VAL:root=:extra= +``` + +We release the weights from evaluating the different models: + + + + + + + + + + + + + + + + + + + + + + + + + + + +
modelImageNet
top-1
linear evaluation
ViT-S/14 distilled81.1%linear head weights
ViT-B/14 distilled84.5%linear head weights
ViT-L/14 distilled86.3%linear head weights
ViT-g/1486.5%linear head weights
+ +The performance of the provided pretrained model weights can be evaluated as follows on ImageNet-1k: + +```shell +python dinov2/run/eval/linear.py \ + --config-file dinov2/configs/eval/vitg14_pretrain.yaml \ + --pretrained-weights https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_pretrain.pth \ + --train-dataset ImageNet:split=TRAIN:root=:extra= \ + --val-dataset ImageNet:split=VAL:root=:extra= +``` + +## License + +DINOv2 code and model weights are released under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for additional details. + +## Contributing + +See [contributing](CONTRIBUTING.md) and the [code of conduct](CODE_OF_CONDUCT.md). + +## Citing DINOv2 + +If you find this repository useful, please consider giving a star :star: and citation :t-rex:: + +``` +@misc{oquab2023dinov2, + title={DINOv2: Learning Robust Visual Features without Supervision}, + author={Oquab, Maxime and Darcet, Timothée and Moutakanni, Theo and Vo, Huy V. and Szafraniec, Marc and Khalidov, Vasil and Fernandez, Pierre and Haziza, Daniel and Massa, Francisco and El-Nouby, Alaaeldin and Howes, Russell and Huang, Po-Yao and Xu, Hu and Sharma, Vasu and Li, Shang-Wen and Galuba, Wojciech and Rabbat, Mike and Assran, Mido and Ballas, Nicolas and Synnaeve, Gabriel and Misra, Ishan and Jegou, Herve and Mairal, Julien and Labatut, Patrick and Joulin, Armand and Bojanowski, Piotr}, + journal={arXiv:2304.07193}, + year={2023} +} +``` diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/conda.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/conda.yaml new file mode 100644 index 0000000000000000000000000000000000000000..35dfc30adc275da51b58ff2340dd1d53d2cb9250 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/conda.yaml @@ -0,0 +1,22 @@ +name: dinov2 +channels: + - defaults + - pytorch + - nvidia + - xformers + - conda-forge +dependencies: + - python=3.9 + - pytorch::pytorch=2.0.0 + - pytorch::pytorch-cuda=11.7.0 + - pytorch::torchvision=0.15.0 + - omegaconf + - torchmetrics=0.10.3 + - fvcore + - iopath + - xformers::xformers=0.0.18 + - pip + - pip: + - git+https://github.com/facebookincubator/submitit + - --extra-index-url https://pypi.nvidia.com + - cuml-cu11 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b4afb514783786adf76744f9b97f3e1db1d6081 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..033c35660450afec6612adb342c7c30e1ccd15ee --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import pathlib + +from omegaconf import OmegaConf + + +def load_config(config_name: str): + config_filename = config_name + ".yaml" + return OmegaConf.load(pathlib.Path(__file__).parent.resolve() / config_filename) + + +dinov2_default_config = load_config("ssl_default_config") + + +def load_and_merge_config(config_name: str): + default_config = OmegaConf.create(dinov2_default_config) + loaded_config = load_config(config_name) + return OmegaConf.merge(default_config, loaded_config) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..117d0f027ca26cd8ce6c010bb78d5a8fac42c70e --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitb14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_base + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a96dd5b117b4d59ee210b65037821f1b3e3f16e3 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitg14_pretrain.yaml @@ -0,0 +1,7 @@ +student: + arch: vit_giant2 + patch_size: 14 + ffn_layer: swiglufused +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitl14_pretrain.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitl14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..7a984548bd034f762d455419d7193917fa462dd8 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vitl14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_large + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vits14_pretrain.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vits14_pretrain.yaml new file mode 100644 index 0000000000000000000000000000000000000000..afbdb4ba14f1c97130a25b579360f4d817cda495 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/eval/vits14_pretrain.yaml @@ -0,0 +1,6 @@ +student: + arch: vit_small + patch_size: 14 +crops: + global_crops_size: 518 # this is to set up the position embeddings properly + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/ssl_default_config.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/ssl_default_config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a4ef04545ce9d6cc52b5179236008adc8a9bbda2 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/ssl_default_config.yaml @@ -0,0 +1,115 @@ +MODEL: + WEIGHTS: '' +compute_precision: + grad_scaler: true + teacher: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + student: + backbone: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp16 + buffer_dtype: fp32 + dino_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 + ibot_head: + sharding_strategy: SHARD_GRAD_OP + mixed_precision: + param_dtype: fp16 + reduce_dtype: fp32 + buffer_dtype: fp32 +dino: + loss_weight: 1.0 + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 + koleo_loss_weight: 0.1 +ibot: + loss_weight: 1.0 + mask_sample_probability: 0.5 + mask_ratio_min_max: + - 0.1 + - 0.5 + separate_head: false + head_n_prototypes: 65536 + head_bottleneck_dim: 256 + head_nlayers: 3 + head_hidden_dim: 2048 +train: + batch_size_per_gpu: 64 + dataset_path: ImageNet:split=TRAIN + output_dir: . + saveckp_freq: 20 + seed: 0 + num_workers: 10 + OFFICIAL_EPOCH_LENGTH: 1250 + cache_dataset: true + centering: "centering" # or "sinkhorn_knopp" +student: + arch: vit_large + patch_size: 16 + drop_path_rate: 0.3 + layerscale: 1.0e-05 + drop_path_uniform: true + pretrained_weights: '' + ffn_layer: "mlp" + block_chunks: 0 + qkv_bias: true + proj_bias: true + ffn_bias: true +teacher: + momentum_teacher: 0.992 + final_momentum_teacher: 1 + warmup_teacher_temp: 0.04 + teacher_temp: 0.07 + warmup_teacher_temp_epochs: 30 +optim: + epochs: 100 + weight_decay: 0.04 + weight_decay_end: 0.4 + base_lr: 0.004 # learning rate for a batch size of 1024 + lr: 0. # will be set after applying scaling rule + warmup_epochs: 10 + min_lr: 1.0e-06 + clip_grad: 3.0 + freeze_last_layer_epochs: 1 + scaling_rule: sqrt_wrt_1024 + patch_embed_lr_mult: 0.2 + layerwise_decay: 0.9 + adamw_beta1: 0.9 + adamw_beta2: 0.999 +crops: + global_crops_scale: + - 0.32 + - 1.0 + local_crops_number: 8 + local_crops_scale: + - 0.05 + - 0.32 + global_crops_size: 224 + local_crops_size: 96 +evaluation: + eval_period_iterations: 12500 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitg14.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitg14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d05cf0d59e07ac6e4a2b0f9bdcb6131d7c508962 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitg14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 12 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_giant2 + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl14.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl14.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9b491dcc6a522c71328fc2933dd0501123c8f6b --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl14.yaml @@ -0,0 +1,26 @@ +dino: + head_n_prototypes: 131072 + head_bottleneck_dim: 384 +ibot: + separate_head: true + head_n_prototypes: 131072 +train: + batch_size_per_gpu: 32 + dataset_path: ImageNet22k + centering: sinkhorn_knopp +student: + arch: vit_large + patch_size: 14 + drop_path_rate: 0.4 + ffn_layer: swiglufused + block_chunks: 4 +teacher: + momentum_teacher: 0.994 +optim: + epochs: 500 + weight_decay_end: 0.2 + base_lr: 2.0e-04 # learning rate for a batch size of 1024 + warmup_epochs: 80 + layerwise_decay: 1.0 +crops: + local_crops_size: 98 \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl16_short.yaml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl16_short.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e7e72864c92175a1354142ac1d64da8070d1e5e --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/configs/train/vitl16_short.yaml @@ -0,0 +1,6 @@ +# this corresponds to the default config +train: + dataset_path: ImageNet:split=TRAIN + batch_size_per_gpu: 64 +student: + block_chunks: 4 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..357db5c542c5810391ba2bd45a60c13c01c3737a --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .adapters import DatasetWithEnumeratedTargets +from .loaders import make_data_loader, make_dataset, SamplerType +from .collate import collate_data_and_cast +from .masking import MaskingGenerator +from .augmentations import DataAugmentationDINO diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/adapters.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/adapters.py new file mode 100644 index 0000000000000000000000000000000000000000..7dcbc68e046f03460d5867f1388d5380d9c6f603 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/adapters.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torch.utils.data import Dataset + + +class DatasetWithEnumeratedTargets(Dataset): + def __init__(self, dataset): + self._dataset = dataset + + def get_image_data(self, index: int) -> bytes: + return self._dataset.get_image_data(index) + + def get_target(self, index: int) -> Tuple[Any, int]: + target = self._dataset.get_target(index) + return (index, target) + + def __getitem__(self, index: int) -> Tuple[Any, Tuple[Any, int]]: + image, target = self._dataset[index] + target = index if target is None else target + return image, (index, target) + + def __len__(self) -> int: + return len(self._dataset) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/augmentations.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..7ca28cb59a4de2566a6c9ef9c301cbbb4e54b5ee --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/augmentations.py @@ -0,0 +1,119 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from torchvision import transforms + +from .transforms import ( + GaussianBlur, + make_normalize_transform, +) + + +logger = logging.getLogger("dinov2") + + +class DataAugmentationDINO(object): + def __init__( + self, + global_crops_scale, + local_crops_scale, + local_crops_number, + global_crops_size=224, + local_crops_size=96, + ): + self.global_crops_scale = global_crops_scale + self.local_crops_scale = local_crops_scale + self.local_crops_number = local_crops_number + self.global_crops_size = global_crops_size + self.local_crops_size = local_crops_size + + logger.info("###################################") + logger.info("Using data augmentation parameters:") + logger.info(f"global_crops_scale: {global_crops_scale}") + logger.info(f"local_crops_scale: {local_crops_scale}") + logger.info(f"local_crops_number: {local_crops_number}") + logger.info(f"global_crops_size: {global_crops_size}") + logger.info(f"local_crops_size: {local_crops_size}") + logger.info("###################################") + + # random resized crop and flip + self.geometric_augmentation_global = transforms.Compose( + [ + transforms.RandomResizedCrop( + global_crops_size, scale=global_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + self.geometric_augmentation_local = transforms.Compose( + [ + transforms.RandomResizedCrop( + local_crops_size, scale=local_crops_scale, interpolation=transforms.InterpolationMode.BICUBIC + ), + transforms.RandomHorizontalFlip(p=0.5), + ] + ) + + # color distorsions / blurring + color_jittering = transforms.Compose( + [ + transforms.RandomApply( + [transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.2, hue=0.1)], + p=0.8, + ), + transforms.RandomGrayscale(p=0.2), + ] + ) + + global_transfo1_extra = GaussianBlur(p=1.0) + + global_transfo2_extra = transforms.Compose( + [ + GaussianBlur(p=0.1), + transforms.RandomSolarize(threshold=128, p=0.2), + ] + ) + + local_transfo_extra = GaussianBlur(p=0.5) + + # normalization + self.normalize = transforms.Compose( + [ + transforms.ToTensor(), + make_normalize_transform(), + ] + ) + + self.global_transfo1 = transforms.Compose([color_jittering, global_transfo1_extra, self.normalize]) + self.global_transfo2 = transforms.Compose([color_jittering, global_transfo2_extra, self.normalize]) + self.local_transfo = transforms.Compose([color_jittering, local_transfo_extra, self.normalize]) + + def __call__(self, image): + output = {} + + # global crops: + im1_base = self.geometric_augmentation_global(image) + global_crop_1 = self.global_transfo1(im1_base) + + im2_base = self.geometric_augmentation_global(image) + global_crop_2 = self.global_transfo2(im2_base) + + output["global_crops"] = [global_crop_1, global_crop_2] + + # global crops for teacher: + output["global_crops_teacher"] = [global_crop_1, global_crop_2] + + # local crops: + local_crops = [ + self.local_transfo(self.geometric_augmentation_local(image)) for _ in range(self.local_crops_number) + ] + output["local_crops"] = local_crops + output["offsets"] = () + + return output diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/collate.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/collate.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0d98906808ed326dff4486d95b3ec04f8a5e75 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/collate.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import random + + +def collate_data_and_cast(samples_list, mask_ratio_tuple, mask_probability, dtype, n_tokens=None, mask_generator=None): + # dtype = torch.half # TODO: Remove + + n_global_crops = len(samples_list[0][0]["global_crops"]) + n_local_crops = len(samples_list[0][0]["local_crops"]) + + collated_global_crops = torch.stack([s[0]["global_crops"][i] for i in range(n_global_crops) for s in samples_list]) + + collated_local_crops = torch.stack([s[0]["local_crops"][i] for i in range(n_local_crops) for s in samples_list]) + + B = len(collated_global_crops) + N = n_tokens + n_samples_masked = int(B * mask_probability) + probs = torch.linspace(*mask_ratio_tuple, n_samples_masked + 1) + upperbound = 0 + masks_list = [] + for i in range(0, n_samples_masked): + prob_min = probs[i] + prob_max = probs[i + 1] + masks_list.append(torch.BoolTensor(mask_generator(int(N * random.uniform(prob_min, prob_max))))) + upperbound += int(N * prob_max) + for i in range(n_samples_masked, B): + masks_list.append(torch.BoolTensor(mask_generator(0))) + + random.shuffle(masks_list) + + collated_masks = torch.stack(masks_list).flatten(1) + mask_indices_list = collated_masks.flatten().nonzero().flatten() + + masks_weight = (1 / collated_masks.sum(-1).clamp(min=1.0)).unsqueeze(-1).expand_as(collated_masks)[collated_masks] + + return { + "collated_global_crops": collated_global_crops.to(dtype), + "collated_local_crops": collated_local_crops.to(dtype), + "collated_masks": collated_masks, + "mask_indices_list": mask_indices_list, + "masks_weight": masks_weight, + "upperbound": upperbound, + "n_masked_patches": torch.full((1,), fill_value=mask_indices_list.shape[0], dtype=torch.long), + } diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7b537aee8fe31d7e0fa06713d2cfe9233ff0ef60 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .image_net import ImageNet +from .image_net_22k import ImageNet22k diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/decoders.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/decoders.py new file mode 100644 index 0000000000000000000000000000000000000000..548720b3b9959b4949f71fb2dd5cf6af3d184066 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/decoders.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from io import BytesIO +from typing import Any + +from PIL import Image + + +class Decoder: + def decode(self) -> Any: + raise NotImplementedError + + +class ImageDataDecoder(Decoder): + def __init__(self, image_data: bytes) -> None: + self._image_data = image_data + + def decode(self) -> Image: + f = BytesIO(self._image_data) + return Image.open(f).convert(mode="RGB") + + +class TargetDecoder(Decoder): + def __init__(self, target: Any): + self._target = target + + def decode(self) -> Any: + return self._target diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/extended.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/extended.py new file mode 100644 index 0000000000000000000000000000000000000000..4da831e6ad275025ed55eaa490f780ecf6083f2c --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/extended.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, Tuple + +from torchvision.datasets import VisionDataset + +from .decoders import TargetDecoder, ImageDataDecoder + + +class ExtendedVisionDataset(VisionDataset): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) # type: ignore + + def get_image_data(self, index: int) -> bytes: + raise NotImplementedError + + def get_target(self, index: int) -> Any: + raise NotImplementedError + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + try: + image_data = self.get_image_data(index) + image = ImageDataDecoder(image_data).decode() + except Exception as e: + raise RuntimeError(f"can not read image for sample {index}") from e + target = self.get_target(index) + target = TargetDecoder(target).decode() + + if self.transforms is not None: + image, target = self.transforms(image, target) + + return image, target + + def __len__(self) -> int: + raise NotImplementedError diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net.py new file mode 100644 index 0000000000000000000000000000000000000000..1e1c384cc96ceb6afeb3e555d9b3e2a2c008c5d4 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net.py @@ -0,0 +1,291 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import csv +from enum import Enum +import logging +import os +from typing import Callable, List, Optional, Tuple, Union + +import numpy as np + +from .extended import ExtendedVisionDataset + + +logger = logging.getLogger("dinov2") +_Target = int + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + TEST = "test" # NOTE: torchvision does not support the test split + + @property + def length(self) -> int: + split_lengths = { + _Split.TRAIN: 1_281_167, + _Split.VAL: 50_000, + _Split.TEST: 100_000, + } + return split_lengths[self] + + def get_dirname(self, class_id: Optional[str] = None) -> str: + return self.value if class_id is None else os.path.join(self.value, class_id) + + def get_image_relpath(self, actual_index: int, class_id: Optional[str] = None) -> str: + dirname = self.get_dirname(class_id) + if self == _Split.TRAIN: + basename = f"{class_id}_{actual_index}" + else: # self in (_Split.VAL, _Split.TEST): + basename = f"ILSVRC2012_{self.value}_{actual_index:08d}" + return os.path.join(dirname, basename + ".JPEG") + + def parse_image_relpath(self, image_relpath: str) -> Tuple[str, int]: + assert self != _Split.TEST + dirname, filename = os.path.split(image_relpath) + class_id = os.path.split(dirname)[-1] + basename, _ = os.path.splitext(filename) + actual_index = int(basename.split("_")[-1]) + return class_id, actual_index + + +class ImageNet(ExtendedVisionDataset): + Target = Union[_Target] + Split = Union[_Split] + + def __init__( + self, + *, + split: "ImageNet.Split", + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + self._split = split + + self._entries = None + self._class_ids = None + self._class_names = None + + @property + def split(self) -> "ImageNet.Split": + return self._split + + def _get_extra_full_path(self, extra_path: str) -> str: + return os.path.join(self._extra_root, extra_path) + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_full_path = self._get_extra_full_path(extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_full_path = self._get_extra_full_path(extra_path) + os.makedirs(self._extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _entries_path(self) -> str: + return f"entries-{self._split.value.upper()}.npy" + + @property + def _class_ids_path(self) -> str: + return f"class-ids-{self._split.value.upper()}.npy" + + @property + def _class_names_path(self) -> str: + return f"class-names-{self._split.value.upper()}.npy" + + def _get_entries(self) -> np.ndarray: + if self._entries is None: + self._entries = self._load_extra(self._entries_path) + assert self._entries is not None + return self._entries + + def _get_class_ids(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class IDs are not available in TEST split" + if self._class_ids is None: + self._class_ids = self._load_extra(self._class_ids_path) + assert self._class_ids is not None + return self._class_ids + + def _get_class_names(self) -> np.ndarray: + if self._split == _Split.TEST: + assert False, "Class names are not available in TEST split" + if self._class_names is None: + self._class_names = self._load_extra(self._class_names_path) + assert self._class_names is not None + return self._class_names + + def find_class_id(self, class_index: int) -> str: + class_ids = self._get_class_ids() + return str(class_ids[class_index]) + + def find_class_name(self, class_index: int) -> str: + class_names = self._get_class_names() + return str(class_names[class_index]) + + def get_image_data(self, index: int) -> bytes: + entries = self._get_entries() + actual_index = entries[index]["actual_index"] + + class_id = self.get_class_id(index) + + image_relpath = self.split.get_image_relpath(actual_index, class_id) + image_full_path = os.path.join(self.root, image_relpath) + with open(image_full_path, mode="rb") as f: + image_data = f.read() + return image_data + + def get_target(self, index: int) -> Optional[Target]: + entries = self._get_entries() + class_index = entries[index]["class_index"] + return None if self.split == _Split.TEST else int(class_index) + + def get_targets(self) -> Optional[np.ndarray]: + entries = self._get_entries() + return None if self.split == _Split.TEST else entries["class_index"] + + def get_class_id(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_id = entries[index]["class_id"] + return None if self.split == _Split.TEST else str(class_id) + + def get_class_name(self, index: int) -> Optional[str]: + entries = self._get_entries() + class_name = entries[index]["class_name"] + return None if self.split == _Split.TEST else str(class_name) + + def __len__(self) -> int: + entries = self._get_entries() + assert len(entries) == self.split.length + return len(entries) + + def _load_labels(self, labels_path: str) -> List[Tuple[str, str]]: + labels_full_path = os.path.join(self.root, labels_path) + labels = [] + + try: + with open(labels_full_path, "r") as f: + reader = csv.reader(f) + for row in reader: + class_id, class_name = row + labels.append((class_id, class_name)) + except OSError as e: + raise RuntimeError(f'can not read labels file "{labels_full_path}"') from e + + return labels + + def _dump_entries(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + dataset = None + sample_count = split.length + max_class_id_length, max_class_name_length = 0, 0 + else: + labels_path = "labels.txt" + logger.info(f'loading labels from "{labels_path}"') + labels = self._load_labels(labels_path) + + # NOTE: Using torchvision ImageFolder for consistency + from torchvision.datasets import ImageFolder + + dataset_root = os.path.join(self.root, split.get_dirname()) + dataset = ImageFolder(dataset_root) + sample_count = len(dataset) + max_class_id_length, max_class_name_length = -1, -1 + for sample in dataset.samples: + _, class_index = sample + class_id, class_name = labels[class_index] + max_class_id_length = max(len(class_id), max_class_id_length) + max_class_name_length = max(len(class_name), max_class_name_length) + + dtype = np.dtype( + [ + ("actual_index", " old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + actual_index = index + 1 + class_index = np.uint32(-1) + class_id, class_name = "", "" + entries_array[index] = (actual_index, class_index, class_id, class_name) + else: + class_names = {class_id: class_name for class_id, class_name in labels} + + assert dataset + old_percent = -1 + for index in range(sample_count): + percent = 100 * (index + 1) // sample_count + if percent > old_percent: + logger.info(f"creating entries: {percent}%") + old_percent = percent + + image_full_path, class_index = dataset.samples[index] + image_relpath = os.path.relpath(image_full_path, self.root) + class_id, actual_index = split.parse_image_relpath(image_relpath) + class_name = class_names[class_id] + entries_array[index] = (actual_index, class_index, class_id, class_name) + + logger.info(f'saving entries to "{self._entries_path}"') + self._save_extra(entries_array, self._entries_path) + + def _dump_class_ids_and_names(self) -> None: + split = self.split + if split == ImageNet.Split.TEST: + return + + entries_array = self._load_extra(self._entries_path) + + max_class_id_length, max_class_name_length, max_class_index = -1, -1, -1 + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + max_class_name_length = max(len(str(class_name)), max_class_name_length) + + class_count = max_class_index + 1 + class_ids_array = np.empty(class_count, dtype=f"U{max_class_id_length}") + class_names_array = np.empty(class_count, dtype=f"U{max_class_name_length}") + for entry in entries_array: + class_index, class_id, class_name = ( + entry["class_index"], + entry["class_id"], + entry["class_name"], + ) + class_ids_array[class_index] = class_id + class_names_array[class_index] = class_name + + logger.info(f'saving class IDs to "{self._class_ids_path}"') + self._save_extra(class_ids_array, self._class_ids_path) + + logger.info(f'saving class names to "{self._class_names_path}"') + self._save_extra(class_names_array, self._class_names_path) + + def dump_extra(self) -> None: + self._dump_entries() + self._dump_class_ids_and_names() diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net_22k.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net_22k.py new file mode 100644 index 0000000000000000000000000000000000000000..2c0bfd335a68b67e02c241f39b1ae06f9fbe1dd0 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/datasets/image_net_22k.py @@ -0,0 +1,303 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass +from enum import Enum +from functools import lru_cache +from gzip import GzipFile +from io import BytesIO +from mmap import ACCESS_READ, mmap +import os +from typing import Any, Callable, List, Optional, Set, Tuple +import warnings + +import numpy as np + +from .extended import ExtendedVisionDataset + + +_Labels = int + +_DEFAULT_MMAP_CACHE_SIZE = 16 # Warning: This can exhaust file descriptors + + +@dataclass +class _ClassEntry: + block_offset: int + maybe_filename: Optional[str] = None + + +@dataclass +class _Entry: + class_index: int # noqa: E701 + start_offset: int + end_offset: int + filename: str + + +class _Split(Enum): + TRAIN = "train" + VAL = "val" + + @property + def length(self) -> int: + return { + _Split.TRAIN: 11_797_647, + _Split.VAL: 561_050, + }[self] + + def entries_path(self): + return f"imagenet21kp_{self.value}.txt" + + +def _get_tarball_path(class_id: str) -> str: + return f"{class_id}.tar" + + +def _make_mmap_tarball(tarballs_root: str, mmap_cache_size: int): + @lru_cache(maxsize=mmap_cache_size) + def _mmap_tarball(class_id: str) -> mmap: + tarball_path = _get_tarball_path(class_id) + tarball_full_path = os.path.join(tarballs_root, tarball_path) + with open(tarball_full_path) as f: + return mmap(fileno=f.fileno(), length=0, access=ACCESS_READ) + + return _mmap_tarball + + +class ImageNet22k(ExtendedVisionDataset): + _GZIPPED_INDICES: Set[int] = { + 841_545, + 1_304_131, + 2_437_921, + 2_672_079, + 2_795_676, + 2_969_786, + 6_902_965, + 6_903_550, + 6_903_628, + 7_432_557, + 7_432_589, + 7_813_809, + 8_329_633, + 10_296_990, + 10_417_652, + 10_492_265, + 10_598_078, + 10_782_398, + 10_902_612, + 11_203_736, + 11_342_890, + 11_397_596, + 11_589_762, + 11_705_103, + 12_936_875, + 13_289_782, + } + Labels = _Labels + + def __init__( + self, + *, + root: str, + extra: str, + transforms: Optional[Callable] = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, + mmap_cache_size: int = _DEFAULT_MMAP_CACHE_SIZE, + ) -> None: + super().__init__(root, transforms, transform, target_transform) + self._extra_root = extra + + entries_path = self._get_entries_path(root) + self._entries = self._load_extra(entries_path) + + class_ids_path = self._get_class_ids_path(root) + self._class_ids = self._load_extra(class_ids_path) + + self._gzipped_indices = ImageNet22k._GZIPPED_INDICES + self._mmap_tarball = _make_mmap_tarball(self._tarballs_root, mmap_cache_size) + + def _get_entries_path(self, root: Optional[str] = None) -> str: + return "entries.npy" + + def _get_class_ids_path(self, root: Optional[str] = None) -> str: + return "class-ids.npy" + + def _find_class_ids(self, path: str) -> List[str]: + class_ids = [] + + with os.scandir(path) as entries: + for entry in entries: + root, ext = os.path.splitext(entry.name) + if ext != ".tar": + continue + class_ids.append(root) + + return sorted(class_ids) + + def _load_entries_class_ids(self, root: Optional[str] = None) -> Tuple[List[_Entry], List[str]]: + root = self.get_root(root) + entries: List[_Entry] = [] + class_ids = self._find_class_ids(root) + + for class_index, class_id in enumerate(class_ids): + path = os.path.join(root, "blocks", f"{class_id}.log") + class_entries = [] + + try: + with open(path) as f: + for line in f: + line = line.rstrip() + block, filename = line.split(":") + block_offset = int(block[6:]) + filename = filename[1:] + + maybe_filename = None + if filename != "** Block of NULs **": + maybe_filename = filename + _, ext = os.path.splitext(filename) + # assert ext == ".JPEG" + + class_entry = _ClassEntry(block_offset, maybe_filename) + class_entries.append(class_entry) + except OSError as e: + raise RuntimeError(f'can not read blocks file "{path}"') from e + + assert class_entries[-1].maybe_filename is None + + for class_entry1, class_entry2 in zip(class_entries, class_entries[1:]): + assert class_entry1.block_offset <= class_entry2.block_offset + start_offset = 512 * class_entry1.block_offset + end_offset = 512 * class_entry2.block_offset + assert class_entry1.maybe_filename is not None + filename = class_entry1.maybe_filename + entry = _Entry(class_index, start_offset, end_offset, filename) + # Skip invalid image files (PIL throws UnidentifiedImageError) + if filename == "n06470073_47249.JPEG": + continue + entries.append(entry) + + return entries, class_ids + + def _load_extra(self, extra_path: str) -> np.ndarray: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + return np.load(extra_full_path, mmap_mode="r") + + def _save_extra(self, extra_array: np.ndarray, extra_path: str) -> None: + extra_root = self._extra_root + extra_full_path = os.path.join(extra_root, extra_path) + os.makedirs(extra_root, exist_ok=True) + np.save(extra_full_path, extra_array) + + @property + def _tarballs_root(self) -> str: + return self.root + + def find_class_id(self, class_index: int) -> str: + return str(self._class_ids[class_index]) + + def get_image_data(self, index: int) -> bytes: + entry = self._entries[index] + class_id = entry["class_id"] + class_mmap = self._mmap_tarball(class_id) + + start_offset, end_offset = entry["start_offset"], entry["end_offset"] + try: + mapped_data = class_mmap[start_offset:end_offset] + data = mapped_data[512:] # Skip entry header block + + if len(data) >= 2 and tuple(data[:2]) == (0x1F, 0x8B): + assert index in self._gzipped_indices, f"unexpected gzip header for sample {index}" + with GzipFile(fileobj=BytesIO(data)) as g: + data = g.read() + except Exception as e: + raise RuntimeError(f"can not retrieve image data for sample {index} " f'from "{class_id}" tarball') from e + + return data + + def get_target(self, index: int) -> Any: + return int(self._entries[index]["class_index"]) + + def get_targets(self) -> np.ndarray: + return self._entries["class_index"] + + def get_class_id(self, index: int) -> str: + return str(self._entries[index]["class_id"]) + + def get_class_ids(self) -> np.ndarray: + return self._entries["class_id"] + + def __getitem__(self, index: int) -> Tuple[Any, Any]: + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + return super().__getitem__(index) + + def __len__(self) -> int: + return len(self._entries) + + def _dump_entries(self, *args, **kwargs) -> None: + entries, class_ids = self._load_entries_class_ids(*args, **kwargs) + + max_class_id_length, max_filename_length, max_class_index = -1, -1, -1 + for entry in entries: + class_id = class_ids[entry.class_index] + max_class_index = max(entry.class_index, max_class_index) + max_class_id_length = max(len(class_id), max_class_id_length) + max_filename_length = max(len(entry.filename), max_filename_length) + + dtype = np.dtype( + [ + ("class_index", " None: + entries_path = self._get_entries_path(*args, **kwargs) + entries_array = self._load_extra(entries_path) + + max_class_id_length, max_class_index = -1, -1 + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + max_class_index = max(int(class_index), max_class_index) + max_class_id_length = max(len(str(class_id)), max_class_id_length) + + class_ids_array = np.empty(max_class_index + 1, dtype=f"U{max_class_id_length}") + for entry in entries_array: + class_index, class_id = entry["class_index"], entry["class_id"] + class_ids_array[class_index] = class_id + class_ids_path = self._get_class_ids_path(*args, **kwargs) + self._save_extra(class_ids_array, class_ids_path) + + def _dump_extra(self, *args, **kwargs) -> None: + self._dump_entries(*args, *kwargs) + self._dump_class_ids(*args, *kwargs) + + def dump_extra(self, root: Optional[str] = None) -> None: + return self._dump_extra(root) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/loaders.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb6f25a0a3c3251b803f48d0a515aa0b9591226 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/loaders.py @@ -0,0 +1,223 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from enum import Enum +from typing import Any, Callable, List, Optional, TypeVar + +import torch +from torch.utils.data import Sampler + +from .datasets import ImageNet, ImageNet22k +from .samplers import EpochSampler, InfiniteSampler, ShardedInfiniteSampler + + +logger = logging.getLogger("dinov2") + + +class SamplerType(Enum): + DISTRIBUTED = 0 + EPOCH = 1 + INFINITE = 2 + SHARDED_INFINITE = 3 + SHARDED_INFINITE_NEW = 4 + + +def _make_bool_str(b: bool) -> str: + return "yes" if b else "no" + + +def _make_sample_transform(image_transform: Optional[Callable] = None, target_transform: Optional[Callable] = None): + def transform(sample): + image, target = sample + if image_transform is not None: + image = image_transform(image) + if target_transform is not None: + target = target_transform(target) + return image, target + + return transform + + +def _parse_dataset_str(dataset_str: str): + tokens = dataset_str.split(":") + + name = tokens[0] + kwargs = {} + + for token in tokens[1:]: + key, value = token.split("=") + assert key in ("root", "extra", "split") + kwargs[key] = value + + if name == "ImageNet": + class_ = ImageNet + if "split" in kwargs: + kwargs["split"] = ImageNet.Split[kwargs["split"]] + elif name == "ImageNet22k": + class_ = ImageNet22k + else: + raise ValueError(f'Unsupported dataset "{name}"') + + return class_, kwargs + + +def make_dataset( + *, + dataset_str: str, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, +): + """ + Creates a dataset with the specified parameters. + + Args: + dataset_str: A dataset string description (e.g. ImageNet:split=TRAIN). + transform: A transform to apply to images. + target_transform: A transform to apply to targets. + + Returns: + The created dataset. + """ + logger.info(f'using dataset: "{dataset_str}"') + + class_, kwargs = _parse_dataset_str(dataset_str) + dataset = class_(transform=transform, target_transform=target_transform, **kwargs) + + logger.info(f"# of dataset samples: {len(dataset):,d}") + + # Aggregated datasets do not expose (yet) these attributes, so add them. + if not hasattr(dataset, "transform"): + setattr(dataset, "transform", transform) + if not hasattr(dataset, "target_transform"): + setattr(dataset, "target_transform", target_transform) + + return dataset + + +def _make_sampler( + *, + dataset, + type: Optional[SamplerType] = None, + shuffle: bool = False, + seed: int = 0, + size: int = -1, + advance: int = 0, +) -> Optional[Sampler]: + sample_count = len(dataset) + + if type == SamplerType.INFINITE: + logger.info("sampler: infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + return InfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + ) + elif type in (SamplerType.SHARDED_INFINITE, SamplerType.SHARDED_INFINITE_NEW): + logger.info("sampler: sharded infinite") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + # TODO: Remove support for old shuffling + use_new_shuffle_tensor_slice = type == SamplerType.SHARDED_INFINITE_NEW + return ShardedInfiniteSampler( + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + advance=advance, + use_new_shuffle_tensor_slice=use_new_shuffle_tensor_slice, + ) + elif type == SamplerType.EPOCH: + logger.info("sampler: epoch") + if advance > 0: + raise NotImplementedError("sampler advance > 0 is not supported") + size = size if size > 0 else sample_count + logger.info(f"# of samples / epoch: {size:,d}") + return EpochSampler( + size=size, + sample_count=sample_count, + shuffle=shuffle, + seed=seed, + ) + elif type == SamplerType.DISTRIBUTED: + logger.info("sampler: distributed") + if size > 0: + raise ValueError("sampler size > 0 is invalid") + if advance > 0: + raise ValueError("sampler advance > 0 is invalid") + return torch.utils.data.DistributedSampler( + dataset=dataset, + shuffle=shuffle, + seed=seed, + drop_last=False, + ) + + logger.info("sampler: none") + return None + + +T = TypeVar("T") + + +def make_data_loader( + *, + dataset, + batch_size: int, + num_workers: int, + shuffle: bool = True, + seed: int = 0, + sampler_type: Optional[SamplerType] = SamplerType.INFINITE, + sampler_size: int = -1, + sampler_advance: int = 0, + drop_last: bool = True, + persistent_workers: bool = False, + collate_fn: Optional[Callable[[List[T]], Any]] = None, +): + """ + Creates a data loader with the specified parameters. + + Args: + dataset: A dataset (third party, LaViDa or WebDataset). + batch_size: The size of batches to generate. + num_workers: The number of workers to use. + shuffle: Whether to shuffle samples. + seed: The random seed to use. + sampler_type: Which sampler to use: EPOCH, INFINITE, SHARDED_INFINITE, SHARDED_INFINITE_NEW, DISTRIBUTED or None. + sampler_size: The number of images per epoch (when applicable) or -1 for the entire dataset. + sampler_advance: How many samples to skip (when applicable). + drop_last: Whether the last non-full batch of data should be dropped. + persistent_workers: maintain the workers Dataset instances alive after a dataset has been consumed once. + collate_fn: Function that performs batch collation + """ + + sampler = _make_sampler( + dataset=dataset, + type=sampler_type, + shuffle=shuffle, + seed=seed, + size=sampler_size, + advance=sampler_advance, + ) + + logger.info("using PyTorch data loader") + data_loader = torch.utils.data.DataLoader( + dataset, + sampler=sampler, + batch_size=batch_size, + num_workers=num_workers, + pin_memory=True, + drop_last=drop_last, + persistent_workers=persistent_workers, + collate_fn=collate_fn, + ) + + try: + logger.info(f"# of batches: {len(data_loader):,d}") + except TypeError: # data loader has no length + logger.info("infinite data loader") + return data_loader diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/masking.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/masking.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3c72648c3e440dcdb284366b98d2df12ad1272 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/masking.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import random +import math +import numpy as np + + +class MaskingGenerator: + def __init__( + self, + input_size, + num_masking_patches=None, + min_num_patches=4, + max_num_patches=None, + min_aspect=0.3, + max_aspect=None, + ): + if not isinstance(input_size, tuple): + input_size = (input_size,) * 2 + self.height, self.width = input_size + + self.num_patches = self.height * self.width + self.num_masking_patches = num_masking_patches + + self.min_num_patches = min_num_patches + self.max_num_patches = num_masking_patches if max_num_patches is None else max_num_patches + + max_aspect = max_aspect or 1 / min_aspect + self.log_aspect_ratio = (math.log(min_aspect), math.log(max_aspect)) + + def __repr__(self): + repr_str = "Generator(%d, %d -> [%d ~ %d], max = %d, %.3f ~ %.3f)" % ( + self.height, + self.width, + self.min_num_patches, + self.max_num_patches, + self.num_masking_patches, + self.log_aspect_ratio[0], + self.log_aspect_ratio[1], + ) + return repr_str + + def get_shape(self): + return self.height, self.width + + def _mask(self, mask, max_mask_patches): + delta = 0 + for _ in range(10): + target_area = random.uniform(self.min_num_patches, max_mask_patches) + aspect_ratio = math.exp(random.uniform(*self.log_aspect_ratio)) + h = int(round(math.sqrt(target_area * aspect_ratio))) + w = int(round(math.sqrt(target_area / aspect_ratio))) + if w < self.width and h < self.height: + top = random.randint(0, self.height - h) + left = random.randint(0, self.width - w) + + num_masked = mask[top : top + h, left : left + w].sum() + # Overlap + if 0 < h * w - num_masked <= max_mask_patches: + for i in range(top, top + h): + for j in range(left, left + w): + if mask[i, j] == 0: + mask[i, j] = 1 + delta += 1 + + if delta > 0: + break + return delta + + def __call__(self, num_masking_patches=0): + mask = np.zeros(shape=self.get_shape(), dtype=bool) + mask_count = 0 + while mask_count < num_masking_patches: + max_mask_patches = num_masking_patches - mask_count + max_mask_patches = min(max_mask_patches, self.max_num_patches) + + delta = self._mask(mask, max_mask_patches) + if delta == 0: + break + else: + mask_count += delta + + return mask diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/samplers.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..5656a1cbe2d8fd25296678619b5eba36fffc3184 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/samplers.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import itertools +from typing import Any, Optional +import warnings + +import numpy as np +import torch +from torch.utils.data.sampler import Sampler + +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed + + +class EpochSampler(Sampler): + def __init__( + self, + *, + size: int, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + ): + self._size = size + self._sample_count = sample_count + self._shuffle = shuffle + self._seed = seed + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._epoch = 0 + + def __iter__(self): + count = (self._size + self._sample_count - 1) // self._sample_count + tiled_indices = np.tile(np.arange(self._sample_count), count) + if self._shuffle: + seed = self._seed * self._epoch if self._seed != 0 else self._epoch + rng = np.random.default_rng(seed) + iterable = rng.choice(tiled_indices, self._size, replace=False) + else: + iterable = tiled_indices[: self._size] + + yield from itertools.islice(iterable, self._start, None, self._step) + + def __len__(self): + return (self._size - self._start + self._step - 1) // self._step + + def set_epoch(self, epoch): + self._epoch = epoch + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +class InfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + + def __iter__(self): + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator().manual_seed(self._seed) + + while True: + iterable = _generate_randperm_indices(size=self._sample_count, generator=generator) + yield from itertools.islice(iterable, self._start, None, self._step) + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. +def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = torch.randint(0, i + 1, size=(1,), generator=generator).item() if i > 0 else 0 + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + self._start = distributed.get_global_rank() if start is None else start + self._step = distributed.get_global_size() if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice if use_new_shuffle_tensor_slice else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to be keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 1 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/transforms.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..f1bc4cbd1a459a9f44314806cf9ccedea112ab14 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/data/transforms.py @@ -0,0 +1,92 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Sequence + +import torch +from torchvision import transforms + + +class GaussianBlur(transforms.RandomApply): + """ + Apply Gaussian Blur to the PIL image. + """ + + def __init__(self, *, p: float = 0.5, radius_min: float = 0.1, radius_max: float = 2.0): + # NOTE: torchvision is applying 1 - probability to return the original image + keep_p = 1 - p + transform = transforms.GaussianBlur(kernel_size=9, sigma=(radius_min, radius_max)) + super().__init__(transforms=[transform], p=keep_p) + + +class MaybeToTensor(transforms.ToTensor): + """ + Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor, or keep as is if already a tensor. + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image, numpy.ndarray or torch.tensor): Image to be converted to tensor. + Returns: + Tensor: Converted image. + """ + if isinstance(pic, torch.Tensor): + return pic + return super().__call__(pic) + + +# Use timm's names +IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) +IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) + + +def make_normalize_transform( + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Normalize: + return transforms.Normalize(mean=mean, std=std) + + +# This roughly matches torchvision's preset for classification training: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L6-L44 +def make_classification_train_transform( + *, + crop_size: int = 224, + interpolation=transforms.InterpolationMode.BICUBIC, + hflip_prob: float = 0.5, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +): + transforms_list = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if hflip_prob > 0.0: + transforms_list.append(transforms.RandomHorizontalFlip(hflip_prob)) + transforms_list.extend( + [ + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + ) + return transforms.Compose(transforms_list) + + +# This matches (roughly) torchvision's preset for classification evaluation: +# https://github.com/pytorch/vision/blob/main/references/classification/presets.py#L47-L69 +def make_classification_eval_transform( + *, + resize_size: int = 256, + interpolation=transforms.InterpolationMode.BICUBIC, + crop_size: int = 224, + mean: Sequence[float] = IMAGENET_DEFAULT_MEAN, + std: Sequence[float] = IMAGENET_DEFAULT_STD, +) -> transforms.Compose: + transforms_list = [ + transforms.Resize(resize_size, interpolation=interpolation), + transforms.CenterCrop(crop_size), + MaybeToTensor(), + make_normalize_transform(mean=mean, std=std), + ] + return transforms.Compose(transforms_list) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/distributed/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccd663f33d5a21ad1f9d25db7bd378ec52aeac2 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/distributed/__init__.py @@ -0,0 +1,271 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import random +import re +import socket +from typing import Dict, List + +import torch +import torch.distributed as dist + +_LOCAL_RANK = -1 +_LOCAL_WORLD_SIZE = -1 + + +def is_enabled() -> bool: + """ + Returns: + True if distributed training is enabled + """ + return dist.is_available() and dist.is_initialized() + + +def get_global_size() -> int: + """ + Returns: + The number of processes in the process group + """ + return dist.get_world_size() if is_enabled() else 1 + + +def get_global_rank() -> int: + """ + Returns: + The rank of the current process within the global process group. + """ + return dist.get_rank() if is_enabled() else 0 + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_enabled(): + return 0 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_RANK + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not is_enabled(): + return 1 + assert 0 <= _LOCAL_RANK < _LOCAL_WORLD_SIZE + return _LOCAL_WORLD_SIZE + + +def is_main_process() -> bool: + """ + Returns: + True if the current process is the main one. + """ + return get_global_rank() == 0 + + +def _restrict_print_to_main_process() -> None: + """ + This function disables printing when not in the main process + """ + import builtins as __builtin__ + + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop("force", False) + if is_main_process() or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def _get_master_port(seed: int = 0) -> int: + MIN_MASTER_PORT, MAX_MASTER_PORT = (20_000, 60_000) + + master_port_str = os.environ.get("MASTER_PORT") + if master_port_str is None: + rng = random.Random(seed) + return rng.randint(MIN_MASTER_PORT, MAX_MASTER_PORT) + + return int(master_port_str) + + +def _get_available_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + # A "" host address means INADDR_ANY i.e. binding to all interfaces. + # Note this is not compatible with IPv6. + s.bind(("", 0)) + port = s.getsockname()[1] + return port + + +_TORCH_DISTRIBUTED_ENV_VARS = ( + "MASTER_ADDR", + "MASTER_PORT", + "RANK", + "WORLD_SIZE", + "LOCAL_RANK", + "LOCAL_WORLD_SIZE", +) + + +def _collect_env_vars() -> Dict[str, str]: + return {env_var: os.environ[env_var] for env_var in _TORCH_DISTRIBUTED_ENV_VARS if env_var in os.environ} + + +def _is_slurm_job_process() -> bool: + return "SLURM_JOB_ID" in os.environ + + +def _parse_slurm_node_list(s: str) -> List[str]: + nodes = [] + # Extract "hostname", "hostname[1-2,3,4-5]," substrings + p = re.compile(r"(([^\[]+)(?:\[([^\]]+)\])?),?") + for m in p.finditer(s): + prefix, suffixes = s[m.start(2) : m.end(2)], s[m.start(3) : m.end(3)] + for suffix in suffixes.split(","): + span = suffix.split("-") + if len(span) == 1: + nodes.append(prefix + suffix) + else: + width = len(span[0]) + start, end = int(span[0]), int(span[1]) + 1 + nodes.extend([prefix + f"{i:0{width}}" for i in range(start, end)]) + return nodes + + +def _check_env_variable(key: str, new_value: str): + # Only check for difference with preset environment variables + if key in os.environ and os.environ[key] != new_value: + raise RuntimeError(f"Cannot export environment variables as {key} is already set") + + +class _TorchDistributedEnvironment: + def __init__(self): + self.master_addr = "127.0.0.1" + self.master_port = 0 + self.rank = -1 + self.world_size = -1 + self.local_rank = -1 + self.local_world_size = -1 + + if _is_slurm_job_process(): + return self._set_from_slurm_env() + + env_vars = _collect_env_vars() + if not env_vars: + # Environment is not set + pass + elif len(env_vars) == len(_TORCH_DISTRIBUTED_ENV_VARS): + # Environment is fully set + return self._set_from_preset_env() + else: + # Environment is partially set + collected_env_vars = ", ".join(env_vars.keys()) + raise RuntimeError(f"Partially set environment: {collected_env_vars}") + + if torch.cuda.device_count() > 0: + return self._set_from_local() + + raise RuntimeError("Can't initialize PyTorch distributed environment") + + # Slurm job created with sbatch, submitit, etc... + def _set_from_slurm_env(self): + # logger.info("Initialization from Slurm environment") + job_id = int(os.environ["SLURM_JOB_ID"]) + node_count = int(os.environ["SLURM_JOB_NUM_NODES"]) + nodes = _parse_slurm_node_list(os.environ["SLURM_JOB_NODELIST"]) + assert len(nodes) == node_count + + self.master_addr = nodes[0] + self.master_port = _get_master_port(seed=job_id) + self.rank = int(os.environ["SLURM_PROCID"]) + self.world_size = int(os.environ["SLURM_NTASKS"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["SLURM_LOCALID"]) + self.local_world_size = self.world_size // node_count + assert self.local_rank < self.local_world_size + + # Single node job with preset environment (i.e. torchrun) + def _set_from_preset_env(self): + # logger.info("Initialization from preset environment") + self.master_addr = os.environ["MASTER_ADDR"] + self.master_port = os.environ["MASTER_PORT"] + self.rank = int(os.environ["RANK"]) + self.world_size = int(os.environ["WORLD_SIZE"]) + assert self.rank < self.world_size + self.local_rank = int(os.environ["LOCAL_RANK"]) + self.local_world_size = int(os.environ["LOCAL_WORLD_SIZE"]) + assert self.local_rank < self.local_world_size + + # Single node and GPU job (i.e. local script run) + def _set_from_local(self): + # logger.info("Initialization from local") + self.master_addr = "127.0.0.1" + self.master_port = _get_available_port() + self.rank = 0 + self.world_size = 1 + self.local_rank = 0 + self.local_world_size = 1 + + def export(self, *, overwrite: bool) -> "_TorchDistributedEnvironment": + # See the "Environment variable initialization" section from + # https://pytorch.org/docs/stable/distributed.html for the complete list of + # environment variables required for the env:// initialization method. + env_vars = { + "MASTER_ADDR": self.master_addr, + "MASTER_PORT": str(self.master_port), + "RANK": str(self.rank), + "WORLD_SIZE": str(self.world_size), + "LOCAL_RANK": str(self.local_rank), + "LOCAL_WORLD_SIZE": str(self.local_world_size), + } + if not overwrite: + for k, v in env_vars.items(): + _check_env_variable(k, v) + + os.environ.update(env_vars) + return self + + +def enable(*, set_cuda_current_device: bool = True, overwrite: bool = False, allow_nccl_timeout: bool = False): + """Enable distributed mode + + Args: + set_cuda_current_device: If True, call torch.cuda.set_device() to set the + current PyTorch CUDA device to the one matching the local rank. + overwrite: If True, overwrites already set variables. Else fails. + """ + + global _LOCAL_RANK, _LOCAL_WORLD_SIZE + if _LOCAL_RANK >= 0 or _LOCAL_WORLD_SIZE >= 0: + raise RuntimeError("Distributed mode has already been enabled") + torch_env = _TorchDistributedEnvironment() + torch_env.export(overwrite=overwrite) + + if set_cuda_current_device: + torch.cuda.set_device(torch_env.local_rank) + + if allow_nccl_timeout: + # This allows to use torch distributed timeout in a NCCL backend + key, value = "NCCL_ASYNC_ERROR_HANDLING", "1" + if not overwrite: + _check_env_variable(key, value) + os.environ[key] = value + + dist.init_process_group(backend="nccl") + dist.barrier() + + # Finalize setup + _LOCAL_RANK = torch_env.local_rank + _LOCAL_WORLD_SIZE = torch_env.local_world_size + _restrict_print_to_main_process() diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/knn.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..b98c630e183702cfbe8f7e5420bbef538a9d881b --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/knn.py @@ -0,0 +1,405 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import torch +from torch.nn.functional import one_hot, softmax + +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import SamplerType, make_data_loader, make_dataset +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data.transforms import make_classification_eval_transform +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.metrics import AccuracyAveraging, build_topk_accuracy_metric +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import get_args_parser as get_setup_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import setup_and_build_model +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.utils import ModelWithNormalize, evaluate, extract_features + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--nb_knn", + nargs="+", + type=int, + help="Number of NN to use. 20 is usually working the best.", + ) + parser.add_argument( + "--temperature", + type=float, + help="Temperature used in the voting coefficient", + ) + parser.add_argument( + "--gather-on-cpu", + action="store_true", + help="Whether to gather the train features on cpu, slower" + "but useful to avoid OOM for large datasets (e.g. ImageNet22k).", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch size.", + ) + parser.add_argument( + "--n-per-class-list", + nargs="+", + type=int, + help="Number to take per class", + ) + parser.add_argument( + "--n-tries", + type=int, + help="Number of tries", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=[10, 20, 100, 200], + temperature=0.07, + batch_size=256, + n_per_class_list=[-1], + n_tries=1, + ) + return parser + + +class KnnModule(torch.nn.Module): + """ + Gets knn of test features from all processes on a chunk of the train features + + Each rank gets a chunk of the train features as well as a chunk of the test features. + In `compute_neighbors`, for each rank one after the other, its chunk of test features + is sent to all devices, partial knns are computed with each chunk of train features + then collated back on the original device. + """ + + def __init__(self, train_features, train_labels, nb_knn, T, device, num_classes=1000): + super().__init__() + + self.global_rank = distributed.get_global_rank() + self.global_size = distributed.get_global_size() + + self.device = device + self.train_features_rank_T = train_features.chunk(self.global_size)[self.global_rank].T.to(self.device) + self.candidates = train_labels.chunk(self.global_size)[self.global_rank].view(1, -1).to(self.device) + + self.nb_knn = nb_knn + self.max_k = max(self.nb_knn) + self.T = T + self.num_classes = num_classes + + def _get_knn_sims_and_labels(self, similarity, train_labels): + topk_sims, indices = similarity.topk(self.max_k, largest=True, sorted=True) + neighbors_labels = torch.gather(train_labels, 1, indices) + return topk_sims, neighbors_labels + + def _similarity_for_rank(self, features_rank, source_rank): + # Send the features from `source_rank` to all ranks + broadcast_shape = torch.tensor(features_rank.shape).to(self.device) + torch.distributed.broadcast(broadcast_shape, source_rank) + + broadcasted = features_rank + if self.global_rank != source_rank: + broadcasted = torch.zeros(*broadcast_shape, dtype=features_rank.dtype, device=self.device) + torch.distributed.broadcast(broadcasted, source_rank) + + # Compute the neighbors for `source_rank` among `train_features_rank_T` + similarity_rank = torch.mm(broadcasted, self.train_features_rank_T) + candidate_labels = self.candidates.expand(len(similarity_rank), -1) + return self._get_knn_sims_and_labels(similarity_rank, candidate_labels) + + def _gather_all_knn_for_rank(self, topk_sims, neighbors_labels, target_rank): + # Gather all neighbors for `target_rank` + topk_sims_rank = retrieved_rank = None + if self.global_rank == target_rank: + topk_sims_rank = [torch.zeros_like(topk_sims) for _ in range(self.global_size)] + retrieved_rank = [torch.zeros_like(neighbors_labels) for _ in range(self.global_size)] + + torch.distributed.gather(topk_sims, topk_sims_rank, dst=target_rank) + torch.distributed.gather(neighbors_labels, retrieved_rank, dst=target_rank) + + if self.global_rank == target_rank: + # Perform a second top-k on the k * global_size retrieved neighbors + topk_sims_rank = torch.cat(topk_sims_rank, dim=1) + retrieved_rank = torch.cat(retrieved_rank, dim=1) + results = self._get_knn_sims_and_labels(topk_sims_rank, retrieved_rank) + return results + return None + + def compute_neighbors(self, features_rank): + for rank in range(self.global_size): + topk_sims, neighbors_labels = self._similarity_for_rank(features_rank, rank) + results = self._gather_all_knn_for_rank(topk_sims, neighbors_labels, rank) + if results is not None: + topk_sims_rank, neighbors_labels_rank = results + return topk_sims_rank, neighbors_labels_rank + + def forward(self, features_rank): + """ + Compute the results on all values of `self.nb_knn` neighbors from the full `self.max_k` + """ + assert all(k <= self.max_k for k in self.nb_knn) + + topk_sims, neighbors_labels = self.compute_neighbors(features_rank) + batch_size = neighbors_labels.shape[0] + topk_sims_transform = softmax(topk_sims / self.T, 1) + matmul = torch.mul( + one_hot(neighbors_labels, num_classes=self.num_classes), + topk_sims_transform.view(batch_size, -1, 1), + ) + probas_for_k = {k: torch.sum(matmul[:, :k, :], 1) for k in self.nb_knn} + return probas_for_k + + +class DictKeysModule(torch.nn.Module): + def __init__(self, keys): + super().__init__() + self.keys = keys + + def forward(self, features_dict, targets): + for k in self.keys: + features_dict = features_dict[k] + return {"preds": features_dict, "target": targets} + + +def create_module_dict(*, module, n_per_class_list, n_tries, nb_knn, train_features, train_labels): + modules = {} + mapping = create_class_indices_mapping(train_labels) + for npc in n_per_class_list: + if npc < 0: # Only one try needed when using the full data + full_module = module( + train_features=train_features, + train_labels=train_labels, + nb_knn=nb_knn, + ) + modules["full"] = ModuleDictWithForward({"1": full_module}) + continue + all_tries = {} + for t in range(n_tries): + final_indices = filter_train(mapping, npc, seed=t) + k_list = list(set(nb_knn + [npc])) + k_list = sorted([el for el in k_list if el <= npc]) + all_tries[str(t)] = module( + train_features=train_features[final_indices], + train_labels=train_labels[final_indices], + nb_knn=k_list, + ) + modules[f"{npc} per class"] = ModuleDictWithForward(all_tries) + + return ModuleDictWithForward(modules) + + +def filter_train(mapping, n_per_class, seed): + torch.manual_seed(seed) + final_indices = [] + for k in mapping.keys(): + index = torch.randperm(len(mapping[k]))[:n_per_class] + final_indices.append(mapping[k][index]) + return torch.cat(final_indices).squeeze() + + +def create_class_indices_mapping(labels): + unique_labels, inverse = torch.unique(labels, return_inverse=True) + mapping = {unique_labels[i]: (inverse == i).nonzero() for i in range(len(unique_labels))} + return mapping + + +class ModuleDictWithForward(torch.nn.ModuleDict): + def forward(self, *args, **kwargs): + return {k: module(*args, **kwargs) for k, module in self._modules.items()} + + +def eval_knn( + model, + train_dataset, + val_dataset, + accuracy_averaging, + nb_knn, + temperature, + batch_size, + num_workers, + gather_on_cpu, + n_per_class_list=[-1], + n_tries=1, +): + model = ModelWithNormalize(model) + + logger.info("Extracting features for train set...") + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=gather_on_cpu + ) + logger.info(f"Train features created, shape {train_features.shape}.") + + val_dataloader = make_data_loader( + dataset=val_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=True, + ) + num_classes = train_labels.max() + 1 + metric_collection = build_topk_accuracy_metric(accuracy_averaging, num_classes=num_classes) + + device = torch.cuda.current_device() + partial_module = partial(KnnModule, T=temperature, device=device, num_classes=num_classes) + knn_module_dict = create_module_dict( + module=partial_module, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + nb_knn=nb_knn, + train_features=train_features, + train_labels=train_labels, + ) + postprocessors, metrics = {}, {} + for n_per_class, knn_module in knn_module_dict.items(): + for t, knn_try in knn_module.items(): + postprocessors = { + **postprocessors, + **{(n_per_class, t, k): DictKeysModule([n_per_class, t, k]) for k in knn_try.nb_knn}, + } + metrics = {**metrics, **{(n_per_class, t, k): metric_collection.clone() for k in knn_try.nb_knn}} + model_with_knn = torch.nn.Sequential(model, knn_module_dict) + + # ============ evaluation ... ============ + logger.info("Start the k-NN classification.") + _, results_dict = evaluate(model_with_knn, val_dataloader, postprocessors, metrics, device) + + # Averaging the results over the n tries for each value of n_per_class + for n_per_class, knn_module in knn_module_dict.items(): + first_try = list(knn_module.keys())[0] + k_list = knn_module[first_try].nb_knn + for k in k_list: + keys = results_dict[(n_per_class, first_try, k)].keys() # keys are e.g. `top-1` and `top-5` + results_dict[(n_per_class, k)] = { + key: torch.mean(torch.stack([results_dict[(n_per_class, t, k)][key] for t in knn_module.keys()])) + for key in keys + } + for t in knn_module.keys(): + del results_dict[(n_per_class, t, k)] + + return results_dict + + +def eval_knn_with_model( + model, + output_dir, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + nb_knn=(10, 20, 100, 200), + temperature=0.07, + autocast_dtype=torch.float, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=False, + batch_size=256, + num_workers=5, + n_per_class_list=[-1], + n_tries=1, +): + transform = transform or make_classification_eval_transform() + + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=transform, + ) + val_dataset = make_dataset( + dataset_str=val_dataset_str, + transform=transform, + ) + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_knn = eval_knn( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + accuracy_averaging=accuracy_averaging, + nb_knn=nb_knn, + temperature=temperature, + batch_size=batch_size, + num_workers=num_workers, + gather_on_cpu=gather_on_cpu, + n_per_class_list=n_per_class_list, + n_tries=n_tries, + ) + + results_dict = {} + if distributed.is_main_process(): + for knn_ in results_dict_knn.keys(): + top1 = results_dict_knn[knn_]["top-1"].item() * 100.0 + top5 = results_dict_knn[knn_]["top-5"].item() * 100.0 + results_dict[f"{knn_} Top 1"] = top1 + results_dict[f"{knn_} Top 5"] = top5 + logger.info(f"{knn_} classifier result: Top1: {top1:.2f} Top5: {top5:.2f}") + + metrics_file_path = os.path.join(output_dir, "results_eval_knn.json") + with open(metrics_file_path, "a") as f: + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + + if distributed.is_enabled(): + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_knn_with_model( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + nb_knn=args.nb_knn, + temperature=args.temperature, + autocast_dtype=autocast_dtype, + accuracy_averaging=AccuracyAveraging.MEAN_ACCURACY, + transform=None, + gather_on_cpu=args.gather_on_cpu, + batch_size=args.batch_size, + num_workers=5, + n_per_class_list=args.n_per_class_list, + n_tries=args.n_tries, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 k-NN evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/linear.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..088c538bd9c2ae2544990b5a3c94ce8bf063b446 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/linear.py @@ -0,0 +1,626 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from functools import partial +import json +import logging +import os +import sys +from typing import List, Optional + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import SamplerType, make_data_loader, make_dataset +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data.transforms import make_classification_eval_transform, make_classification_train_transform +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.metrics import MetricType, build_metric +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import get_args_parser as get_setup_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import setup_and_build_model +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.utils import ModelWithIntermediateLayers, evaluate +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--test-datasets", + dest="test_dataset_strs", + type=str, + nargs="+", + help="Test datasets, none to reuse the validation dataset", + ) + parser.add_argument( + "--epochs", + type=int, + help="Number of training epochs", + ) + parser.add_argument( + "--batch-size", + type=int, + help="Batch Size (per GPU)", + ) + parser.add_argument( + "--num-workers", + type=int, + help="Number de Workers", + ) + parser.add_argument( + "--epoch-length", + type=int, + help="Length of an epoch in number of iterations", + ) + parser.add_argument( + "--save-checkpoint-frequency", + type=int, + help="Number of epochs between two named checkpoint saves.", + ) + parser.add_argument( + "--eval-period-iterations", + type=int, + help="Number of iterations between two evaluations.", + ) + parser.add_argument( + "--learning-rates", + nargs="+", + type=float, + help="Learning rates to grid search.", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not resume from existing checkpoints", + ) + parser.add_argument( + "--val-metric-type", + type=MetricType, + choices=list(MetricType), + help="Validation metric", + ) + parser.add_argument( + "--test-metric-types", + type=MetricType, + choices=list(MetricType), + nargs="+", + help="Evaluation metric", + ) + parser.add_argument( + "--classifier-fpath", + type=str, + help="Path to a file containing pretrained linear classifiers", + ) + parser.add_argument( + "--val-class-mapping-fpath", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.add_argument( + "--test-class-mapping-fpaths", + nargs="+", + type=str, + help="Path to a file containing a mapping to adjust classifier outputs", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + test_dataset_strs=None, + epochs=10, + batch_size=128, + num_workers=8, + epoch_length=1250, + save_checkpoint_frequency=20, + eval_period_iterations=1250, + learning_rates=[1e-5, 2e-5, 5e-5, 1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2, 2e-2, 5e-2, 0.1], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + ) + return parser + + +def has_ddp_wrapper(m: nn.Module) -> bool: + return isinstance(m, DistributedDataParallel) + + +def remove_ddp_wrapper(m: nn.Module) -> nn.Module: + return m.module if has_ddp_wrapper(m) else m + + +def _pad_and_collate(batch): + maxlen = max(len(targets) for image, targets in batch) + padded_batch = [ + (image, np.pad(targets, (0, maxlen - len(targets)), constant_values=-1)) for image, targets in batch + ] + return torch.utils.data.default_collate(padded_batch) + + +def create_linear_input(x_tokens_list, use_n_blocks, use_avgpool): + intermediate_output = x_tokens_list[-use_n_blocks:] + output = torch.cat([class_token for _, class_token in intermediate_output], dim=-1) + if use_avgpool: + output = torch.cat( + ( + output, + torch.mean(intermediate_output[-1][0], dim=1), # patch tokens + ), + dim=-1, + ) + output = output.reshape(output.shape[0], -1) + return output.float() + + +class LinearClassifier(nn.Module): + """Linear layer to train on top of frozen features""" + + def __init__(self, out_dim, use_n_blocks, use_avgpool, num_classes=1000): + super().__init__() + self.out_dim = out_dim + self.use_n_blocks = use_n_blocks + self.use_avgpool = use_avgpool + self.num_classes = num_classes + self.linear = nn.Linear(out_dim, num_classes) + self.linear.weight.data.normal_(mean=0.0, std=0.01) + self.linear.bias.data.zero_() + + def forward(self, x_tokens_list): + output = create_linear_input(x_tokens_list, self.use_n_blocks, self.use_avgpool) + return self.linear(output) + + +class AllClassifiers(nn.Module): + def __init__(self, classifiers_dict): + super().__init__() + self.classifiers_dict = nn.ModuleDict() + self.classifiers_dict.update(classifiers_dict) + + def forward(self, inputs): + return {k: v.forward(inputs) for k, v in self.classifiers_dict.items()} + + def __len__(self): + return len(self.classifiers_dict) + + +class LinearPostprocessor(nn.Module): + def __init__(self, linear_classifier, class_mapping=None): + super().__init__() + self.linear_classifier = linear_classifier + self.register_buffer("class_mapping", None if class_mapping is None else torch.LongTensor(class_mapping)) + + def forward(self, samples, targets): + preds = self.linear_classifier(samples) + return { + "preds": preds[:, self.class_mapping] if self.class_mapping is not None else preds, + "target": targets, + } + + +def scale_lr(learning_rates, batch_size): + return learning_rates * (batch_size * distributed.get_global_size()) / 256.0 + + +def setup_linear_classifiers(sample_output, n_last_blocks_list, learning_rates, batch_size, num_classes=1000): + linear_classifiers_dict = nn.ModuleDict() + optim_param_groups = [] + for n in n_last_blocks_list: + for avgpool in [False, True]: + for _lr in learning_rates: + lr = scale_lr(_lr, batch_size) + out_dim = create_linear_input(sample_output, use_n_blocks=n, use_avgpool=avgpool).shape[1] + linear_classifier = LinearClassifier( + out_dim, use_n_blocks=n, use_avgpool=avgpool, num_classes=num_classes + ) + linear_classifier = linear_classifier.cuda() + linear_classifiers_dict[ + f"classifier_{n}_blocks_avgpool_{avgpool}_lr_{lr:.5f}".replace(".", "_") + ] = linear_classifier + optim_param_groups.append({"params": linear_classifier.parameters(), "lr": lr}) + + linear_classifiers = AllClassifiers(linear_classifiers_dict) + if distributed.is_enabled(): + linear_classifiers = nn.parallel.DistributedDataParallel(linear_classifiers) + + return linear_classifiers, optim_param_groups + + +@torch.no_grad() +def evaluate_linear_classifiers( + feature_model, + linear_classifiers, + data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=None, + best_classifier_on_val=None, +): + logger.info("running validation !") + + num_classes = len(class_mapping) if class_mapping is not None else training_num_classes + metric = build_metric(metric_type, num_classes=num_classes) + postprocessors = {k: LinearPostprocessor(v, class_mapping) for k, v in linear_classifiers.classifiers_dict.items()} + metrics = {k: metric.clone() for k in linear_classifiers.classifiers_dict} + + _, results_dict_temp = evaluate( + feature_model, + data_loader, + postprocessors, + metrics, + torch.cuda.current_device(), + ) + + logger.info("") + results_dict = {} + max_accuracy = 0 + best_classifier = "" + for i, (classifier_string, metric) in enumerate(results_dict_temp.items()): + logger.info(f"{prefixstring} -- Classifier: {classifier_string} * {metric}") + if ( + best_classifier_on_val is None and metric["top-1"].item() > max_accuracy + ) or classifier_string == best_classifier_on_val: + max_accuracy = metric["top-1"].item() + best_classifier = classifier_string + + results_dict["best_classifier"] = {"name": best_classifier, "accuracy": max_accuracy} + + logger.info(f"best classifier: {results_dict['best_classifier']}") + + if distributed.is_main_process(): + with open(metrics_file_path, "a") as f: + f.write(f"iter: {iteration}\n") + for k, v in results_dict.items(): + f.write(json.dumps({k: v}) + "\n") + f.write("\n") + + return results_dict + + +def eval_linear( + *, + feature_model, + linear_classifiers, + train_data_loader, + val_data_loader, + metrics_file_path, + optimizer, + scheduler, + output_dir, + max_iter, + checkpoint_period, # In number of iter, creates a new file every period + running_checkpoint_period, # Period to update main checkpoint file + eval_period, + metric_type, + training_num_classes, + resume=True, + classifier_fpath=None, + val_class_mapping=None, +): + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + + periodic_checkpointer = PeriodicCheckpointer(checkpointer, checkpoint_period, max_iter=max_iter) + iteration = start_iter + logger.info("Starting training from iteration {}".format(start_iter)) + metric_logger = MetricLogger(delimiter=" ") + header = "Training" + + for data, labels in metric_logger.log_every( + train_data_loader, + 10, + header, + max_iter, + start_iter, + ): + data = data.cuda(non_blocking=True) + labels = labels.cuda(non_blocking=True) + + features = feature_model(data) + outputs = linear_classifiers(features) + + losses = {f"loss_{k}": nn.CrossEntropyLoss()(v, labels) for k, v in outputs.items()} + loss = sum(losses.values()) + + # compute the gradients + optimizer.zero_grad() + loss.backward() + + # step + optimizer.step() + scheduler.step() + + # log + if iteration % 10 == 0: + torch.cuda.synchronize() + metric_logger.update(loss=loss.item()) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + print("lr", optimizer.param_groups[0]["lr"]) + + if iteration - start_iter > 5: + if iteration % running_checkpoint_period == 0: + torch.cuda.synchronize() + if distributed.is_main_process(): + logger.info("Checkpointing running_checkpoint") + periodic_checkpointer.save("running_checkpoint_linear_eval", iteration=iteration) + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + if eval_period > 0 and (iteration + 1) % eval_period == 0 and iteration != max_iter - 1: + _ = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + prefixstring=f"ITER: {iteration}", + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + torch.cuda.synchronize() + + iteration = iteration + 1 + + val_results_dict = evaluate_linear_classifiers( + feature_model=feature_model, + linear_classifiers=remove_ddp_wrapper(linear_classifiers), + data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + metric_type=metric_type, + training_num_classes=training_num_classes, + iteration=iteration, + class_mapping=val_class_mapping, + ) + return val_results_dict, feature_model, linear_classifiers, iteration + + +def make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type): + test_dataset = make_dataset( + dataset_str=test_dataset_str, + transform=make_classification_eval_transform(), + ) + test_data_loader = make_data_loader( + dataset=test_dataset, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + persistent_workers=False, + collate_fn=_pad_and_collate if metric_type == MetricType.IMAGENET_REAL_ACCURACY else None, + ) + return test_data_loader + + +def test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + best_classifier_on_val, + prefixstring="", + test_class_mappings=[None], +): + results_dict = {} + for test_dataset_str, class_mapping, metric_type in zip(test_dataset_strs, test_class_mappings, test_metric_types): + logger.info(f"Testing on {test_dataset_str}") + test_data_loader = make_eval_data_loader(test_dataset_str, batch_size, num_workers, metric_type) + dataset_results_dict = evaluate_linear_classifiers( + feature_model, + remove_ddp_wrapper(linear_classifiers), + test_data_loader, + metric_type, + metrics_file_path, + training_num_classes, + iteration, + prefixstring="", + class_mapping=class_mapping, + best_classifier_on_val=best_classifier_on_val, + ) + results_dict[f"{test_dataset_str}_accuracy"] = 100.0 * dataset_results_dict["best_classifier"]["accuracy"] + return results_dict + + +def run_eval_linear( + model, + output_dir, + train_dataset_str, + val_dataset_str, + batch_size, + epochs, + epoch_length, + num_workers, + save_checkpoint_frequency, + eval_period_iterations, + learning_rates, + autocast_dtype, + test_dataset_strs=None, + resume=True, + classifier_fpath=None, + val_class_mapping_fpath=None, + test_class_mapping_fpaths=[None], + val_metric_type=MetricType.MEAN_ACCURACY, + test_metric_types=None, +): + seed = 0 + + if test_dataset_strs is None: + test_dataset_strs = [val_dataset_str] + if test_metric_types is None: + test_metric_types = [val_metric_type] * len(test_dataset_strs) + else: + assert len(test_metric_types) == len(test_dataset_strs) + assert len(test_dataset_strs) == len(test_class_mapping_fpaths) + + train_transform = make_classification_train_transform() + train_dataset = make_dataset( + dataset_str=train_dataset_str, + transform=train_transform, + ) + training_num_classes = len(torch.unique(torch.Tensor(train_dataset.get_targets().astype(int)))) + sampler_type = SamplerType.SHARDED_INFINITE + # sampler_type = SamplerType.INFINITE + + n_last_blocks_list = [1, 4] + n_last_blocks = max(n_last_blocks_list) + autocast_ctx = partial(torch.cuda.amp.autocast, enabled=True, dtype=autocast_dtype) + feature_model = ModelWithIntermediateLayers(model, n_last_blocks, autocast_ctx) + sample_output = feature_model(train_dataset[0][0].unsqueeze(0).cuda()) + + linear_classifiers, optim_param_groups = setup_linear_classifiers( + sample_output, + n_last_blocks_list, + learning_rates, + batch_size, + training_num_classes, + ) + + optimizer = torch.optim.SGD(optim_param_groups, momentum=0.9, weight_decay=0) + max_iter = epochs * epoch_length + scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, max_iter, eta_min=0) + checkpointer = Checkpointer(linear_classifiers, output_dir, optimizer=optimizer, scheduler=scheduler) + start_iter = checkpointer.resume_or_load(classifier_fpath or "", resume=resume).get("iteration", -1) + 1 + train_data_loader = make_data_loader( + dataset=train_dataset, + batch_size=batch_size, + num_workers=num_workers, + shuffle=True, + seed=seed, + sampler_type=sampler_type, + sampler_advance=start_iter, + drop_last=True, + persistent_workers=True, + ) + val_data_loader = make_eval_data_loader(val_dataset_str, batch_size, num_workers, val_metric_type) + + checkpoint_period = save_checkpoint_frequency * epoch_length + + if val_class_mapping_fpath is not None: + logger.info(f"Using class mapping from {val_class_mapping_fpath}") + val_class_mapping = np.load(val_class_mapping_fpath) + else: + val_class_mapping = None + + test_class_mappings = [] + for class_mapping_fpath in test_class_mapping_fpaths: + if class_mapping_fpath is not None and class_mapping_fpath != "None": + logger.info(f"Using class mapping from {class_mapping_fpath}") + class_mapping = np.load(class_mapping_fpath) + else: + class_mapping = None + test_class_mappings.append(class_mapping) + + metrics_file_path = os.path.join(output_dir, "results_eval_linear.json") + val_results_dict, feature_model, linear_classifiers, iteration = eval_linear( + feature_model=feature_model, + linear_classifiers=linear_classifiers, + train_data_loader=train_data_loader, + val_data_loader=val_data_loader, + metrics_file_path=metrics_file_path, + optimizer=optimizer, + scheduler=scheduler, + output_dir=output_dir, + max_iter=max_iter, + checkpoint_period=checkpoint_period, + running_checkpoint_period=epoch_length, + eval_period=eval_period_iterations, + metric_type=val_metric_type, + training_num_classes=training_num_classes, + resume=resume, + val_class_mapping=val_class_mapping, + classifier_fpath=classifier_fpath, + ) + results_dict = {} + if len(test_dataset_strs) > 1 or test_dataset_strs[0] != val_dataset_str: + results_dict = test_on_datasets( + feature_model, + linear_classifiers, + test_dataset_strs, + batch_size, + 0, # num_workers, + test_metric_types, + metrics_file_path, + training_num_classes, + iteration, + val_results_dict["best_classifier"]["name"], + prefixstring="", + test_class_mappings=test_class_mappings, + ) + results_dict["best_classifier"] = val_results_dict["best_classifier"]["name"] + results_dict[f"{val_dataset_str}_accuracy"] = 100.0 * val_results_dict["best_classifier"]["accuracy"] + logger.info("Test Results Dict " + str(results_dict)) + + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + run_eval_linear( + model=model, + output_dir=args.output_dir, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + test_dataset_strs=args.test_dataset_strs, + batch_size=args.batch_size, + epochs=args.epochs, + epoch_length=args.epoch_length, + num_workers=args.num_workers, + save_checkpoint_frequency=args.save_checkpoint_frequency, + eval_period_iterations=args.eval_period_iterations, + learning_rates=args.learning_rates, + autocast_dtype=autocast_dtype, + resume=not args.no_resume, + classifier_fpath=args.classifier_fpath, + val_metric_type=args.val_metric_type, + test_metric_types=args.test_metric_types, + val_class_mapping_fpath=args.val_class_mapping_fpath, + test_class_mapping_fpaths=args.test_class_mapping_fpaths, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 linear evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/log_regression.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..8bcc402abee36b7bf6ded0b11ba3d3e99da5539e --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/log_regression.py @@ -0,0 +1,445 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import gc +import logging +import sys +import time +from typing import List, Optional + +from cuml.linear_model import LogisticRegression +import torch +import torch.backends.cudnn as cudnn +import torch.distributed +from torch import nn +from torch.utils.data import TensorDataset +from torchmetrics import MetricTracker + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import make_dataset +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data.transforms import make_classification_eval_transform +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed import get_global_rank, get_global_size +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.metrics import MetricType, build_metric +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import get_args_parser as get_setup_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.setup import setup_and_build_model +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.utils import evaluate, extract_features +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.dtype import as_torch_dtype + + +logger = logging.getLogger("dinov2") + +DEFAULT_MAX_ITER = 1_000 +C_POWER_RANGE = torch.linspace(-6, 5, 45) +_CPU_DEVICE = torch.device("cpu") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parents = parents or [] + setup_args_parser = get_setup_args_parser(parents=parents, add_help=False) + parents = [setup_args_parser] + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--train-dataset", + dest="train_dataset_str", + type=str, + help="Training dataset", + ) + parser.add_argument( + "--val-dataset", + dest="val_dataset_str", + type=str, + help="Validation dataset", + ) + parser.add_argument( + "--finetune-dataset-str", + dest="finetune_dataset_str", + type=str, + help="Fine-tuning dataset", + ) + parser.add_argument( + "--finetune-on-val", + action="store_true", + help="If there is no finetune dataset, whether to choose the " + "hyperparameters on the val set instead of 10%% of the train dataset", + ) + parser.add_argument( + "--metric-type", + type=MetricType, + choices=list(MetricType), + help="Metric type", + ) + parser.add_argument( + "--train-features-device", + type=str, + help="Device to gather train features (cpu, cuda, cuda:0, etc.), default: %(default)s", + ) + parser.add_argument( + "--train-dtype", + type=str, + help="Data type to convert the train features to (default: %(default)s)", + ) + parser.add_argument( + "--max-train-iters", + type=int, + help="Maximum number of train iterations (default: %(default)s)", + ) + parser.set_defaults( + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + metric_type=MetricType.MEAN_ACCURACY, + train_features_device="cpu", + train_dtype="float64", + max_train_iters=DEFAULT_MAX_ITER, + finetune_on_val=False, + ) + return parser + + +class LogRegModule(nn.Module): + def __init__( + self, + C, + max_iter=DEFAULT_MAX_ITER, + dtype=torch.float64, + device=_CPU_DEVICE, + ): + super().__init__() + self.dtype = dtype + self.device = device + self.estimator = LogisticRegression( + penalty="l2", + C=C, + max_iter=max_iter, + output_type="numpy", + tol=1e-12, + linesearch_max_iter=50, + ) + + def forward(self, samples, targets): + samples_device = samples.device + samples = samples.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + samples = samples.numpy() + probas = self.estimator.predict_proba(samples) + return {"preds": torch.from_numpy(probas).to(samples_device), "target": targets} + + def fit(self, train_features, train_labels): + train_features = train_features.to(dtype=self.dtype, device=self.device) + train_labels = train_labels.to(dtype=self.dtype, device=self.device) + if self.device == _CPU_DEVICE: + # both cuML and sklearn only work with numpy arrays on CPU + train_features = train_features.numpy() + train_labels = train_labels.numpy() + self.estimator.fit(train_features, train_labels) + + +def evaluate_model(*, logreg_model, logreg_metric, test_data_loader, device): + postprocessors = {"metrics": logreg_model} + metrics = {"metrics": logreg_metric} + return evaluate(nn.Identity(), test_data_loader, postprocessors, metrics, device) + + +def train_for_C(*, C, max_iter, train_features, train_labels, dtype=torch.float64, device=_CPU_DEVICE): + logreg_model = LogRegModule(C, max_iter=max_iter, dtype=dtype, device=device) + logreg_model.fit(train_features, train_labels) + return logreg_model + + +def train_and_evaluate( + *, + C, + max_iter, + train_features, + train_labels, + logreg_metric, + test_data_loader, + train_dtype=torch.float64, + train_features_device, + eval_device, +): + logreg_model = train_for_C( + C=C, + max_iter=max_iter, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + return evaluate_model( + logreg_model=logreg_model, + logreg_metric=logreg_metric, + test_data_loader=test_data_loader, + device=eval_device, + ) + + +def sweep_C_values( + *, + train_features, + train_labels, + test_data_loader, + metric_type, + num_classes, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + if metric_type == MetricType.PER_CLASS_ACCURACY: + # If we want to output per-class accuracy, we select the hyperparameters with mean per class + metric_type = MetricType.MEAN_PER_CLASS_ACCURACY + logreg_metric = build_metric(metric_type, num_classes=num_classes) + metric_tracker = MetricTracker(logreg_metric, maximize=True) + ALL_C = 10**C_POWER_RANGE + logreg_models = {} + + train_features = train_features.to(dtype=train_dtype, device=train_features_device) + train_labels = train_labels.to(device=train_features_device) + + for i in range(get_global_rank(), len(ALL_C), get_global_size()): + C = ALL_C[i].item() + logger.info( + f"Training for C = {C:.5f}, dtype={train_dtype}, " + f"features: {train_features.shape}, {train_features.dtype}, " + f"labels: {train_labels.shape}, {train_labels.dtype}" + ) + logreg_models[C] = train_for_C( + C=C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + dtype=train_dtype, + device=train_features_device, + ) + + gather_list = [None for _ in range(get_global_size())] + torch.distributed.all_gather_object(gather_list, logreg_models) + + logreg_models_gathered = {} + for logreg_dict in gather_list: + logreg_models_gathered.update(logreg_dict) + + for i in range(len(ALL_C)): + metric_tracker.increment() + C = ALL_C[i].item() + evals = evaluate_model( + logreg_model=logreg_models_gathered[C], + logreg_metric=metric_tracker, + test_data_loader=test_data_loader, + device=torch.cuda.current_device(), + ) + logger.info(f"Trained for C = {C:.5f}, accuracies = {evals}") + + best_stats, which_epoch = metric_tracker.best_metric(return_step=True) + best_stats_100 = {k: 100.0 * v for k, v in best_stats.items()} + if which_epoch["top-1"] == i: + best_C = C + logger.info(f"Sweep best {best_stats_100}, best C = {best_C:.6f}") + + return best_stats, best_C + + +def eval_log_regression( + *, + model, + train_dataset, + val_dataset, + finetune_dataset, + metric_type, + batch_size, + num_workers, + finetune_on_val=False, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + """ + Implements the "standard" process for log regression evaluation: + The value of C is chosen by training on train_dataset and evaluating on + finetune_dataset. Then, the final model is trained on a concatenation of + train_dataset and finetune_dataset, and is evaluated on val_dataset. + If there is no finetune_dataset, the value of C is the one that yields + the best results on a random 10% subset of the train dataset + """ + + start = time.time() + + train_features, train_labels = extract_features( + model, train_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_features, val_labels = extract_features( + model, val_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + val_data_loader = torch.utils.data.DataLoader( + TensorDataset(val_features, val_labels), + batch_size=batch_size, + drop_last=False, + num_workers=0, + persistent_workers=False, + ) + + if finetune_dataset is None and finetune_on_val: + logger.info("Choosing hyperparameters on the val dataset") + finetune_features, finetune_labels = val_features, val_labels + elif finetune_dataset is None and not finetune_on_val: + logger.info("Choosing hyperparameters on 10% of the train dataset") + torch.manual_seed(0) + indices = torch.randperm(len(train_features), device=train_features.device) + finetune_index = indices[: len(train_features) // 10] + train_index = indices[len(train_features) // 10 :] + finetune_features, finetune_labels = train_features[finetune_index], train_labels[finetune_index] + train_features, train_labels = train_features[train_index], train_labels[train_index] + else: + logger.info("Choosing hyperparameters on the finetune dataset") + finetune_features, finetune_labels = extract_features( + model, finetune_dataset, batch_size, num_workers, gather_on_cpu=(train_features_device == _CPU_DEVICE) + ) + # release the model - free GPU memory + del model + gc.collect() + torch.cuda.empty_cache() + finetune_data_loader = torch.utils.data.DataLoader( + TensorDataset(finetune_features, finetune_labels), + batch_size=batch_size, + drop_last=False, + ) + + if len(train_labels.shape) > 1: + num_classes = train_labels.shape[1] + else: + num_classes = train_labels.max() + 1 + + logger.info("Using cuML for logistic regression") + + best_stats, best_C = sweep_C_values( + train_features=train_features, + train_labels=train_labels, + test_data_loader=finetune_data_loader, + metric_type=metric_type, + num_classes=num_classes, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + if not finetune_on_val: + logger.info("Best parameter found, concatenating features") + train_features = torch.cat((train_features, finetune_features)) + train_labels = torch.cat((train_labels, finetune_labels)) + + logger.info("Training final model") + logreg_metric = build_metric(metric_type, num_classes=num_classes) + evals = train_and_evaluate( + C=best_C, + max_iter=max_train_iters, + train_features=train_features, + train_labels=train_labels, + logreg_metric=logreg_metric.clone(), + test_data_loader=val_data_loader, + eval_device=torch.cuda.current_device(), + train_dtype=train_dtype, + train_features_device=train_features_device, + ) + + best_stats = evals[1]["metrics"] + + best_stats["best_C"] = best_C + + logger.info(f"Log regression evaluation done in {int(time.time() - start)}s") + return best_stats + + +def eval_log_regression_with_model( + model, + train_dataset_str="ImageNet:split=TRAIN", + val_dataset_str="ImageNet:split=VAL", + finetune_dataset_str=None, + autocast_dtype=torch.float, + finetune_on_val=False, + metric_type=MetricType.MEAN_ACCURACY, + train_dtype=torch.float64, + train_features_device=_CPU_DEVICE, + max_train_iters=DEFAULT_MAX_ITER, +): + cudnn.benchmark = True + + transform = make_classification_eval_transform(resize_size=224) + target_transform = None + + train_dataset = make_dataset(dataset_str=train_dataset_str, transform=transform, target_transform=target_transform) + val_dataset = make_dataset(dataset_str=val_dataset_str, transform=transform, target_transform=target_transform) + if finetune_dataset_str is not None: + finetune_dataset = make_dataset( + dataset_str=finetune_dataset_str, transform=transform, target_transform=target_transform + ) + else: + finetune_dataset = None + + with torch.cuda.amp.autocast(dtype=autocast_dtype): + results_dict_logreg = eval_log_regression( + model=model, + train_dataset=train_dataset, + val_dataset=val_dataset, + finetune_dataset=finetune_dataset, + metric_type=metric_type, + batch_size=256, + num_workers=0, # 5, + finetune_on_val=finetune_on_val, + train_dtype=train_dtype, + train_features_device=train_features_device, + max_train_iters=max_train_iters, + ) + + results_dict = { + "top-1": results_dict_logreg["top-1"].cpu().numpy() * 100.0, + "top-5": results_dict_logreg.get("top-5", torch.tensor(0.0)).cpu().numpy() * 100.0, + "best_C": results_dict_logreg["best_C"], + } + logger.info( + "\n".join( + [ + "Training of the supervised logistic regression on frozen features completed.\n" + "Top-1 test accuracy: {acc:.1f}".format(acc=results_dict["top-1"]), + "Top-5 test accuracy: {acc:.1f}".format(acc=results_dict["top-5"]), + "obtained for C = {c:.6f}".format(c=results_dict["best_C"]), + ] + ) + ) + + torch.distributed.barrier() + return results_dict + + +def main(args): + model, autocast_dtype = setup_and_build_model(args) + eval_log_regression_with_model( + model=model, + train_dataset_str=args.train_dataset_str, + val_dataset_str=args.val_dataset_str, + finetune_dataset_str=args.finetune_dataset_str, + autocast_dtype=autocast_dtype, + finetune_on_val=args.finetune_on_val, + metric_type=args.metric_type, + train_dtype=as_torch_dtype(args.train_dtype), + train_features_device=torch.device(args.train_features_device), + max_train_iters=args.max_train_iters, + ) + return 0 + + +if __name__ == "__main__": + description = "DINOv2 logistic regression evaluation" + args_parser = get_args_parser(description=description) + args = args_parser.parse_args() + sys.exit(main(args)) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/metrics.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..80bf88da224e749dd6b3dd4b2bd90ec99eaec34e --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/metrics.py @@ -0,0 +1,114 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +import logging +from typing import Any, Dict, Optional + +import torch +from torch import Tensor +from torchmetrics import Metric, MetricCollection +from torchmetrics.classification import MulticlassAccuracy +from torchmetrics.utilities.data import dim_zero_cat, select_topk + + +logger = logging.getLogger("dinov2") + + +class MetricType(Enum): + MEAN_ACCURACY = "mean_accuracy" + MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy" + PER_CLASS_ACCURACY = "per_class_accuracy" + IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy" + + @property + def accuracy_averaging(self): + return getattr(AccuracyAveraging, self.name, None) + + def __str__(self): + return self.value + + +class AccuracyAveraging(Enum): + MEAN_ACCURACY = "micro" + MEAN_PER_CLASS_ACCURACY = "macro" + PER_CLASS_ACCURACY = "none" + + def __str__(self): + return self.value + + +def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None): + if metric_type.accuracy_averaging is not None: + return build_topk_accuracy_metric( + average_type=metric_type.accuracy_averaging, + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + elif metric_type == MetricType.IMAGENET_REAL_ACCURACY: + return build_topk_imagenet_real_accuracy_metric( + num_classes=num_classes, + ks=(1, 5) if ks is None else ks, + ) + + raise ValueError(f"Unknown metric type {metric_type}") + + +def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = { + f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks + } + return MetricCollection(metrics) + + +def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)): + metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks} + return MetricCollection(metrics) + + +class ImageNetReaLAccuracy(Metric): + is_differentiable: bool = False + higher_is_better: Optional[bool] = None + full_state_update: bool = False + + def __init__( + self, + num_classes: int, + top_k: int = 1, + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.num_classes = num_classes + self.top_k = top_k + self.add_state("tp", [], dist_reduce_fx="cat") + + def update(self, preds: Tensor, target: Tensor) -> None: # type: ignore + # preds [B, D] + # target [B, A] + # preds_oh [B, D] with 0 and 1 + # select top K highest probabilities, use one hot representation + preds_oh = select_topk(preds, self.top_k) + # target_oh [B, D + 1] with 0 and 1 + target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32) + target = target.long() + # for undefined targets (-1) use a fake value `num_classes` + target[target == -1] = self.num_classes + # fill targets, use one hot representation + target_oh.scatter_(1, target, 1) + # target_oh [B, D] (remove the fake target at index `num_classes`) + target_oh = target_oh[:, :-1] + # tp [B] with 0 and 1 + tp = (preds_oh * target_oh == 1).sum(dim=1) + # at least one match between prediction and target + tp.clip_(max=1) + # ignore instances where no targets are defined + mask = target_oh.sum(dim=1) > 0 + tp = tp[mask] + self.tp.append(tp) # type: ignore + + def compute(self) -> Tensor: + tp = dim_zero_cat(self.tp) # type: ignore + return tp.float().mean() diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/setup.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..eff2ad733b5595bef8d9c74abda10dd450208519 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/setup.py @@ -0,0 +1,76 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +from typing import Any, List, Optional, Tuple + +import torch +import torch.backends.cudnn as cudnn + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.models import build_model_from_cfg +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.config import setup +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.utils as dinov2_utils + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +): + parser = argparse.ArgumentParser( + description=description, + parents=parents or [], + add_help=add_help, + ) + parser.add_argument( + "--config-file", + type=str, + help="Model configuration file", + ) + parser.add_argument( + "--pretrained-weights", + type=str, + help="Pretrained model weights", + ) + parser.add_argument( + "--output-dir", + default="", + type=str, + help="Output directory to write results and logs", + ) + parser.add_argument( + "--opts", + help="Extra configuration options", + default=[], + nargs="+", + ) + return parser + + +def get_autocast_dtype(config): + teacher_dtype_str = config.compute_precision.teacher.backbone.mixed_precision.param_dtype + if teacher_dtype_str == "fp16": + return torch.half + elif teacher_dtype_str == "bf16": + return torch.bfloat16 + else: + return torch.float + + +def build_model_for_eval(config, pretrained_weights): + model, _ = build_model_from_cfg(config, only_teacher=True) + dinov2_utils.load_pretrained_weights(model, pretrained_weights, "teacher") + model.eval() + model.cuda() + return model + + +def setup_and_build_model(args) -> Tuple[Any, torch.dtype]: + cudnn.benchmark = True + config = setup(args) + model = build_model_for_eval(config, args.pretrained_weights) + autocast_dtype = get_autocast_dtype(config) + return model, autocast_dtype diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/utils.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..86f070d79069c592af2e81fae23c0eb5b5efa4b6 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/eval/utils.py @@ -0,0 +1,147 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +from typing import Dict, Optional + +import torch +from torch import nn +from torchmetrics import MetricCollection + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import DatasetWithEnumeratedTargets, SamplerType, make_data_loader +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import MetricLogger + + +logger = logging.getLogger("dinov2") + + +class ModelWithNormalize(torch.nn.Module): + def __init__(self, model): + super().__init__() + self.model = model + + def forward(self, samples): + return nn.functional.normalize(self.model(samples), dim=1, p=2) + + +class ModelWithIntermediateLayers(nn.Module): + def __init__(self, feature_model, n_last_blocks, autocast_ctx): + super().__init__() + self.feature_model = feature_model + self.feature_model.eval() + self.n_last_blocks = n_last_blocks + self.autocast_ctx = autocast_ctx + + def forward(self, images): + with torch.inference_mode(): + with self.autocast_ctx(): + features = self.feature_model.get_intermediate_layers( + images, self.n_last_blocks, return_class_token=True + ) + return features + + +@torch.inference_mode() +def evaluate( + model: nn.Module, + data_loader, + postprocessors: Dict[str, nn.Module], + metrics: Dict[str, MetricCollection], + device: torch.device, + criterion: Optional[nn.Module] = None, +): + model.eval() + if criterion is not None: + criterion.eval() + + for metric in metrics.values(): + metric = metric.to(device) + + metric_logger = MetricLogger(delimiter=" ") + header = "Test:" + + for samples, targets, *_ in metric_logger.log_every(data_loader, 10, header): + outputs = model(samples.to(device)) + targets = targets.to(device) + + if criterion is not None: + loss = criterion(outputs, targets) + metric_logger.update(loss=loss.item()) + + for k, metric in metrics.items(): + metric_inputs = postprocessors[k](outputs, targets) + metric.update(**metric_inputs) + + metric_logger.synchronize_between_processes() + logger.info(f"Averaged stats: {metric_logger}") + + stats = {k: metric.compute() for k, metric in metrics.items()} + metric_logger_stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()} + return metric_logger_stats, stats + + +def all_gather_and_flatten(tensor_rank): + tensor_all_ranks = torch.empty( + distributed.get_global_size(), + *tensor_rank.shape, + dtype=tensor_rank.dtype, + device=tensor_rank.device, + ) + tensor_list = list(tensor_all_ranks.unbind(0)) + torch.distributed.all_gather(tensor_list, tensor_rank.contiguous()) + return tensor_all_ranks.flatten(end_dim=1) + + +def extract_features(model, dataset, batch_size, num_workers, gather_on_cpu=False): + dataset_with_enumerated_targets = DatasetWithEnumeratedTargets(dataset) + sample_count = len(dataset_with_enumerated_targets) + data_loader = make_data_loader( + dataset=dataset_with_enumerated_targets, + batch_size=batch_size, + num_workers=num_workers, + sampler_type=SamplerType.DISTRIBUTED, + drop_last=False, + shuffle=False, + ) + return extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu) + + +@torch.inference_mode() +def extract_features_with_dataloader(model, data_loader, sample_count, gather_on_cpu=False): + gather_device = torch.device("cpu") if gather_on_cpu else torch.device("cuda") + metric_logger = MetricLogger(delimiter=" ") + features, all_labels = None, None + for samples, (index, labels_rank) in metric_logger.log_every(data_loader, 10): + samples = samples.cuda(non_blocking=True) + labels_rank = labels_rank.cuda(non_blocking=True) + index = index.cuda(non_blocking=True) + features_rank = model(samples).float() + + # init storage feature matrix + if features is None: + features = torch.zeros(sample_count, features_rank.shape[-1], device=gather_device) + labels_shape = list(labels_rank.shape) + labels_shape[0] = sample_count + all_labels = torch.full(labels_shape, fill_value=-1, device=gather_device) + logger.info(f"Storing features into tensor of shape {features.shape}") + + # share indexes, features and labels between processes + index_all = all_gather_and_flatten(index).to(gather_device) + features_all_ranks = all_gather_and_flatten(features_rank).to(gather_device) + labels_all_ranks = all_gather_and_flatten(labels_rank).to(gather_device) + + # update storage feature matrix + if len(index_all) > 0: + features.index_copy_(0, index_all, features_all_ranks) + all_labels.index_copy_(0, index_all, labels_all_ranks) + + logger.info(f"Features shape: {tuple(features.shape)}") + logger.info(f"Labels shape: {tuple(all_labels.shape)}") + + assert torch.all(all_labels > -1) + + return features, all_labels diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/fsdp/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..65f2ae1e9b84628411affb531044e285c912f01a --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/fsdp/__init__.py @@ -0,0 +1,158 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Any + +import torch +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from functools import partial +from fvcore.common.checkpoint import Checkpointer +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP +from torch.distributed.fsdp import ShardingStrategy +from torch.distributed.fsdp import MixedPrecision +from torch.distributed.fsdp import StateDictType +from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler +from torch.distributed.fsdp.wrap import ModuleWrapPolicy +from torch.distributed.fsdp._runtime_utils import _reshard + + +def get_fsdp_wrapper(model_cfg, modules_to_wrap=set()): + sharding_strategy_dict = { + "NO_SHARD": ShardingStrategy.NO_SHARD, + "SHARD_GRAD_OP": ShardingStrategy.SHARD_GRAD_OP, + "FULL_SHARD": ShardingStrategy.FULL_SHARD, + } + + dtype_dict = { + "fp32": torch.float32, + "fp16": torch.float16, + "bf16": torch.bfloat16, + } + + mixed_precision_config = MixedPrecision( + param_dtype=dtype_dict[model_cfg.mixed_precision.param_dtype], + reduce_dtype=dtype_dict[model_cfg.mixed_precision.reduce_dtype], + buffer_dtype=dtype_dict[model_cfg.mixed_precision.buffer_dtype], + ) + + sharding_strategy_config = sharding_strategy_dict[model_cfg.sharding_strategy] + + local_rank = distributed.get_local_rank() + + fsdp_wrapper = partial( + FSDP, + sharding_strategy=sharding_strategy_config, + mixed_precision=mixed_precision_config, + device_id=local_rank, + sync_module_states=True, + use_orig_params=True, + auto_wrap_policy=ModuleWrapPolicy(modules_to_wrap), + ) + return fsdp_wrapper + + +def is_fsdp(x): + return isinstance(x, FSDP) + + +def is_sharded_fsdp(x): + return is_fsdp(x) and x.sharding_strategy is not ShardingStrategy.NO_SHARD + + +def free_if_fsdp(x): + if is_sharded_fsdp(x): + handles = x._handles + true_list = [True for h in handles] + _reshard(x, handles, true_list) + + +def get_fsdp_modules(x): + return FSDP.fsdp_modules(x) + + +def reshard_fsdp_model(x): + for m in get_fsdp_modules(x): + free_if_fsdp(m) + + +def rankstr(): + return f"rank_{distributed.get_global_rank()}" + + +class FSDPCheckpointer(Checkpointer): + def save(self, name: str, **kwargs: Any) -> None: + """ + Dump model and checkpointables to a file. + + Args: + name (str): name of the file. + kwargs (dict): extra arbitrary data to save. + """ + if not self.save_dir or not self.save_to_disk: + return + + data = {} + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + data["model"] = self.model.state_dict() + + # data["model"] = self.model.state_dict() + for key, obj in self.checkpointables.items(): + data[key] = obj.state_dict() + data.update(kwargs) + + basename = f"{name}.{rankstr()}.pth" + save_file = os.path.join(self.save_dir, basename) + assert os.path.basename(save_file) == basename, basename + self.logger.info("Saving checkpoint to {}".format(save_file)) + with self.path_manager.open(save_file, "wb") as f: + torch.save(data, f) + self.tag_last_checkpoint(basename) + + def load(self, *args, **kwargs): + with FSDP.state_dict_type(self.model, StateDictType.LOCAL_STATE_DICT): + return super().load(*args, **kwargs) + + def has_checkpoint(self) -> bool: + """ + Returns: + bool: whether a checkpoint exists in the target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + return self.path_manager.exists(save_file) + + def get_checkpoint_file(self) -> str: + """ + Returns: + str: The latest checkpoint file in target directory. + """ + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + try: + with self.path_manager.open(save_file, "r") as f: + last_saved = f.read().strip() + except IOError: + # if file doesn't exist, maybe because it has just been + # deleted by a separate process + return "" + # pyre-fixme[6]: For 2nd param expected `Union[PathLike[str], str]` but got + # `Union[bytes, str]`. + return os.path.join(self.save_dir, last_saved) + + def tag_last_checkpoint(self, last_filename_basename: str) -> None: + """ + Tag the last checkpoint. + + Args: + last_filename_basename (str): the basename of the last filename. + """ + if distributed.is_enabled(): + torch.distributed.barrier() + save_file = os.path.join(self.save_dir, f"last_checkpoint.{rankstr()}") + with self.path_manager.open(save_file, "w") as f: + f.write(last_filename_basename) # pyre-ignore + + +ShardedGradScaler = ShardedGradScaler diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..31f196aacac5be8a7c537a3dfa8f97084671b466 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/attention.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..2c3bfb4e17352de6b6f6a42b15847b5232f675e0 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/attention.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/block.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..8de1d57fcb27abccdcefc393151f65b19f1a31d7 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/block.py @@ -0,0 +1,245 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/dino_head.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..7212db92a4fd8d4c7230e284e551a0234e9d8623 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/dino_head.py @@ -0,0 +1,59 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/drop_path.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/layer_scale.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/mlp.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9305295f3a139fa6e5385beb61c28f797a86e85c --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/__init__.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import functools +import logging +import os +import sys +from typing import Optional + +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from .helpers import MetricLogger, SmoothedValue + + +# So that calling _configure_logger multiple times won't add many handlers +@functools.lru_cache() +def _configure_logger( + name: Optional[str] = None, + *, + level: int = logging.DEBUG, + output: Optional[str] = None, +): + """ + Configure a logger. + + Adapted from Detectron2. + + Args: + name: The name of the logger to configure. + level: The logging level to use. + output: A file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + + Returns: + The configured logger. + """ + + logger = logging.getLogger(name) + logger.setLevel(level) + logger.propagate = False + + # Loosely match Google glog format: + # [IWEF]yyyymmdd hh:mm:ss.uuuuuu threadid file:line] msg + # but use a shorter timestamp and include the logger name: + # [IWEF]yyyymmdd hh:mm:ss logger threadid file:line] msg + fmt_prefix = "%(levelname).1s%(asctime)s %(process)s %(name)s %(filename)s:%(lineno)s] " + fmt_message = "%(message)s" + fmt = fmt_prefix + fmt_message + datefmt = "%Y%m%d %H:%M:%S" + formatter = logging.Formatter(fmt=fmt, datefmt=datefmt) + + # stdout logging for main worker only + if distributed.is_main_process(): + handler = logging.StreamHandler(stream=sys.stdout) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # file logging for all workers + if output: + if os.path.splitext(output)[-1] in (".txt", ".log"): + filename = output + else: + filename = os.path.join(output, "logs", "log.txt") + + if not distributed.is_main_process(): + global_rank = distributed.get_global_rank() + filename = filename + ".rank{}".format(global_rank) + + os.makedirs(os.path.dirname(filename), exist_ok=True) + + handler = logging.StreamHandler(open(filename, "a")) + handler.setLevel(logging.DEBUG) + handler.setFormatter(formatter) + logger.addHandler(handler) + + return logger + + +def setup_logging( + output: Optional[str] = None, + *, + name: Optional[str] = None, + level: int = logging.DEBUG, + capture_warnings: bool = True, +) -> None: + """ + Setup logging. + + Args: + output: A file name or a directory to save log files. If None, log + files will not be saved. If output ends with ".txt" or ".log", it + is assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name: The name of the logger to configure, by default the root logger. + level: The logging level to use. + capture_warnings: Whether warnings should be captured as logs. + """ + logging.captureWarnings(capture_warnings) + _configure_logger(name, level=level, output=output) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/helpers.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..7ab2191266ed24efdefd98df1f0be6b97993de86 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/logging/helpers.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict, deque +import datetime +import json +import logging +import time + +import torch + +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed + + +logger = logging.getLogger("dinov2") + + +class MetricLogger(object): + def __init__(self, delimiter="\t", output_file=None): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + self.output_file = output_file + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append("{}: {}".format(name, str(meter))) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def dump_in_output_file(self, iteration, iter_time, data_time): + if self.output_file is None or not distributed.is_main_process(): + return + dict_to_dump = dict( + iteration=iteration, + iter_time=iter_time, + data_time=data_time, + ) + dict_to_dump.update({k: v.median for k, v in self.meters.items()}) + with open(self.output_file, "a") as f: + f.write(json.dumps(dict_to_dump) + "\n") + pass + + def log_every(self, iterable, print_freq, header=None, n_iterations=None, start_iteration=0): + i = start_iteration + if not header: + header = "" + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt="{avg:.6f}") + data_time = SmoothedValue(fmt="{avg:.6f}") + + if n_iterations is None: + n_iterations = len(iterable) + + space_fmt = ":" + str(len(str(n_iterations))) + "d" + + log_list = [ + header, + "[{0" + space_fmt + "}/{1}]", + "eta: {eta}", + "{meters}", + "time: {time}", + "data: {data}", + ] + if torch.cuda.is_available(): + log_list += ["max mem: {memory:.0f}"] + + log_msg = self.delimiter.join(log_list) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == n_iterations - 1: + self.dump_in_output_file(iteration=i, iter_time=iter_time.avg, data_time=data_time.avg) + eta_seconds = iter_time.global_avg * (n_iterations - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB, + ) + ) + else: + logger.info( + log_msg.format( + i, + n_iterations, + eta=eta_string, + meters=str(self), + time=str(iter_time), + data=str(data_time), + ) + ) + i += 1 + end = time.time() + if i >= n_iterations: + break + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info("{} Total time: {} ({:.6f} s / it)".format(header, total_time_str, total_time / n_iterations)) + + +class SmoothedValue: + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, num=1): + self.deque.append(value) + self.count += num + self.total += value * num + + def synchronize_between_processes(self): + """ + Distributed synchronization of the metric + Warning: does not synchronize the deque! + """ + if not distributed.is_enabled(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda") + torch.distributed.barrier() + torch.distributed.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value, + ) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..477b71b28259bf97b806df3f3d2f392dded866d6 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .dino_clstoken_loss import DINOLoss +from .ibot_patch_loss import iBOTPatchLoss +from .koleo_loss import KoLeoLoss diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/dino_clstoken_loss.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/dino_clstoken_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2f33897efb1084e6c1c14ae00bc93ab332c61074 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/dino_clstoken_loss.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + + +class DINOLoss(nn.Module): + def __init__( + self, + out_dim, + student_temp=0.1, + center_momentum=0.9, + ): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_output = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_output, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + return F.softmax((teacher_output - self.center) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_iterations=3): + teacher_output = teacher_output.float() + world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + B = Q.shape[1] * world_size # number of samples to assign + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_output_list, teacher_out_softmaxed_centered_list): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + """ + # TODO: Use cross_entropy_distribution here + total_loss = 0 + for s in student_output_list: + lsm = F.log_softmax(s / self.student_temp, dim=-1) + for t in teacher_out_softmaxed_centered_list: + loss = torch.sum(t * lsm, dim=-1) + total_loss -= loss.mean() + return total_loss + + @torch.no_grad() + def update_center(self, teacher_output): + self.reduce_center_update(teacher_output) + + @torch.no_grad() + def reduce_center_update(self, teacher_output): + self.updated = False + self.len_teacher_output = len(teacher_output) + self.async_batch_center = torch.sum(teacher_output, dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_output * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..16bc5cf634d661f1fa337304273f60dcd43c79c3 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/ibot_patch_loss.py @@ -0,0 +1,152 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from torch import nn + +import logging + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import cross_entropy + + def lossfunc(t, s, temp): + s = s.float() + t = t.float() + if s.ndim == 2: + return -cross_entropy(s.unsqueeze(0), t.unsqueeze(0), temp, bw_inplace=True).squeeze(0) + elif s.ndim == 3: + return -cross_entropy(s, t, temp, bw_inplace=True) + +except ImportError: + + def lossfunc(t, s, temp): + return torch.sum(t * F.log_softmax(s / temp, dim=-1), dim=-1) + + +class iBOTPatchLoss(nn.Module): + def __init__(self, patch_out_dim, student_temp=0.1, center_momentum=0.9): + super().__init__() + self.student_temp = student_temp + self.center_momentum = center_momentum + self.register_buffer("center", torch.zeros(1, 1, patch_out_dim)) + self.updated = True + self.reduce_handle = None + self.len_teacher_patch_tokens = None + self.async_batch_center = None + + @torch.no_grad() + def softmax_center_teacher(self, teacher_patch_tokens, teacher_temp): + self.apply_center_update() + # teacher centering and sharpening + # + # WARNING: + # as self.center is a float32, everything gets casted to float32 afterwards + # + # teacher_patch_tokens = teacher_patch_tokens.float() + # return F.softmax((teacher_patch_tokens.sub_(self.center.to(teacher_patch_tokens.dtype))).mul_(1 / teacher_temp), dim=-1) + + return F.softmax((teacher_patch_tokens - self.center) / teacher_temp, dim=-1) + + # this is experimental, keep everything in float16 and let's see what happens: + # return F.softmax((teacher_patch_tokens.sub_(self.center)) / teacher_temp, dim=-1) + + @torch.no_grad() + def sinkhorn_knopp_teacher(self, teacher_output, teacher_temp, n_masked_patches_tensor, n_iterations=3): + teacher_output = teacher_output.float() + # world_size = dist.get_world_size() if dist.is_initialized() else 1 + Q = torch.exp(teacher_output / teacher_temp).t() # Q is K-by-B for consistency with notations from our paper + # B = Q.shape[1] * world_size # number of samples to assign + B = n_masked_patches_tensor + dist.all_reduce(B) + K = Q.shape[0] # how many prototypes + + # make the matrix sums to 1 + sum_Q = torch.sum(Q) + if dist.is_initialized(): + dist.all_reduce(sum_Q) + Q /= sum_Q + + for it in range(n_iterations): + # normalize each row: total weight per prototype must be 1/K + sum_of_rows = torch.sum(Q, dim=1, keepdim=True) + if dist.is_initialized(): + dist.all_reduce(sum_of_rows) + Q /= sum_of_rows + Q /= K + + # normalize each column: total weight per sample must be 1/B + Q /= torch.sum(Q, dim=0, keepdim=True) + Q /= B + + Q *= B # the columns must sum to 1 so that Q is an assignment + return Q.t() + + def forward(self, student_patch_tokens, teacher_patch_tokens, student_masks_flat): + """ + Cross-entropy between softmax outputs of the teacher and student networks. + student_patch_tokens: (B, N, D) tensor + teacher_patch_tokens: (B, N, D) tensor + student_masks_flat: (B, N) tensor + """ + t = teacher_patch_tokens + s = student_patch_tokens + loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = torch.sum(loss * student_masks_flat.float(), dim=-1) / student_masks_flat.sum(dim=-1).clamp(min=1.0) + return -loss.mean() + + def forward_masked( + self, + student_patch_tokens_masked, + teacher_patch_tokens_masked, + student_masks_flat, + n_masked_patches=None, + masks_weight=None, + ): + t = teacher_patch_tokens_masked + s = student_patch_tokens_masked + # loss = torch.sum(t * F.log_softmax(s / self.student_temp, dim=-1), dim=-1) + loss = lossfunc(t, s, self.student_temp) + if masks_weight is None: + masks_weight = ( + (1 / student_masks_flat.sum(-1).clamp(min=1.0)) + .unsqueeze(-1) + .expand_as(student_masks_flat)[student_masks_flat] + ) + if n_masked_patches is not None: + loss = loss[:n_masked_patches] + loss = loss * masks_weight + return -loss.sum() / student_masks_flat.shape[0] + + @torch.no_grad() + def update_center(self, teacher_patch_tokens): + self.reduce_center_update(teacher_patch_tokens) + + @torch.no_grad() + def reduce_center_update(self, teacher_patch_tokens): + self.updated = False + self.len_teacher_patch_tokens = len(teacher_patch_tokens) + self.async_batch_center = torch.sum(teacher_patch_tokens.mean(1), dim=0, keepdim=True) + if dist.is_initialized(): + self.reduce_handle = dist.all_reduce(self.async_batch_center, async_op=True) + + @torch.no_grad() + def apply_center_update(self): + if self.updated is False: + world_size = dist.get_world_size() if dist.is_initialized() else 1 + + if self.reduce_handle is not None: + self.reduce_handle.wait() + _t = self.async_batch_center / (self.len_teacher_patch_tokens * world_size) + + self.center = self.center * self.center_momentum + _t * (1 - self.center_momentum) + + self.updated = True diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/koleo_loss.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/koleo_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..e776d0426bb029cf48f25b0c94077720bc8421c4 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/loss/koleo_loss.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import torch +import torch.nn as nn +import torch.nn.functional as F + +# import torch.distributed as dist + + +logger = logging.getLogger("dinov2") + + +class KoLeoLoss(nn.Module): + """Kozachenko-Leonenko entropic loss regularizer from Sablayrolles et al. - 2018 - Spreading vectors for similarity search""" + + def __init__(self): + super().__init__() + self.pdist = nn.PairwiseDistance(2, eps=1e-8) + + def pairwise_NNs_inner(self, x): + """ + Pairwise nearest neighbors for L2-normalized vectors. + Uses Torch rather than Faiss to remain on GPU. + """ + # parwise dot products (= inverse distance) + dots = torch.mm(x, x.t()) + n = x.shape[0] + dots.view(-1)[:: (n + 1)].fill_(-1) # Trick to fill diagonal with -1 + # max inner prod -> min distance + _, I = torch.max(dots, dim=1) # noqa: E741 + return I + + def forward(self, student_output, eps=1e-8): + """ + Args: + student_output (BxD): backbone output of student + """ + with torch.cuda.amp.autocast(enabled=False): + student_output = F.normalize(student_output, eps=eps, p=2, dim=-1) + I = self.pairwise_NNs_inner(student_output) # noqa: E741 + distances = self.pdist(student_output, student_output[I]) # BxD, BxD -> B + loss = -torch.log(distances + eps).mean() + return loss diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5e5a1f3832464f898752e57e865760e9864613cb --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/__init__.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..a21d70f4e059bab934b813e8544bfc7d49ea8655 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/models/vision_transformer.py @@ -0,0 +1,358 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/knn.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..5fb30af05494fef9cebdd3bbf696ac04be95a589 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/knn.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.knn import get_args_parser as get_knn_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import setup_logging +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.knn import main as knn_main + + self._setup_args() + knn_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 k-NN evaluation" + knn_args_parser = get_knn_args_parser(add_help=False) + parents = [knn_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:knn") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/linear.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/linear.py new file mode 100644 index 0000000000000000000000000000000000000000..3aec9c0143c288d9828fde3ae69671496c4cc328 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/linear.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.linear import get_args_parser as get_linear_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import setup_logging +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.linear import main as linear_main + + self._setup_args() + linear_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 linear evaluation" + linear_args_parser = get_linear_args_parser(add_help=False) + parents = [linear_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:linear") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/log_regression.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/log_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..57515d3e9baf2c26736201b0488b7ea325df2ed6 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/eval/log_regression.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.log_regression import get_args_parser as get_log_regression_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import setup_logging +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Evaluator: + def __init__(self, args): + self.args = args + + def __call__(self): + from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.eval.log_regression import main as log_regression_main + + self._setup_args() + log_regression_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 logistic evaluation" + log_regression_args_parser = get_log_regression_args_parser(add_help=False) + parents = [log_regression_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Evaluator, args, name="dinov2:logreg") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/submit.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/submit.py new file mode 100644 index 0000000000000000000000000000000000000000..8a84cbba32b29c44604d57cfd55b782121ded258 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/submit.py @@ -0,0 +1,123 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import os +from pathlib import Path +from typing import List, Optional + +import submitit + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.cluster import ( + get_slurm_executor_parameters, + get_slurm_partition, + get_user_checkpoint_path, +) + + +logger = logging.getLogger("dinov2") + + +def get_args_parser( + description: Optional[str] = None, + parents: Optional[List[argparse.ArgumentParser]] = None, + add_help: bool = True, +) -> argparse.ArgumentParser: + parents = parents or [] + slurm_partition = get_slurm_partition() + parser = argparse.ArgumentParser( + description=description, + parents=parents, + add_help=add_help, + ) + parser.add_argument( + "--ngpus", + "--gpus", + "--gpus-per-node", + default=8, + type=int, + help="Number of GPUs to request on each node", + ) + parser.add_argument( + "--nodes", + "--nnodes", + default=2, + type=int, + help="Number of nodes to request", + ) + parser.add_argument( + "--timeout", + default=2800, + type=int, + help="Duration of the job", + ) + parser.add_argument( + "--partition", + default=slurm_partition, + type=str, + help="Partition where to submit", + ) + parser.add_argument( + "--use-volta32", + action="store_true", + help="Request V100-32GB GPUs", + ) + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. priority message", + ) + parser.add_argument( + "--exclude", + default="", + type=str, + help="Nodes to exclude", + ) + return parser + + +def get_shared_folder() -> Path: + user_checkpoint_path = get_user_checkpoint_path() + if user_checkpoint_path is None: + raise RuntimeError("Path to user checkpoint cannot be determined") + path = user_checkpoint_path / "experiments" + path.mkdir(exist_ok=True) + return path + + +def submit_jobs(task_class, args, name: str): + if not args.output_dir: + args.output_dir = str(get_shared_folder() / "%j") + + Path(args.output_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor(folder=args.output_dir, slurm_max_num_timeout=30) + + kwargs = {} + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + if args.comment: + kwargs["slurm_comment"] = args.comment + if args.exclude: + kwargs["slurm_exclude"] = args.exclude + + executor_params = get_slurm_executor_parameters( + nodes=args.nodes, + num_gpus_per_node=args.ngpus, + timeout_min=args.timeout, # max is 60 * 72 + slurm_signal_delay_s=120, + slurm_partition=args.partition, + **kwargs, + ) + executor.update_parameters(name=name, **executor_params) + + task = task_class(args) + job = executor.submit(task) + + logger.info(f"Submitted job_id: {job.job_id}") + str_output_dir = os.path.abspath(args.output_dir).replace("%j", str(job.job_id)) + logger.info(f"Logs and checkpoints will be saved at: {str_output_dir}") diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/train/train.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..837872e0783f4499dd7ca8d491cd1b00debb2072 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/run/train/train.py @@ -0,0 +1,60 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import sys + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import setup_logging +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.train import get_args_parser as get_train_args_parser +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.run.submit import get_args_parser, submit_jobs + + +logger = logging.getLogger("dinov2") + + +class Trainer(object): + def __init__(self, args): + self.args = args + + def __call__(self): + from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.train import main as train_main + + self._setup_args() + train_main(self.args) + + def checkpoint(self): + import submitit + + logger.info(f"Requeuing {self.args}") + empty = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty) + + def _setup_args(self): + import submitit + + job_env = submitit.JobEnvironment() + self.args.output_dir = self.args.output_dir.replace("%j", str(job_env.job_id)) + logger.info(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + logger.info(f"Args: {self.args}") + + +def main(): + description = "Submitit launcher for DINOv2 training" + train_args_parser = get_train_args_parser(add_help=False) + parents = [train_args_parser] + args_parser = get_args_parser(description=description, parents=parents) + args = args_parser.parse_args() + + setup_logging() + + assert os.path.exists(args.config_file), "Configuration file does not exist!" + submit_jobs(Trainer, args, name="dinov2:train") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b66d17aa547ed5560e75a03f5c1587da2d4fd7 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/__init__.py @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .train import get_args_parser, main +from .ssl_meta_arch import SSLMetaArch diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/ssl_meta_arch.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/ssl_meta_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..d239496963b6ebca8f6524734d11badca5687b7f --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/ssl_meta_arch.py @@ -0,0 +1,403 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from functools import partial +import logging + +import torch +from torch import nn + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.loss import DINOLoss, iBOTPatchLoss, KoLeoLoss +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.models import build_model_from_cfg +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.layers import DINOHead +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.utils import has_batchnorms +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.param_groups import get_params_groups_with_decay, fuse_params_groups +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.fsdp import get_fsdp_wrapper, ShardedGradScaler, get_fsdp_modules, reshard_fsdp_model + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.models.vision_transformer import BlockChunk + +try: + from xformers.ops import fmha + + XFORMERS_AVAILABLE = True +except ImportError: + XFORMERS_AVAILABLE = False +assert XFORMERS_AVAILABLE, "xFormers is required for DINOv2 training" + + +logger = logging.getLogger("dinov2") + + +class SSLMetaArch(nn.Module): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.fp16_scaler = ShardedGradScaler() if cfg.compute_precision.grad_scaler else None + + student_model_dict = dict() + teacher_model_dict = dict() + + student_backbone, teacher_backbone, embed_dim = build_model_from_cfg(cfg) + student_model_dict["backbone"] = student_backbone + teacher_model_dict["backbone"] = teacher_backbone + logger.info(f"OPTIONS -- architecture : embed_dim: {embed_dim}") + + if cfg.student.pretrained_weights: + chkpt = torch.load(cfg.student.pretrained_weights) + logger.info(f"OPTIONS -- pretrained weights: loading from {cfg.student.pretrained_weights}") + student_backbone.load_state_dict(chkpt["model"], strict=False) + + self.embed_dim = embed_dim + self.dino_out_dim = cfg.dino.head_n_prototypes + + self.do_dino = cfg.dino.loss_weight > 0 + self.do_koleo = cfg.dino.koleo_loss_weight > 0 + self.do_ibot = cfg.ibot.loss_weight > 0 + self.ibot_separate_head = cfg.ibot.separate_head + + logger.info("OPTIONS -- DINO") + if self.do_dino: + logger.info(f"OPTIONS -- DINO -- loss_weight: {cfg.dino.loss_weight}") + logger.info(f"OPTIONS -- DINO -- head_n_prototypes: {cfg.dino.head_n_prototypes}") + logger.info(f"OPTIONS -- DINO -- head_bottleneck_dim: {cfg.dino.head_bottleneck_dim}") + logger.info(f"OPTIONS -- DINO -- head_hidden_dim: {cfg.dino.head_hidden_dim}") + self.dino_loss_weight = cfg.dino.loss_weight + dino_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.dino.head_n_prototypes, + hidden_dim=cfg.dino.head_hidden_dim, + bottleneck_dim=cfg.dino.head_bottleneck_dim, + nlayers=cfg.dino.head_nlayers, + ) + self.dino_loss = DINOLoss(self.dino_out_dim) + if self.do_koleo: + logger.info("OPTIONS -- DINO -- applying KOLEO regularization") + self.koleo_loss = KoLeoLoss() + + else: + logger.info("OPTIONS -- DINO -- not using DINO") + + if self.do_dino or self.do_ibot: + student_model_dict["dino_head"] = dino_head() + teacher_model_dict["dino_head"] = dino_head() + + logger.info("OPTIONS -- IBOT") + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_ratio_tuple: {cfg.ibot.mask_ratio_min_max}") + logger.info(f"OPTIONS -- IBOT masking -- ibot_mask_sample_probability: {cfg.ibot.mask_sample_probability}") + if self.do_ibot: + self.ibot_loss_weight = cfg.ibot.loss_weight + assert max(cfg.ibot.mask_ratio_min_max) > 0, "please provide a positive mask ratio tuple for ibot" + assert cfg.ibot.mask_sample_probability > 0, "please provide a positive mask probability for ibot" + self.ibot_out_dim = cfg.ibot.head_n_prototypes if self.ibot_separate_head else cfg.dino.head_n_prototypes + self.ibot_patch_loss = iBOTPatchLoss(self.ibot_out_dim) + if self.ibot_separate_head: + logger.info(f"OPTIONS -- IBOT -- loss_weight: {cfg.ibot.loss_weight}") + logger.info(f"OPTIONS -- IBOT -- head_n_prototypes: {cfg.ibot.head_n_prototypes}") + logger.info(f"OPTIONS -- IBOT -- head_bottleneck_dim: {cfg.ibot.head_bottleneck_dim}") + logger.info(f"OPTIONS -- IBOT -- head_hidden_dim: {cfg.ibot.head_hidden_dim}") + ibot_head = partial( + DINOHead, + in_dim=embed_dim, + out_dim=cfg.ibot.head_n_prototypes, + hidden_dim=cfg.ibot.head_hidden_dim, + bottleneck_dim=cfg.ibot.head_bottleneck_dim, + nlayers=cfg.ibot.head_nlayers, + ) + student_model_dict["ibot_head"] = ibot_head() + teacher_model_dict["ibot_head"] = ibot_head() + else: + logger.info("OPTIONS -- IBOT -- head shared with DINO") + + self.need_to_synchronize_fsdp_streams = True + + self.student = nn.ModuleDict(student_model_dict) + self.teacher = nn.ModuleDict(teacher_model_dict) + + # there is no backpropagation through the teacher, so no need for gradients + for p in self.teacher.parameters(): + p.requires_grad = False + logger.info(f"Student and Teacher are built: they are both {cfg.student.arch} network.") + + def forward(self, inputs): + raise NotImplementedError + + def backprop_loss(self, loss): + if self.fp16_scaler is not None: + self.fp16_scaler.scale(loss).backward() + else: + loss.backward() + + def forward_backward(self, images, teacher_temp): + n_global_crops = 2 + assert n_global_crops == 2 + n_local_crops = self.cfg.crops.local_crops_number + + global_crops = images["collated_global_crops"].cuda(non_blocking=True) + local_crops = images["collated_local_crops"].cuda(non_blocking=True) + + masks = images["collated_masks"].cuda(non_blocking=True) + mask_indices_list = images["mask_indices_list"].cuda(non_blocking=True) + n_masked_patches_tensor = images["n_masked_patches"].cuda(non_blocking=True) + n_masked_patches = mask_indices_list.shape[0] + upperbound = images["upperbound"] + masks_weight = images["masks_weight"].cuda(non_blocking=True) + + n_local_crops_loss_terms = max(n_local_crops * n_global_crops, 1) + n_global_crops_loss_terms = (n_global_crops - 1) * n_global_crops + + do_dino = self.do_dino + do_ibot = self.do_ibot + + # loss scales + ibot_loss_scale = 1.0 / n_global_crops + + # teacher output + @torch.no_grad() + def get_teacher_output(): + x, n_global_crops_teacher = global_crops, n_global_crops + teacher_backbone_output_dict = self.teacher.backbone(x, is_training=True) + teacher_cls_tokens = teacher_backbone_output_dict["x_norm_clstoken"] + teacher_cls_tokens = teacher_cls_tokens.chunk(n_global_crops_teacher) + # watch out: these are chunked and cat'd in reverse so A is matched to B in the global crops dino loss + teacher_cls_tokens = torch.cat((teacher_cls_tokens[1], teacher_cls_tokens[0])) + ibot_teacher_patch_tokens = teacher_backbone_output_dict["x_norm_patchtokens"] + _dim = ibot_teacher_patch_tokens.shape[-1] + n_cls_tokens = teacher_cls_tokens.shape[0] + + if do_ibot and not self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound + n_cls_tokens, _dim) + buffer_tensor_teacher[:n_cls_tokens].copy_(teacher_cls_tokens) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[n_cls_tokens : n_cls_tokens + n_masked_patches], + ) + tokens_after_head = self.teacher.dino_head(buffer_tensor_teacher) + teacher_cls_tokens_after_head = tokens_after_head[:n_cls_tokens] + masked_teacher_patch_tokens_after_head = tokens_after_head[ + n_cls_tokens : n_cls_tokens + n_masked_patches + ] + elif do_ibot and self.ibot_separate_head: + buffer_tensor_teacher = ibot_teacher_patch_tokens.new_zeros(upperbound, _dim) + torch.index_select( + ibot_teacher_patch_tokens.flatten(0, 1), + dim=0, + index=mask_indices_list, + out=buffer_tensor_teacher[:n_masked_patches], + ) + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_patch_tokens_after_head = self.teacher.ibot_head(buffer_tensor_teacher)[ + :n_masked_patches + ] + else: + teacher_cls_tokens_after_head = self.teacher.dino_head(teacher_cls_tokens) + masked_teacher_ibot_softmaxed_centered = None + + if self.cfg.train.centering == "centering": + teacher_dino_softmaxed_centered_list = self.dino_loss.softmax_center_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + self.dino_loss.update_center(teacher_cls_tokens_after_head) + if do_ibot: + masked_teacher_patch_tokens_after_head = masked_teacher_patch_tokens_after_head.unsqueeze(0) + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.softmax_center_teacher( + masked_teacher_patch_tokens_after_head[:, :n_masked_patches], teacher_temp=teacher_temp + ) + masked_teacher_ibot_softmaxed_centered = masked_teacher_ibot_softmaxed_centered.squeeze(0) + self.ibot_patch_loss.update_center(masked_teacher_patch_tokens_after_head[:n_masked_patches]) + + elif self.cfg.train.centering == "sinkhorn_knopp": + teacher_dino_softmaxed_centered_list = self.dino_loss.sinkhorn_knopp_teacher( + teacher_cls_tokens_after_head, teacher_temp=teacher_temp + ).view(n_global_crops_teacher, -1, *teacher_cls_tokens_after_head.shape[1:]) + + if do_ibot: + masked_teacher_ibot_softmaxed_centered = self.ibot_patch_loss.sinkhorn_knopp_teacher( + masked_teacher_patch_tokens_after_head, + teacher_temp=teacher_temp, + n_masked_patches_tensor=n_masked_patches_tensor, + ) + + else: + raise NotImplementedError + + return teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered + + teacher_dino_softmaxed_centered_list, masked_teacher_ibot_softmaxed_centered = get_teacher_output() + reshard_fsdp_model(self.teacher) + + loss_dict = {} + + loss_accumulator = 0 # for backprop + student_global_backbone_output_dict, student_local_backbone_output_dict = self.student.backbone( + [global_crops, local_crops], masks=[masks, None], is_training=True + ) + + inputs_for_student_head_list = [] + + # 1a: local crops cls tokens + student_local_cls_tokens = student_local_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_local_cls_tokens.unsqueeze(0)) + + # 1b: global crops cls tokens + student_global_cls_tokens = student_global_backbone_output_dict["x_norm_clstoken"] + inputs_for_student_head_list.append(student_global_cls_tokens.unsqueeze(0)) + + # 1c: global crops patch tokens + if do_ibot: + _dim = student_global_backbone_output_dict["x_norm_clstoken"].shape[-1] + ibot_student_patch_tokens = student_global_backbone_output_dict["x_norm_patchtokens"] + buffer_tensor_patch_tokens = ibot_student_patch_tokens.new_zeros(upperbound, _dim) + buffer_tensor_patch_tokens[:n_masked_patches].copy_( + torch.index_select(ibot_student_patch_tokens.flatten(0, 1), dim=0, index=mask_indices_list) + ) + if not self.ibot_separate_head: + inputs_for_student_head_list.append(buffer_tensor_patch_tokens.unsqueeze(0)) + else: + student_global_masked_patch_tokens_after_head = self.student.ibot_head(buffer_tensor_patch_tokens)[ + :n_masked_patches + ] + + # 2: run + _attn_bias, cat_inputs = fmha.BlockDiagonalMask.from_tensor_list(inputs_for_student_head_list) + outputs_list = _attn_bias.split(self.student.dino_head(cat_inputs)) + + # 3a: local crops cls tokens + student_local_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3b: global crops cls tokens + student_global_cls_tokens_after_head = outputs_list.pop(0).squeeze(0) + + # 3c: global crops patch tokens + if do_ibot and not self.ibot_separate_head: + student_global_masked_patch_tokens_after_head = outputs_list.pop(0).squeeze(0)[:n_masked_patches] + + if n_local_crops > 0: + dino_local_crops_loss = self.dino_loss( + student_output_list=student_local_cls_tokens_after_head.chunk(n_local_crops), + teacher_out_softmaxed_centered_list=teacher_dino_softmaxed_centered_list, + ) / (n_global_crops_loss_terms + n_local_crops_loss_terms) + + # store for display + loss_dict["dino_local_crops_loss"] = dino_local_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_local_crops_loss + + # process global crops + loss_scales = 2 # this is here since we process global crops together + + if do_dino: + # compute loss + dino_global_crops_loss = ( + self.dino_loss( + student_output_list=[student_global_cls_tokens_after_head], + teacher_out_softmaxed_centered_list=[ + teacher_dino_softmaxed_centered_list.flatten(0, 1) + ], # these were chunked and stacked in reverse so A is matched to B + ) + * loss_scales + / (n_global_crops_loss_terms + n_local_crops_loss_terms) + ) + + loss_dict["dino_global_crops_loss"] = dino_global_crops_loss + + # accumulate loss + loss_accumulator += self.dino_loss_weight * dino_global_crops_loss + + student_cls_tokens = student_global_cls_tokens + + if self.do_koleo: + koleo_loss = self.cfg.dino.koleo_loss_weight * sum( + self.koleo_loss(p) for p in student_cls_tokens.chunk(2) + ) # we don't apply koleo loss between cls tokens of a same image + loss_accumulator += koleo_loss + loss_dict["koleo_loss"] = ( + koleo_loss / loss_scales + ) # this is to display the same losses as before but we can remove eventually + + if do_ibot: + # compute loss + ibot_patch_loss = ( + self.ibot_patch_loss.forward_masked( + student_global_masked_patch_tokens_after_head, + masked_teacher_ibot_softmaxed_centered, + student_masks_flat=masks, + n_masked_patches=n_masked_patches, + masks_weight=masks_weight, + ) + * loss_scales + * ibot_loss_scale + ) + + # store for display + loss_dict["ibot_loss"] = ibot_patch_loss / 2 + + # accumulate loss + loss_accumulator += self.ibot_loss_weight * ibot_patch_loss + + self.backprop_loss(loss_accumulator) + + self.fsdp_synchronize_streams() + + return loss_dict + + def fsdp_synchronize_streams(self): + if self.need_to_synchronize_fsdp_streams: + torch.cuda.synchronize() + self.student.dino_head._streams = ( + self.teacher.dino_head._streams + ) = self.student.backbone._streams = self.teacher.backbone._streams + self.need_to_synchronize_fsdp_streams = False + + def update_teacher(self, m): + student_param_list = [] + teacher_param_list = [] + with torch.no_grad(): + for k in self.student.keys(): + for ms, mt in zip(get_fsdp_modules(self.student[k]), get_fsdp_modules(self.teacher[k])): + student_param_list += ms.params + teacher_param_list += mt.params + torch._foreach_mul_(teacher_param_list, m) + torch._foreach_add_(teacher_param_list, student_param_list, alpha=1 - m) + + def train(self): + super().train() + self.teacher.eval() + + def get_maybe_fused_params_for_submodel(self, m): + params_groups = get_params_groups_with_decay( + model=m, + lr_decay_rate=self.cfg.optim.layerwise_decay, + patch_embed_lr_mult=self.cfg.optim.patch_embed_lr_mult, + ) + fused_params_groups = fuse_params_groups(params_groups) + logger.info("fusing param groups") + + for g in fused_params_groups: + g["foreach"] = True + return fused_params_groups + + def get_params_groups(self): + all_params_groups = [] + for m in self.student.values(): + all_params_groups += self.get_maybe_fused_params_for_submodel(m) + return all_params_groups + + def prepare_for_distributed_training(self): + logger.info("DISTRIBUTED FSDP -- preparing model for distributed training") + if has_batchnorms(self.student): + raise NotImplementedError + # below will synchronize all student subnetworks across gpus: + for k, v in self.student.items(): + self.teacher[k].load_state_dict(self.student[k].state_dict()) + student_model_cfg = self.cfg.compute_precision.student[k] + self.student[k] = get_fsdp_wrapper(student_model_cfg, modules_to_wrap={BlockChunk})(self.student[k]) + teacher_model_cfg = self.cfg.compute_precision.teacher[k] + self.teacher[k] = get_fsdp_wrapper(teacher_model_cfg, modules_to_wrap={BlockChunk})(self.teacher[k]) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/train.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/train.py new file mode 100644 index 0000000000000000000000000000000000000000..472fbdb3deda5aa9c2484cc714244397b50f59db --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/train/train.py @@ -0,0 +1,319 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import logging +import math +import os +from functools import partial + +from fvcore.common.checkpoint import PeriodicCheckpointer +import torch + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import SamplerType, make_data_loader, make_dataset +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.data import collate_data_and_cast, DataAugmentationDINO, MaskingGenerator +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.fsdp import FSDPCheckpointer +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import MetricLogger +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.config import setup +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils.utils import CosineScheduler + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.train.ssl_meta_arch import SSLMetaArch + + +torch.backends.cuda.matmul.allow_tf32 = True # PyTorch 1.12 sets this to False by default +logger = logging.getLogger("dinov2") + + +def get_args_parser(add_help: bool = True): + parser = argparse.ArgumentParser("DINOv2 training", add_help=add_help) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--no-resume", + action="store_true", + help="Whether to not attempt to resume from the checkpoint directory. ", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--eval", type=str, default="", help="Eval type to perform") + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + parser.add_argument( + "--output-dir", + "--output_dir", + default="", + type=str, + help="Output directory to save logs and checkpoints", + ) + + return parser + + +def build_optimizer(cfg, params_groups): + return torch.optim.AdamW(params_groups, betas=(cfg.optim.adamw_beta1, cfg.optim.adamw_beta2)) + + +def build_schedulers(cfg): + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + lr = dict( + base_value=cfg.optim["lr"], + final_value=cfg.optim["min_lr"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.optim["warmup_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=0, + ) + wd = dict( + base_value=cfg.optim["weight_decay"], + final_value=cfg.optim["weight_decay_end"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + momentum = dict( + base_value=cfg.teacher["momentum_teacher"], + final_value=cfg.teacher["final_momentum_teacher"], + total_iters=cfg.optim["epochs"] * OFFICIAL_EPOCH_LENGTH, + ) + teacher_temp = dict( + base_value=cfg.teacher["teacher_temp"], + final_value=cfg.teacher["teacher_temp"], + total_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + warmup_iters=cfg.teacher["warmup_teacher_temp_epochs"] * OFFICIAL_EPOCH_LENGTH, + start_warmup_value=cfg.teacher["warmup_teacher_temp"], + ) + + lr_schedule = CosineScheduler(**lr) + wd_schedule = CosineScheduler(**wd) + momentum_schedule = CosineScheduler(**momentum) + teacher_temp_schedule = CosineScheduler(**teacher_temp) + last_layer_lr_schedule = CosineScheduler(**lr) + + last_layer_lr_schedule.schedule[ + : cfg.optim["freeze_last_layer_epochs"] * OFFICIAL_EPOCH_LENGTH + ] = 0 # mimicking the original schedules + + logger.info("Schedulers ready.") + + return ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) + + +def apply_optim_scheduler(optimizer, lr, wd, last_layer_lr): + for param_group in optimizer.param_groups: + is_last_layer = param_group["is_last_layer"] + lr_multiplier = param_group["lr_multiplier"] + wd_multiplier = param_group["wd_multiplier"] + param_group["weight_decay"] = wd * wd_multiplier + param_group["lr"] = (last_layer_lr if is_last_layer else lr) * lr_multiplier + + +def do_test(cfg, model, iteration): + new_state_dict = model.teacher.state_dict() + + if distributed.is_main_process(): + iterstring = str(iteration) + eval_dir = os.path.join(cfg.train.output_dir, "eval", iterstring) + os.makedirs(eval_dir, exist_ok=True) + # save teacher checkpoint + teacher_ckp_path = os.path.join(eval_dir, "teacher_checkpoint.pth") + torch.save({"teacher": new_state_dict}, teacher_ckp_path) + + +def do_train(cfg, model, resume=False): + model.train() + inputs_dtype = torch.half + fp16_scaler = model.fp16_scaler # for mixed precision training + + # setup optimizer + + optimizer = build_optimizer(cfg, model.get_params_groups()) + ( + lr_schedule, + wd_schedule, + momentum_schedule, + teacher_temp_schedule, + last_layer_lr_schedule, + ) = build_schedulers(cfg) + + # checkpointer + checkpointer = FSDPCheckpointer(model, cfg.train.output_dir, optimizer=optimizer, save_to_disk=True) + + start_iter = checkpointer.resume_or_load(cfg.MODEL.WEIGHTS, resume=resume).get("iteration", -1) + 1 + + OFFICIAL_EPOCH_LENGTH = cfg.train.OFFICIAL_EPOCH_LENGTH + max_iter = cfg.optim.epochs * OFFICIAL_EPOCH_LENGTH + + periodic_checkpointer = PeriodicCheckpointer( + checkpointer, + period=3 * OFFICIAL_EPOCH_LENGTH, + max_iter=max_iter, + max_to_keep=3, + ) + + # setup data preprocessing + + img_size = cfg.crops.global_crops_size + patch_size = cfg.student.patch_size + n_tokens = (img_size // patch_size) ** 2 + mask_generator = MaskingGenerator( + input_size=(img_size // patch_size, img_size // patch_size), + max_num_patches=0.5 * img_size // patch_size * img_size // patch_size, + ) + + data_transform = DataAugmentationDINO( + cfg.crops.global_crops_scale, + cfg.crops.local_crops_scale, + cfg.crops.local_crops_number, + global_crops_size=cfg.crops.global_crops_size, + local_crops_size=cfg.crops.local_crops_size, + ) + + collate_fn = partial( + collate_data_and_cast, + mask_ratio_tuple=cfg.ibot.mask_ratio_min_max, + mask_probability=cfg.ibot.mask_sample_probability, + n_tokens=n_tokens, + mask_generator=mask_generator, + dtype=inputs_dtype, + ) + + # setup data loader + + dataset = make_dataset( + dataset_str=cfg.train.dataset_path, + transform=data_transform, + target_transform=lambda _: (), + ) + # sampler_type = SamplerType.INFINITE + sampler_type = SamplerType.SHARDED_INFINITE + data_loader = make_data_loader( + dataset=dataset, + batch_size=cfg.train.batch_size_per_gpu, + num_workers=cfg.train.num_workers, + shuffle=True, + seed=start_iter, # TODO: Fix this -- cfg.train.seed + sampler_type=sampler_type, + sampler_advance=0, # TODO(qas): fix this -- start_iter * cfg.train.batch_size_per_gpu, + drop_last=True, + collate_fn=collate_fn, + ) + + # training loop + + iteration = start_iter + + logger.info("Starting training from iteration {}".format(start_iter)) + metrics_file = os.path.join(cfg.train.output_dir, "training_metrics.json") + metric_logger = MetricLogger(delimiter=" ", output_file=metrics_file) + header = "Training" + + for data in metric_logger.log_every( + data_loader, + 10, + header, + max_iter, + start_iter, + ): + current_batch_size = data["collated_global_crops"].shape[0] / 2 + if iteration > max_iter: + return + + # apply schedules + + lr = lr_schedule[iteration] + wd = wd_schedule[iteration] + mom = momentum_schedule[iteration] + teacher_temp = teacher_temp_schedule[iteration] + last_layer_lr = last_layer_lr_schedule[iteration] + apply_optim_scheduler(optimizer, lr, wd, last_layer_lr) + + # compute losses + + optimizer.zero_grad(set_to_none=True) + loss_dict = model.forward_backward(data, teacher_temp=teacher_temp) + + # clip gradients + + if fp16_scaler is not None: + if cfg.optim.clip_grad: + fp16_scaler.unscale_(optimizer) + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + fp16_scaler.step(optimizer) + fp16_scaler.update() + else: + if cfg.optim.clip_grad: + for v in model.student.values(): + v.clip_grad_norm_(cfg.optim.clip_grad) + optimizer.step() + + # perform teacher EMA update + + model.update_teacher(mom) + + # logging + + if distributed.get_global_size() > 1: + for v in loss_dict.values(): + torch.distributed.all_reduce(v) + loss_dict_reduced = {k: v.item() / distributed.get_global_size() for k, v in loss_dict.items()} + + if math.isnan(sum(loss_dict_reduced.values())): + logger.info("NaN detected") + raise AssertionError + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + metric_logger.update(lr=lr) + metric_logger.update(wd=wd) + metric_logger.update(mom=mom) + metric_logger.update(last_layer_lr=last_layer_lr) + metric_logger.update(current_batch_size=current_batch_size) + metric_logger.update(total_loss=losses_reduced, **loss_dict_reduced) + + # checkpointing and testing + + if cfg.evaluation.eval_period_iterations > 0 and (iteration + 1) % cfg.evaluation.eval_period_iterations == 0: + do_test(cfg, model, f"training_{iteration}") + torch.cuda.synchronize() + periodic_checkpointer.step(iteration) + + iteration = iteration + 1 + metric_logger.synchronize_between_processes() + return {k: meter.global_avg for k, meter in metric_logger.meters.items()} + + +def main(args): + cfg = setup(args) + + model = SSLMetaArch(cfg).to(torch.device("cuda")) + model.prepare_for_distributed_training() + + logger.info("Model:\n{}".format(model)) + if args.eval_only: + iteration = ( + FSDPCheckpointer(model, save_dir=cfg.train.output_dir) + .resume_or_load(cfg.MODEL.WEIGHTS, resume=not args.no_resume) + .get("iteration", -1) + + 1 + ) + return do_test(cfg, model, f"manual_{iteration}") + + do_train(cfg, model, resume=not args.no_resume) + + +if __name__ == "__main__": + args = get_args_parser(add_help=True).parse_args() + main(args) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/__init__.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0952fcc3f57e34b3747962e9ebd6fc57aeea63fa --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/cluster.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..8d98c05d68aa6e9dc165df3db06bd70d999b3fda --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/cluster.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/config.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..57c433c627fbe5e357893b80981665029358d54d --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/config.py @@ -0,0 +1,73 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.distributed as distributed +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.logging import setup_logging +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.utils import utils +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/dtype.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..cef122b25ff3533e004799a1d977f63eb213fee0 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/dtype.py @@ -0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..d707e70cc11591858d4166410d6ed80621cd49ff --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/param_groups.py @@ -0,0 +1,94 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ".pos_embed" in name or ".patch_embed" in name or ".mask_token" in name or ".cls_token" in name: + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name or "patch_embed" in name or "mask_token" in name or "cls_token" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/utils.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53e63eb427f6d5396c8dc153ab07e825c72b68b4 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/dinov2/utils/utils.py @@ -0,0 +1,96 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/hubconf.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..882af7a7e3697ab1b191bf66f88264c91695b70c --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/hubconf.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + import custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/pyproject.toml b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..da67abd8ceabe6d427a96e5d9d4f04b25aebcd32 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/pyproject.toml @@ -0,0 +1,29 @@ +[tool.black] +line-length = 120 + +[tool.pylint.master] +persistent = false +score = false + +[tool.pylint.messages_control] +disable = "all" +enable = [ + "miscellaneous", + "similarities", +] + +[tool.pylint.similarities] +ignore-comments = true +ignore-docstrings = true +ignore-imports = true +min-similarity-lines = 8 + +[tool.pylint.reports] +reports = false + +[tool.pylint.miscellaneous] +notes = [ + "FIXME", + "XXX", + "TODO", +] diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements-dev.txt b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements-dev.txt new file mode 100644 index 0000000000000000000000000000000000000000..5cad34c34cde3a182b616d68b168588827eb9b7c --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements-dev.txt @@ -0,0 +1,3 @@ +black==22.6.0 +flake8==5.0.4 +pylint==2.15.0 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements.txt b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..04c159c443b89330ff3c84257c41b011f9791257 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/requirements.txt @@ -0,0 +1,11 @@ +--extra-index-url https://download.pytorch.org/whl/cu117 +torch==2.0.0 +torchvision==0.15.0 +omegaconf +torchmetrics==0.10.3 +fvcore +iopath +xformers==0.0.18 +submitit +--extra-index-url https://pypi.nvidia.com +cuml-cu11 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/scripts/lint.sh b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/scripts/lint.sh new file mode 100644 index 0000000000000000000000000000000000000000..b91acaf762c4be3a0c9d2a162210bfebfaacba08 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/scripts/lint.sh @@ -0,0 +1,28 @@ +#!/bin/sh + +if [ -n "$1" ]; then + echo "linting \"$1\"" +fi + +echo "running black" +if [ -n "$1" ]; then + black "$1" +else + black dinov2 +fi + +echo "running flake8" +if [ -n "$1" ]; then + flake8 "$1" +else + flake8 +fi + +echo "running pylint" +if [ -n "$1" ]; then + pylint "$1" +else + pylint dinov2 +fi + +exit 0 diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.cfg b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.cfg new file mode 100644 index 0000000000000000000000000000000000000000..3cac0c045434cde205eebe91fd5a2c35a1226b4b --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.cfg @@ -0,0 +1,7 @@ +[flake8] +max-line-length = 120 +ignore = E203,E501,W503 +per-file-ignores = + __init__.py:F401 +exclude = + venv diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..001987cfeef6c5fe3469ea09cd4698352fa90939 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/setup.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from pathlib import Path +import re +from typing import List, Tuple + +from setuptools import setup, find_packages + + +NAME = "dinov2" +DESCRIPTION = "PyTorch code and models for the DINOv2 self-supervised learning method." + +URL = "https://github.com/facebookresearch/dinov2" +AUTHOR = "FAIR" +REQUIRES_PYTHON = ">=3.9.0" +HERE = Path(__file__).parent + + +try: + with open(HERE / "README.md", encoding="utf-8") as f: + long_description = "\n" + f.read() +except FileNotFoundError: + long_description = DESCRIPTION + + +def get_requirements(path: str = HERE / "requirements.txt") -> Tuple[List[str], List[str]]: + requirements = [] + extra_indices = [] + with open(path) as f: + for line in f.readlines(): + line = line.rstrip("\r\n") + if line.startswith("--extra-index-url "): + extra_indices.append(line[18:]) + continue + requirements.append(line) + return requirements, extra_indices + + +def get_package_version() -> str: + with open(HERE / "dinov2/__init__.py") as f: + result = re.search(r"^__version__ = ['\"]([^'\"]*)['\"]", f.read(), re.M) + if result: + return result.group(1) + raise RuntimeError("Can't get package version") + + +requirements, extra_indices = get_requirements() +version = get_package_version() +dev_requirements, _ = get_requirements(HERE / "requirements-dev.txt") + + +setup( + name=NAME, + version=version, + description=DESCRIPTION, + long_description=long_description, + long_description_content_type="text/markdown", + author=AUTHOR, + python_requires=REQUIRES_PYTHON, + url=URL, + packages=find_packages(), + package_data={ + "": ["*.yaml"], + }, + install_requires=requirements, + dependency_links=extra_indices, + extras_require={ + "dev": dev_requirements, + }, + install_package_data=True, + license="CC-BY-NC", + license_files=("LICENSE",), + classifiers=[ + # Trove classifiers: https://github.com/pypa/trove-classifiers/blob/main/src/trove_classifiers/__init__.py + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: Other/Proprietary License", + "Programming Language :: Python :: 3.9", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + ], +) diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/utils.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/vision_transformer.py b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1db64e54db87f0723680b959c16f215efe5c03 --- /dev/null +++ b/custom_controlnet_aux/depth_anything/torchhub/facebookresearch_dinov2_main/vision_transformer.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from custom_controlnet_aux.depth_anything.torchhub.facebookresearch_dinov2_main.dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/custom_controlnet_aux/depth_anything_v2/__init__.py b/custom_controlnet_aux/depth_anything_v2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..924e7bc7f6ae44ecaa1a365a4983bc72bbce71a4 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/__init__.py @@ -0,0 +1,56 @@ +import numpy as np +import torch +from einops import repeat +from PIL import Image +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DEPTH_ANYTHING_V2_MODEL_NAME_DICT +from custom_controlnet_aux.depth_anything_v2.dpt import DepthAnythingV2 +import cv2 +import torch.nn.functional as F + + +# https://github.com/DepthAnything/Depth-Anything-V2/blob/main/app.py +model_configs = { + 'depth_anything_v2_vits.pth': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]}, + 'depth_anything_v2_vitb.pth': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]}, + 'depth_anything_v2_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'depth_anything_v2_vitg.pth': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}, + 'depth_anything_v2_metric_vkitti_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, + 'depth_anything_v2_metric_hypersim_vitl.pth': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}, +} + +class DepthAnythingV2Detector: + def __init__(self, model, filename): + self.model = model + self.device = "cpu" + self.filename = filename + @classmethod + def from_pretrained(cls, pretrained_model_or_path=None, filename="depth_anything_v2_vits.pth"): + if pretrained_model_or_path is None: + pretrained_model_or_path = DEPTH_ANYTHING_V2_MODEL_NAME_DICT[filename] + model_path = custom_hf_download(pretrained_model_or_path, filename) + model = DepthAnythingV2(**model_configs[filename]) + model.load_state_dict(torch.load(model_path, map_location="cpu")) + model = model.eval() + return cls(model, filename) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", max_depth=20.0, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + + depth = self.model.infer_image(cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR), input_size=518, max_depth=max_depth) + depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0 + depth = depth.astype(np.uint8) + if 'metric' in self.filename: + depth = 255 - depth + + detected_map = repeat(depth, "h w -> h w 3") + detected_map, remove_pad = resize_image_with_pad(detected_map, detect_resolution, upscale_method) + detected_map = remove_pad(detected_map) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2.py b/custom_controlnet_aux/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..78e99975d317fbaac81d0bab5401ce91ccfff070 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from custom_controlnet_aux.depth_anything_v2.dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + # w0, h0 = w0 + 0.1, h0 + 0.1 + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + # (int(w0), int(h0)), # to solve the upsampling shape issue + mode="bicubic", + antialias=self.interpolate_antialias + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def DINOv2(model_name): + model_zoo = { + "vits": vit_small, + "vitb": vit_base, + "vitl": vit_large, + "vitg": vit_giant2 + } + + return model_zoo[model_name]( + img_size=518, + patch_size=14, + init_values=1.0, + ffn_layer="mlp" if model_name != "vitg" else "swiglufused", + block_chunks=0, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1 + ) diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8120f4bc83066cb3f825ce32daa3b437f88486f1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Callable, List, Any, Tuple, Dict + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha + from xformers.ops import scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..af05625984dd14682cc96a63bf0c97bab1f123b1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/drop_path.py @@ -0,0 +1,35 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..ca5daa52bd81d3581adeb2198ea5b7dba2a3aea1 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5e4b315f972f9a9f54aef1e4ef4e81b52976f018 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..574abe41175568d700a389b8b96d1ba554914779 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/patch_embed.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3324b266fb0a50ccf8c3a0ede2ae10ac4dfa03e --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dinov2_layers/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Callable, Optional + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/custom_controlnet_aux/depth_anything_v2/dpt.py b/custom_controlnet_aux/depth_anything_v2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..630ae6936895db1d0d8806b30a0815fbecf83443 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/dpt.py @@ -0,0 +1,220 @@ +import cv2 +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision.transforms import Compose + +from custom_controlnet_aux.depth_anything_v2.dinov2 import DINOv2 +from custom_controlnet_aux.depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch +from custom_controlnet_aux.depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet + + +def _make_fusion_block(features, use_bn, size=None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class ConvBlock(nn.Module): + def __init__(self, in_feature, out_feature): + super().__init__() + + self.conv_block = nn.Sequential( + nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(out_feature), + nn.ReLU(True) + ) + + def forward(self, x): + return self.conv_block(x) + + +class DPTHead(nn.Module): + def __init__( + self, + in_channels, + features=256, + use_bn=False, + out_channels=[256, 512, 1024, 1024], + use_clstoken=False + ): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DepthAnythingV2(nn.Module): + def __init__( + self, + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024], + use_bn=False, + use_clstoken=False + ): + super(DepthAnythingV2, self).__init__() + + self.intermediate_layer_idx = { + 'vits': [2, 5, 8, 11], + 'vitb': [2, 5, 8, 11], + 'vitl': [4, 11, 17, 23], + 'vitg': [9, 19, 29, 39] + } + + self.encoder = encoder + self.pretrained = DINOv2(model_name=encoder) + + self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x, max_depth): + patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14 + + features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True) + + depth = self.depth_head(features, patch_h, patch_w) * max_depth + + return depth.squeeze(1) + + @torch.no_grad() + def infer_image(self, raw_image, input_size=518, max_depth=20.0): + image, (h, w) = self.image2tensor(raw_image, input_size) + + depth = self.forward(image, max_depth) + + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + + return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + + h, w = raw_image.shape[:2] + + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) \ No newline at end of file diff --git a/custom_controlnet_aux/depth_anything_v2/util/blocks.py b/custom_controlnet_aux/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/custom_controlnet_aux/depth_anything_v2/util/transform.py b/custom_controlnet_aux/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73 --- /dev/null +++ b/custom_controlnet_aux/depth_anything_v2/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. + + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + self.__width = width + self.__height = height + + self.__resize_target = resize_target + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + self.__image_interpolation_method = image_interpolation_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height) + new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError(f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, sample): + width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0]) + + # resize sample + sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method) + + if self.__resize_target: + if "depth" in sample: + sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST) + + if "mask" in sample: + sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST) + + return sample + + +class NormalizeImage(object): + """Normlize image by given mean and std. + """ + + def __init__(self, mean, std): + self.__mean = mean + self.__std = std + + def __call__(self, sample): + sample["image"] = (sample["image"] - self.__mean) / self.__std + + return sample + + +class PrepareForNet(object): + """Prepare sample for usage as network input. + """ + + def __init__(self): + pass + + def __call__(self, sample): + image = np.transpose(sample["image"], (2, 0, 1)) + sample["image"] = np.ascontiguousarray(image).astype(np.float32) + + if "depth" in sample: + depth = sample["depth"].astype(np.float32) + sample["depth"] = np.ascontiguousarray(depth) + + if "mask" in sample: + sample["mask"] = sample["mask"].astype(np.float32) + sample["mask"] = np.ascontiguousarray(sample["mask"]) + + return sample \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/__init__.py b/custom_controlnet_aux/diffusion_edge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d6010f8d4b16327bb78a3baf2dee27581cf1ef --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/__init__.py @@ -0,0 +1,40 @@ +from custom_controlnet_aux.diffusion_edge.model import DiffusionEdge, prepare_args +import numpy as np +import torch +from einops import rearrange +from PIL import Image +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DIFFUSION_EDGE_MODEL_NAME + +class DiffusionEdgeDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="diffusion_edge_indoor.pt"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + model = DiffusionEdge(prepare_args(model_path)) + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, patch_batch_size=8, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + with torch.no_grad(): + input_image = rearrange(torch.from_numpy(input_image), "h w c -> 1 c h w") + input_image = input_image.float() / 255. + line = self.model(input_image, patch_batch_size) + line = rearrange(line, "1 c h w -> h w c") + + detected_map = line.cpu().numpy().__mul__(255.).astype(np.uint8) + detected_map = remove_pad(HWC3(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/default.yaml b/custom_controlnet_aux/diffusion_edge/default.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0f13fbd63921bee108c5a015a2aae2a44122fb69 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/default.yaml @@ -0,0 +1,74 @@ +model: + model_type: const_sde + model_name: cond_unet + image_size: [320, 320] + input_keys: ['image', 'cond'] + ckpt_path: + ignore_keys: [ ] + only_model: False + timesteps: 1000 + train_sample: -1 + sampling_timesteps: 1 + loss_type: l2 + objective: pred_noise + start_dist: normal + perceptual_weight: 0 + scale_factor: 0.3 + scale_by_std: True + default_scale: True + scale_by_softsign: False + eps: !!float 1e-4 + weighting_loss: False + first_stage: + embed_dim: 3 + lossconfig: + disc_start: 50001 + kl_weight: 0.000001 + disc_weight: 0.5 + disc_in_channels: 1 + ddconfig: + double_z: True + z_channels: 3 + resolution: [ 320, 320 ] + in_channels: 1 + out_ch: 1 + ch: 128 + ch_mult: [ 1,2,4 ] # num_down = len(ch_mult)-1 + num_res_blocks: 2 + attn_resolutions: [ ] + dropout: 0.0 + ckpt_path: + unet: + dim: 128 + cond_net: swin + without_pretrain: False + channels: 3 + out_mul: 1 + dim_mults: [ 1, 2, 4, 4, ] # num_down = len(dim_mults) + cond_in_dim: 3 + cond_dim: 128 + cond_dim_mults: [ 2, 4 ] # num_down = len(cond_dim_mults) + # window_sizes1: [ [4, 4], [2, 2], [1, 1], [1, 1] ] + # window_sizes2: [ [4, 4], [2, 2], [1, 1], [1, 1] ] + window_sizes1: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] + window_sizes2: [ [ 8, 8 ], [ 4, 4 ], [ 2, 2 ], [ 1, 1 ] ] + fourier_scale: 16 + cond_pe: False + num_pos_feats: 128 + cond_feature_size: [ 80, 80 ] + +data: + name: edge + img_folder: '/data/yeyunfan/edge_detection_datasets/datasets/BSDS_test' + augment_horizontal_flip: True + batch_size: 8 + num_workers: 4 + +sampler: + sample_type: "slide" + stride: [240, 240] + batch_size: 1 + sample_num: 300 + use_ema: True + save_folder: + ckpt_path: \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc97bb8da2613d82f5756f5bc9476e4fda1fd95 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/__init__.py @@ -0,0 +1 @@ +# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.denoising_diffusion_pytorch import GaussianDiffusion, Unet, Trainer diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py new file mode 100644 index 0000000000000000000000000000000000000000..cf215fc5d8bae5101299cc360ee036b0b6647393 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/data.py @@ -0,0 +1,598 @@ +import torch +import torchvision.transforms as T +import torch.utils.data as data +import torch.nn as nn +from pathlib import Path +from functools import partial +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import exists, convert_image_to_fn, normalize_to_neg_one_to_one +from PIL import Image, ImageDraw +import torch.nn.functional as F +import math +import torchvision.transforms.functional as F2 +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from typing import Any, Callable, Optional, Tuple +import os +import pickle +import numpy as np +import copy +import custom_albumentations as albumentations +from torchvision.transforms.functional import InterpolationMode + +def get_imgs_list(imgs_dir): + imgs_list = os.listdir(imgs_dir) + imgs_list.sort() + return [os.path.join(imgs_dir, f) for f in imgs_list if f.endswith('.jpg') or f.endswith('.JPG')or f.endswith('.png') or f.endswith('.pgm') or f.endswith('.ppm')] + + +def fit_img_postfix(img_path): + if not os.path.exists(img_path) and img_path.endswith(".jpg"): + img_path = img_path[:-4] + ".png" + if not os.path.exists(img_path) and img_path.endswith(".png"): + img_path = img_path[:-4] + ".jpg" + return img_path + + +class AdaptEdgeDataset(data.Dataset): + def __init__( + self, + data_root, + # mask_folder, + image_size, + exts = ['png', 'jpg'], + augment_horizontal_flip = False, + convert_image_to = None, + normalize_to_neg_one_to_one=True, + split='train', + # inter_type='bicubic', + # down=4, + threshold=0.3, use_uncertainty=False + ): + super().__init__() + # self.img_folder = Path(img_folder) + # self.edge_folder = Path(os.path.join(data_root, f'gt_imgs')) + # self.img_folder = Path(os.path.join(data_root, f'imgs')) + # self.edge_folder = Path(os.path.join(data_root, "edge", "aug")) + # self.img_folder = Path(os.path.join(data_root, "image", "aug")) + self.data_root = data_root + self.image_size = image_size + + # self.edge_paths = [p for ext in exts for p in self.edge_folder.rglob(f'*.{ext}')] + # self.img_paths = [(self.img_folder / item.parent.name / f'{item.stem}.jpg') for item in self.edge_paths] + # self.img_paths = [(self.img_folder / f'{item.stem}.jpg') for item in self.edge_paths] + + self.threshold = threshold * 256 + self.use_uncertainty = use_uncertainty + self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one + + maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity() + + # self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one + # self.random_crop = RandomCrop(size=image_size) + # self.transform = Compose([ + # # Lambda(maybe_convert_fn), + # # Resize(image_size, interpolation=3, interpolation2=0), + # Resize(image_size, interpolation=InterpolationMode.BILINEAR, interpolation2=InterpolationMode.NEAREST), + # RandomHorizontalFlip() if augment_horizontal_flip else Identity(), + # # RandomCrop(image_size), + # ToTensor() + # ]) + self.data_list = self.build_list() + + self.transform = transforms.Compose([ + # Resize(self.image_size, interpolation=InterpolationMode.BILINEAR, interpolation2=InterpolationMode.NEAREST), + transforms.ToTensor()]) + + def __len__(self): + return len(self.data_list) + + + def read_img(self, image_path): + with open(image_path, 'rb') as f: + img = Image.open(f) + img = img.convert('RGB') + + raw_width, raw_height = img.size + # width = int(raw_width / 32) * 32 + # height = int(raw_height / 32) * 32 + # img = img.resize((width, height), Image.Resampling.BILINEAR) + # # print("img.size:", img.size) + # img = self.transform(img) + + return img, (raw_width, raw_height) + + def read_lb(self, lb_path): + lb_data = Image.open(lb_path) + + width, height = lb_data.size + width = int(width / 32) * 32 + height = int(height / 32) * 32 + lb_data = lb_data.resize((width, height), Image.Resampling.BILINEAR) + # print("lb_data.size:", lb_data.size) + lb = np.array(lb_data, dtype=np.float32) + if lb.ndim == 3: + lb = np.squeeze(lb[:, :, 0]) + assert lb.ndim == 2 + threshold = self.threshold + lb = lb[np.newaxis, :, :] + + lb[lb == 0] = 0 + + # ---------- important ---------- + if self.use_uncertainty: + lb[np.logical_and(lb > 0, lb < threshold)] = 2 + else: + lb[np.logical_and(lb > 0, lb < threshold)] /= 255. + + lb[lb >= threshold] = 1 + return lb + + def build_list(self): + data_root = os.path.abspath(self.data_root) + images_path = os.path.join(data_root, 'image', "raw") + labels_path = os.path.join(data_root, 'edge', "raw") + + samples = [] + for directory_name in os.listdir(images_path): + image_directories = os.path.join(images_path, directory_name) + for file_name_ext in os.listdir(image_directories): + file_name = os.path.basename(file_name_ext) + image_path = fit_img_postfix(os.path.join(images_path, directory_name, file_name)) + lb_path = fit_img_postfix(os.path.join(labels_path, directory_name, file_name)) + samples.append((image_path, lb_path)) + return samples + + def __getitem__(self, index): + img_path, edge_path = self.data_list[index] + # edge_path = self.edge_paths[index] + # img_path = self.img_paths[index] + img_name = os.path.basename(img_path) + + img, raw_size = self.read_img(img_path) + edge = self.read_lb(edge_path) + + # print("-------hhhhhhhhhhhhh--------:", img.shape, edge.shape) + # edge = Image.open(edge_path).convert('L') + # # default to score-sde preprocessing + # mask = Image.open(img_path).convert('RGB') + # edge, img = self.transform(edge, mask) + if self.normalize_to_neg_one_to_one: # transform to [-1, 1] + edge = normalize_to_neg_one_to_one(edge) + img = normalize_to_neg_one_to_one(img) + return {'image': edge, 'cond': img, 'raw_size': raw_size, 'img_name': img_name} + +class EdgeDataset(data.Dataset): + def __init__( + self, + data_root, + # mask_folder, + image_size, + exts = ['png', 'jpg'], + augment_horizontal_flip = True, + convert_image_to = None, + normalize_to_neg_one_to_one=True, + split='train', + # inter_type='bicubic', + # down=4, + threshold=0.3, use_uncertainty=False, cfg={} + ): + super().__init__() + # self.img_folder = Path(img_folder) + # self.edge_folder = Path(os.path.join(data_root, f'gt_imgs')) + # self.img_folder = Path(os.path.join(data_root, f'imgs')) + # self.edge_folder = Path(os.path.join(data_root, "edge", "aug")) + # self.img_folder = Path(os.path.join(data_root, "image", "aug")) + self.data_root = data_root + self.image_size = image_size + + # self.edge_paths = [p for ext in exts for p in self.edge_folder.rglob(f'*.{ext}')] + # self.img_paths = [(self.img_folder / item.parent.name / f'{item.stem}.jpg') for item in self.edge_paths] + # self.img_paths = [(self.img_folder / f'{item.stem}.jpg') for item in self.edge_paths] + + self.threshold = threshold * 255 + self.use_uncertainty = use_uncertainty + self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one + + maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity() + + self.data_list = self.build_list() + + # self.transform = Compose([ + # Resize(image_size), + # RandomHorizontalFlip() if augment_horizontal_flip else Identity(), + # ToTensor() + # ]) + crop_type = cfg.get('crop_type') if 'crop_type' in cfg else 'rand_crop' + if crop_type == 'rand_crop': + self.transform = Compose([ + RandomCrop(image_size), + RandomHorizontalFlip() if augment_horizontal_flip else Identity(), + ToTensor() + ]) + elif crop_type == 'rand_resize_crop': + self.transform = Compose([ + RandomResizeCrop(image_size), + RandomHorizontalFlip() if augment_horizontal_flip else Identity(), + ToTensor() + ]) + print("crop_type:", crop_type) + + def __len__(self): + return len(self.data_list) + + + def read_img(self, image_path): + with open(image_path, 'rb') as f: + img = Image.open(f) + img = img.convert('RGB') + + raw_width, raw_height = img.size + # width = int(raw_width / 32) * 32 + # height = int(raw_height / 32) * 32 + # img = img.resize((width, height), Image.Resampling.BILINEAR) + # # print("img.size:", img.size) + # img = self.transform(img) + + return img, (raw_width, raw_height) + + def read_lb(self, lb_path): + lb_data = Image.open(lb_path).convert('L') + lb = np.array(lb_data).astype(np.float32) + # width, height = lb_data.size + # width = int(width / 32) * 32 + # height = int(height / 32) * 32 + # lb_data = lb_data.resize((width, height), Image.Resampling.BILINEAR) + # print("lb_data.size:", lb_data.size) + # lb = np.array(lb_data, dtype=np.float32) + # if lb.ndim == 3: + # lb = np.squeeze(lb[:, :, 0]) + # assert lb.ndim == 2 + threshold = self.threshold + # lb = lb[np.newaxis, :, :] + # lb[lb == 0] = 0 + + # ---------- important ---------- + # if self.use_uncertainty: + # lb[np.logical_and(lb > 0, lb < threshold)] = 2 + # else: + # lb[np.logical_and(lb > 0, lb < threshold)] /= 255. + + lb[lb >= threshold] = 255 + lb = Image.fromarray(lb.astype(np.uint8)) + return lb + + def build_list(self): + data_root = os.path.abspath(self.data_root) + images_path = os.path.join(data_root, 'image') + labels_path = os.path.join(data_root, 'edge') + + samples = [] + for directory_name in os.listdir(images_path): + image_directories = os.path.join(images_path, directory_name) + for file_name_ext in os.listdir(image_directories): + file_name = os.path.basename(file_name_ext) + image_path = fit_img_postfix(os.path.join(images_path, directory_name, file_name)) + lb_path = fit_img_postfix(os.path.join(labels_path, directory_name, file_name)) + samples.append((image_path, lb_path)) + return samples + + def __getitem__(self, index): + img_path, edge_path = self.data_list[index] + # edge_path = self.edge_paths[index] + # img_path = self.img_paths[index] + img_name = os.path.basename(img_path) + + img, raw_size = self.read_img(img_path) + edge = self.read_lb(edge_path) + img, edge = self.transform(img, edge) + + # print("-------hhhhhhhhhhhhh--------:", img.shape, edge.shape) + # edge = Image.open(edge_path).convert('L') + # # default to score-sde preprocessing + # mask = Image.open(img_path).convert('RGB') + # edge, img = self.transform(edge, mask) + if self.normalize_to_neg_one_to_one: # transform to [-1, 1] + edge = normalize_to_neg_one_to_one(edge) + img = normalize_to_neg_one_to_one(img) + return {'image': edge, 'cond': img, 'raw_size': raw_size, 'img_name': img_name} + +class EdgeDatasetTest(data.Dataset): + def __init__( + self, + data_root, + # mask_folder, + image_size, + exts = ['png', 'jpg'], + convert_image_to = None, + normalize_to_neg_one_to_one=True, + ): + super().__init__() + + self.data_root = data_root + self.image_size = image_size + self.normalize_to_neg_one_to_one = normalize_to_neg_one_to_one + + maybe_convert_fn = partial(convert_image_to_fn, convert_image_to) if exists(convert_image_to) else Identity() + + self.data_list = self.build_list() + + self.transform = Compose([ + ToTensor() + ]) + + def __len__(self): + return len(self.data_list) + + + def read_img(self, image_path): + with open(image_path, 'rb') as f: + img = Image.open(f) + img = img.convert('RGB') + + raw_width, raw_height = img.size + + + return img, (raw_width, raw_height) + + def read_lb(self, lb_path): + lb_data = Image.open(lb_path).convert('L') + lb = np.array(lb_data).astype(np.float32) + + threshold = self.threshold + + + lb[lb >= threshold] = 255 + lb = Image.fromarray(lb.astype(np.uint8)) + return lb + + def build_list(self): + data_root = os.path.abspath(self.data_root) + # images_path = os.path.join(data_root) + images_path = data_root + samples = get_imgs_list(images_path) + return samples + + def __getitem__(self, index): + img_path = self.data_list[index] + # edge_path = self.edge_paths[index] + # img_path = self.img_paths[index] + img_name = os.path.basename(img_path) + + img, raw_size = self.read_img(img_path) + + img = self.transform(img) + if self.normalize_to_neg_one_to_one: # transform to [-1, 1] + img = normalize_to_neg_one_to_one(img) + return {'cond': img, 'raw_size': raw_size, 'img_name': img_name} + + +class Identity(nn.Identity): + r"""A placeholder identity operator that is argument-insensitive. + + Args: + args: any argument (unused) + kwargs: any keyword argument (unused) + + Shape: + - Input: :math:`(*)`, where :math:`*` means any number of dimensions. + - Output: :math:`(*)`, same shape as the input. + + Examples:: + + >>> m = nn.Identity(54, unused_argument1=0.1, unused_argument2=False) + >>> input = torch.randn(128, 20) + >>> output = m(input) + >>> print(output.size()) + torch.Size([128, 20]) + + """ + def __init__(self, *args, **kwargs): + super(Identity, self).__init__(*args, **kwargs) + + def forward(self, input, target): + return input, target + +class Resize(T.Resize): + def __init__(self, size, interpolation2=None, **kwargs): + super().__init__(size, **kwargs) + if interpolation2 is None: + self.interpolation2 = self.interpolation + else: + self.interpolation2 = interpolation2 + + def forward(self, img, target=None): + if target is None: + img = F2.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + return img + else: + img = F2.resize(img, self.size, self.interpolation, self.max_size, self.antialias) + target = F2.resize(target, self.size, self.interpolation2, self.max_size, self.antialias) + return img, target + +class RandomHorizontalFlip(T.RandomHorizontalFlip): + def __init__(self, p=0.5): + super().__init__(p) + + def forward(self, img, target=None): + if target is None: + if torch.rand(1) < self.p: + img = F2.hflip(img) + return img + else: + if torch.rand(1) < self.p: + img = F2.hflip(img) + target = F2.hflip(target) + return img, target + +class CenterCrop(T.CenterCrop): + def __init__(self, size): + super().__init__(size) + + def forward(self, img, target=None): + if target is None: + img = F2.center_crop(img, self.size) + return img + else: + img = F2.center_crop(img, self.size) + target = F2.center_crop(target, self.size) + return img, target + +class RandomCrop(T.RandomCrop): + def __init__(self, size, **kwargs): + super().__init__(size, **kwargs) + + def single_forward(self, img, i, j, h, w): + if self.padding is not None: + img = F2.pad(img, self.padding, self.fill, self.padding_mode) + width, height = F2.get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F2.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F2.pad(img, padding, self.fill, self.padding_mode) + + return F2.crop(img, i, j, h, w) + + def forward(self, img, target=None): + i, j, h, w = self.get_params(img, self.size) + if target is None: + img = self.single_forward(img, i, j, h, w) + return img + else: + img = self.single_forward(img, i, j, h, w) + target = self.single_forward(target, i, j, h, w) + return img, target + +class RandomResizeCrop(T.RandomResizedCrop): + def __init__(self, size, scale=(0.25, 1.0), **kwargs): + super().__init__(size, scale, **kwargs) + + # def single_forward(self, img, i, j, h, w): + # if self.padding is not None: + # img = F2.pad(img, self.padding, self.fill, self.padding_mode) + # width, height = F2.get_image_size(img) + # # pad the width if needed + # if self.pad_if_needed and width < self.size[1]: + # padding = [self.size[1] - width, 0] + # img = F2.pad(img, padding, self.fill, self.padding_mode) + # # pad the height if needed + # if self.pad_if_needed and height < self.size[0]: + # padding = [0, self.size[0] - height] + # img = F2.pad(img, padding, self.fill, self.padding_mode) + # + # return F2.crop(img, i, j, h, w) + + def single_forward(self, img, i, j, h, w, interpolation=InterpolationMode.BILINEAR): + """ + Args: + img (PIL Image or Tensor): Image to be cropped and resized. + + Returns: + PIL Image or Tensor: Randomly cropped and resized image. + """ + # i, j, h, w = self.get_params(img, self.scale, self.ratio) + return F2.resized_crop(img, i, j, h, w, self.size, interpolation) + + def forward(self, img, target=None): + i, j, h, w = self.get_params(img, self.scale, self.ratio) + if target is None: + img = self.single_forward(img, i, j, h, w) + return img + else: + img = self.single_forward(img, i, j, h, w) + target = self.single_forward(target, i, j, h, w, interpolation=InterpolationMode.NEAREST) + return img, target + +class ToTensor(T.ToTensor): + def __init__(self): + super().__init__() + + def __call__(self, img, target=None): + if target is None: + img = F2.to_tensor(img) + return img + else: + img = F2.to_tensor(img) + target = F2.to_tensor(target) + return img, target + +class Lambda(T.Lambda): + """Apply a user-defined lambda as a transform. This transform does not support torchscript. + + Args: + lambd (function): Lambda/function to be used for transform. + """ + + def __init__(self, lambd): + super().__init__(lambd) + + def __call__(self, img, target=None): + if target is None: + return self.lambd(img) + else: + return self.lambd(img), self.lambd(target) + +class Compose(T.Compose): + def __init__(self, transforms): + super().__init__(transforms) + + def __call__(self, img, target=None): + if target is None: + for t in self.transforms: + img = t(img) + return img + else: + for t in self.transforms: + img, target = t(img, target) + return img, target + + +if __name__ == '__main__': + dataset = CIFAR10( + img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/cifar-10-python', + augment_horizontal_flip=False + ) + # dataset = CityscapesDataset( + # # img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/CelebAHQ/celeba_hq_256', + # data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/Cityscapes/', + # # data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/ADEChallengeData2016/', + # image_size=[512, 1024], + # exts = ['png'], + # augment_horizontal_flip = False, + # convert_image_to = None, + # normalize_to_neg_one_to_one=True, + # ) + # dataset = SRDataset( + # img_folder='/media/huang/ZX3 512G/data/DIV2K/DIV2K_train_HR', + # image_size=[512, 512], + # ) + # dataset = InpaintDataset( + # img_folder='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/CelebAHQ/celeba_hq_256', + # image_size=[256, 256], + # augment_horizontal_flip = True + # ) + dataset = EdgeDataset( + data_root='/media/huang/2da18d46-7cba-4259-9abd-0df819bb104c/data/BSDS', + image_size=[320, 320], + ) + for i in range(len(dataset)): + d = dataset[i] + mask = d['cond'] + print(mask.max()) + dl = data.DataLoader(dataset, batch_size=2, shuffle=False, pin_memory=True, num_workers=0) + + + dataset_builder = tfds.builder('cifar10') + split = 'train' + dataset_options = tf.data.Options() + dataset_options.experimental_optimization.map_parallelization = True + dataset_options.experimental_threading.private_threadpool_size = 48 + dataset_options.experimental_threading.max_intra_op_parallelism = 1 + read_config = tfds.ReadConfig(options=dataset_options) + dataset_builder.download_and_prepare() + ds = dataset_builder.as_dataset( + split=split, shuffle_files=True, read_config=read_config) + pause = 0 \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py new file mode 100644 index 0000000000000000000000000000000000000000..9b733a2b00c4a3b750ccb35b0cbeabdc34ec0d0f --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ddm_const_sde.py @@ -0,0 +1,992 @@ +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd +import math +import torch.nn.functional as F +# import torchvision.transforms.functional as F2 +from .utils import default, identity, normalize_to_neg_one_to_one, unnormalize_to_zero_to_one +from tqdm.auto import tqdm +from einops import rearrange, reduce +from functools import partial +from collections import namedtuple +from random import random, randint, sample, choice +from .encoder_decoder import DiagonalGaussianDistribution +import random +from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import * + +# gaussian diffusion trainer class +ModelPrediction = namedtuple('ModelPrediction', ['pred_noise', 'pred_x_start']) + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def linear_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype = torch.float64) + +def cosine_beta_schedule(timesteps, s = 0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype = torch.float64) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + +class DDPM(nn.Module): + def __init__( + self, + model, + *, + image_size, + timesteps = 1000, + sampling_timesteps = None, + loss_type = 'l2', + objective = 'pred_noise', + beta_schedule = 'cosine', + p2_loss_weight_gamma = 0., # p2 loss weight, from https://arxiv.org/abs/2204.00227 - 0 is equivalent to weight of 1 across time - 1. is recommended + p2_loss_weight_k = 1, + original_elbo_weight=0., + ddim_sampling_eta = 1., + clip_x_start=True, + train_sample=-1, + input_keys=['image'], + start_dist='normal', + sample_type='ddim', + perceptual_weight=1., + use_l1=False, + **kwargs + ): + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + only_model = kwargs.pop("only_model", False) + cfg = kwargs.pop("cfg", None) + super().__init__(**kwargs) + # assert not (type(self) == DDPM and model.channels != model.out_dim) + # assert not model.random_or_learned_sinusoidal_cond + + self.model = model + self.channels = self.model.channels + self.self_condition = self.model.self_condition + self.input_keys = input_keys + self.cfg = cfg + self.eps = cfg.get('eps', 1e-4) if cfg is not None else 1e-4 + self.weighting_loss = cfg.get("weighting_loss", False) if cfg is not None else False + if self.weighting_loss: + print('#### WEIGHTING LOSS ####') + + self.clip_x_start = clip_x_start + self.image_size = image_size + self.train_sample = train_sample + self.objective = objective + self.start_dist = start_dist + assert start_dist in ['normal', 'uniform'] + + assert objective in {'pred_noise', 'pred_x0', 'pred_v', 'pred_delta', 'pred_KC'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])' + + if beta_schedule == 'linear': + betas = linear_beta_schedule(timesteps) + elif beta_schedule == 'cosine': + betas = cosine_beta_schedule(timesteps, s=1e-4) + else: + raise ValueError(f'unknown beta schedule {beta_schedule}') + # betas[0] = 2e-3 * betas[0] + alphas = 1. - betas + alphas_cumprod = torch.cumprod(alphas, dim=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.time_range = list(range(self.num_timesteps + 1)) + self.loss_type = loss_type + self.original_elbo_weight = original_elbo_weight + + # sampling related parameters + + self.sampling_timesteps = default(sampling_timesteps, timesteps) # default num sampling timesteps to number of timesteps at training + + # assert self.sampling_timesteps <= timesteps + self.is_ddim_sampling = self.sampling_timesteps < timesteps + self.ddim_sampling_eta = ddim_sampling_eta + + # helper function to register buffer from float64 to float32 + + register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20))) + register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # calculate p2 reweighting + + register_buffer('p2_loss_weight', (p2_loss_weight_k + alphas_cumprod / (1 - alphas_cumprod)) ** -p2_loss_weight_gamma) + assert not torch.isnan(self.p2_loss_weight).all() + if self.objective == "pred_noise": + lvlb_weights = self.betas ** 2 / ( + 2 * (self.posterior_variance+1e-5) * alphas * (1 - self.alphas_cumprod)) + elif self.objective == "pred_x0": + lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod) + elif self.objective == "pred_delta": + lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod) + elif self.objective == "pred_KC": + lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod) + elif self.objective == "pred_v": + lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod) / (2. * 1 - alphas_cumprod) + self.register_buffer('lvlb_weights', lvlb_weights, persistent=False) + assert not torch.isnan(self.lvlb_weights).all() + self.use_l1 = use_l1 + + self.perceptual_weight = perceptual_weight + if self.perceptual_weight > 0: + self.perceptual_loss = LPIPS().eval() + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model) + + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False, use_ema=False): + sd = torch.load(path, map_location="cpu") + if 'ema' in list(sd.keys()) and use_ema: + sd = sd['ema'] + new_sd = {} + for k in sd.keys(): + if k.startswith("ema_model."): + new_k = k[10:] # remove ema_model. + new_sd[new_k] = sd[k] + sd = new_sd + else: + if "model" in list(sd.keys()): + sd = sd["model"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + + @torch.no_grad() + def p_sample(self, x, mask, t: int, x_self_cond = None, clip_denoised = True): + b, *_, device = *x.shape, x.device + batched_times = torch.full((x.shape[0],), t, device = x.device, dtype = torch.long) + model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, mask=mask, t = batched_times, x_self_cond = x_self_cond, clip_denoised = clip_denoised) + noise = torch.randn_like(x) if t > 0 else 0. # no noise if t == 0 + pred_img = model_mean + (0.5 * model_log_variance).exp() * noise + return pred_img, x_start + + @torch.no_grad() + def p_sample_loop(self, shape, mask, up_scale=1, unnormalize=True): + batch, device = shape[0], self.betas.device + + img = torch.randn(shape, device=device) + img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True) + + x_start = None + + for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps): + self_cond = x_start if self.self_condition else None + img, x_start = self.p_sample(img, mask, t, self_cond) + if unnormalize: + img = unnormalize_to_zero_to_one(img) + return img + + @torch.no_grad() + def ddim_sample(self, shape, mask, up_scale=1, unnormalize=True): + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective + + times = torch.linspace(-1, total_timesteps - 1, steps=sampling_timesteps + 1) # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps + times = list(reversed(times.int().tolist())) + time_pairs = list(zip(times[:-1], times[1:])) # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)] + + img = torch.randn(shape, device = device) + img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True) + + x_start = None + + for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', total=len(time_pairs)): + time_cond = torch.full((batch,), time, device=device, dtype=torch.long) + self_cond = x_start if self.self_condition else None + pred_noise, x_start, *_ = self.model_predictions(img, time_cond, mask, self_cond) + + if time_next < 0: + img = x_start + continue + + alpha = self.alphas_cumprod[time] + alpha_next = self.alphas_cumprod[time_next] + + sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt() + c = (1 - alpha_next - sigma ** 2).sqrt() + + noise = torch.randn_like(img) + + img = x_start * alpha_next.sqrt() + \ + c * pred_noise + \ + sigma * noise + if unnormalize: + img = unnormalize_to_zero_to_one(img) + return img + + + @torch.no_grad() + def interpolate(self, x1, x2, mask, t = None, lam = 0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device = device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t = t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in tqdm(reversed(range(0, t)), desc = 'interpolation sample time step', total = t): + img = self.p_sample(img, mask, torch.full((b,), i, device=device, dtype=torch.long)) + return img + + def get_input(self, batch, return_first_stage_outputs=False, return_original_cond=False): + assert 'image' in self.input_keys; + if len(self.input_keys) > len(batch.keys()): + x, *_ = batch.values() + else: + x = batch.values() + return x + + def training_step(self, batch): + z, *_ = self.get_input(batch) + cond = batch['cond'] if 'cond' in batch else None + loss, loss_dict = self(z, cond) + return loss, loss_dict + + def forward(self, x, *args, **kwargs): + # continuous time, t in [0, 1] + # t = [] + # for _ in range(x.shape[0]): + # if self.train_sample <= 0: + # t.append(torch.tensor(sample(self.time_range, 2), device=x.device).long()) + # else: + # sl = choice(self.time_range) + # sl_range = list(range(sl - self.train_sample, sl + self.train_sample)) + # sl_range = list(set(sl_range) & set(self.time_range)) + # sl_range.pop(sl_range.index(sl)) + # sl2 = choice(sl_range) + # t.append(torch.tensor([sl, sl2], device=x.device).long()) + # t = torch.stack(t, dim=0) + # t = torch.randint(0, self.num_timesteps+1, (x.shape[0],), device=x.device).long() + eps = self.eps # smallest time step + # t = torch.rand((x.shape[0],), device=x.device) * (self.num_timesteps / eps) + # t = t.round() * eps + # t[t < eps] = eps + t = torch.rand(x.shape[0], device=x.device) * (1. - eps) + eps + return self.p_losses(x, t, *args, **kwargs) + + def q_sample2(self, x_start, t, noise=None): + b, c, h, w = x_start.shape + noise = default(noise, lambda: torch.randn_like(x_start)) + _, nt = t.shape + param_x = self.sqrt_alphas_cumprod.repeat(b, 1).gather(-1, t) # (b, nt) + x = x_start.expand(nt, b, c, h, w).transpose(1, 0) * param_x.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) + param_noise = self.sqrt_one_minus_alphas_cumprod.repeat(b, 1).gather(-1, t) + n = noise.expand(nt, b, c, h, w).transpose(1, 0) * param_noise.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) + return x + n # (b, nt, c, h, w) + + def q_sample3(self, x_start, t, C): + b, c, h, w = x_start.shape + _, nt = t.shape + # K_ = K.unsqueeze(1).repeat(1, nt, 1, 1, 1) + C_ = C.unsqueeze(1).repeat(1, nt, 1, 1, 1) + x_noisy = x_start.expand(nt, b, c, h, w).transpose(1, 0) + \ + + C_ * t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps + return x_noisy # (b, nt, c, h, w) + + # def q_sample(self, x_start, t, C): + # x_noisy = x_start + C * t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) / self.num_timesteps + # return x_noisy + def q_sample(self, x_start, noise, t, C): + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + x_noisy = x_start + C * time + torch.sqrt(time) * noise + return x_noisy + + def q_sample2(self, x_start, noise, t, C): + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + x_noisy = x_start + C / 2 * time ** 2 + torch.sqrt(time) * noise + return x_noisy + + def pred_x0_from_xt(self, xt, noise, C, t): + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + x0 = xt - C * time - torch.sqrt(time) * noise + return x0 + + def pred_x0_from_xt2(self, xt, noise, C, t): + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + x0 = xt - C / 2 * time ** 2 - torch.sqrt(time) * noise + return x0 + + def pred_xtms_from_xt(self, xt, noise, C, t, s): + # noise = noise / noise.std(dim=[1, 2, 3]).reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + mean = xt + C * (time-s) - C * time - s / torch.sqrt(time) * noise + epsilon = torch.randn_like(mean, device=xt.device) + sigma = torch.sqrt(s * (time-s) / time) + xtms = mean + sigma * epsilon + return xtms + + def pred_xtms_from_xt2(self, xt, noise, C, t, s): + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + mean = xt + C / 2 * (time-s) ** 2 - C / 2 * time ** 2 - s / torch.sqrt(time) * noise + epsilon = torch.randn_like(mean, device=xt.device) + sigma = torch.sqrt(s * (time-s) / time) + xtms = mean + sigma * epsilon + return xtms + + def WCE_loss(self, prediction, labelf, beta=1.1): + label = labelf.long() + mask = labelf.clone() + + num_positive = torch.sum(label == 1).float() + num_negative = torch.sum(label == 0).float() + + mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative) + mask[label == 0] = beta * num_positive / (num_positive + num_negative) + mask[label == 2] = 0 + cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='sum') + + return cost + + def Dice_Loss(self, pred, label): + # pred = torch.sigmoid(pred) + smooth = 1 + pred_flat = pred.view(-1) + label_flat = label.view(-1) + + intersecion = pred_flat * label_flat + unionsection = pred_flat.pow(2).sum() + label_flat.pow(2).sum() + smooth + loss = unionsection / (2 * intersecion.sum() + smooth) + loss = loss.sum() + return loss + + def p_losses(self, x_start, t, *args, **kwargs): + if self.start_dist == 'normal': + noise = torch.randn_like(x_start) + elif self.start_dist == 'uniform': + noise = 2 * torch.rand_like(x_start) - 1. + else: + raise NotImplementedError(f'{self.start_dist} is not supported !') + # K = -1. * torch.ones_like(x_start) + # C = noise - x_start # t = 1000 / 1000 + C = -1 * x_start # U(t) = Ct, U(1) = -x0 + # C = -2 * x_start # U(t) = 1/2 * C * t**2, U(1) = 1/2 * C = -x0 + x_noisy = self.q_sample(x_start=x_start, noise=noise, t=t, C=C) # (b, 2, c, h, w) + C_pred, noise_pred = self.model(x_noisy, t, **kwargs) + # C_pred = C_pred / torch.sqrt(t) + # noise_pred = noise_pred / torch.sqrt(1 - t) + x_rec = self.pred_x0_from_xt(x_noisy, noise_pred, C_pred, t) # x_rec:(B, 1, H, W) + loss_dict = {} + prefix = 'train' + + # elif self.objective == 'pred_KC': + # target1 = C + # target2 = noise + # target3 = x_start + + target1 = C + target2 = noise + target3 = x_start + + loss_simple = 0. + loss_vlb = 0. + # use l1 + l2 + if self.weighting_loss: + simple_weight1 = 2*torch.exp(1-t) + simple_weight2 = torch.exp(torch.sqrt(t)) + if self.cfg.model_name == 'ncsnpp9': + simple_weight1 = (t + 1) / t.sqrt() + simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt() + else: + simple_weight1 = 1 + simple_weight2 = 1 + + loss_simple += simple_weight1 * self.get_loss(C_pred, target1, mean=False).mean([1, 2, 3]) + \ + simple_weight2 * self.get_loss(noise_pred, target2, mean=False).mean([1, 2, 3]) + if self.use_l1: + loss_simple += simple_weight1 * (C_pred - target1).abs().mean([1, 2, 3]) + \ + simple_weight2 * (noise_pred - target2).abs().mean([1, 2, 3]) + loss_simple = loss_simple / 2 + # rec_weight = (1 - t.reshape(C.shape[0], 1)) ** 2 + rec_weight = 1 - t.reshape(C.shape[0], 1) # (B, 1) + loss_simple = loss_simple.mean() + loss_dict.update({f'{prefix}/loss_simple': loss_simple}) + + # loss_vlb += torch.abs(x_rec - target3).mean([1, 2, 3]) * rec_weight: (B, 1) + loss_vlb += self.Dice_Loss(x_rec, target3) + loss_vlb = loss_vlb.mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + loss_vlb + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + @torch.no_grad() + def sample(self, batch_size=16, up_scale=1, cond=None, denoise=True): + image_size, channels = self.image_size, self.channels + if cond is not None: + batch_size = cond.shape[0] + return self.sample_fn((batch_size, channels, image_size[0], image_size[1]), + up_scale=up_scale, unnormalize=True, cond=cond, denoise=denoise) + + @torch.no_grad() + def sample_fn(self, shape, up_scale=1, unnormalize=True, cond=None, denoise=False): + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], \ + self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective + + # times = torch.linspace(-1, total_timesteps, steps=self.sampling_timesteps + 1).int() + # times = list(reversed(times.int().tolist())) + # time_pairs = list(zip(times[:-1], times[1:])) + # time_steps = torch.tensor([0.25, 0.15, 0.1, 0.1, 0.1, 0.09, 0.075, 0.06, 0.045, 0.03]) + step = 1. / self.sampling_timesteps + # time_steps = torch.tensor([0.1]).repeat(10) + time_steps = torch.tensor([step]).repeat(self.sampling_timesteps) + if denoise: + eps = self.eps + time_steps = torch.cat((time_steps[:-1], torch.tensor([step - eps]), torch.tensor([eps])), dim=0) + + if self.start_dist == 'normal': + img = torch.randn(shape, device=device) + elif self.start_dist == 'uniform': + img = 2 * torch.rand(shape, device=device) - 1. + else: + raise NotImplementedError(f'{self.start_dist} is not supported !') + img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True) + # K = -1 * torch.ones_like(img) + cur_time = torch.ones((batch,), device=device) + for i, time_step in enumerate(time_steps): + s = torch.full((batch,), time_step, device=device) + if i == time_steps.shape[0] - 1: + s = cur_time + if cond is not None: + pred = self.model(img, cur_time, cond) + else: + pred = self.model(img, cur_time) + # C, noise = pred.chunk(2, dim=1) + C, noise = pred[:2] + # correct C + x0 = self.pred_x0_from_xt(img, noise, C, cur_time) + if self.clip_x_start: + x0.clamp_(-1., 1.) + # C.clamp_(-2., 2.) + C = -1 * x0 + img = self.pred_xtms_from_xt(img, noise, C, cur_time, s) + # img = self.pred_xtms_from_xt2(img, noise, C, cur_time, s) + cur_time = cur_time - s + img.clamp_(-1., 1.) + if unnormalize: + img = unnormalize_to_zero_to_one(img) + return img + + + +class LatentDiffusion(DDPM): + def __init__(self, + auto_encoder, + scale_factor=1.0, + scale_by_std=True, + scale_by_softsign=False, + input_keys=['image'], + sample_type='ddim', + num_timesteps_cond=1, + train_sample=-1, + default_scale=False, + *args, + **kwargs + ): + self.scale_by_std = scale_by_std + self.scale_by_softsign = scale_by_softsign + self.default_scale = default_scale + self.num_timesteps_cond = num_timesteps_cond + self.train_sample = train_sample + self.perceptual_weight = 0 + ckpt_path = kwargs.pop("ckpt_path", None) + ignore_keys = kwargs.pop("ignore_keys", []) + only_model = kwargs.pop("only_model", False) + super().__init__(*args, **kwargs) + assert self.num_timesteps_cond <= self.num_timesteps + if not scale_by_std: + self.scale_factor = scale_factor + else: + self.register_buffer('scale_factor', torch.tensor(scale_factor)) + if self.scale_by_softsign: + self.scale_by_std = False + print('### USING SOFTSIGN RESCALING') + assert (self.scale_by_std and self.scale_by_softsign) is False; + + self.init_first_stage(auto_encoder) + # self.instantiate_cond_stage(cond_stage_config) + self.input_keys = input_keys + self.clip_denoised = False + assert sample_type in ['p_loop', 'ddim', 'dpm', 'transformer'] ### 'dpm' is not availible now, suggestion 'ddim' + self.sample_type = sample_type + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys, only_model) + + def init_first_stage(self, first_stage_model): + self.first_stage_model = first_stage_model.eval() + # self.first_stage_model.train = disabled_train + for param in self.first_stage_model.parameters(): + param.requires_grad = False + + ''' + def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): + sd = torch.load(path, map_location="cpu") + if "state_dict" in list(sd.keys()): + sd = sd["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( + sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + if len(unexpected) > 0: + print(f"Unexpected Keys: {unexpected}") + ''' + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError(f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented") + # return self.scale_factor * z.detach() + self.scale_bias + return z.detach() + + @torch.no_grad() + def on_train_batch_start(self, batch): + # only for the first batch + if self.scale_by_std and (not self.scale_by_softsign): + if not self.default_scale: + assert self.scale_factor == 1., 'rather not use custom rescaling and std-rescaling simultaneously' + # set rescale weight to 1./std of encodings + print("### USING STD-RESCALING ###") + x, *_ = batch.values() + encoder_posterior = self.first_stage_model.encode(x) + z = self.get_first_stage_encoding(encoder_posterior) + del self.scale_factor + self.register_buffer('scale_factor', 1. / z.flatten().std()) + print(f"setting self.scale_factor to {self.scale_factor}") + # print("### USING STD-RESCALING ###") + else: + print(f'### USING DEFAULT SCALE {self.scale_factor}') + else: + print(f'### USING SOFTSIGN SCALE !') + + @torch.no_grad() + def get_input(self, batch, return_first_stage_outputs=False, return_original_cond=False): + assert 'image' in self.input_keys; + # if len(self.input_keys) > len(batch.keys()): + # x, cond, *_ = batch.values() + # else: + # x, cond = batch.values() + x = batch['image'] + cond = batch['cond'] if 'cond' in batch else None + z = self.first_stage_model.encode(x) + # print('zzzz', z.shape) + z = self.get_first_stage_encoding(z) + out = [z, cond, x] + if return_first_stage_outputs: + xrec = self.first_stage_model.decode(z) + out.extend([x, xrec]) + if return_original_cond: + out.append(cond) + return out + + def training_step(self, batch): + z, c, *_ = self.get_input(batch) + # print(_[0].shape) + if self.scale_by_softsign: + z = F.softsign(z) + elif self.scale_by_std: + z = self.scale_factor * z + # print('grad', self.scale_bias.grad) + loss, loss_dict = self(z, c, edge=_[0]) + return loss, loss_dict + + def q_sample3(self, x_start, t, K, C): + b, c, h, w = x_start.shape + _, nt = t.shape + K_ = K.unsqueeze(1).repeat(1, nt, 1, 1, 1) + C_ = C.unsqueeze(1).repeat(1, nt, 1, 1, 1) + x_noisy = x_start.expand(nt, b, c, h, w).transpose(1, 0) + K_ / 2 * (t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps) ** 2 \ + + C_ * t.reshape(b, nt, 1, 1, 1).repeat(1, 1, c, h, w) / self.num_timesteps + return x_noisy # (b, nt, c, h, w) + + def pred_xtms_from_xt(self, xt, noise, C, t, s): + # noise = noise / noise.std(dim=[1, 2, 3]).reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + time = t.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + s = s.reshape(C.shape[0], *((1,) * (len(C.shape) - 1))) + mean = xt - C * s - s / torch.sqrt(time) * noise + epsilon = torch.randn_like(mean, device=xt.device) + sigma = torch.sqrt(s * (time-s) / time) + xtms = mean + sigma * epsilon + return xtms + + def WCE_loss(self, prediction, labelf, beta=1.1): + label = labelf.long() + mask = labelf.clone() + + num_positive = torch.sum(label == 1).float() + num_negative = torch.sum(label == 0).float() + + mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative) + mask[label == 0] = beta * num_positive / (num_positive + num_negative) + mask[label == 2] = 0 + cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='sum') + + return cost + + def Dice_Loss(self, pred, label): + # pred = torch.sigmoid(pred) + B = pred.shape[0] + smooth = 1 + pred_flat = pred.view(B, -1) + label_flat = label.view(B, -1) + + intersecion = pred_flat * label_flat + unionsection = pred_flat.pow(2).sum(dim=-1) + label_flat.pow(2).sum(dim=-1) + smooth + loss = unionsection / (2 * intersecion.sum(dim=-1) + smooth) + loss = loss.reshape(B, 1) + return loss + + def p_losses(self, x_start, t, *args, **kwargs): + if self.start_dist == 'normal': + noise = torch.randn_like(x_start) + elif self.start_dist == 'uniform': + noise = 2 * torch.rand_like(x_start) - 1. + else: + raise NotImplementedError(f'{self.start_dist} is not supported !') + # K = -1. * torch.ones_like(x_start) + # C = noise - x_start # t = 1000 / 1000 + C = -1 * x_start # U(t) = Ct, U(1) = -x0 + # C = -2 * x_start # U(t) = 1/2 * C * t**2, U(1) = 1/2 * C = -x0 + x_noisy = self.q_sample(x_start=x_start, noise=noise, t=t, C=C) # (b, 2, c, h, w) + if self.cfg.model_name == 'cond_unet8': + C_pred, noise_pred, (e1, e2) = self.model(x_noisy, t, *args, **kwargs) + if self.cfg.model_name == 'cond_unet13': + C_pred, noise_pred, aux_C = self.model(x_noisy, t, *args, **kwargs) + else: + C_pred, noise_pred = self.model(x_noisy, t, *args, **kwargs) + # C_pred = C_pred / torch.sqrt(t) + # noise_pred = noise_pred / torch.sqrt(1 - t) + x_rec = self.pred_x0_from_xt(x_noisy, noise_pred, C_pred, t) # x_rec:(B, C, H, W) + loss_dict = {} + prefix = 'train' + + # elif self.objective == 'pred_KC': + # target1 = C + # target2 = noise + # target3 = x_start + + target1 = C + target2 = noise + target3 = x_start + + loss_simple = 0. + loss_vlb = 0. + + simple_weight1 = (t + 1) / t.sqrt() + simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt() + + # if self.weighting_loss: + # simple_weight1 = 2 * torch.exp(1 - t) + # simple_weight2 = torch.exp(torch.sqrt(t)) + # if self.cfg.model_name == 'ncsnpp9': + # simple_weight1 = (t + 1) / t.sqrt() + # simple_weight2 = (2 - t).sqrt() / (1 - t + self.eps).sqrt() + # else: + # simple_weight1 = 1 + # simple_weight2 = 1 + + loss_simple += simple_weight1 * self.get_loss(C_pred, target1, mean=False).mean([1, 2, 3]) + \ + simple_weight2 * self.get_loss(noise_pred, target2, mean=False).mean([1, 2, 3]) + + # loss_simple += self.Dice_Loss(C_pred, target1) * simple_weight1 + + if self.use_l1: + loss_simple += simple_weight1 * (C_pred - target1).abs().mean([1, 2, 3]) + \ + simple_weight2 * (noise_pred - target2).abs().mean([1, 2, 3]) + loss_simple = loss_simple / 2 + + if self.cfg.model_name == 'cond_unet8': + loss_simple += 0.05*(self.Dice_Loss(e1, (kwargs['edge'] + 1)/2) + self.Dice_Loss(e2, (kwargs['edge'] + 1)/2)) + elif self.cfg.model_name == 'cond_unet13': + loss_simple += 0.5 * (simple_weight1 * self.get_loss(aux_C, target1, mean=False).mean([1, 2, 3]) + \ + simple_weight1 * (aux_C - target1).abs().mean([1, 2, 3])) + + rec_weight = (1 - t.reshape(C.shape[0], 1)) ** 2 + # rec_weight = 1 - t.reshape(C.shape[0], 1) # (B, 1) + loss_simple = loss_simple.mean() + loss_dict.update({f'{prefix}/loss_simple': loss_simple}) + + loss_vlb += torch.abs(x_rec - target3).mean([1, 2, 3]) * rec_weight # : (B, 1) + # loss_vlb += self.Dice_Loss(x_rec, target3) * rec_weight + + # loss_vlb = loss_vlb + loss_vlb = loss_vlb.mean() + + if self.cfg.get('use_disloss', False): + with torch.no_grad(): + edge_rec = self.first_stage_model.decode(x_rec / self.scale_factor) + edge_rec = unnormalize_to_zero_to_one(edge_rec) + edge_rec = torch.clamp(edge_rec, min=0., max=1.) # B, 1, 320, 320 + loss_tmp = self.cross_entropy_loss_RCF(edge_rec, (kwargs['edge'] + 1)/2) * rec_weight # B, 1 + loss_ce = SpecifyGradient.apply(x_rec, loss_tmp.mean()) + # print(loss_ce.shape) + # print(loss_vlb.shape) + loss_vlb += loss_ce.mean() + loss_dict.update({f'{prefix}/loss_vlb': loss_vlb}) + + loss = loss_simple + loss_vlb + loss_dict.update({f'{prefix}/loss': loss}) + + return loss, loss_dict + + def get_loss(self, pred, target, mean=True): + if self.loss_type == 'l1': + loss = (target - pred).abs() + if mean: + loss = loss.mean() + elif self.loss_type == 'l2': + if mean: + loss = torch.nn.functional.mse_loss(target, pred) + else: + loss = torch.nn.functional.mse_loss(target, pred, reduction='none') + else: + raise NotImplementedError("unknown loss type '{loss_type}'") + + return loss + + def cross_entropy_loss_RCF(self, prediction, labelf, beta=1.1): + # label = labelf.long() + label = labelf + mask = labelf.clone() + + num_positive = torch.sum(label == 1).float() + num_negative = torch.sum(label == 0).float() + + mask_temp = (label > 0) & (label <= 0.3) + mask[mask_temp] = 0. + + mask[label == 1] = 1.0 * num_negative / (num_positive + num_negative) + mask[label == 0] = beta * num_positive / (num_positive + num_negative) + + # mask[label == 2] = 0 + cost = F.binary_cross_entropy(prediction, labelf, weight=mask, reduction='none') + return cost.mean([1, 2, 3]) + + @torch.no_grad() + def sample(self, batch_size=16, up_scale=1, cond=None, mask=None, denoise=True): + # image_size, channels = self.image_size, self.channels + channels = self.channels + image_size = cond.shape[-2:] + if cond is not None: + batch_size = cond.shape[0] + down_ratio = self.first_stage_model.down_ratio + if self.cfg.model_name == 'cond_unet8' or self.cfg.model_name == 'cond_unet13': + z, aux_out = self.sample_fn((batch_size, channels, image_size[0] // down_ratio, image_size[1] // down_ratio), + up_scale=up_scale, unnormalize=False, cond=cond, denoise=denoise) + else: + z = self.sample_fn((batch_size, channels, image_size[0]//down_ratio, image_size[1]//down_ratio), + up_scale=up_scale, unnormalize=False, cond=cond, denoise=denoise) + aux_out = None + + if self.scale_by_std: + z = 1. / self.scale_factor * z.detach() + if self.cfg.model_name == 'cond_unet13': + aux_out = 1. / self.scale_factor * aux_out.detach() + elif self.scale_by_softsign: + z = z / (1 - z.abs()) + z = z.detach() + #print(z.shape) + x_rec = self.first_stage_model.decode(z) + x_rec = unnormalize_to_zero_to_one(x_rec) + x_rec = torch.clamp(x_rec, min=0., max=1.) + if self.cfg.model_name == 'cond_unet13': + aux_out = self.first_stage_model.decode(aux_out) + aux_out = unnormalize_to_zero_to_one(aux_out) + aux_out = torch.clamp(aux_out, min=0., max=1.) + if mask is not None: + x_rec = mask * unnormalize_to_zero_to_one(cond) + (1 - mask) * x_rec + if aux_out is not None: + return x_rec, aux_out + return x_rec + + @torch.no_grad() + def sample_fn(self, shape, up_scale=1, unnormalize=True, cond=None, denoise=False): + batch, device, total_timesteps, sampling_timesteps, eta, objective = shape[0], \ + self.betas.device, self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta, self.objective + + # times = torch.linspace(-1, total_timesteps, steps=self.sampling_timesteps + 1).int() + # times = list(reversed(times.int().tolist())) + # time_pairs = list(zip(times[:-1], times[1:])) + # time_steps = torch.tensor([0.25, 0.15, 0.1, 0.1, 0.1, 0.09, 0.075, 0.06, 0.045, 0.03]) + step = 1. / self.sampling_timesteps + # time_steps = torch.tensor([0.1]).repeat(10) + time_steps = torch.tensor([step]).repeat(self.sampling_timesteps) + if denoise: + eps = self.eps + time_steps = torch.cat((time_steps[:-1], torch.tensor([step - eps]), torch.tensor([eps])), dim=0) + + if self.start_dist == 'normal': + img = torch.randn(shape, device=device) + elif self.start_dist == 'uniform': + img = 2 * torch.rand(shape, device=device) - 1. + else: + raise NotImplementedError(f'{self.start_dist} is not supported !') + img = F.interpolate(img, scale_factor=up_scale, mode='bilinear', align_corners=True) + img_aux = F.interpolate(img.clone(), scale_factor=up_scale, mode='bilinear', align_corners=True) + # img_aux = img.clone() + # K = -1 * torch.ones_like(img) + cur_time = torch.ones((batch,), device=device) + for i, time_step in enumerate(time_steps): + s = torch.full((batch,), time_step, device=device) + if i == time_steps.shape[0] - 1: + s = cur_time + if cond is not None: + pred = self.model(img, cur_time, cond) + else: + pred = self.model(img, cur_time) + # C, noise = pred.chunk(2, dim=1) + C, noise = pred[:2] + if self.cfg.model_name == 'cond_unet8' or self.cfg.model_name == 'cond_unet13': + aux_out = pred[-1] + else: + aux_out = None + # if self.scale_by_softsign: + # # correct the C for softsign + # x0 = self.pred_x0_from_xt(img, noise, C, cur_time) + # x0 = torch.clamp(x0, min=-0.987654321, max=0.987654321) + # C = -x0 + # correct C + x0 = self.pred_x0_from_xt(img, noise, C, cur_time) + C = -1 * x0 + img = self.pred_xtms_from_xt(img, noise, C, cur_time, s) + # if self.cfg.model_name == 'cond_unet13' and i == len(time_steps) - 2: + # img_aux = img + # if self.cfg.model_name == 'cond_unet13' and i in [len(time_steps)-2, len(time_steps)-1]: + # x0_aux = self.pred_x0_from_xt(img_aux, noise, aux_out, cur_time) + # C_aux = -1 * x0_aux + # img_aux = self.pred_xtms_from_xt(img_aux, noise, C_aux, cur_time, s) + if self.cfg.model_name == 'cond_unet13': + for _ in range(1): + x0_aux = self.pred_x0_from_xt(img_aux, noise, aux_out, cur_time) + C_aux = -1 * x0_aux + img_aux = self.pred_xtms_from_xt(img_aux, noise, C_aux, cur_time, s) + cur_time = cur_time - s + if self.scale_by_softsign: + img.clamp_(-0.987654321, 0.987654321) + if unnormalize: + img = unnormalize_to_zero_to_one(img) + if self.cfg.model_name == 'cond_unet13': + aux_out = img_aux + if aux_out is not None: + return img, aux_out + return img + +class SpecifyGradient(torch.autograd.Function): + @staticmethod + @custom_fwd + def forward(ctx, input_tensor, gt_grad): + ctx.save_for_backward(gt_grad) + # we return a dummy value 1, which will be scaled by amp's scaler so we get the scale in backward. + return torch.ones(input_tensor.shape, device=input_tensor.device, dtype=input_tensor.dtype) + + @staticmethod + @custom_bwd + def backward(ctx, grad_scale): + (gt_grad,) = ctx.saved_tensors + gt_grad = gt_grad * grad_scale + return gt_grad, None + +if __name__ == "__main__": + ddconfig = {'double_z': True, + 'z_channels': 4, + 'resolution': (240, 960), + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [1, 2, 4, 4], # num_down = len(ch_mult)-1 + 'num_res_blocks': 2, + 'attn_resolutions': [], + 'dropout': 0.0} + lossconfig = {'disc_start': 50001, + 'kl_weight': 0.000001, + 'disc_weight': 0.5} + from encoder_decoder import AutoencoderKL + auto_encoder = AutoencoderKL(ddconfig, lossconfig, embed_dim=4, + ) + from mask_cond_unet import Unet + unet = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=4, cond_in_dim=1,) + ldm = LatentDiffusion(auto_encoder=auto_encoder, model=unet, image_size=ddconfig['resolution']) + image = torch.rand(1, 3, 128, 128) + mask = torch.rand(1, 1, 128, 128) + input = {'image': image, 'cond': mask} + time = torch.tensor([1]) + with torch.no_grad(): + y = ldm.training_step(input) + pass \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..ee49d7ab757febe8ef840020e0b4ae67f7311422 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/efficientnet.py @@ -0,0 +1,1130 @@ +import copy +import math +import warnings +from dataclasses import dataclass +from functools import partial +from typing import Any, Callable, Dict, Optional, List, Sequence, Tuple, Union + +import torch +from torch import nn, Tensor +from torchvision.ops import StochasticDepth + +from torchvision.ops.misc import Conv2dNormActivation, SqueezeExcitation +from torchvision.transforms._presets import ImageClassification, InterpolationMode +from torchvision.utils import _log_api_usage_once +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param, _make_divisible + + +__all__ = [ + "EfficientNet", + "EfficientNet_B0_Weights", + "EfficientNet_B1_Weights", + "EfficientNet_B2_Weights", + "EfficientNet_B3_Weights", + "EfficientNet_B4_Weights", + "EfficientNet_B5_Weights", + "EfficientNet_B6_Weights", + "EfficientNet_B7_Weights", + "EfficientNet_V2_S_Weights", + "EfficientNet_V2_M_Weights", + "EfficientNet_V2_L_Weights", + "efficientnet_b0", + "efficientnet_b1", + "efficientnet_b2", + "efficientnet_b3", + "efficientnet_b4", + "efficientnet_b5", + "efficientnet_b6", + "efficientnet_b7", + "efficientnet_v2_s", + "efficientnet_v2_m", + "efficientnet_v2_l", +] + + +@dataclass +class _MBConvConfig: + expand_ratio: float + kernel: int + stride: int + input_channels: int + out_channels: int + num_layers: int + block: Callable[..., nn.Module] + + @staticmethod + def adjust_channels(channels: int, width_mult: float, min_value: Optional[int] = None) -> int: + return _make_divisible(channels * width_mult, 8, min_value) + + +class MBConvConfig(_MBConvConfig): + # Stores information listed at Table 1 of the EfficientNet paper & Table 4 of the EfficientNetV2 paper + def __init__( + self, + expand_ratio: float, + kernel: int, + stride: int, + input_channels: int, + out_channels: int, + num_layers: int, + width_mult: float = 1.0, + depth_mult: float = 1.0, + block: Optional[Callable[..., nn.Module]] = None, + ) -> None: + input_channels = self.adjust_channels(input_channels, width_mult) + out_channels = self.adjust_channels(out_channels, width_mult) + num_layers = self.adjust_depth(num_layers, depth_mult) + if block is None: + block = MBConv + super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) + + @staticmethod + def adjust_depth(num_layers: int, depth_mult: float): + return int(math.ceil(num_layers * depth_mult)) + + +class FusedMBConvConfig(_MBConvConfig): + # Stores information listed at Table 4 of the EfficientNetV2 paper + def __init__( + self, + expand_ratio: float, + kernel: int, + stride: int, + input_channels: int, + out_channels: int, + num_layers: int, + block: Optional[Callable[..., nn.Module]] = None, + ) -> None: + if block is None: + block = FusedMBConv + super().__init__(expand_ratio, kernel, stride, input_channels, out_channels, num_layers, block) + + +class MBConv(nn.Module): + def __init__( + self, + cnf: MBConvConfig, + stochastic_depth_prob: float, + norm_layer: Callable[..., nn.Module], + se_layer: Callable[..., nn.Module] = SqueezeExcitation, + ) -> None: + super().__init__() + + if not (1 <= cnf.stride <= 2): + raise ValueError("illegal stride value") + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.SiLU + + # expand + expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) + if expanded_channels != cnf.input_channels: + layers.append( + Conv2dNormActivation( + cnf.input_channels, + expanded_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + # depthwise + layers.append( + Conv2dNormActivation( + expanded_channels, + expanded_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + groups=expanded_channels, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + # squeeze and excitation + squeeze_channels = max(1, cnf.input_channels // 4) + layers.append(se_layer(expanded_channels, squeeze_channels, activation=partial(nn.SiLU, inplace=True))) + + # project + layers.append( + Conv2dNormActivation( + expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None + ) + ) + + self.block = nn.Sequential(*layers) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.out_channels = cnf.out_channels + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result = self.stochastic_depth(result) + result += input + return result + + +class FusedMBConv(nn.Module): + def __init__( + self, + cnf: FusedMBConvConfig, + stochastic_depth_prob: float, + norm_layer: Callable[..., nn.Module], + ) -> None: + super().__init__() + + if not (1 <= cnf.stride <= 2): + raise ValueError("illegal stride value") + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.SiLU + + expanded_channels = cnf.adjust_channels(cnf.input_channels, cnf.expand_ratio) + if expanded_channels != cnf.input_channels: + # fused expand + layers.append( + Conv2dNormActivation( + cnf.input_channels, + expanded_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + # project + layers.append( + Conv2dNormActivation( + expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=None + ) + ) + else: + layers.append( + Conv2dNormActivation( + cnf.input_channels, + cnf.out_channels, + kernel_size=cnf.kernel, + stride=cnf.stride, + norm_layer=norm_layer, + activation_layer=activation_layer, + ) + ) + + self.block = nn.Sequential(*layers) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.out_channels = cnf.out_channels + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result = self.stochastic_depth(result) + result += input + return result + + +class EfficientNet(nn.Module): + def __init__( + self, + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], + dropout: float, + stochastic_depth_prob: float = 0.2, + num_classes: int = 1000, + norm_layer: Optional[Callable[..., nn.Module]] = None, + last_channel: Optional[int] = None, + **kwargs: Any, + ) -> None: + """ + EfficientNet V1 and V2 main class + + Args: + inverted_residual_setting (Sequence[Union[MBConvConfig, FusedMBConvConfig]]): Network structure + dropout (float): The droupout probability + stochastic_depth_prob (float): The stochastic depth probability + num_classes (int): Number of classes + norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + last_channel (int): The number of channels on the penultimate layer + """ + super().__init__() + _log_api_usage_once(self) + + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + elif not ( + isinstance(inverted_residual_setting, Sequence) + and all([isinstance(s, _MBConvConfig) for s in inverted_residual_setting]) + ): + raise TypeError("The inverted_residual_setting should be List[MBConvConfig]") + + if "block" in kwargs: + warnings.warn( + "The parameter 'block' is deprecated since 0.13 and will be removed 0.15. " + "Please pass this information on 'MBConvConfig.block' instead." + ) + if kwargs["block"] is not None: + for s in inverted_residual_setting: + if isinstance(s, MBConvConfig): + s.block = kwargs["block"] + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + + layers: List[nn.Module] = [] + + # building first layer + firstconv_output_channels = inverted_residual_setting[0].input_channels + # layers.append( + # Conv2dNormActivation( + # 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU + # ) + # ) + self.first_coonv = Conv2dNormActivation( + 3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, activation_layer=nn.SiLU + ) + + # building inverted residual blocks + total_stage_blocks = sum(cnf.num_layers for cnf in inverted_residual_setting) + stage_block_id = 0 + for cnf in inverted_residual_setting: + stage: List[nn.Module] = [] + for _ in range(cnf.num_layers): + # copy to avoid modifications. shallow copy is enough + block_cnf = copy.copy(cnf) + + # overwrite info if not the first conv in the stage + if stage: + block_cnf.input_channels = block_cnf.out_channels + block_cnf.stride = 1 + + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / total_stage_blocks + + stage.append(block_cnf.block(block_cnf, sd_prob, norm_layer)) + stage_block_id += 1 + + layers.append(nn.Sequential(*stage)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].out_channels + lastconv_output_channels = last_channel if last_channel is not None else 4 * lastconv_input_channels + layers.append( + Conv2dNormActivation( + lastconv_input_channels, + lastconv_output_channels, + kernel_size=1, + norm_layer=norm_layer, + activation_layer=nn.SiLU, + ) + ) + # self.last_conv = Conv2dNormActivation( + # lastconv_input_channels, + # lastconv_output_channels, + # kernel_size=1, + # norm_layer=norm_layer, + # activation_layer=nn.SiLU, + # ) + + # self.features = nn.Sequential(*layers) + self.features = nn.ModuleList(layers) + # self.avgpool = nn.AdaptiveAvgPool2d(1) + # self.classifier = nn.Sequential( + # nn.Dropout(p=dropout, inplace=True), + # nn.Linear(lastconv_output_channels, num_classes), + # ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out") + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + init_range = 1.0 / math.sqrt(m.out_features) + nn.init.uniform_(m.weight, -init_range, init_range) + nn.init.zeros_(m.bias) + + def _forward_impl(self, x: Tensor): + x = self.first_coonv(x) + # x = self.features(x) + feats = [] + for i, layer in enumerate(self.features): + x = layer(x) + if i in [1, 2, 4, 6]: + feats.append(x) + + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # + # x = self.classifier(x) + + return feats + + def forward(self, x: Tensor): + return self._forward_impl(x) + + +def _efficientnet( + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]], + dropout: float, + last_channel: Optional[int], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> EfficientNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = EfficientNet(inverted_residual_setting, dropout, last_channel=last_channel, **kwargs) + + if weights is not None: + ckpt1 = weights.get_state_dict(progress=progress) + ckpt2 = model.state_dict() + kl1 = list(ckpt1.keys()) + for i, k in enumerate(list(ckpt2.keys())): + ckpt2[k] = ckpt1[kl1[i]] + msg = model.load_state_dict(ckpt2, strict=False) + print(f'Load EfficientNet: {msg}') + else: + print('No pretrained weight loaded!') + return model + + +def _efficientnet_conf( + arch: str, + **kwargs: Any, +) -> Tuple[Sequence[Union[MBConvConfig, FusedMBConvConfig]], Optional[int]]: + inverted_residual_setting: Sequence[Union[MBConvConfig, FusedMBConvConfig]] + if arch.startswith("efficientnet_b"): + bneck_conf = partial(MBConvConfig, width_mult=kwargs.pop("width_mult"), depth_mult=kwargs.pop("depth_mult")) + inverted_residual_setting = [ + bneck_conf(1, 3, 1, 32, 16, 1), + bneck_conf(6, 3, 2, 16, 24, 2), + bneck_conf(6, 5, 2, 24, 40, 2), + bneck_conf(6, 3, 2, 40, 80, 3), + bneck_conf(6, 5, 1, 80, 112, 3), + bneck_conf(6, 5, 2, 112, 192, 4), + bneck_conf(6, 3, 1, 192, 320, 1), + ] + last_channel = None + elif arch.startswith("efficientnet_v2_s"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 24, 24, 2), + FusedMBConvConfig(4, 3, 2, 24, 48, 4), + FusedMBConvConfig(4, 3, 2, 48, 64, 4), + MBConvConfig(4, 3, 2, 64, 128, 6), + MBConvConfig(6, 3, 1, 128, 160, 9), + MBConvConfig(6, 3, 2, 160, 256, 15), + ] + last_channel = 1280 + elif arch.startswith("efficientnet_v2_m"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 24, 24, 3), + FusedMBConvConfig(4, 3, 2, 24, 48, 5), + FusedMBConvConfig(4, 3, 2, 48, 80, 5), + MBConvConfig(4, 3, 2, 80, 160, 7), + MBConvConfig(6, 3, 1, 160, 176, 14), + MBConvConfig(6, 3, 2, 176, 304, 18), + MBConvConfig(6, 3, 1, 304, 512, 5), + ] + last_channel = 1280 + elif arch.startswith("efficientnet_v2_l"): + inverted_residual_setting = [ + FusedMBConvConfig(1, 3, 1, 32, 32, 4), + FusedMBConvConfig(4, 3, 2, 32, 64, 7), + FusedMBConvConfig(4, 3, 2, 64, 96, 7), + MBConvConfig(4, 3, 2, 96, 192, 10), + MBConvConfig(6, 3, 1, 192, 224, 19), + MBConvConfig(6, 3, 2, 224, 384, 25), + MBConvConfig(6, 3, 1, 384, 640, 7), + ] + last_channel = 1280 + else: + raise ValueError(f"Unsupported model type {arch}") + + return inverted_residual_setting, last_channel + + +_COMMON_META: Dict[str, Any] = { + "categories": _IMAGENET_CATEGORIES, +} + + +_COMMON_META_V1 = { + **_COMMON_META, + "min_size": (1, 1), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v1", +} + + +_COMMON_META_V2 = { + **_COMMON_META, + "min_size": (33, 33), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#efficientnet-v2", +} + + +class EfficientNet_B0_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/rwightman/pytorch-image-models/ + url="https://download.pytorch.org/models/efficientnet_b0_rwightman-3dd342df.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 5288548, + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.692, + "acc@5": 93.532, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B1_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/rwightman/pytorch-image-models/ + url="https://download.pytorch.org/models/efficientnet_b1_rwightman-533bc792.pth", + transforms=partial( + ImageClassification, crop_size=240, resize_size=256, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.642, + "acc@5": 94.186, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/efficientnet_b1-c27df63c.pth", + transforms=partial( + ImageClassification, crop_size=240, resize_size=255, interpolation=InterpolationMode.BILINEAR + ), + meta={ + **_COMMON_META_V1, + "num_params": 7794184, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-lr-wd-crop-tuning", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.838, + "acc@5": 94.934, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class EfficientNet_B2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/rwightman/pytorch-image-models/ + url="https://download.pytorch.org/models/efficientnet_b2_rwightman-bcdf34b7.pth", + transforms=partial( + ImageClassification, crop_size=288, resize_size=288, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 9109994, + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.608, + "acc@5": 95.310, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B3_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/rwightman/pytorch-image-models/ + url="https://download.pytorch.org/models/efficientnet_b3_rwightman-cf984f9c.pth", + transforms=partial( + ImageClassification, crop_size=300, resize_size=320, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 12233232, + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.008, + "acc@5": 96.054, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B4_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/rwightman/pytorch-image-models/ + url="https://download.pytorch.org/models/efficientnet_b4_rwightman-7eb33cd5.pth", + transforms=partial( + ImageClassification, crop_size=380, resize_size=384, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 19341616, + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.384, + "acc@5": 96.594, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B5_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ + url="https://download.pytorch.org/models/efficientnet_b5_lukemelas-b6417697.pth", + transforms=partial( + ImageClassification, crop_size=456, resize_size=456, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 30389784, + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.444, + "acc@5": 96.628, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B6_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ + url="https://download.pytorch.org/models/efficientnet_b6_lukemelas-c76e70fd.pth", + transforms=partial( + ImageClassification, crop_size=528, resize_size=528, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 43040704, + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.008, + "acc@5": 96.916, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_B7_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + # Weights ported from https://github.com/lukemelas/EfficientNet-PyTorch/ + url="https://download.pytorch.org/models/efficientnet_b7_lukemelas-dcc49843.pth", + transforms=partial( + ImageClassification, crop_size=600, resize_size=600, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META_V1, + "num_params": 66347960, + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.122, + "acc@5": 96.908, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_s-dd5fe13b.pth", + transforms=partial( + ImageClassification, + crop_size=384, + resize_size=384, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 21458488, + "_metrics": { + "ImageNet-1K": { + "acc@1": 84.228, + "acc@5": 96.878, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_M_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_m-dc08266a.pth", + transforms=partial( + ImageClassification, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BILINEAR, + ), + meta={ + **_COMMON_META_V2, + "num_params": 54139356, + "_metrics": { + "ImageNet-1K": { + "acc@1": 85.112, + "acc@5": 97.156, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using a modified version of TorchVision's + `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class EfficientNet_V2_L_Weights(WeightsEnum): + # Weights ported from https://github.com/google/automl/tree/master/efficientnetv2 + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/efficientnet_v2_l-59c71312.pth", + transforms=partial( + ImageClassification, + crop_size=480, + resize_size=480, + interpolation=InterpolationMode.BICUBIC, + mean=(0.5, 0.5, 0.5), + std=(0.5, 0.5, 0.5), + ), + meta={ + **_COMMON_META_V2, + "num_params": 118515272, + "_metrics": { + "ImageNet-1K": { + "acc@1": 85.808, + "acc@5": 97.788, + } + }, + "_docs": """These weights are ported from the original paper.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B0_Weights.IMAGENET1K_V1)) +def efficientnet_b0( + *, weights: Optional[EfficientNet_B0_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B0 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B0_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B0_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B0_Weights + :members: + """ + weights = EfficientNet_B0_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b0", width_mult=1.0, depth_mult=1.0) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B1_Weights.IMAGENET1K_V1)) +def efficientnet_b1( + *, weights: Optional[EfficientNet_B1_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B1 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B1_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B1_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B1_Weights + :members: + """ + weights = EfficientNet_B1_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b1", width_mult=1.0, depth_mult=1.1) + return _efficientnet(inverted_residual_setting, 0.2, last_channel, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B2_Weights.IMAGENET1K_V1)) +def efficientnet_b2( + *, weights: Optional[EfficientNet_B2_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B2 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B2_Weights + :members: + """ + weights = EfficientNet_B2_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b2", width_mult=1.1, depth_mult=1.2) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B3_Weights.IMAGENET1K_V1)) +def efficientnet_b3( + *, weights: Optional[EfficientNet_B3_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B3 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B3_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B3_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B3_Weights + :members: + """ + weights = EfficientNet_B3_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b3", width_mult=1.2, depth_mult=1.4) + return _efficientnet(inverted_residual_setting, 0.3, last_channel, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B4_Weights.IMAGENET1K_V1)) +def efficientnet_b4( + *, weights: Optional[EfficientNet_B4_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B4 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B4_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B4_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B4_Weights + :members: + """ + weights = EfficientNet_B4_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b4", width_mult=1.4, depth_mult=1.8) + return _efficientnet(inverted_residual_setting, 0.4, last_channel, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B5_Weights.IMAGENET1K_V1)) +def efficientnet_b5( + *, weights: Optional[EfficientNet_B5_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B5 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B5_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B5_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B5_Weights + :members: + """ + weights = EfficientNet_B5_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b5", width_mult=1.6, depth_mult=2.2) + return _efficientnet( + inverted_residual_setting, + 0.4, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B6_Weights.IMAGENET1K_V1)) +def efficientnet_b6( + *, weights: Optional[EfficientNet_B6_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B6 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B6_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B6_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B6_Weights + :members: + """ + weights = EfficientNet_B6_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b6", width_mult=1.8, depth_mult=2.6) + return _efficientnet( + inverted_residual_setting, + 0.5, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_B7_Weights.IMAGENET1K_V1)) +def efficientnet_b7( + *, weights: Optional[EfficientNet_B7_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """EfficientNet B7 model architecture from the `EfficientNet: Rethinking Model Scaling for Convolutional + Neural Networks `_ paper. + + Args: + weights (:class:`~torchvision.models.EfficientNet_B7_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_B7_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_B7_Weights + :members: + """ + weights = EfficientNet_B7_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_b7", width_mult=2.0, depth_mult=3.1) + return _efficientnet( + inverted_residual_setting, + 0.5, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=0.001, momentum=0.01), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_S_Weights.IMAGENET1K_V1)) +def efficientnet_v2_s( + *, weights: Optional[EfficientNet_V2_S_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """ + Constructs an EfficientNetV2-S architecture from + `EfficientNetV2: Smaller Models and Faster Training `_. + + Args: + weights (:class:`~torchvision.models.EfficientNet_V2_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_V2_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_V2_S_Weights + :members: + """ + weights = EfficientNet_V2_S_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_s") + return _efficientnet( + inverted_residual_setting, + 0.2, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_M_Weights.IMAGENET1K_V1)) +def efficientnet_v2_m( + *, weights: Optional[EfficientNet_V2_M_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """ + Constructs an EfficientNetV2-M architecture from + `EfficientNetV2: Smaller Models and Faster Training `_. + + Args: + weights (:class:`~torchvision.models.EfficientNet_V2_M_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_V2_M_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_V2_M_Weights + :members: + """ + weights = EfficientNet_V2_M_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_m") + return _efficientnet( + inverted_residual_setting, + 0.3, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +@handle_legacy_interface(weights=("pretrained", EfficientNet_V2_L_Weights.IMAGENET1K_V1)) +def efficientnet_v2_l( + *, weights: Optional[EfficientNet_V2_L_Weights] = None, progress: bool = True, **kwargs: Any +) -> EfficientNet: + """ + Constructs an EfficientNetV2-L architecture from + `EfficientNetV2: Smaller Models and Faster Training `_. + + Args: + weights (:class:`~torchvision.models.EfficientNet_V2_L_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.EfficientNet_V2_L_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.efficientnet.EfficientNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.EfficientNet_V2_L_Weights + :members: + """ + weights = EfficientNet_V2_L_Weights.verify(weights) + + inverted_residual_setting, last_channel = _efficientnet_conf("efficientnet_v2_l") + return _efficientnet( + inverted_residual_setting, + 0.4, + last_channel, + weights, + progress, + norm_layer=partial(nn.BatchNorm2d, eps=1e-03), + **kwargs, + ) + + +# The dictionary below is internal implementation detail and will be removed in v0.15 +from torchvision.models._utils import _ModelURLs + + +model_urls = _ModelURLs( + { + "efficientnet_b0": EfficientNet_B0_Weights.IMAGENET1K_V1.url, + "efficientnet_b1": EfficientNet_B1_Weights.IMAGENET1K_V1.url, + "efficientnet_b2": EfficientNet_B2_Weights.IMAGENET1K_V1.url, + "efficientnet_b3": EfficientNet_B3_Weights.IMAGENET1K_V1.url, + "efficientnet_b4": EfficientNet_B4_Weights.IMAGENET1K_V1.url, + "efficientnet_b5": EfficientNet_B5_Weights.IMAGENET1K_V1.url, + "efficientnet_b6": EfficientNet_B6_Weights.IMAGENET1K_V1.url, + "efficientnet_b7": EfficientNet_B7_Weights.IMAGENET1K_V1.url, + } +) diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..0267f9d92635e19d2f32e76d53ff1a22227eb025 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/ema.py @@ -0,0 +1,191 @@ +import copy +import torch +from torch import nn + + +def exists(val): + return val is not None + + +def clamp(value, min_value=None, max_value=None): + assert exists(min_value) or exists(max_value) + if exists(min_value): + value = max(value, min_value) + + if exists(max_value): + value = min(value, max_value) + + return value + + +class EMA(nn.Module): + """ + Implements exponential moving average shadowing for your model. + + Utilizes an inverse decay schedule to manage longer term training runs. + By adjusting the power, you can control how fast EMA will ramp up to your specified beta. + + @crowsonkb's notes on EMA Warmup: + + If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are + good values for models you plan to train for a million or more steps (reaches decay + factor 0.999 at 31.6K steps, 0.9999 at 1M steps), gamma=1, power=3/4 for models + you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at + 215.4k steps). + + Args: + inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1. + power (float): Exponential factor of EMA warmup. Default: 1. + min_value (float): The minimum EMA decay rate. Default: 0. + """ + + def __init__( + self, + model, + ema_model=None, + # if your model has lazylinears or other types of non-deepcopyable modules, you can pass in your own ema model + beta=0.9999, + update_after_step=100, + update_every=10, + inv_gamma=1.0, + power=2 / 3, + min_value=0.0, + param_or_buffer_names_no_ema=set(), + ignore_names=set(), + ignore_startswith_names=set(), + include_online_model=True + # set this to False if you do not wish for the online model to be saved along with the ema model (managed externally) + ): + super().__init__() + self.beta = beta + + # whether to include the online model within the module tree, so that state_dict also saves it + + self.include_online_model = include_online_model + + if include_online_model: + self.online_model = model + else: + self.online_model = [model] # hack + + # ema model + + self.ema_model = ema_model + + if not exists(self.ema_model): + try: + self.ema_model = copy.deepcopy(model) + except: + print('Your model was not copyable. Please make sure you are not using any LazyLinear') + exit() + + self.ema_model.requires_grad_(False) + + self.parameter_names = {name for name, param in self.ema_model.named_parameters() if param.dtype == torch.float} + self.buffer_names = {name for name, buffer in self.ema_model.named_buffers() if buffer.dtype == torch.float} + + self.update_every = update_every + self.update_after_step = update_after_step + + self.inv_gamma = inv_gamma + self.power = power + self.min_value = min_value + + assert isinstance(param_or_buffer_names_no_ema, (set, list)) + self.param_or_buffer_names_no_ema = param_or_buffer_names_no_ema # parameter or buffer + + self.ignore_names = ignore_names + self.ignore_startswith_names = ignore_startswith_names + + self.register_buffer('initted', torch.Tensor([False])) + self.register_buffer('step', torch.tensor([0])) + + @property + def model(self): + return self.online_model if self.include_online_model else self.online_model[0] + + def restore_ema_model_device(self): + device = self.initted.device + self.ema_model.to(device) + + def get_params_iter(self, model): + for name, param in model.named_parameters(): + if name not in self.parameter_names: + continue + yield name, param + + def get_buffers_iter(self, model): + for name, buffer in model.named_buffers(): + if name not in self.buffer_names: + continue + yield name, buffer + + def copy_params_from_model_to_ema(self): + for (_, ma_params), (_, current_params) in zip(self.get_params_iter(self.ema_model), + self.get_params_iter(self.model)): + ma_params.data.copy_(current_params.data) + + for (_, ma_buffers), (_, current_buffers) in zip(self.get_buffers_iter(self.ema_model), + self.get_buffers_iter(self.model)): + ma_buffers.data.copy_(current_buffers.data) + + def get_current_decay(self): + epoch = clamp(self.step.item() - self.update_after_step - 1, min_value=0.) + value = 1 - (1 + epoch / self.inv_gamma) ** - self.power + + if epoch <= 0: + return 0. + + return clamp(value, min_value=self.min_value, max_value=self.beta) + + def update(self): + step = self.step.item() + self.step += 1 + + if (step % self.update_every) != 0: + return + + if step <= self.update_after_step: + self.copy_params_from_model_to_ema() + return + + if not self.initted.item(): + self.copy_params_from_model_to_ema() + self.initted.data.copy_(torch.Tensor([True])) + + self.update_moving_average(self.ema_model, self.model) + + @torch.no_grad() + def update_moving_average(self, ma_model, current_model): + current_decay = self.get_current_decay() + + for (name, current_params), (_, ma_params) in zip(self.get_params_iter(current_model), + self.get_params_iter(ma_model)): + if name in self.ignore_names: + continue + + if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): + continue + + if name in self.param_or_buffer_names_no_ema: + ma_params.data.copy_(current_params.data) + continue + + ma_params.data.lerp_(current_params.data, 1. - current_decay) + + for (name, current_buffer), (_, ma_buffer) in zip(self.get_buffers_iter(current_model), + self.get_buffers_iter(ma_model)): + if name in self.ignore_names: + continue + + if any([name.startswith(prefix) for prefix in self.ignore_startswith_names]): + continue + + if name in self.param_or_buffer_names_no_ema: + ma_buffer.data.copy_(current_buffer.data) + continue + + ma_buffer.data.lerp_(current_buffer.data, 1. - current_decay) + + def __call__(self, *args, **kwargs): + return self.ema_model(*args, **kwargs) diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d35e8da2cb64e2512ddb4031e120c8b7aed78832 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/encoder_decoder.py @@ -0,0 +1,1086 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange +from .loss import LPIPSWithDiscriminator + +# from ldm.util import instantiate_from_config +# from ldm.modules.attention import LinearAttention + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t=None, context=None): + #assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + def get_last_layer(self): + return self.conv_out.weight + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = (curr_res[0] // 2, curr_res[1] // 2) + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = (resolution[0] // 2**(self.num_resolutions-1), resolution[1] // 2**(self.num_resolutions-1)) + self.z_shape = (1,z_channels,curr_res[0],curr_res[1]) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = (curr_res[0] * 2, curr_res[1] * 2) + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class LatentRescaler(nn.Module): + def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2): + super().__init__() + # residual block, interpolate, residual block + self.factor = factor + self.conv_in = nn.Conv2d(in_channels, + mid_channels, + kernel_size=3, + stride=1, + padding=1) + self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + self.attn = AttnBlock(mid_channels) + self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels, + out_channels=mid_channels, + temb_channels=0, + dropout=0.0) for _ in range(depth)]) + + self.conv_out = nn.Conv2d(mid_channels, + out_channels, + kernel_size=1, + ) + + def forward(self, x): + x = self.conv_in(x) + for block in self.res_block1: + x = block(x, None) + x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor)))) + x = self.attn(x) + for block in self.res_block2: + x = block(x, None) + x = self.conv_out(x) + return x + + +class MergedRescaleEncoder(nn.Module): + def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + intermediate_chn = ch * ch_mult[-1] + self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult, + z_channels=intermediate_chn, double_z=False, resolution=resolution, + attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv, + out_ch=None) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn, + mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth) + + def forward(self, x): + x = self.encoder(x) + x = self.rescaler(x) + return x + + +class MergedRescaleDecoder(nn.Module): + def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8), + dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1): + super().__init__() + tmp_chn = z_channels*ch_mult[-1] + self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout, + resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks, + ch_mult=ch_mult, resolution=resolution, ch=ch) + self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn, + out_channels=tmp_chn, depth=rescale_module_depth) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Upsampler(nn.Module): + def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2): + super().__init__() + assert out_size >= in_size + num_blocks = int(np.log2(out_size//in_size))+1 + factor_up = 1.+ (out_size % in_size) + print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}") + self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels, + out_channels=in_channels) + self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2, + attn_resolutions=[], in_channels=None, ch=in_channels, + ch_mult=[ch_mult for _ in range(num_blocks)]) + + def forward(self, x): + x = self.rescaler(x) + x = self.decoder(x) + return x + + +class Resize(nn.Module): + def __init__(self, in_channels=None, learned=False, mode="bilinear"): + super().__init__() + self.with_conv = learned + self.mode = mode + if self.with_conv: + print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode") + raise NotImplementedError() + assert in_channels is not None + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=4, + stride=2, + padding=1) + + def forward(self, x, scale_factor=1.0): + if scale_factor==1.0: + return x + else: + x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor) + return x + +class FirstStagePostProcessor(nn.Module): + + def __init__(self, ch_mult:list, in_channels, + pretrained_model:nn.Module=None, + reshape=False, + n_channels=None, + dropout=0., + pretrained_config=None): + super().__init__() + if pretrained_config is None: + assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.pretrained_model = pretrained_model + else: + assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None' + self.instantiate_pretrained(pretrained_config) + + self.do_reshape = reshape + + if n_channels is None: + n_channels = self.pretrained_model.encoder.ch + + self.proj_norm = Normalize(in_channels,num_groups=in_channels//2) + self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3, + stride=1,padding=1) + + blocks = [] + downs = [] + ch_in = n_channels + for m in ch_mult: + blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout)) + ch_in = m * n_channels + downs.append(Downsample(ch_in, with_conv=False)) + + self.model = nn.ModuleList(blocks) + self.downsampler = nn.ModuleList(downs) + + + def instantiate_pretrained(self, config): + model = instantiate_from_config(config) + self.pretrained_model = model.eval() + # self.pretrained_model.train = False + for param in self.pretrained_model.parameters(): + param.requires_grad = False + + + @torch.no_grad() + def encode_with_pretrained(self,x): + c = self.pretrained_model.encode(x) + if isinstance(c, DiagonalGaussianDistribution): + c = c.mode() + return c + + def forward(self,x): + z_fs = self.encode_with_pretrained(x) + z = self.proj_norm(z_fs) + z = self.proj(z) + z = nonlinearity(z) + + for submodel, downmodel in zip(self.model,self.downsampler): + z = submodel(z,temb=None) + z = downmodel(z) + + if self.do_reshape: + z = rearrange(z,'b c h w -> b (h w) c') + return z + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + +class AutoencoderKL(nn.Module): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.down_ratio = 2 ** (len(ddconfig['ch_mult']) - 1) + self.loss = LPIPSWithDiscriminator(**lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list(), use_ema=True): + sd = torch.load(path, map_location="cpu") + sd_keys = sd.keys() + if 'ema' in list(sd.keys()) and use_ema: + sd = sd['ema'] + new_sd = {} + for k in sd.keys(): + if k.startswith("ema_model."): + new_k = k[10:] # remove ema_model. + new_sd[new_k] = sd[k] + sd = new_sd + else: + if 'model' in sd_keys: + sd = sd["model"] + elif 'state_dict' in sd_keys: + sd = sd['state_dict'] + else: + sd = sd + # raise ValueError("") + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + msg = self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + print('==>Load AutoEncoder Info: ', msg) + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, inputs, optimizer_idx, global_step): + # inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, global_step, + last_layer=self.get_last_layer(), split="train") + # self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + # self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss, log_dict_ae + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, global_step, + last_layer=self.get_last_layer(), split="train") + + # self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + # self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss, log_dict_disc + + def validation_step(self, inputs, global_step): + # inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, global_step, + last_layer=self.get_last_layer(), split="val") + + # self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + # self.log_dict(log_dict_ae) + # self.log_dict(log_dict_disc) + return log_dict_ae, log_dict_disc + + def validate_img(self, inputs): + reconstructions, posterior = self(inputs) + return reconstructions + + # def configure_optimizers(self): + # lr = self.learning_rate + # opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + # list(self.decoder.parameters())+ + # list(self.quant_conv.parameters())+ + # list(self.post_quant_conv.parameters()), + # lr=lr, betas=(0.5, 0.9)) + # opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + # lr=lr, betas=(0.5, 0.9)) + # return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + ''' + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = nn.functional.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + ''' + +if __name__ == '__main__': + ddconfig = {'double_z': True, + 'z_channels': 4, + 'resolution': (240, 960), + 'in_channels': 3, + 'out_ch': 3, + 'ch': 128, + 'ch_mult': [ 1,2,4 ], # num_down = len(ch_mult)-1 + 'num_res_blocks': 2, + 'attn_resolutions': [ ], + 'dropout': 0.0} + lossconfig = {'disc_start': 50001, + 'kl_weight': 0.000001, + 'disc_weight': 0.5} + model = AutoencoderKL(ddconfig, lossconfig, embed_dim=4, + ckpt_path='/pretrain_weights/model-kl-f8.ckpt', ) + ''' + from torch.optim import AdamW + optimizer = AdamW(model.parameters(), lr=0.01) + lr_lambda = lambda iter: (1 - iter / 1000) ** 0.95 + lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) + for s in range(1000): + lr_scheduler.step() + cur_lr = optimizer.param_groups[0]['lr'] + print(cur_lr) + ''' + x = torch.rand(1, 3, 240, 960) + with torch.no_grad(): + y = model(x) + pass diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..f126d0cfdacf2c3d7f47c5713ef4082db98d6352 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/imagenet.py @@ -0,0 +1,395 @@ +import os, yaml, pickle, shutil, tarfile, glob +import cv2 +import custom_albumentations as albumentations +import PIL +import numpy as np +import torchvision.transforms.functional as TF +# from omegaconf import OmegaConf +from functools import partial +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset, Subset + +import taming.data.utils as tdu +from custom_controlnet_aux.diffusion_edge.taming.data.imagenet import str_to_indices, give_synsets_from_indices, download, retrieve +from custom_controlnet_aux.diffusion_edge.taming.data.imagenet import ImagePaths + +# from ldm.modules.image_degradation import degradation_fn_bsr, degradation_fn_bsr_light + + +def synset2idx(path_to_yaml="data/index_synset.yaml"): + with open(path_to_yaml) as f: + di2s = yaml.load(f) + return dict((v,k) for k,v in di2s.items()) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config + # if not type(self.config)==dict: + # self.config = OmegaConf.to_container(self.config) + self.keep_orig_class_label = self.config.get("keep_orig_class_label", False) + self.process_images = True # if False we skip loading & processing images and self.data contains filepaths + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._prepare_human_to_integer_label() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + self.synset2idx = synset2idx(path_to_yaml=self.idx2syn) + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _prepare_human_to_integer_label(self): + URL = "https://heibox.uni-heidelberg.de/f/2362b797d5be43b883f6/?dl=1" + self.human2integer = os.path.join(self.root, "imagenet1000_clsidx_to_labels.txt") + if (not os.path.exists(self.human2integer)): + download(URL, self.human2integer) + with open(self.human2integer, "r") as f: + lines = f.read().splitlines() + assert len(lines) == 1000 + self.human2integer_dict = dict() + for line in lines: + value, key = line.split(":") + self.human2integer_dict[key] = int(value) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + if not self.keep_orig_class_label: + self.class_labels = [class_dict[s] for s in self.synsets] + else: + self.class_labels = [self.synset2idx[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + + if self.process_images: + # self.size = retrieve(self.config, "size", default=256) + self.size = self.config.get("size", default=256) + self.data = ImagePaths(self.abspaths, + labels=labels, + size=self.size, + random_crop=self.random_crop, + ) + else: + self.data = self.abspaths + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.process_images = process_images + self.data_root = data_root + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + # self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", default=True) + self.random_crop = self.config.get("random_crop", default=True) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def __init__(self, process_images=True, data_root=None, **kwargs): + self.data_root = data_root + self.process_images = process_images + super().__init__(**kwargs) + + def _prepare(self): + if self.data_root: + self.root = os.path.join(self.data_root, self.NAME) + else: + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + # self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", default=False) + self.random_crop = self.config.get("random_crop", default=False) + if not tdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + tdu.mark_prepared(self.root) + + + +class ImageNetSR(Dataset): + def __init__(self, size=None, + degradation=None, downscale_f=4, min_crop_f=0.5, max_crop_f=1., + random_crop=True): + """ + Imagenet Superresolution Dataloader + Performs following ops in order: + 1. crops a crop of size s from image either as random or center crop + 2. resizes crop to size with cv2.area_interpolation + 3. degrades resized crop with degradation_fn + + :param size: resizing to size after cropping + :param degradation: degradation_fn, e.g. cv_bicubic or bsrgan_light + :param downscale_f: Low Resolution Downsample factor + :param min_crop_f: determines crop size s, + where s = c * min_img_side_len with c sampled from interval (min_crop_f, max_crop_f) + :param max_crop_f: "" + :param data_root: + :param random_crop: + """ + self.base = self.get_base() + assert size + assert (size / downscale_f).is_integer() + self.size = size + self.LR_size = int(size / downscale_f) + self.min_crop_f = min_crop_f + self.max_crop_f = max_crop_f + assert(max_crop_f <= 1.) + self.center_crop = not random_crop + + self.image_rescaler = albumentations.SmallestMaxSize(max_size=size, interpolation=cv2.INTER_AREA) + + self.pil_interpolation = False # gets reset later if incase interp_op is from pillow + + if degradation == "bsrgan": + self.degradation_process = partial(degradation_fn_bsr, sf=downscale_f) + + elif degradation == "bsrgan_light": + self.degradation_process = partial(degradation_fn_bsr_light, sf=downscale_f) + + else: + interpolation_fn = { + "cv_nearest": cv2.INTER_NEAREST, + "cv_bilinear": cv2.INTER_LINEAR, + "cv_bicubic": cv2.INTER_CUBIC, + "cv_area": cv2.INTER_AREA, + "cv_lanczos": cv2.INTER_LANCZOS4, + "pil_nearest": PIL.Image.NEAREST, + "pil_bilinear": PIL.Image.BILINEAR, + "pil_bicubic": PIL.Image.BICUBIC, + "pil_box": PIL.Image.BOX, + "pil_hamming": PIL.Image.HAMMING, + "pil_lanczos": PIL.Image.LANCZOS, + }[degradation] + + self.pil_interpolation = degradation.startswith("pil_") + + if self.pil_interpolation: + self.degradation_process = partial(TF.resize, size=self.LR_size, interpolation=interpolation_fn) + + else: + self.degradation_process = albumentations.SmallestMaxSize(max_size=self.LR_size, + interpolation=interpolation_fn) + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = Image.open(example["file_path_"]) + + if not image.mode == "RGB": + image = image.convert("RGB") + + image = np.array(image).astype(np.uint8) + + min_side_len = min(image.shape[:2]) + crop_side_len = min_side_len * np.random.uniform(self.min_crop_f, self.max_crop_f, size=None) + crop_side_len = int(crop_side_len) + + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=crop_side_len, width=crop_side_len) + + else: + self.cropper = albumentations.RandomCrop(height=crop_side_len, width=crop_side_len) + + image = self.cropper(image=image)["image"] + image = self.image_rescaler(image=image)["image"] + + if self.pil_interpolation: + image_pil = PIL.Image.fromarray(image) + LR_image = self.degradation_process(image_pil) + LR_image = np.array(LR_image).astype(np.uint8) + + else: + LR_image = self.degradation_process(image=image)["image"] + + example["image"] = (image/127.5 - 1.0).astype(np.float32) + example["LR_image"] = (LR_image/127.5 - 1.0).astype(np.float32) + + return example + + +class ImageNetSRTrain(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_train_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetTrain(process_images=False,) + return Subset(dset, indices) + + +class ImageNetSRValidation(ImageNetSR): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def get_base(self): + with open("data/imagenet_val_hr_indices.p", "rb") as f: + indices = pickle.load(f) + dset = ImageNetValidation(process_images=False,) + return Subset(dset, indices) diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..07f00750858bd9597a3e312c75fd81d6d2c30c1d --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/loss.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +import sys +# .path.append() +from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import * + + +class LPIPSWithDiscriminator(nn.Module): + def __init__(self, *, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_loss="hinge"): + + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.kl_weight = kl_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + # output log variance + self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm + ).apply(weights_init) + self.discriminator_iter_start = disc_start + self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, inputs, reconstructions, posteriors, optimizer_idx, + global_step, last_layer=None, cond=None, split="train", + weights=None): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + \ + F.mse_loss(inputs, reconstructions, reduction="none") + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + + nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar + weighted_nll_loss = nll_loss + if weights is not None: + weighted_nll_loss = weights*nll_loss + weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] + nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + kl_loss = posteriors.kl() + kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + if self.disc_factor > 0.0: + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + else: + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), + "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log + diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..2ac84c5d7788f049a2f8e82d35ab075f2065718d --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/mask_cond_unet.py @@ -0,0 +1,1009 @@ +import fvcore.common.config +import torch +import torch.nn as nn +import math +import torch.nn.functional as F +from functools import partial +from einops import rearrange, reduce +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.efficientnet import efficientnet_b7, EfficientNet_B7_Weights +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.resnet import resnet101, ResNet101_Weights +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.swin_transformer import swin_b, Swin_B_Weights +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.vgg import vgg16, VGG16_Weights + +from custom_controlnet_aux.util import custom_torch_download +# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.wcc import fft +### Compared to unet4: +# 1. add FFT-Conv on the mid feature. +######## Attention Layer ########## + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + # self.class_token_pos = nn.Parameter(torch.zeros(1, 1, num_pos_feats * 2)) + # self.class_token_pos + + def forward(self, x): + # x: b, h, w, d + num_feats = x.shape[3] + num_pos_feats = num_feats // 2 + # mask = tensor_list.mask + mask = torch.zeros(x.shape[0], x.shape[1], x.shape[2], device=x.device).to(torch.bool) + batch = mask.shape[0] + assert mask is not None + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-5 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + # pos = torch.cat((pos_y, pos_x), dim=3).flatten(1, 2) + pos = torch.cat((pos_y, pos_x), dim=3).contiguous() + ''' + pos_x: b ,h, w, d//2 + pos_y: b, h, w, d//2 + pos: b, h, w, d + ''' + return pos + +class PositionEmbeddingLearned(nn.Module): + """ + Absolute pos embedding, learned. + """ + def __init__(self, feature_size, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(feature_size[0], num_pos_feats) + self.col_embed = nn.Embedding(feature_size[1], num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + nn.init.uniform_(self.row_embed.weight) + nn.init.uniform_(self.col_embed.weight) + + def forward(self, x): + h, w = x.shape[-2:] + i = torch.arange(w, device=x.device) + j = torch.arange(h, device=x.device) + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + pos = torch.cat([ + x_emb.unsqueeze(0).repeat(h, 1, 1), + y_emb.unsqueeze(1).repeat(1, w, 1), + ], dim=-1).permute(2, 0, 1).unsqueeze(0).repeat(x.shape[0], 1, 1, 1) + return torch.cat([x, pos], dim=1) + + +class ChannelAttention(nn.Module): + def __init__(self, in_planes, ratio=8): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False) + self.relu1 = nn.ReLU() + self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False) + + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x)))) + max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) + out = avg_out + max_out + return self.sigmoid(out) * x + +class SpatialAtt(nn.Module): + def __init__(self, in_dim): + super(SpatialAtt, self).__init__() + self.map = nn.Conv2d(in_dim, 1, 1) + self.q_conv = nn.Conv2d(1, 1, 1) + self.k_conv = nn.Conv2d(1, 1, 1) + self.activation = nn.Softsign() + + def forward(self, x): + b, _, h, w = x.shape + att = self.map(x) # b, 1, h, w + q = self.q_conv(att) # b, 1, h, w + q = rearrange(q, 'b c h w -> b (h w) c') + k = self.k_conv(att) + k = rearrange(k, 'b c h w -> b c (h w)') + att = rearrange(att, 'b c h w -> b (h w) c') + att = F.softmax(q @ k, dim=-1) @ att # b, hw, 1 + att = att.reshape(b, 1, h, w) + return self.activation(att) * x + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + # self.fc1 = nn.Linear(in_features, hidden_features) + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1) + # self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + + +class BasicAttetnionLayer(nn.Module): + def __init__(self, embed_dim=128, nhead=8, ffn_dim=512, window_size1=[4, 4], + window_size2=[1, 1], dropout=0.1): + super().__init__() + self.window_size1 = window_size1 + self.window_size2 = window_size2 + self.avgpool_q = nn.AvgPool2d(kernel_size=window_size1) + self.avgpool_k = nn.AvgPool2d(kernel_size=window_size2) + self.softmax = nn.Softmax(dim=-1) + self.nhead = nhead + + self.q_lin = nn.Linear(embed_dim, embed_dim) + self.k_lin = nn.Linear(embed_dim, embed_dim) + self.v_lin = nn.Linear(embed_dim, embed_dim) + + self.mlp = Mlp(in_features=embed_dim, hidden_features=ffn_dim, drop=dropout) + self.pos_enc = PositionEmbeddingSine(embed_dim) + self.concat_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1) + self.gn = nn.GroupNorm(8, embed_dim) + + self.out_conv = nn.Conv2d(embed_dim, embed_dim, 1) + self.init_weights() + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0.) + + def forward(self, x1, x2): # x1 for q (conditional input), x2 for k,v + B, C1, H1, W1 = x1.shape + _, C2, H2, W2 = x2.shape + # x1 = x1.permute(0, 2, 3, 1).contiguous() # B, H1, W1, C1 + shortcut = x2 + self.concat_conv(torch.cat( + [F.interpolate(x1, size=(H2, W2), mode='bilinear', align_corners=True), + x2], dim=1)) + shortcut = self.gn(shortcut) + pad_l = pad_t = 0 + pad_r = (self.window_size1[1] - W1 % self.window_size1[1]) % self.window_size1[1] + pad_b = (self.window_size1[0] - H1 % self.window_size1[0]) % self.window_size1[0] + x1 = F.pad(x1, (pad_l, pad_r, pad_t, pad_b, 0, 0)) + _, _, H1p, W1p = x1.shape + # x2 = x2.permute(0, 2, 3, 1).contiguous() # B, H2, W2, C2 + pad_l = pad_t = 0 + pad_r = (self.window_size2[1] - W2 % self.window_size2[1]) % self.window_size2[1] + pad_b = (self.window_size2[0] - H2 % self.window_size2[0]) % self.window_size2[0] + x2 = F.pad(x2, (pad_l, pad_r, pad_t, pad_b, 0, 0)) + _, _, H2p, W2p = x2.shape + # x1g = x1 #B, C1, H1p, W1p + # x2g = x2 #B, C2, H2p, W2p + x1_s = self.avgpool_q(x1) + qg = self.avgpool_q(x1).permute(0, 2, 3, 1).contiguous() + qg = qg + self.pos_enc(qg) + qg= qg.view(B, -1, C2) + kg = self.avgpool_k(x2).permute(0, 2, 3, 1).contiguous() + kg = kg + self.pos_enc(kg) + kg = kg.view(B, -1, C1) + num_window_q = qg.shape[1] + num_window_k = kg.shape[1] + qg = self.q_lin(qg).reshape(B, num_window_q, self.nhead, C1 // self.nhead).permute(0, 2, 1, + 3).contiguous() + kg2 = self.k_lin(kg).reshape(B, num_window_k, self.nhead, C1 // self.nhead).permute(0, 2, 1, + 3).contiguous() + vg = self.v_lin(kg).reshape(B, num_window_k, self.nhead, C1 // self.nhead).permute(0, 2, 1, + 3).contiguous() + kg = kg2 + attn = (qg @ kg.transpose(-2, -1)) + attn = self.softmax(attn) + qg = (attn @ vg).transpose(1, 2).reshape(B, num_window_q, C1) + qg = qg.transpose(1, 2).reshape(B, C1, H1p // self.window_size1[0], W1p // self.window_size1[1]) + # qg = F.interpolate(qg, size=(H1p, W1p), mode='bilinear', align_corners=False) + x1_s = x1_s + qg + x1_s = x1_s + self.mlp(x1_s) + x1_s = F.interpolate(x1_s, size=(H2, W2), mode='bilinear', align_corners=True) + x1_s = shortcut + self.out_conv(x1_s) + # x1_s = self.out_norm(x1_s) + return x1_s + +class RelationNet(nn.Module): + def __init__(self, in_channel1=128, in_channel2=128, nhead=8, layers=3, embed_dim=128, ffn_dim=512, + window_size1= [4, 4], window_size2=[1, 1]): + # self.attention = BasicAttetnionLayer(embed_dim=embed_dim, nhead=nhead, ffn_dim=ffn_dim, + # window_size1=window_size1, window_size2=window_size2, dropout=0.1) + super().__init__() + self.layers = layers + self.input_conv1 = nn.Sequential( + nn.Conv2d(in_channel1, embed_dim, 1), + nn.BatchNorm2d(embed_dim, momentum=0.03, eps=0.001), + ) + self.input_conv2 = nn.Sequential( + nn.Conv2d(in_channel2, embed_dim, 1), + nn.BatchNorm2d(embed_dim, momentum=0.03, eps=0.001), + ) + # self.input_conv1 = ConvModule(in_channel1, + # embed_dim, + # 1, + # norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + # act_cfg=None) + # self.input_conv2 = ConvModule(in_channel2, + # embed_dim, + # 1, + # norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), + # act_cfg=None) + # self.input_conv2 = nn.Linear(in_channel2, embed_dim) + self.attentions = nn.ModuleList() + for i in range(layers): + self.attentions.append( + BasicAttetnionLayer(embed_dim=embed_dim, nhead=nhead, ffn_dim=ffn_dim, + window_size1=window_size1, window_size2=window_size2, dropout=0.1) + ) + + def forward(self, cond, feat): + # cluster = cluster.unsqueeze(0).repeat(feature.shape[0], 1, 1, 1) + cond = self.input_conv1(cond) + feat = self.input_conv2(feat) + for att in self.attentions: + feat = att(cond, feat) + return feat + + + +################# U-Net model defenition #################### + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + +def convert_image_to_fn(img_type, image): + if image.mode != img_type: + return image.convert(img_type) + return image + +# normalization functions + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# small helper modules + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + ) + +def Downsample(dim, dim_out = None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + +class WeightStandardizedConv2d(nn.Conv2d): + """ + https://arxiv.org/abs/1903.10520 + weight standardization purportedly works synergistically with group normalization + """ + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + + weight = self.weight + mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') + var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +# sinusoidal positional embeds + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class GaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + def __init__(self, embedding_size=256, scale=1.0): + super().__init__() + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + + def forward(self, x): + x_proj = x[:, None] * self.W[None, :] * 2 * math.pi + return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + +class RandomOrLearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim, is_random = False): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class BlockFFT(nn.Module): + def __init__(self, dim, h, w, groups = 8): + super().__init__() + # self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.complex_weight = nn.Parameter(torch.randn(dim, h, w//2+1, 2, dtype=torch.float32) * 0.02) + # self.complex_weight = nn.Parameter(torch.normal(mean=0, std=0.01, size=(dim, h, w // 2 + 1, 2), dtype=torch.float32)) + # self.norm = nn.GroupNorm(groups, dim) + # self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + B, C, H, W = x.shape + x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho') + x = x * torch.view_as_complex(self.complex_weight) + x = torch.fft.irfft2(x, s=(H, W), dim=(2, 3), norm='ortho') + x = x.reshape(B, C, H, W) + + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class ResnetBlockFFT(nn.Module): + def __init__(self, dim, dim_out, h, w, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = BlockFFT(dim_out, h, w, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class ResnetDownsampleBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = nn.Sequential( + WeightStandardizedConv2d(dim_out, dim_out, 3, stride=2, padding=1), + nn.GroupNorm(groups, dim_out), + nn.SiLU() + ) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv( + F.interpolate(x, size=h.shape[-2:], mode="bilinear") + ) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv) + + q = q * self.scale + + sim = torch.einsum('b h d i, b h d j -> b h i j', q, k) + attn = sim.softmax(dim=-1) + out = torch.einsum('b h i j, b h d j -> b h i d', attn, v) + + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) + return self.to_out(out) + + +class ConditionEncoder(nn.Module): + def __init__(self, + down_dim_mults=(2, 4, 8), + dim=64, + in_dim=1, + out_dim=64): + super(ConditionEncoder, self).__init__() + self.init_conv = nn.Sequential( + nn.Conv2d(in_dim, dim, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(num_groups=min(dim // 4, 8), num_channels=dim), + ) + self.num_resolutions = len(down_dim_mults) + self.downs = nn.ModuleList() + in_mults = (1,) + tuple(down_dim_mults[:-1]) + in_dims = [mult*dim for mult in in_mults] + out_dims = [mult*dim for mult in down_dim_mults] + for i_level in range(self.num_resolutions): + block_in = in_dims[i_level] + block_out = out_dims[i_level] + self.downs.append(ResnetDownsampleBlock(dim=block_in, + dim_out=block_out)) + if self.num_resolutions < 1: + self.out_conv = nn.Conv2d(dim, out_dim, 1) + else: + self.out_conv = nn.Conv2d(out_dims[-1], out_dim, 1) + + def forward(self, x): + x = self.init_conv(x) + for down_layer in self.downs: + x = down_layer(x) + x = self.out_conv(x) + return x + + +class Unet(nn.Module): + def __init__( + self, + dim, + init_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + cond_in_dim=1, + cond_dim=64, + cond_dim_mults=(2, 4, 8), + channels=1, + out_mul=1, + self_condition=False, + resnet_block_groups=8, + learned_variance=False, + learned_sinusoidal_cond=False, + random_fourier_features=False, + learned_sinusoidal_dim=16, + window_sizes1=[[16, 16], [8, 8], [4, 4], [2, 2]], + window_sizes2=[[16, 16], [8, 8], [4, 4], [2, 2]], + fourier_scale=16, + ckpt_path=None, + ignore_keys=[], + cfg={}, + **kwargs + ): + super().__init__() + + # determine dimensions + self.cond_pe = cfg.get('cond_pe', False) + num_pos_feats = cfg.num_pos_feats if self.cond_pe else 0 + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + # self.init_conv_mask = nn.Sequential( + # nn.Conv2d(cond_in_dim, cond_dim, 3, padding=1), + # nn.GroupNorm(num_groups=min(init_dim // 4, 8), num_channels=init_dim), + # nn.SiLU(), + # nn.Conv2d(cond_dim, cond_dim, 3, padding=1), + # ) + # self.init_conv_mask = ConditionEncoder(down_dim_mults=cond_dim_mults, dim=cond_dim, + # in_dim=cond_in_dim, out_dim=init_dim) + + if cfg.cond_net == 'effnet': + f_condnet = 48 + if cfg.get('without_pretrain', False): + self.init_conv_mask = efficientnet_b7() + else: + self.init_conv_mask = efficientnet_b7(weights=EfficientNet_B7_Weights) + elif cfg.cond_net == 'resnet': + f_condnet = 256 + if cfg.get('without_pretrain', False): + self.init_conv_mask = resnet101() + else: + self.init_conv_mask = resnet101(weights=ResNet101_Weights) + elif cfg.cond_net == 'swin': + f_condnet = 128 + if cfg.get('without_pretrain', False): + self.init_conv_mask = swin_b() + else: + swin_b_model = swin_b(pretrained=False) + swin_b_model.load_state_dict(torch.load(custom_torch_download(filename="swin_b-68c6b09e.pth")), strict=False) + self.init_conv_mask = swin_b_model + elif cfg.cond_net == 'vgg': + f_condnet = 128 + if cfg.get('without_pretrain', False): + self.init_conv_mask = vgg16() + else: + self.init_conv_mask = vgg16(weights=VGG16_Weights) + else: + raise NotImplementedError + self.init_conv = nn.Sequential( + nn.Conv2d(input_channels + f_condnet, init_dim, 7, padding=3), + nn.GroupNorm(num_groups=min(init_dim // 4, 8), num_channels=init_dim), + ) + + if self.cond_pe: + self.cond_pos_embedding = nn.Sequential( + PositionEmbeddingLearned( + feature_size=cfg.cond_feature_size, num_pos_feats=cfg.num_pos_feats//2), + nn.Conv2d(num_pos_feats + init_dim, init_dim, 1) + ) + # self.init_conv_mask = nn.Conv2d(1, init_dim, 7, padding=3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + dims_rev = dims[::-1] + in_out = list(zip(dims[:-1], dims[1:])) + self.projects = nn.ModuleList() + print(cfg.cond_net) + if cfg.cond_net == 'effnet': + self.projects.append(nn.Conv2d(48, dims[0], 1)) + self.projects.append(nn.Conv2d(80, dims[1], 1)) + self.projects.append(nn.Conv2d(224, dims[2], 1)) + self.projects.append(nn.Conv2d(640, dims[3], 1)) + print(len(self.projects)) + elif cfg.cond_net == 'vgg': + self.projects.append(nn.Conv2d(128, dims[0], 1)) + self.projects.append(nn.Conv2d(256, dims[1], 1)) + self.projects.append(nn.Conv2d(512, dims[2], 1)) + self.projects.append(nn.Conv2d(512, dims[3], 1)) + else: + self.projects.append(nn.Conv2d(f_condnet, dims[0], 1)) + self.projects.append(nn.Conv2d(f_condnet*2, dims[1], 1)) + self.projects.append(nn.Conv2d(f_condnet*4, dims[2], 1)) + self.projects.append(nn.Conv2d(f_condnet*8, dims[3], 1)) + #print(len(self.projects)) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + time_dim = dim * 4 + + self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features + + if self.random_or_learned_sinusoidal_cond: + sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = GaussianFourierProjection(dim//2, scale=fourier_scale) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # layers + + self.downs = nn.ModuleList([]) + self.downs_mask = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + self.relation_layers_down = nn.ModuleList([]) + self.relation_layers_up = nn.ModuleList([]) + self.ups2 = nn.ModuleList([]) + self.relation_layers_up2 = nn.ModuleList([]) + num_resolutions = len(in_out) + input_size = cfg.get('input_size', [80, 80]) + feature_size_list = [[int(input_size[0]/2**k), int(input_size[1]/2**k)] for k in range(len(dim_mults))] + + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + # self.downs_mask.append(nn.ModuleList([ + # block_klass(dim_in, dim_in, time_emb_dim=time_dim), + # # block_klass(dim_in, dim_in, time_emb_dim=time_dim), + # Residual(PreNorm(dim_in, LinearAttention(dim_in))), + # Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding=1) + # ])) + self.relation_layers_down.append(RelationNet(in_channel1=dims[ind], in_channel2=dims[ind], nhead=8, + layers=1, embed_dim=dims[ind], ffn_dim=dims[ind]*2, + window_size1=window_sizes1[ind], window_size2=window_sizes2[ind]) + ) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.decouple1 = nn.Sequential( + nn.GroupNorm(num_groups=min(mid_dim // 4, 8), num_channels=mid_dim), + nn.Conv2d(mid_dim, mid_dim, 3, padding=1), + BlockFFT(mid_dim, input_size[0]//8, input_size[1]//8), + ) + self.decouple2 = nn.Sequential( + nn.GroupNorm(num_groups=min(mid_dim // 4, 8), num_channels=mid_dim), + nn.Conv2d(mid_dim, mid_dim, 3, padding=1), + BlockFFT(mid_dim, input_size[0]//8, input_size[1]//8), + ) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + self.relation_layers_up.append(RelationNet(in_channel1=dims_rev[ind+1], in_channel2=dims_rev[ind], + nhead=8, layers=1, embed_dim=dims_rev[ind], + ffn_dim=dims_rev[ind] * 2, + window_size1=window_sizes1[::-1][ind], + window_size2=window_sizes2[::-1][ind]) + ) + self.ups2.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding=1) + ])) + self.relation_layers_up2.append(RelationNet(in_channel1=dims_rev[ind + 1], in_channel2=dims_rev[ind], + nhead=8, layers=1, embed_dim=dims_rev[ind], + ffn_dim=dims_rev[ind] * 2, + window_size1=window_sizes1[::-1][ind], + window_size2=window_sizes2[::-1][ind]) + ) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim * out_mul, 1) + + self.final_res_block2 = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv2 = nn.Conv2d(dim, self.out_dim, 1) + + # self.init_weights() + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + fix_bb = cfg.get('fix_bb', True) + if fix_bb: + for n, p in self.init_conv_mask.named_parameters(): + p.requires_grad = False + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["model"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + msg = self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + print('==>Load Unet Info: ', msg) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.weight, 1.) + nn.init.constant_(m.bias, 0.) + elif isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + nn.init.constant_(m.bias, 0.) + + def forward(self, x, time, mask, x_self_cond = None, **kwargs): + if self.self_condition: + x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + sigma = time.reshape(-1, 1, 1, 1) + eps = 1e-4 + c_skip1 = 1 - sigma + c_skip2 = torch.sqrt(sigma) + # c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_out1 = sigma / torch.sqrt(sigma ** 2 + 1) + c_out2 = torch.sqrt(1 - sigma) / torch.sqrt(sigma ** 2 + 1) + c_in = 1 + + x_clone = x.clone() + x = c_in * x + # mask = torch.cat([], dim=1) + hm = self.init_conv_mask(mask) + # if self.cond_pe: + # m = self.cond_pos_embedding(m) + x = self.init_conv(torch.cat([x, F.interpolate(hm[0], size=x.shape[-2:], mode="bilinear")], dim=1)) + r = x.clone() + + t = self.time_mlp(torch.log(time)/4) + + h = [] + h2 = [] + for i, layer in enumerate(self.projects): + # print(hm[i].shape) + hm[i] = layer(hm[i]) + hm2 = [] + for i in range(len(hm)): + hm2.append(hm[i].clone()) + # hm = [] + # hm2 = [] + for i, ((block1, block2, attn, downsample), relation_layer) \ + in enumerate(zip(self.downs, self.relation_layers_down)): + x = block1(x, t) + h.append(x) + h2.append(x.clone()) + # m = m_block(m, t) + # hm.append(m) + # hm2.append(m.clone()) + + x = relation_layer(hm[i], x) + + x = block2(x, t) + x = attn(x) + h.append(x) + h2.append(x.clone()) + + x = downsample(x) + # m = m_downsample(m) + + + # x = x + F.interpolate(hm[-1], size=x.shape[2:], mode="bilinear", align_corners=True) + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + x1 = x + self.decouple1(x) + x2 = x + self.decouple2(x) + + x = x1 + for (block1, block2, attn, upsample), relation_layer in zip(self.ups, self.relation_layers_up): + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + x = relation_layer(hm.pop(), x) + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + x = upsample(x) + + x1 = torch.cat((x, r), dim=1) + x1 = self.final_res_block(x1, t) + x1 = self.final_conv(x1) + + x = x2 + for (block1, block2, attn, upsample), relation_layer in zip(self.ups2, self.relation_layers_up2): + x = torch.cat((x, h2.pop()), dim = 1) + x = block1(x, t) + x = relation_layer(hm2.pop(), x) + x = torch.cat((x, h2.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + x = upsample(x) + + x2 = torch.cat((x, r), dim=1) + x2 = self.final_res_block2(x2, t) + x2 = self.final_conv2(x2) + # sigma = time.reshape(x1.shape[0], *((1,) * (len(x1.shape) - 1))) + # scale_C = torch.exp(sigma) + x1 = c_skip1 * x_clone + c_out1 * x1 + x2 = c_skip2 * x_clone + c_out2 * x2 + return x1, x2 + + +if __name__ == "__main__": + # resnet = resnet101(weights=ResNet101_Weights) + # effnet = efficientnet_b7(weights=EfficientNet_B7_Weights) + # effnet = efficientnet_b7(weights=None) + # x = torch.rand(1, 3, 320, 320) + # y = effnet(x) + model = Unet(dim=128, dim_mults=(1, 2, 4, 4), + cond_dim=128, + cond_dim_mults=(2, 4, ), + channels=1, + window_sizes1=[[8, 8], [4, 4], [2, 2], [1, 1]], + window_sizes2=[[8, 8], [4, 4], [2, 2], [1, 1]], + cfg=fvcore.common.config.CfgNode({'cond_pe': False, 'input_size': [80, 80], + 'cond_feature_size': (32, 128), 'cond_net': 'vgg', + 'num_pos_feats': 96}) + ) + x = torch.rand(1, 1, 80, 80) + mask = torch.rand(1, 3, 320, 320) + time = torch.tensor([0.5124]) + with torch.no_grad(): + y = model(x, time, mask) + pass \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..1db3d379f0e44a885f7e4a249b1014172b95264d --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/quantization.py @@ -0,0 +1,103 @@ +import torch +from torch import nn as nn +from torch.nn import Parameter + + +def weight_quantization(b): + def uniform_quant(x, b): + xdiv = x.mul((2 ** b - 1)) + xhard = xdiv.round().div(2 ** b - 1) + return xhard + + class _pq(torch.autograd.Function): + @staticmethod + def forward(ctx, input, alpha): + input.div_(alpha) # weights are first divided by alpha + input_c = input.clamp(min=-1, max=1) # then clipped to [-1,1] + sign = input_c.sign() + input_abs = input_c.abs() + input_q = uniform_quant(input_abs, b).mul(sign) + ctx.save_for_backward(input, input_q) + input_q = input_q.mul(alpha) # rescale to the original range + return input_q + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() # grad for weights will not be clipped + input, input_q = ctx.saved_tensors + i = (input.abs() > 1.).float() + sign = input.sign() + grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum() + return grad_input, grad_alpha + + return _pq().apply + + +class weight_quantize_fn(nn.Module): + def __init__(self, bit_w): + super(weight_quantize_fn, self).__init__() + assert bit_w > 0 + + self.bit_w = bit_w - 1 + self.weight_q = weight_quantization(b=self.bit_w) + self.register_parameter('w_alpha', Parameter(torch.tensor(3.0), requires_grad=True)) + + def forward(self, weight): + mean = weight.data.mean() + std = weight.data.std() + weight = weight.add(-mean).div(std) # weights normalization + weight_q = self.weight_q(weight, self.w_alpha) + return weight_q + + def change_bit(self, bit_w): + self.bit_w = bit_w - 1 + self.weight_q = weight_quantization(b=self.bit_w) + +def act_quantization(b, signed=False): + def uniform_quant(x, b=3): + xdiv = x.mul(2 ** b - 1) + xhard = xdiv.round().div(2 ** b - 1) + return xhard + + class _uq(torch.autograd.Function): + @staticmethod + def forward(ctx, input, alpha): + input = input.div(alpha) + input_c = input.clamp(min=-1, max=1) if signed else input.clamp(max=1) + input_q = uniform_quant(input_c, b) + ctx.save_for_backward(input, input_q) + input_q = input_q.mul(alpha) + return input_q + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + input, input_q = ctx.saved_tensors + i = (input.abs() > 1.).float() + sign = input.sign() + grad_alpha = (grad_output * (sign * i + (input_q - input) * (1 - i))).sum() + grad_input = grad_input * (1 - i) + return grad_input, grad_alpha + + return _uq().apply + +class act_quantize_fn(nn.Module): + def __init__(self, bit_a, signed=False): + super(act_quantize_fn, self).__init__() + self.bit_a = bit_a + self.signed = signed + if signed: + self.bit_a -= 1 + assert bit_a > 0 + + self.act_q = act_quantization(b=self.bit_a, signed=signed) + self.register_parameter('a_alpha', Parameter(torch.tensor(8.0), requires_grad=True)) + + def forward(self, x): + return self.act_q(x, self.a_alpha) + + def change_bit(self, bit_a): + self.bit_a = bit_a + if self.signed: + self.bit_a -= 1 + self.act_q = act_quantization(b=self.bit_a, signed=self.signed) \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..256a73a06e32b1cc07d35a8dfcfb3e4601ceeb38 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/resnet.py @@ -0,0 +1,963 @@ +from functools import partial +from typing import Type, Any, Callable, Union, List, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from torchvision.transforms._presets import ImageClassification +from torchvision.utils import _log_api_usage_once +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param + + +__all__ = [ + "ResNet", + "ResNet18_Weights", + "ResNet34_Weights", + "ResNet50_Weights", + "ResNet101_Weights", + "ResNet152_Weights", + "ResNeXt50_32X4D_Weights", + "ResNeXt101_32X8D_Weights", + "ResNeXt101_64X4D_Weights", + "Wide_ResNet50_2_Weights", + "Wide_ResNet101_2_Weights", + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x8d", + "resnext101_64x4d", + "wide_resnet50_2", + "wide_resnet101_2", +] + + +def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d: + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=dilation, + groups=groups, + bias=False, + dilation=dilation, + ) + + +def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d: + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion: int = 1 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError("BasicBlock only supports groups=1 and base_width=64") + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion: int = 4 + + def __init__( + self, + inplanes: int, + planes: int, + stride: int = 1, + downsample: Optional[nn.Module] = None, + groups: int = 1, + base_width: int = 64, + dilation: int = 1, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.0)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + def __init__( + self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + zero_init_residual: bool = False, + groups: int = 1, + width_per_group: int = 64, + replace_stride_with_dilation: Optional[List[bool]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None, + ) -> None: + super().__init__() + _log_api_usage_once(self) + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError( + "replace_stride_with_dilation should be None " + f"or a 3-element tuple, got {replace_stride_with_dilation}" + ) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=replace_stride_with_dilation[2]) + # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + # self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck) and m.bn3.weight is not None: + nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type] + elif isinstance(m, BasicBlock) and m.bn2.weight is not None: + nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type] + + def _make_layer( + self, + block: Type[Union[BasicBlock, Bottleneck]], + planes: int, + blocks: int, + stride: int = 1, + dilate: bool = False, + ) -> nn.Sequential: + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append( + block( + self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer + ) + ) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append( + block( + self.inplanes, + planes, + groups=self.groups, + base_width=self.base_width, + dilation=self.dilation, + norm_layer=norm_layer, + ) + ) + + return nn.Sequential(*layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + # See note [TorchScript super()] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + feats = [] + x = self.layer1(x) + feats.append(x) + x = self.layer2(x) + feats.append(x) + x = self.layer3(x) + feats.append(x) + x = self.layer4(x) + feats.append(x) + + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # x = self.fc(x) + + return feats + + def forward(self, x: Tensor) -> Tensor: + return self._forward_impl(x) + + +def _resnet( + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> ResNet: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = ResNet(block, layers, **kwargs) + + if weights is not None: + model.load_state_dict(weights.get_state_dict(progress=progress), strict=False) + + return model + + +_COMMON_META = { + "min_size": (1, 1), + "categories": _IMAGENET_CATEGORIES, +} + + +class ResNet18_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet18-f37072fd.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 11689512, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.758, + "acc@5": 89.078, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet34_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet34-b627a593.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 21797672, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.314, + "acc@5": 91.420, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class ResNet50_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet50-0676ba61.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 76.130, + "acc@5": 92.862, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet50-11ad3fa6.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25557032, + "recipe": "https://github.com/pytorch/vision/issues/3995#issuecomment-1013906621", + "_metrics": { + "ImageNet-1K": { + "acc@1": 80.858, + "acc@5": 95.434, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet101_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet101-63fe2227.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.374, + "acc@5": 93.546, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet101-cd907fc2.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 44549160, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.886, + "acc@5": 95.780, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNet152_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnet152-394f9c45.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnet", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.312, + "acc@5": 94.046, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnet152-f82ba261.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 60192808, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.284, + "acc@5": 96.002, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt50_32X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "_metrics": { + "ImageNet-1K": { + "acc@1": 77.618, + "acc@5": 93.698, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext50_32x4d-1a0047aa.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 25028904, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.198, + "acc@5": 95.340, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_32X8D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#resnext", + "_metrics": { + "ImageNet-1K": { + "acc@1": 79.312, + "acc@5": 94.526, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/resnext101_32x8d-110c445d.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 88791336, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.834, + "acc@5": 96.228, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class ResNeXt101_64X4D_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/resnext101_64x4d-173b62eb.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 83455272, + "recipe": "https://github.com/pytorch/vision/pull/5935", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.246, + "acc@5": 96.454, + } + }, + "_docs": """ + These weights were trained from scratch by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Wide_ResNet50_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.468, + "acc@5": 94.086, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet50_2-9ba9bcbe.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 68883240, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe-with-fixres", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.602, + "acc@5": 95.758, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +class Wide_ResNet101_2_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/pull/912#issue-445437439", + "_metrics": { + "ImageNet-1K": { + "acc@1": 78.848, + "acc@5": 94.284, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a simple training recipe.""", + }, + ) + IMAGENET1K_V2 = Weights( + url="https://download.pytorch.org/models/wide_resnet101_2-d733dc28.pth", + transforms=partial(ImageClassification, crop_size=224, resize_size=232), + meta={ + **_COMMON_META, + "num_params": 126886696, + "recipe": "https://github.com/pytorch/vision/issues/3995#new-recipe", + "_metrics": { + "ImageNet-1K": { + "acc@1": 82.510, + "acc@5": 96.020, + } + }, + "_docs": """ + These weights improve upon the results of the original paper by using TorchVision's `new training recipe + `_. + """, + }, + ) + DEFAULT = IMAGENET1K_V2 + + +@handle_legacy_interface(weights=("pretrained", ResNet18_Weights.IMAGENET1K_V1)) +def resnet18(*, weights: Optional[ResNet18_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-18 from `Deep Residual Learning for Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.ResNet18_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet18_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet18_Weights + :members: + """ + weights = ResNet18_Weights.verify(weights) + + return _resnet(BasicBlock, [2, 2, 2, 2], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNet34_Weights.IMAGENET1K_V1)) +def resnet34(*, weights: Optional[ResNet34_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-34 from `Deep Residual Learning for Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.ResNet34_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet34_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet34_Weights + :members: + """ + weights = ResNet34_Weights.verify(weights) + + return _resnet(BasicBlock, [3, 4, 6, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNet50_Weights.IMAGENET1K_V1)) +def resnet50(*, weights: Optional[ResNet50_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-50 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet50_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet50_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet50_Weights + :members: + """ + weights = ResNet50_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNet101_Weights.IMAGENET1K_V1)) +def resnet101(*, weights: Optional[ResNet101_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-101 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet101_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet101_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet101_Weights + :members: + """ + weights = ResNet101_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNet152_Weights.IMAGENET1K_V1)) +def resnet152(*, weights: Optional[ResNet152_Weights] = None, progress: bool = True, **kwargs: Any) -> ResNet: + """ResNet-152 from `Deep Residual Learning for Image Recognition `__. + + .. note:: + The bottleneck of TorchVision places the stride for downsampling to the second 3x3 + convolution while the original paper places it to the first 1x1 convolution. + This variant improves the accuracy and is known as `ResNet V1.5 + `_. + + Args: + weights (:class:`~torchvision.models.ResNet152_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNet152_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.ResNet152_Weights + :members: + """ + weights = ResNet152_Weights.verify(weights) + + return _resnet(Bottleneck, [3, 8, 36, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNeXt50_32X4D_Weights.IMAGENET1K_V1)) +def resnext50_32x4d( + *, weights: Optional[ResNeXt50_32X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-50 32x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt50_32X4D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNext50_32X4D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.ResNeXt50_32X4D_Weights + :members: + """ + weights = ResNeXt50_32X4D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", ResNeXt101_32X8D_Weights.IMAGENET1K_V1)) +def resnext101_32x8d( + *, weights: Optional[ResNeXt101_32X8D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-101 32x8d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt101_32X8D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNeXt101_32X8D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.ResNeXt101_32X8D_Weights + :members: + """ + weights = ResNeXt101_32X8D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 32) + _ovewrite_named_param(kwargs, "width_per_group", 8) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +def resnext101_64x4d( + *, weights: Optional[ResNeXt101_64X4D_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """ResNeXt-101 64x4d model from + `Aggregated Residual Transformation for Deep Neural Networks `_. + + Args: + weights (:class:`~torchvision.models.ResNeXt101_64X4D_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.ResNeXt101_64X4D_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.ResNeXt101_64X4D_Weights + :members: + """ + weights = ResNeXt101_64X4D_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "groups", 64) + _ovewrite_named_param(kwargs, "width_per_group", 4) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", Wide_ResNet50_2_Weights.IMAGENET1K_V1)) +def wide_resnet50_2( + *, weights: Optional[Wide_ResNet50_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """Wide ResNet-50-2 model from + `Wide Residual Networks `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 + channels, and in Wide ResNet-50-2 has 2048-1024-2048. + + Args: + weights (:class:`~torchvision.models.Wide_ResNet50_2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Wide_ResNet50_2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.Wide_ResNet50_2_Weights + :members: + """ + weights = Wide_ResNet50_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 6, 3], weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", Wide_ResNet101_2_Weights.IMAGENET1K_V1)) +def wide_resnet101_2( + *, weights: Optional[Wide_ResNet101_2_Weights] = None, progress: bool = True, **kwargs: Any +) -> ResNet: + """Wide ResNet-101-2 model from + `Wide Residual Networks `_. + + The model is the same as ResNet except for the bottleneck number of channels + which is twice larger in every block. The number of channels in outer 1x1 + convolutions is the same, e.g. last block in ResNet-101 has 2048-512-2048 + channels, and in Wide ResNet-101-2 has 2048-1024-2048. + + Args: + weights (:class:`~torchvision.models.Wide_ResNet101_2_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Wide_ResNet101_2_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.resnet.ResNet`` + base class. Please refer to the `source code + `_ + for more details about this class. + .. autoclass:: torchvision.models.Wide_ResNet101_2_Weights + :members: + """ + weights = Wide_ResNet101_2_Weights.verify(weights) + + _ovewrite_named_param(kwargs, "width_per_group", 64 * 2) + return _resnet(Bottleneck, [3, 4, 23, 3], weights, progress, **kwargs) + + +# The dictionary below is internal implementation detail and will be removed in v0.15 +from torchvision.models._utils import _ModelURLs + + +model_urls = _ModelURLs( + { + "resnet18": ResNet18_Weights.IMAGENET1K_V1.url, + "resnet34": ResNet34_Weights.IMAGENET1K_V1.url, + "resnet50": ResNet50_Weights.IMAGENET1K_V1.url, + "resnet101": ResNet101_Weights.IMAGENET1K_V1.url, + "resnet152": ResNet152_Weights.IMAGENET1K_V1.url, + "resnext50_32x4d": ResNeXt50_32X4D_Weights.IMAGENET1K_V1.url, + "resnext101_32x8d": ResNeXt101_32X8D_Weights.IMAGENET1K_V1.url, + "wide_resnet50_2": Wide_ResNet50_2_Weights.IMAGENET1K_V1.url, + "wide_resnet101_2": Wide_ResNet101_2_Weights.IMAGENET1K_V1.url, + } +) diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..9915bfd6df1de65f8a84975afcf22a386d15958e --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/swin_transformer.py @@ -0,0 +1,651 @@ +from functools import partial +from typing import Optional, Callable, List, Any + +import torch +import torch.nn.functional as F +from torch import nn, Tensor + +from torchvision.ops.misc import MLP, Permute +from torchvision.ops.stochastic_depth import StochasticDepth +from torchvision.transforms._presets import ImageClassification, InterpolationMode +from torchvision.utils import _log_api_usage_once +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import _ovewrite_named_param + + +__all__ = [ + "SwinTransformer", + "Swin_T_Weights", + "Swin_S_Weights", + "Swin_B_Weights", + "swin_t", + "swin_s", + "swin_b", +] + + +def _patch_merging_pad(x): + H, W, _ = x.shape[-3:] + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + return x + + +torch.fx.wrap("_patch_merging_pad") + + +class PatchMerging(nn.Module): + """Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + """ + + def __init__(self, dim: int, norm_layer: Callable[..., nn.Module] = nn.LayerNorm): + super().__init__() + _log_api_usage_once(self) + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): input tensor with expected layout of [..., H, W, C] + Returns: + Tensor with layout of [..., H/2, W/2, 2*C] + """ + x = _patch_merging_pad(x) + + x0 = x[..., 0::2, 0::2, :] # ... H/2 W/2 C + x1 = x[..., 1::2, 0::2, :] # ... H/2 W/2 C + x2 = x[..., 0::2, 1::2, :] # ... H/2 W/2 C + x3 = x[..., 1::2, 1::2, :] # ... H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # ... H/2 W/2 4*C + + x = self.norm(x) + x = self.reduction(x) # ... H/2 W/2 2*C + return x + + +def shifted_window_attention( + input: Tensor, + qkv_weight: Tensor, + proj_weight: Tensor, + relative_position_bias: Tensor, + window_size: List[int], + num_heads: int, + shift_size: List[int], + attention_dropout: float = 0.0, + dropout: float = 0.0, + qkv_bias: Optional[Tensor] = None, + proj_bias: Optional[Tensor] = None, +): + """ + Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + input (Tensor[N, H, W, C]): The input tensor or 4-dimensions. + qkv_weight (Tensor[in_dim, out_dim]): The weight tensor of query, key, value. + proj_weight (Tensor[out_dim, out_dim]): The weight tensor of projection. + relative_position_bias (Tensor): The learned relative position bias added to attention. + window_size (List[int]): Window size. + num_heads (int): Number of attention heads. + shift_size (List[int]): Shift size for shifted window attention. + attention_dropout (float): Dropout ratio of attention weight. Default: 0.0. + dropout (float): Dropout ratio of output. Default: 0.0. + qkv_bias (Tensor[out_dim], optional): The bias tensor of query, key, value. Default: None. + proj_bias (Tensor[out_dim], optional): The bias tensor of projection. Default: None. + Returns: + Tensor[N, H, W, C]: The output tensor after shifted window attention. + """ + B, H, W, C = input.shape + # pad feature maps to multiples of window size + pad_r = (window_size[1] - W % window_size[1]) % window_size[1] + pad_b = (window_size[0] - H % window_size[0]) % window_size[0] + x = F.pad(input, (0, 0, 0, pad_r, 0, pad_b)) + _, pad_H, pad_W, _ = x.shape + + # If window size is larger than feature size, there is no need to shift window + if window_size[0] >= pad_H: + shift_size[0] = 0 + if window_size[1] >= pad_W: + shift_size[1] = 0 + + # cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(-shift_size[0], -shift_size[1]), dims=(1, 2)) + + # partition windows + num_windows = (pad_H // window_size[0]) * (pad_W // window_size[1]) + x = x.view(B, pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B * num_windows, window_size[0] * window_size[1], C) # B*nW, Ws*Ws, C + + # multi-head attention + qkv = F.linear(x, qkv_weight, qkv_bias) + qkv = qkv.reshape(x.size(0), x.size(1), 3, num_heads, C // num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + q = q * (C // num_heads) ** -0.5 + attn = q.matmul(k.transpose(-2, -1)) + # add relative position bias + attn = attn + relative_position_bias + + if sum(shift_size) > 0: + # generate attention mask + attn_mask = x.new_zeros((pad_H, pad_W)) + h_slices = ((0, -window_size[0]), (-window_size[0], -shift_size[0]), (-shift_size[0], None)) + w_slices = ((0, -window_size[1]), (-window_size[1], -shift_size[1]), (-shift_size[1], None)) + count = 0 + for h in h_slices: + for w in w_slices: + attn_mask[h[0] : h[1], w[0] : w[1]] = count + count += 1 + attn_mask = attn_mask.view(pad_H // window_size[0], window_size[0], pad_W // window_size[1], window_size[1]) + attn_mask = attn_mask.permute(0, 2, 1, 3).reshape(num_windows, window_size[0] * window_size[1]) + attn_mask = attn_mask.unsqueeze(1) - attn_mask.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + attn = attn.view(x.size(0) // num_windows, num_windows, num_heads, x.size(1), x.size(1)) + attn = attn + attn_mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, num_heads, x.size(1), x.size(1)) + + attn = F.softmax(attn, dim=-1) + attn = F.dropout(attn, p=attention_dropout) + + x = attn.matmul(v).transpose(1, 2).reshape(x.size(0), x.size(1), C) + x = F.linear(x, proj_weight, proj_bias) + x = F.dropout(x, p=dropout) + + # reverse windows + x = x.view(B, pad_H // window_size[0], pad_W // window_size[1], window_size[0], window_size[1], C) + x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, pad_H, pad_W, C) + + # reverse cyclic shift + if sum(shift_size) > 0: + x = torch.roll(x, shifts=(shift_size[0], shift_size[1]), dims=(1, 2)) + + # unpad features + x = x[:, :H, :W, :].contiguous() + return x + + +torch.fx.wrap("shifted_window_attention") + + +class ShiftedWindowAttention(nn.Module): + """ + See :func:`shifted_window_attention`. + """ + + def __init__( + self, + dim: int, + window_size: List[int], + shift_size: List[int], + num_heads: int, + qkv_bias: bool = True, + proj_bias: bool = True, + attention_dropout: float = 0.0, + dropout: float = 0.0, + ): + super().__init__() + if len(window_size) != 2 or len(shift_size) != 2: + raise ValueError("window_size and shift_size must be of length 2") + self.window_size = window_size + self.shift_size = shift_size + self.num_heads = num_heads + self.attention_dropout = attention_dropout + self.dropout = dropout + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + # coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij")) # 2, Wh, Ww + coords = torch.stack(torch.meshgrid(coords_h, coords_w)) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1).view(-1) # Wh*Ww*Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + + def forward(self, x: Tensor): + """ + Args: + x (Tensor): Tensor with layout of [B, H, W, C] + Returns: + Tensor with same layout as input, i.e. [B, H, W, C] + """ + + N = self.window_size[0] * self.window_size[1] + relative_position_bias = self.relative_position_bias_table[self.relative_position_index] # type: ignore[index] + relative_position_bias = relative_position_bias.view(N, N, -1) + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().unsqueeze(0) + + return shifted_window_attention( + x, + self.qkv.weight, + self.proj.weight, + relative_position_bias, + self.window_size, + self.num_heads, + shift_size=self.shift_size, + attention_dropout=self.attention_dropout, + dropout=self.dropout, + qkv_bias=self.qkv.bias, + proj_bias=self.proj.bias, + ) + + +class SwinTransformerBlock(nn.Module): + """ + Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (List[int]): Window size. + shift_size (List[int]): Shift size for shifted window attention. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob: (float): Stochastic depth rate. Default: 0.0. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + attn_layer (nn.Module): Attention layer. Default: ShiftedWindowAttention + """ + + def __init__( + self, + dim: int, + num_heads: int, + window_size: List[int], + shift_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_layer: Callable[..., nn.Module] = ShiftedWindowAttention, + ): + super().__init__() + _log_api_usage_once(self) + + self.norm1 = norm_layer(dim) + self.attn = attn_layer( + dim, + window_size, + shift_size, + num_heads, + attention_dropout=attention_dropout, + dropout=dropout, + ) + self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row") + self.norm2 = norm_layer(dim) + self.mlp = MLP(dim, [int(dim * mlp_ratio), dim], activation_layer=nn.GELU, inplace=None, dropout=dropout) + + for m in self.mlp.modules(): + if isinstance(m, nn.Linear): + nn.init.xavier_uniform_(m.weight) + if m.bias is not None: + nn.init.normal_(m.bias, std=1e-6) + + def forward(self, x: Tensor): + x = x + self.stochastic_depth(self.attn(self.norm1(x))) + x = x + self.stochastic_depth(self.mlp(self.norm2(x))) + return x + + +class SwinTransformer(nn.Module): + """ + Implements Swin Transformer from the `"Swin Transformer: Hierarchical Vision Transformer using + Shifted Windows" `_ paper. + Args: + patch_size (List[int]): Patch size. + embed_dim (int): Patch embedding dimension. + depths (List(int)): Depth of each Swin Transformer layer. + num_heads (List(int)): Number of attention heads in different layers. + window_size (List[int]): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0. + dropout (float): Dropout rate. Default: 0.0. + attention_dropout (float): Attention dropout rate. Default: 0.0. + stochastic_depth_prob (float): Stochastic depth rate. Default: 0.0. + num_classes (int): Number of classes for classification head. Default: 1000. + block (nn.Module, optional): SwinTransformer Block. Default: None. + norm_layer (nn.Module, optional): Normalization layer. Default: None. + """ + + def __init__( + self, + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + mlp_ratio: float = 4.0, + dropout: float = 0.0, + attention_dropout: float = 0.0, + stochastic_depth_prob: float = 0.0, + num_classes: int = 1000, + norm_layer: Optional[Callable[..., nn.Module]] = None, + block: Optional[Callable[..., nn.Module]] = None, + ): + super().__init__() + _log_api_usage_once(self) + self.num_classes = num_classes + + if block is None: + block = SwinTransformerBlock + + if norm_layer is None: + norm_layer = partial(nn.LayerNorm, eps=1e-5) + + layers: List[nn.Module] = [] + # split image into non-overlapping patches + # layers.append( + # nn.Sequential( + # nn.Conv2d( + # 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + # ), + # Permute([0, 2, 3, 1]), + # norm_layer(embed_dim), + # ) + # ) + self.first_coonv = nn.Sequential( + nn.Conv2d( + 3, embed_dim, kernel_size=(patch_size[0], patch_size[1]), stride=(patch_size[0], patch_size[1]) + ), + Permute([0, 2, 3, 1]), + norm_layer(embed_dim), + ) + + total_stage_blocks = sum(depths) + stage_block_id = 0 + # build SwinTransformer blocks + for i_stage in range(len(depths)): + stage: List[nn.Module] = [] + dim = embed_dim * 2 ** i_stage + for i_layer in range(depths[i_stage]): + # adjust stochastic depth probability based on the depth of the stage block + sd_prob = stochastic_depth_prob * float(stage_block_id) / (total_stage_blocks - 1) + stage.append( + block( + dim, + num_heads[i_stage], + window_size=window_size, + shift_size=[0 if i_layer % 2 == 0 else w // 2 for w in window_size], + mlp_ratio=mlp_ratio, + dropout=dropout, + attention_dropout=attention_dropout, + stochastic_depth_prob=sd_prob, + norm_layer=norm_layer, + ) + ) + stage_block_id += 1 + layers.append(nn.Sequential(*stage)) + # add patch merging layer + if i_stage < (len(depths) - 1): + layers.append(PatchMerging(dim, norm_layer)) + # self.features = nn.Sequential(*layers) + self.features = nn.ModuleList(layers) + + num_features = embed_dim * 2 ** (len(depths) - 1) + self.norm = norm_layer(num_features) + self.avgpool = nn.AdaptiveAvgPool2d(1) + self.head = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.zeros_(m.bias) + + def forward(self, x): + feats = [] + x = self.first_coonv(x) + for i, layer in enumerate(self.features): + x = layer(x) + if i in [0, 2, 4, 6]: + feats.append(x.permute(0, 3, 1, 2).contiguous()) + # x = self.features(x) + # x = self.norm(x) + # x = x.permute(0, 3, 1, 2) + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # x = self.head(x) + return feats + + +def _swin_transformer( + patch_size: List[int], + embed_dim: int, + depths: List[int], + num_heads: List[int], + window_size: List[int], + stochastic_depth_prob: float, + weights: Optional[WeightsEnum], + progress: bool, + **kwargs: Any, +) -> SwinTransformer: + if weights is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + + model = SwinTransformer( + patch_size=patch_size, + embed_dim=embed_dim, + depths=depths, + num_heads=num_heads, + window_size=window_size, + stochastic_depth_prob=stochastic_depth_prob, + **kwargs, + ) + + if weights is not None: + ckpt1 = weights.get_state_dict(progress=progress) + ckpt2 = model.state_dict() + kl1 = list(ckpt1.keys()) + for i, k in enumerate(list(ckpt2.keys())): + ckpt2[k] = ckpt1[kl1[i]] + msg = model.load_state_dict(ckpt2, strict=False) + print(f'Load swin_transformer: {msg}') + + return model + + +_COMMON_META = { + "categories": _IMAGENET_CATEGORIES, +} + + +class Swin_T_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_t-704ceda3.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=232, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 28288354, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 81.474, + "acc@5": 95.776, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_S_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_s-5e29d889.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=246, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 49606258, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.196, + "acc@5": 96.360, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class Swin_B_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/swin_b-68c6b09e.pth", + transforms=partial( + ImageClassification, crop_size=224, resize_size=238, interpolation=InterpolationMode.BICUBIC + ), + meta={ + **_COMMON_META, + "num_params": 87768224, + "min_size": (224, 224), + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#swintransformer", + "_metrics": { + "ImageNet-1K": { + "acc@1": 83.582, + "acc@5": 96.640, + } + }, + "_docs": """These weights reproduce closely the results of the paper using a similar training recipe.""", + }, + ) + DEFAULT = IMAGENET1K_V1 + + +def swin_t(*, weights: Optional[Swin_T_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_tiny architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_T_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_T_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_T_Weights + :members: + """ + weights = Swin_T_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.2, + weights=weights, + progress=progress, + **kwargs, + ) + + +def swin_s(*, weights: Optional[Swin_S_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_small architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_S_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_S_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_S_Weights + :members: + """ + weights = Swin_S_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=96, + depths=[2, 2, 18, 2], + num_heads=[3, 6, 12, 24], + window_size=[7, 7], + stochastic_depth_prob=0.3, + weights=weights, + progress=progress, + **kwargs, + ) + + +from torchvision.models._utils import handle_legacy_interface +@handle_legacy_interface(weights=("pretrained", Swin_B_Weights.IMAGENET1K_V1)) +def swin_b(*, weights: Optional[Swin_B_Weights] = None, progress: bool = True, **kwargs: Any) -> SwinTransformer: + """ + Constructs a swin_base architecture from + `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows `_. + + Args: + weights (:class:`~torchvision.models.Swin_B_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.Swin_B_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.swin_transformer.SwinTransformer`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.Swin_B_Weights + :members: + """ + weights = Swin_B_Weights.verify(weights) + + return _swin_transformer( + patch_size=[4, 4], + embed_dim=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=[7, 7], + stochastic_depth_prob=0.5, + weights=weights, + progress=progress, + **kwargs, + ) + +if __name__ == '__main__': + model = swin_b(weights=Swin_B_Weights) + x = torch.rand(1, 3, 320, 320) + y = model(x) + pause = 0 \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py new file mode 100644 index 0000000000000000000000000000000000000000..c6e6dd3f7e1e8a69ba9e36009b5a1088cd82726f --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/uncond_unet.py @@ -0,0 +1,376 @@ +import math +import torch +from torch import nn, einsum +import torch.nn.functional as F +from einops import rearrange, reduce +from functools import partial + + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + +def convert_image_to_fn(img_type, image): + if image.mode != img_type: + return image.convert(img_type) + return image + +# normalization functions + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +# small helper modules + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + +def Upsample(dim, dim_out = None): + return nn.Sequential( + nn.Upsample(scale_factor = 2, mode = 'nearest'), + nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1) + ) + +def Downsample(dim, dim_out = None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + +class WeightStandardizedConv2d(nn.Conv2d): + """ + https://arxiv.org/abs/1903.10520 + weight standardization purportedly works synergistically with group normalization + """ + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + + weight = self.weight + mean = reduce(weight, 'o ... -> o 1 1 1', 'mean') + var = reduce(weight, 'o ... -> o 1 1 1', partial(torch.var, unbiased = False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d(x, normalized_weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim = 1, unbiased = False, keepdim = True) + mean = torch.mean(x, dim = 1, keepdim = True) + return (x - mean) * (var + eps).rsqrt() * self.g + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + +# sinusoidal positional embeds + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + +class RandomOrLearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with random (learned optional) sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim, is_random = False): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim), requires_grad = not is_random) + + def forward(self, x): + x = rearrange(x, 'b -> b 1') + freqs = x * rearrange(self.weights, 'd -> 1 d') * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim = -1) + fouriered = torch.cat((x, fouriered), dim = -1) + return fouriered + +# building block modules + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups = 8): + super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding = 1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift = None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups = groups) + self.block2 = Block(dim_out, dim_out, groups = groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb = None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1') + scale_shift = time_emb.chunk(2, dim = 1) + + h = self.block1(x, scale_shift = scale_shift) + + h = self.block2(h) + + return h + self.res_conv(x) + +class LinearAttention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + + self.to_out = nn.Sequential( + nn.Conv2d(hidden_dim, dim, 1), + LayerNorm(dim) + ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q.softmax(dim = -2) + k = k.softmax(dim = -1) + + q = q * self.scale + v = v / (h * w) + + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', h = self.heads, x = h, y = w) + return self.to_out(out) + +class Attention(nn.Module): + def __init__(self, dim, heads = 4, dim_head = 32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim = 1) + q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h = self.heads), qkv) + + q = q * self.scale + + sim = einsum('b h d i, b h d j -> b h i j', q, k) + attn = sim.softmax(dim = -1) + out = einsum('b h i j, b h d j -> b h i d', attn, v) + + out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w) + return self.to_out(out) + +# model + +class Unet(nn.Module): + def __init__( + self, + dim, + init_dim = None, + out_dim = None, + dim_mults=(1, 2, 4, 8), + channels = 3, + self_condition = False, + resnet_block_groups = 8, + heads=8, + learned_variance = False, + learned_sinusoidal_cond = False, + random_fourier_features = False, + learned_sinusoidal_dim = 16, + out_mul=1, + ): + super().__init__() + + # determine dimensions + + self.channels = channels + self.self_condition = self_condition + input_channels = channels * (2 if self_condition else 1) + + init_dim = default(init_dim, dim) + self.init_conv = nn.Conv2d(input_channels, init_dim, 7, padding = 3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups = resnet_block_groups) + + # time embeddings + + time_dim = dim * 4 + + self.random_or_learned_sinusoidal_cond = learned_sinusoidal_cond or random_fourier_features + + if self.random_or_learned_sinusoidal_cond: + sinu_pos_emb = RandomOrLearnedSinusoidalPosEmb(learned_sinusoidal_dim, random_fourier_features) + fourier_dim = learned_sinusoidal_dim + 1 + else: + sinu_pos_emb = SinusoidalPosEmb(dim) + fourier_dim = dim + + self.time_mlp = nn.Sequential( + sinu_pos_emb, + nn.Linear(fourier_dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + block_klass(dim_in, dim_in, time_emb_dim = time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in, heads=heads))), + Downsample(dim_in, dim_out) if not is_last else nn.Conv2d(dim_in, dim_out, 3, padding = 1) + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim, heads=heads))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim = time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + + self.ups.append(nn.ModuleList([ + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + block_klass(dim_out + dim_in, dim_out, time_emb_dim = time_dim), + Residual(PreNorm(dim_out, LinearAttention(dim_out))), + Upsample(dim_out, dim_in) if not is_last else nn.Conv2d(dim_out, dim_in, 3, padding = 1) + ])) + + default_out_dim = channels * out_mul + self.out_dim = default(out_dim, default_out_dim) + + self.final_res_block = block_klass(dim * 2, dim, time_emb_dim = time_dim) + self.final_conv = nn.Conv2d(dim, self.out_dim, 1) + + def forward(self, x, time, cond=None, x_self_cond=None): ## cond is always None for unconditional model + if self.self_condition: + x_self_cond = default(x_self_cond, lambda: torch.zeros_like(x)) + x = torch.cat((x_self_cond, x), dim = 1) + + x = self.init_conv(x) + r = x.clone() + + t = self.time_mlp(time) + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t) + h.append(x) + + x = block2(x, t) + x = attn(x) + h.append(x) + + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim = 1) + x = block1(x, t) + + x = torch.cat((x, h.pop()), dim = 1) + x = block2(x, t) + x = attn(x) + + x = upsample(x) + + x = torch.cat((x, r), dim = 1) + + x = self.final_res_block(x, t) + return self.final_conv(x) + +if __name__ == '__main__': + model = Unet(96, out_mul=2, dim_mults=[1,2,4,8], heads=8) + x = torch.rand(2, 3, 8, 8) + time = torch.tensor([2, 5]) + with torch.no_grad(): + y = model(x, time) + pass \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4975a5de594ac8fed9370d95fe978de0b36fc230 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/utils.py @@ -0,0 +1,68 @@ +import os +from pathlib import Path +import time +import logging +import math + +def create_logger(root_dir, des=''): + root_output_dir = Path(root_dir) + # set up logger + if not root_output_dir.exists(): + print('=> creating {}'.format(root_output_dir)) + root_output_dir.mkdir(exist_ok=True, parents=True) + time_str = time.strftime('%Y-%m-%d-%H-%M') + log_file = '{}_{}.log'.format(time_str, des) + final_log_file = root_output_dir / log_file + head = '%(asctime)-15s %(message)s' + logging.basicConfig(filename=str(final_log_file), format=head) + logger = logging.getLogger() + logger.setLevel(logging.INFO) + console = logging.StreamHandler() + logging.getLogger('').addHandler(console) + return logger + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + +def identity(t, *args, **kwargs): + return t + +def cycle(dl): + while True: + for data in dl: + yield data + +def has_int_squareroot(num): + return (math.sqrt(num) ** 2) == num + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + +def convert_image_to_fn(img_type, image): + if image.mode != img_type: + return image.convert(img_type) + return image + +# normalization functions + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + +def dict2str(dict): + s = '' + for k, v in dict.items(): + s += "{}: {:.5f}, ".format(k, v) + return s \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..e33c626885864748ea395b0db73aa500edde05d7 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/vgg.py @@ -0,0 +1,517 @@ +from functools import partial +from typing import Union, List, Dict, Any, Optional, cast + +import torch +import torch.nn as nn + +from torchvision.transforms._presets import ImageClassification +from torchvision.utils import _log_api_usage_once +from torchvision.models._api import WeightsEnum, Weights +from torchvision.models._meta import _IMAGENET_CATEGORIES +from torchvision.models._utils import handle_legacy_interface, _ovewrite_named_param + + +__all__ = [ + "VGG", + "VGG11_Weights", + "VGG11_BN_Weights", + "VGG13_Weights", + "VGG13_BN_Weights", + "VGG16_Weights", + "VGG16_BN_Weights", + "VGG19_Weights", + "VGG19_BN_Weights", + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19", + "vgg19_bn", +] + + +class VGG(nn.Module): + def __init__( + self, features: nn.Module, num_classes: int = 1000, init_weights: bool = True, dropout: float = 0.5 + ) -> None: + super().__init__() + _log_api_usage_once(self) + self.features = features + self.avgpool = nn.AdaptiveAvgPool2d((7, 7)) + self.classifier = nn.Sequential( + nn.Linear(512 * 7 * 7, 4096), + nn.ReLU(True), + nn.Dropout(p=dropout), + nn.Linear(4096, 4096), + nn.ReLU(True), + nn.Dropout(p=dropout), + nn.Linear(4096, num_classes), + ) + if init_weights: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + feats = [] + # x = self.features(x) + # x = self.avgpool(x) + # x = torch.flatten(x, 1) + # x = self.classifier(x) + for i, layer in enumerate(self.features): + x = layer(x) + if i in [9, 16, 23, 30]: + feats.append(x) + return feats + + +def make_layers(cfg: List[Union[str, int]], batch_norm: bool = False) -> nn.Sequential: + layers: List[nn.Module] = [] + in_channels = 3 + for v in cfg: + if v == "M": + layers += [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + v = cast(int, v) + conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = v + return nn.ModuleList(layers) + + +cfgs: Dict[str, List[Union[str, int]]] = { + "A": [64, "M", 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "B": [64, 64, "M", 128, 128, "M", 256, 256, "M", 512, 512, "M", 512, 512, "M"], + "D": [64, 64, "M", 128, 128, "M", 256, 256, 256, "M", 512, 512, 512, "M", 512, 512, 512, "M"], + "E": [64, 64, "M", 128, 128, "M", 256, 256, 256, 256, "M", 512, 512, 512, 512, "M", 512, 512, 512, 512, "M"], +} + + +def _vgg(cfg: str, batch_norm: bool, weights: Optional[WeightsEnum], progress: bool, **kwargs: Any) -> VGG: + if weights is not None: + kwargs["init_weights"] = False + if weights.meta["categories"] is not None: + _ovewrite_named_param(kwargs, "num_classes", len(weights.meta["categories"])) + model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) + + if weights is not None: + ckpt1 = weights.get_state_dict(progress=progress) + ckpt2 = model.state_dict() + kl1 = list(ckpt1.keys()) + for i, k in enumerate(list(ckpt2.keys())): + ckpt2[k] = ckpt1[kl1[i]] + msg = model.load_state_dict(ckpt2, strict=False) + print(f'Load VGG: {msg}') + else: + print('No pretrained weight loaded!') + return model + + +_COMMON_META = { + "min_size": (32, 32), + "categories": _IMAGENET_CATEGORIES, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#alexnet-and-vgg", + "_docs": """These weights were trained from scratch by using a simplified training recipe.""", +} + + +class VGG11_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11-8a719046.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132863336, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.020, + "acc@5": 88.628, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG11_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg11_bn-6002323d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 132868840, + "_metrics": { + "ImageNet-1K": { + "acc@1": 70.370, + "acc@5": 89.810, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13-19584684.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133047848, + "_metrics": { + "ImageNet-1K": { + "acc@1": 69.928, + "acc@5": 89.246, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG13_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg13_bn-abd245e5.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 133053736, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.586, + "acc@5": 90.374, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16-397923af.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138357544, + "_metrics": { + "ImageNet-1K": { + "acc@1": 71.592, + "acc@5": 90.382, + } + }, + }, + ) + IMAGENET1K_FEATURES = Weights( + # Weights ported from https://github.com/amdegroot/ssd.pytorch/ + url="https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth", + transforms=partial( + ImageClassification, + crop_size=224, + mean=(0.48235, 0.45882, 0.40784), + std=(1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0), + ), + meta={ + **_COMMON_META, + "num_params": 138357544, + "categories": None, + "recipe": "https://github.com/amdegroot/ssd.pytorch#training-ssd", + "_metrics": { + "ImageNet-1K": { + "acc@1": float("nan"), + "acc@5": float("nan"), + } + }, + "_docs": """ + These weights can't be used for classification because they are missing values in the `classifier` + module. Only the `features` module has valid values and can be used for feature extraction. The weights + were trained using the original input standardization method as described in the paper. + """, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG16_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg16_bn-6c64b313.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 138365992, + "_metrics": { + "ImageNet-1K": { + "acc@1": 73.360, + "acc@5": 91.516, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19-dcbb9e9d.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143667240, + "_metrics": { + "ImageNet-1K": { + "acc@1": 72.376, + "acc@5": 90.876, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +class VGG19_BN_Weights(WeightsEnum): + IMAGENET1K_V1 = Weights( + url="https://download.pytorch.org/models/vgg19_bn-c79401a0.pth", + transforms=partial(ImageClassification, crop_size=224), + meta={ + **_COMMON_META, + "num_params": 143678248, + "_metrics": { + "ImageNet-1K": { + "acc@1": 74.218, + "acc@5": 91.842, + } + }, + }, + ) + DEFAULT = IMAGENET1K_V1 + + +@handle_legacy_interface(weights=("pretrained", VGG11_Weights.IMAGENET1K_V1)) +def vgg11(*, weights: Optional[VGG11_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-11 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG11_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG11_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG11_Weights + :members: + """ + weights = VGG11_Weights.verify(weights) + + return _vgg("A", False, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG11_BN_Weights.IMAGENET1K_V1)) +def vgg11_bn(*, weights: Optional[VGG11_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-11-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG11_BN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG11_BN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG11_BN_Weights + :members: + """ + weights = VGG11_BN_Weights.verify(weights) + + return _vgg("A", True, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG13_Weights.IMAGENET1K_V1)) +def vgg13(*, weights: Optional[VGG13_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-13 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG13_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG13_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG13_Weights + :members: + """ + weights = VGG13_Weights.verify(weights) + + return _vgg("B", False, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG13_BN_Weights.IMAGENET1K_V1)) +def vgg13_bn(*, weights: Optional[VGG13_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-13-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG13_BN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG13_BN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG13_BN_Weights + :members: + """ + weights = VGG13_BN_Weights.verify(weights) + + return _vgg("B", True, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG16_Weights.IMAGENET1K_V1)) +def vgg16(*, weights: Optional[VGG16_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-16 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG16_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG16_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG16_Weights + :members: + """ + weights = VGG16_Weights.verify(weights) + + return _vgg("D", False, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG16_BN_Weights.IMAGENET1K_V1)) +def vgg16_bn(*, weights: Optional[VGG16_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-16-BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG16_BN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG16_BN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG16_BN_Weights + :members: + """ + weights = VGG16_BN_Weights.verify(weights) + + return _vgg("D", True, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG19_Weights.IMAGENET1K_V1)) +def vgg19(*, weights: Optional[VGG19_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-19 from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG19_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG19_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG19_Weights + :members: + """ + weights = VGG19_Weights.verify(weights) + + return _vgg("E", False, weights, progress, **kwargs) + + +@handle_legacy_interface(weights=("pretrained", VGG19_BN_Weights.IMAGENET1K_V1)) +def vgg19_bn(*, weights: Optional[VGG19_BN_Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: + """VGG-19_BN from `Very Deep Convolutional Networks for Large-Scale Image Recognition `__. + + Args: + weights (:class:`~torchvision.models.VGG19_BN_Weights`, optional): The + pretrained weights to use. See + :class:`~torchvision.models.VGG19_BN_Weights` below for + more details, and possible values. By default, no pre-trained + weights are used. + progress (bool, optional): If True, displays a progress bar of the + download to stderr. Default is True. + **kwargs: parameters passed to the ``torchvision.models.vgg.VGG`` + base class. Please refer to the `source code + `_ + for more details about this class. + + .. autoclass:: torchvision.models.VGG19_BN_Weights + :members: + """ + weights = VGG19_BN_Weights.verify(weights) + + return _vgg("E", True, weights, progress, **kwargs) + + +# The dictionary below is internal implementation detail and will be removed in v0.15 +from torchvision.models._utils import _ModelURLs + + +model_urls = _ModelURLs( + { + "vgg11": VGG11_Weights.IMAGENET1K_V1.url, + "vgg13": VGG13_Weights.IMAGENET1K_V1.url, + "vgg16": VGG16_Weights.IMAGENET1K_V1.url, + "vgg19": VGG19_Weights.IMAGENET1K_V1.url, + "vgg11_bn": VGG11_BN_Weights.IMAGENET1K_V1.url, + "vgg13_bn": VGG13_BN_Weights.IMAGENET1K_V1.url, + "vgg16_bn": VGG16_BN_Weights.IMAGENET1K_V1.url, + "vgg19_bn": VGG19_BN_Weights.IMAGENET1K_V1.url, + } +) diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py new file mode 100644 index 0000000000000000000000000000000000000000..b051a3e10aaa33bdb8852a842229ffaaba68df0a --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wavelet.py @@ -0,0 +1,83 @@ +import pywt +import pywt.data +import torch +from torch import nn +from torch.autograd import Function +import torch.nn.functional as F + + +def create_wavelet_filter(wave, in_size, out_size, type=torch.float): + w = pywt.Wavelet(wave) + dec_hi = torch.tensor(w.dec_hi[::-1], dtype=type) + dec_lo = torch.tensor(w.dec_lo[::-1], dtype=type) + dec_filters = torch.stack([dec_lo.unsqueeze(0) * dec_lo.unsqueeze(1), + dec_lo.unsqueeze(0) * dec_hi.unsqueeze(1), + dec_hi.unsqueeze(0) * dec_lo.unsqueeze(1), + dec_hi.unsqueeze(0) * dec_hi.unsqueeze(1)], dim=0) + + dec_filters = dec_filters[:, None].repeat(in_size, 1, 1, 1) + + rec_hi = torch.tensor(w.rec_hi[::-1], dtype=type).flip(dims=[0]) + rec_lo = torch.tensor(w.rec_lo[::-1], dtype=type).flip(dims=[0]) + rec_filters = torch.stack([rec_lo.unsqueeze(0) * rec_lo.unsqueeze(1), + rec_lo.unsqueeze(0) * rec_hi.unsqueeze(1), + rec_hi.unsqueeze(0) * rec_lo.unsqueeze(1), + rec_hi.unsqueeze(0) * rec_hi.unsqueeze(1)], dim=0) + + rec_filters = rec_filters[:, None].repeat(out_size, 1, 1, 1) + + return dec_filters, rec_filters + + +def wt(x, filters, in_size, level): + _, _, h, w = x.shape + pad = (filters.shape[2] // 2 - 1, filters.shape[3] // 2 - 1) + res = F.conv2d(x, filters, stride=2, groups=in_size, padding=pad) + if level > 1: + res[:, ::4] = wt(res[:, ::4], filters, in_size, level - 1) + res = res.reshape(-1, 2, h // 2, w // 2).transpose(1, 2).reshape(-1, in_size, h, w) + return res + + +def iwt(x, inv_filters, in_size, level): + _, _, h, w = x.shape + pad = (inv_filters.shape[2] // 2 - 1, inv_filters.shape[3] // 2 - 1) + res = x.reshape(-1, h // 2, 2, w // 2).transpose(1, 2).reshape(-1, 4 * in_size, h // 2, w // 2) + if level > 1: + res[:, ::4] = iwt(res[:, ::4], inv_filters, in_size, level - 1) + res = F.conv_transpose2d(res, inv_filters, stride=2, groups=in_size, padding=pad) + return res + + +def get_inverse_transform(weights, in_size, level): + class InverseWaveletTransform(Function): + + @staticmethod + def forward(ctx, input): + with torch.no_grad(): + x = iwt(input, weights, in_size, level) + return x + + @staticmethod + def backward(ctx, grad_output): + grad = wt(grad_output, weights, in_size, level) + return grad, None + + return InverseWaveletTransform().apply + + +def get_transform(weights, in_size, level): + class WaveletTransform(Function): + + @staticmethod + def forward(ctx, input): + with torch.no_grad(): + x = wt(input, weights, in_size, level) + return x + + @staticmethod + def backward(ctx, grad_output): + grad = iwt(grad_output, weights, in_size, level) + return grad, None + + return WaveletTransform().apply \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py new file mode 100644 index 0000000000000000000000000000000000000000..5faeb37b717f47a9f526156ef77ac2711a2eb738 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/denoising_diffusion_pytorch/wcc.py @@ -0,0 +1,101 @@ +from typing import Union, Tuple + +import torch +from torch import nn as nn +from torch.nn import functional as F + +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.quantization import weight_quantize_fn, act_quantize_fn +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch import wavelet + + +class WCC(nn.Conv1d): + def __init__(self, in_channels: int, + out_channels: int, + stride: Union[int, Tuple] = 1, + padding: Union[int, Tuple] = 0, + dilation: Union[int, Tuple] = 1, + groups: int = 1, + bias: bool = False, + levels: int = 3, + compress_rate: float = 0.25, + bit_w: int = 8, + bit_a: int = 8, + wt_type: str = "db1"): + super(WCC, self).__init__(in_channels, out_channels, 1, stride, padding, dilation, groups, bias) + self.layer_type = 'WCC' + self.bit_w = bit_w + self.bit_a = bit_a + + self.weight_quant = weight_quantize_fn(self.bit_w) + self.act_quant = act_quantize_fn(self.bit_a, signed=True) + + self.levels = levels + self.wt_type = wt_type + self.compress_rate = compress_rate + + dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type, + in_size=in_channels, + out_size=out_channels) + self.wt_filters = nn.Parameter(dec_filters, requires_grad=False) + self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False) + self.wt = wavelet.get_transform(self.wt_filters, in_channels, levels) + self.iwt = wavelet.get_inverse_transform(self.iwt_filters, out_channels, levels) + + self.get_pad = lambda n: ((2 ** levels) - n) % (2 ** levels) + + def forward(self, x): + in_shape = x.shape + pads = (0, self.get_pad(in_shape[2]), 0, self.get_pad(in_shape[3])) + x = F.pad(x, pads) # pad to match 2^(levels) + + weight_q = self.weight_quant(self.weight) # quantize weights + x = self.wt(x) # H + topk, ids = self.compress(x) # T + topk_q = self.act_quant(topk) # quantize activations + topk_q = F.conv1d(topk_q, weight_q, self.bias, self.stride, self.padding, self.dilation, self.groups) # K_1x1 + x = self.decompress(topk_q, ids, x.shape) # T^T + x = self.iwt(x) # H^T + + x = x[:, :, :in_shape[2], :in_shape[3]] # remove pads + return x + + def compress(self, x): + b, c, h, w = x.shape + acc = x.norm(dim=1).pow(2) + acc = acc.view(b, h * w) + k = int(h * w * self.compress_rate) + ids = acc.topk(k, dim=1, sorted=False)[1] + ids.unsqueeze_(dim=1) + topk = x.reshape((b, c, h * w)).gather(dim=2, index=ids.repeat(1, c, 1)) + return topk, ids + + def decompress(self, topk, ids, shape): + b, _, h, w = shape + ids = ids.repeat(1, self.out_channels, 1) + x = torch.zeros(size=(b, self.out_channels, h * w), requires_grad=True, device=topk.device) + x = x.scatter(dim=2, index=ids, src=topk) + x = x.reshape((b, self.out_channels, h, w)) + return x + + def change_wt_params(self, compress_rate, levels, wt_type="db1"): + self.compress_rate = compress_rate + self.levels = levels + dec_filters, rec_filters = wavelet.create_wavelet_filter(wave=self.wt_type, + in_size=self.in_channels, + out_size=self.out_channels) + self.wt_filters = nn.Parameter(dec_filters, requires_grad=False) + self.iwt_filters = nn.Parameter(rec_filters, requires_grad=False) + self.wt = wavelet.get_transform(self.wt_filters, self.in_channels, levels) + self.iwt = wavelet.get_inverse_transform(self.iwt_filters, self.out_channels, levels) + + def change_bit(self, bit_w, bit_a): + self.bit_w = bit_w + self.bit_a = bit_a + self.weight_quant.change_bit(bit_w) + self.act_quant.change_bit(bit_a) + +if __name__ == '__main__': + wcc = WCC(80, 80) + x = torch.rand(1, 80, 80, 80) + y = wcc(x) + pause = 0 \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/model.py b/custom_controlnet_aux/diffusion_edge/model.py new file mode 100644 index 0000000000000000000000000000000000000000..c3c86b512503ba681ed30cb894ec937ba188ee4e --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/model.py @@ -0,0 +1,197 @@ +import numpy as np +import yaml +import argparse +import math +import torch +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.utils import * +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.encoder_decoder import AutoencoderKL +# from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.transmodel import TransModel +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.uncond_unet import Unet +from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.data import * +from fvcore.common.config import CfgNode +from pathlib import Path + +def load_conf(config_file, conf={}): + with open(config_file) as f: + exp_conf = yaml.load(f, Loader=yaml.FullLoader) + for k, v in exp_conf.items(): + conf[k] = v + return conf + +def prepare_args(ckpt_path, sampling_timesteps=1): + return argparse.Namespace( + cfg=load_conf(Path(__file__).parent / "default.yaml"), + pre_weight=ckpt_path, + sampling_timesteps=sampling_timesteps + ) + +class DiffusionEdge: + def __init__(self, args) -> None: + self.cfg = CfgNode(args.cfg) + torch.manual_seed(42) + np.random.seed(42) + model_cfg = self.cfg.model + first_stage_cfg = model_cfg.first_stage + first_stage_model = AutoencoderKL( + ddconfig=first_stage_cfg.ddconfig, + lossconfig=first_stage_cfg.lossconfig, + embed_dim=first_stage_cfg.embed_dim, + ckpt_path=first_stage_cfg.ckpt_path, + ) + if model_cfg.model_name == 'cond_unet': + from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.mask_cond_unet import Unet + unet_cfg = model_cfg.unet + unet = Unet(dim=unet_cfg.dim, + channels=unet_cfg.channels, + dim_mults=unet_cfg.dim_mults, + learned_variance=unet_cfg.get('learned_variance', False), + out_mul=unet_cfg.out_mul, + cond_in_dim=unet_cfg.cond_in_dim, + cond_dim=unet_cfg.cond_dim, + cond_dim_mults=unet_cfg.cond_dim_mults, + window_sizes1=unet_cfg.window_sizes1, + window_sizes2=unet_cfg.window_sizes2, + fourier_scale=unet_cfg.fourier_scale, + cfg=unet_cfg, + ) + else: + raise NotImplementedError + if model_cfg.model_type == 'const_sde': + from custom_controlnet_aux.diffusion_edge.denoising_diffusion_pytorch.ddm_const_sde import LatentDiffusion + else: + raise NotImplementedError(f'{model_cfg.model_type} is not surportted !') + + self.model = LatentDiffusion( + model=unet, + auto_encoder=first_stage_model, + train_sample=model_cfg.train_sample, + image_size=model_cfg.image_size, + timesteps=model_cfg.timesteps, + sampling_timesteps=args.sampling_timesteps, + loss_type=model_cfg.loss_type, + objective=model_cfg.objective, + scale_factor=model_cfg.scale_factor, + scale_by_std=model_cfg.scale_by_std, + scale_by_softsign=model_cfg.scale_by_softsign, + default_scale=model_cfg.get('default_scale', False), + input_keys=model_cfg.input_keys, + ckpt_path=model_cfg.ckpt_path, + ignore_keys=model_cfg.ignore_keys, + only_model=model_cfg.only_model, + start_dist=model_cfg.start_dist, + perceptual_weight=model_cfg.perceptual_weight, + use_l1=model_cfg.get('use_l1', True), + cfg=model_cfg, + ) + self.cfg.sampler.ckpt_path = args.pre_weight + + data = torch.load(self.cfg.sampler.ckpt_path, map_location="cpu") + if self.cfg.sampler.use_ema: + sd = data['ema'] + new_sd = {} + for k in sd.keys(): + if k.startswith("ema_model."): + new_k = k[10:] # remove ema_model. + new_sd[new_k] = sd[k] + sd = new_sd + self.model.load_state_dict(sd) + else: + self.model.load_state_dict(data['model']) + if 'scale_factor' in data['model']: + self.model.scale_factor = data['model']['scale_factor'] + + self.model.eval() + self.device = "cpu" + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, image, batch_size=8): + image = normalize_to_neg_one_to_one(image).to(self.device) + mask = None + if self.cfg.sampler.sample_type == 'whole': + return self.whole_sample(image, raw_size=image.shape[2:], mask=mask) + elif self.cfg.sampler.sample_type == 'slide': + return self.slide_sample(image, crop_size=self.cfg.sampler.get('crop_size', [320, 320]), + stride=self.cfg.sampler.stride, mask=mask, bs=batch_size) + + def whole_sample(self, inputs, raw_size, mask=None): + inputs = F.interpolate(inputs, size=(416, 416), mode='bilinear', align_corners=True) + seg_logits = self.model.sample(batch_size=inputs.shape[0], cond=inputs, mask=mask) + seg_logits = F.interpolate(seg_logits, size=raw_size, mode='bilinear', align_corners=True) + return seg_logits + + def slide_sample(self, inputs, crop_size, stride, mask=None, bs=8): + """Inference by sliding-window with overlap. + + If h_crop > h_img or w_crop > w_img, the small patch will be used to + decode without padding. + + Args: + inputs (tensor): the tensor should have a shape NxCxHxW, + which contains all images in the batch. + batch_img_metas (List[dict]): List of image metainfo where each may + also contain: 'img_shape', 'scale_factor', 'flip', 'img_path', + 'ori_shape', and 'pad_shape'. + For details on the values of these keys see + `mmseg/datasets/pipelines/formatting.py:PackSegInputs`. + + Returns: + Tensor: The segmentation results, seg_logits from model of each + input image. + """ + + h_stride, w_stride = stride + h_crop, w_crop = crop_size + batch_size, _, h_img, w_img = inputs.size() + out_channels = 1 + h_grids = max(h_img - h_crop + h_stride - 1, 0) // h_stride + 1 + w_grids = max(w_img - w_crop + w_stride - 1, 0) // w_stride + 1 + preds = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + # aux_out1 = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + # aux_out2 = inputs.new_zeros((batch_size, out_channels, h_img, w_img)) + count_mat = inputs.new_zeros((batch_size, 1, h_img, w_img)) + crop_imgs = [] + x1s = [] + x2s = [] + y1s = [] + y2s = [] + for h_idx in range(h_grids): + for w_idx in range(w_grids): + y1 = h_idx * h_stride + x1 = w_idx * w_stride + y2 = min(y1 + h_crop, h_img) + x2 = min(x1 + w_crop, w_img) + y1 = max(y2 - h_crop, 0) + x1 = max(x2 - w_crop, 0) + crop_img = inputs[:, :, y1:y2, x1:x2] + crop_imgs.append(crop_img) + x1s.append(x1) + x2s.append(x2) + y1s.append(y1) + y2s.append(y2) + crop_imgs = torch.cat(crop_imgs, dim=0) + crop_seg_logits_list = [] + num_windows = crop_imgs.shape[0] + bs = bs + length = math.ceil(num_windows / bs) + for i in range(length): + if i == length - 1: + crop_imgs_temp = crop_imgs[bs * i:num_windows, ...] + else: + crop_imgs_temp = crop_imgs[bs * i:bs * (i + 1), ...] + + crop_seg_logits = self.model.sample(batch_size=crop_imgs_temp.shape[0], cond=crop_imgs_temp, mask=mask) + crop_seg_logits_list.append(crop_seg_logits) + crop_seg_logits = torch.cat(crop_seg_logits_list, dim=0) + for crop_seg_logit, x1, x2, y1, y2 in zip(crop_seg_logits, x1s, x2s, y1s, y2s): + preds += F.pad(crop_seg_logit, + (int(x1), int(preds.shape[3] - x2), int(y1), + int(preds.shape[2] - y2))) + count_mat[:, :, y1:y2, x1:x2] += 1 + + assert (count_mat == 0).sum() == 0 + seg_logits = preds / count_mat + return seg_logits diff --git a/custom_controlnet_aux/diffusion_edge/requirement.txt b/custom_controlnet_aux/diffusion_edge/requirement.txt new file mode 100644 index 0000000000000000000000000000000000000000..333c5e3dd6f2662a9e887bc015895949eaf8d126 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/requirement.txt @@ -0,0 +1,9 @@ +#torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113 +einops +scikit-learn +scipy +tensorboard +fvcore +albumentations +omegaconf +numpy==1.23.5 \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/taming/__init__.py b/custom_controlnet_aux/diffusion_edge/taming/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py b/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..453832c4aa1fa5518398c791e141e0c382b3f504 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/ade20k.py @@ -0,0 +1,124 @@ +import os +import numpy as np +import cv2 +import custom_albumentations as albumentations +from PIL import Image +from torch.utils.data import Dataset + +from custom_controlnet_aux.diffusion_edge.taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/ade20k_examples.txt", + data_root="data/ade20k_images", + segmentation_root="data/ade20k_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=151, shift_segmentation=False) + + +# With semantic map and scene label +class ADE20kBase(Dataset): + def __init__(self, config=None, size=None, random_crop=False, interpolation="bicubic", crop_size=None): + self.split = self.get_split() + self.n_labels = 151 # unknown + 150 + self.data_csv = {"train": "data/ade20k_train.txt", + "validation": "data/ade20k_test.txt"}[self.split] + self.data_root = "data/ade20k_root" + with open(os.path.join(self.data_root, "sceneCategories.txt"), "r") as f: + self.scene_categories = f.read().splitlines() + self.scene_categories = dict(line.split() for line in self.scene_categories) + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, "images", l) + for l in self.image_paths], + "relative_segmentation_path_": [l.replace(".jpg", ".png") + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.data_root, "annotations", + l.replace(".jpg", ".png")) + for l in self.image_paths], + "scene_category": [self.scene_categories[l.split("/")[1].replace(".jpg", "")] + for l in self.image_paths], + } + + size = None if size is not None and size<=0 else size + self.size = size + if crop_size is None: + self.crop_size = size if size is not None else None + else: + self.crop_size = crop_size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + + if crop_size is not None: + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + segmentation = np.array(segmentation).astype(np.uint8) + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, mask=segmentation) + else: + processed = {"image": image, "mask": segmentation} + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class ADE20kTrain(ADE20kBase): + # default to random_crop=True + def __init__(self, config=None, size=None, random_crop=True, interpolation="bicubic", crop_size=None): + super().__init__(config=config, size=size, random_crop=random_crop, + interpolation=interpolation, crop_size=crop_size) + + def get_split(self): + return "train" + + +class ADE20kValidation(ADE20kBase): + def get_split(self): + return "validation" + + +if __name__ == "__main__": + dset = ADE20kValidation() + ex = dset[0] + for k in ["image", "scene_category", "segmentation"]: + print(type(ex[k])) + try: + print(ex[k].shape) + except: + print(ex[k]) diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..bb27bbfb56a9cb9e490f447e93d0d28fe351721b --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_coco.py @@ -0,0 +1,139 @@ +import json +from itertools import chain +from pathlib import Path +from typing import Iterable, Dict, List, Callable, Any +from collections import defaultdict + +from tqdm import tqdm + +from custom_controlnet_aux.diffusion_edge.taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation, ImageDescription, Category + +COCO_PATH_STRUCTURE = { + 'train': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_train2017.json', + 'stuff_annotations': 'annotations/stuff_train2017.json', + 'files': 'train2017' + }, + 'validation': { + 'top_level': '', + 'instances_annotations': 'annotations/instances_val2017.json', + 'stuff_annotations': 'annotations/stuff_val2017.json', + 'files': 'val2017' + } +} + + +def load_image_descriptions(description_json: List[Dict]) -> Dict[str, ImageDescription]: + return { + str(img['id']): ImageDescription( + id=img['id'], + license=img.get('license'), + file_name=img['file_name'], + coco_url=img['coco_url'], + original_size=(img['width'], img['height']), + date_captured=img.get('date_captured'), + flickr_url=img.get('flickr_url') + ) + for img in description_json + } + + +def load_categories(category_json: Iterable) -> Dict[str, Category]: + return {str(cat['id']): Category(id=str(cat['id']), super_category=cat['supercategory'], name=cat['name']) + for cat in category_json if cat['name'] != 'other'} + + +def load_annotations(annotations_json: List[Dict], image_descriptions: Dict[str, ImageDescription], + category_no_for_id: Callable[[str], int], split: str) -> Dict[str, List[Annotation]]: + annotations = defaultdict(list) + total = sum(len(a) for a in annotations_json) + for ann in tqdm(chain(*annotations_json), f'Loading {split} annotations', total=total): + image_id = str(ann['image_id']) + if image_id not in image_descriptions: + raise ValueError(f'image_id [{image_id}] has no image description.') + category_id = ann['category_id'] + try: + category_no = category_no_for_id(str(category_id)) + except KeyError: + continue + + width, height = image_descriptions[image_id].original_size + bbox = (ann['bbox'][0] / width, ann['bbox'][1] / height, ann['bbox'][2] / width, ann['bbox'][3] / height) + + annotations[image_id].append( + Annotation( + id=ann['id'], + area=bbox[2]*bbox[3], # use bbox area + is_group_of=ann['iscrowd'], + image_id=ann['image_id'], + bbox=bbox, + category_id=str(category_id), + category_no=category_no + ) + ) + return dict(annotations) + + +class AnnotatedObjectsCoco(AnnotatedObjectsDataset): + def __init__(self, use_things: bool = True, use_stuff: bool = True, **kwargs): + """ + @param data_path: is the path to the following folder structure: + coco/ + ├── annotations + │ ├── instances_train2017.json + │ ├── instances_val2017.json + │ ├── stuff_train2017.json + │ └── stuff_val2017.json + ├── train2017 + │ ├── 000000000009.jpg + │ ├── 000000000025.jpg + │ └── ... + ├── val2017 + │ ├── 000000000139.jpg + │ ├── 000000000285.jpg + │ └── ... + @param: split: one of 'train' or 'validation' + @param: desired image size (give square images) + """ + super().__init__(**kwargs) + self.use_things = use_things + self.use_stuff = use_stuff + + with open(self.paths['instances_annotations']) as f: + inst_data_json = json.load(f) + with open(self.paths['stuff_annotations']) as f: + stuff_data_json = json.load(f) + + category_jsons = [] + annotation_jsons = [] + if self.use_things: + category_jsons.append(inst_data_json['categories']) + annotation_jsons.append(inst_data_json['annotations']) + if self.use_stuff: + category_jsons.append(stuff_data_json['categories']) + annotation_jsons.append(stuff_data_json['annotations']) + + self.categories = load_categories(chain(*category_jsons)) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = load_image_descriptions(inst_data_json['images']) + annotations = load_annotations(annotation_jsons, self.image_descriptions, self.get_category_number, self.split) + self.annotations = self.filter_object_number(annotations, self.min_object_area, + self.min_objects_per_image, self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in COCO_PATH_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for COCO data.]') + return COCO_PATH_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(self.image_descriptions[str(image_id)].file_name) + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + # noinspection PyProtectedMember + return self.image_descriptions[image_id]._asdict() diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..af3af83a56abafa1c91c610b34c4ff9e333bcfd4 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_dataset.py @@ -0,0 +1,218 @@ +from pathlib import Path +from typing import Optional, List, Callable, Dict, Any, Union +import warnings + +import PIL.Image as pil_image +from torch import Tensor +from torch.utils.data import Dataset +from torchvision import transforms + +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_bbox import ObjectsBoundingBoxConditionalBuilder +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import load_object_from_string +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, CropMethodType, Image, Annotation, SplitType +from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import CenterCropReturnCoordinates, RandomCrop1dReturnCoordinates, \ + Random2dCropReturnCoordinates, RandomHorizontalFlipReturn, convert_pil_to_tensor + + +class AnnotatedObjectsDataset(Dataset): + def __init__(self, data_path: Union[str, Path], split: SplitType, keys: List[str], target_image_size: int, + min_object_area: float, min_objects_per_image: int, max_objects_per_image: int, + crop_method: CropMethodType, random_flip: bool, no_tokens: int, use_group_parameter: bool, + encode_crop: bool, category_allow_list_target: str = "", category_mapping_target: str = "", + no_object_classes: Optional[int] = None): + self.data_path = data_path + self.split = split + self.keys = keys + self.target_image_size = target_image_size + self.min_object_area = min_object_area + self.min_objects_per_image = min_objects_per_image + self.max_objects_per_image = max_objects_per_image + self.crop_method = crop_method + self.random_flip = random_flip + self.no_tokens = no_tokens + self.use_group_parameter = use_group_parameter + self.encode_crop = encode_crop + + self.annotations = None + self.image_descriptions = None + self.categories = None + self.category_ids = None + self.category_number = None + self.image_ids = None + self.transform_functions: List[Callable] = self.setup_transform(target_image_size, crop_method, random_flip) + self.paths = self.build_paths(self.data_path) + self._conditional_builders = None + self.category_allow_list = None + if category_allow_list_target: + allow_list = load_object_from_string(category_allow_list_target) + self.category_allow_list = {name for name, _ in allow_list} + self.category_mapping = {} + if category_mapping_target: + self.category_mapping = load_object_from_string(category_mapping_target) + self.no_object_classes = no_object_classes + + def build_paths(self, top_level: Union[str, Path]) -> Dict[str, Path]: + top_level = Path(top_level) + sub_paths = {name: top_level.joinpath(sub_path) for name, sub_path in self.get_path_structure().items()} + for path in sub_paths.values(): + if not path.exists(): + raise FileNotFoundError(f'{type(self).__name__} data structure error: [{path}] does not exist.') + return sub_paths + + @staticmethod + def load_image_from_disk(path: Path) -> Image: + return pil_image.open(path).convert('RGB') + + @staticmethod + def setup_transform(target_image_size: int, crop_method: CropMethodType, random_flip: bool): + transform_functions = [] + if crop_method == 'none': + transform_functions.append(transforms.Resize((target_image_size, target_image_size))) + elif crop_method == 'center': + transform_functions.extend([ + transforms.Resize(target_image_size), + CenterCropReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-1d': + transform_functions.extend([ + transforms.Resize(target_image_size), + RandomCrop1dReturnCoordinates(target_image_size) + ]) + elif crop_method == 'random-2d': + transform_functions.extend([ + Random2dCropReturnCoordinates(target_image_size), + transforms.Resize(target_image_size) + ]) + elif crop_method is None: + return None + else: + raise ValueError(f'Received invalid crop method [{crop_method}].') + if random_flip: + transform_functions.append(RandomHorizontalFlipReturn()) + transform_functions.append(transforms.Lambda(lambda x: x / 127.5 - 1.)) + return transform_functions + + def image_transform(self, x: Tensor) -> (Optional[BoundingBox], Optional[bool], Tensor): + crop_bbox = None + flipped = None + for t in self.transform_functions: + if isinstance(t, (RandomCrop1dReturnCoordinates, CenterCropReturnCoordinates, Random2dCropReturnCoordinates)): + crop_bbox, x = t(x) + elif isinstance(t, RandomHorizontalFlipReturn): + flipped, x = t(x) + else: + x = t(x) + return crop_bbox, flipped, x + + @property + def no_classes(self) -> int: + return self.no_object_classes if self.no_object_classes else len(self.categories) + + @property + def conditional_builders(self) -> ObjectsCenterPointsConditionalBuilder: + # cannot set this up in init because no_classes is only known after loading data in init of superclass + if self._conditional_builders is None: + self._conditional_builders = { + 'objects_center_points': ObjectsCenterPointsConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ), + 'objects_bbox': ObjectsBoundingBoxConditionalBuilder( + self.no_classes, + self.max_objects_per_image, + self.no_tokens, + self.encode_crop, + self.use_group_parameter, + getattr(self, 'use_additional_parameters', False) + ) + } + return self._conditional_builders + + def filter_categories(self) -> None: + if self.category_allow_list: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.name in self.category_allow_list} + if self.category_mapping: + self.categories = {id_: cat for id_, cat in self.categories.items() if cat.id not in self.category_mapping} + + def setup_category_id_and_number(self) -> None: + self.category_ids = list(self.categories.keys()) + self.category_ids.sort() + if '/m/01s55n' in self.category_ids: + self.category_ids.remove('/m/01s55n') + self.category_ids.append('/m/01s55n') + self.category_number = {category_id: i for i, category_id in enumerate(self.category_ids)} + if self.category_allow_list is not None and self.category_mapping is None \ + and len(self.category_ids) != len(self.category_allow_list): + warnings.warn('Unexpected number of categories: Mismatch with category_allow_list. ' + 'Make sure all names in category_allow_list exist.') + + def clean_up_annotations_and_image_descriptions(self) -> None: + image_id_set = set(self.image_ids) + self.annotations = {k: v for k, v in self.annotations.items() if k in image_id_set} + self.image_descriptions = {k: v for k, v in self.image_descriptions.items() if k in image_id_set} + + @staticmethod + def filter_object_number(all_annotations: Dict[str, List[Annotation]], min_object_area: float, + min_objects_per_image: int, max_objects_per_image: int) -> Dict[str, List[Annotation]]: + filtered = {} + for image_id, annotations in all_annotations.items(): + annotations_with_min_area = [a for a in annotations if a.area > min_object_area] + if min_objects_per_image <= len(annotations_with_min_area) <= max_objects_per_image: + filtered[image_id] = annotations_with_min_area + return filtered + + def __len__(self): + return len(self.image_ids) + + def __getitem__(self, n: int) -> Dict[str, Any]: + image_id = self.get_image_id(n) + sample = self.get_image_description(image_id) + sample['annotations'] = self.get_annotation(image_id) + + if 'image' in self.keys: + sample['image_path'] = str(self.get_image_path(image_id)) + sample['image'] = self.load_image_from_disk(sample['image_path']) + sample['image'] = convert_pil_to_tensor(sample['image']) + sample['crop_bbox'], sample['flipped'], sample['image'] = self.image_transform(sample['image']) + sample['image'] = sample['image'].permute(1, 2, 0) + + for conditional, builder in self.conditional_builders.items(): + if conditional in self.keys: + sample[conditional] = builder.build(sample['annotations'], sample['crop_bbox'], sample['flipped']) + + if self.keys: + # only return specified keys + sample = {key: sample[key] for key in self.keys} + return sample + + def get_image_id(self, no: int) -> str: + return self.image_ids[no] + + def get_annotation(self, image_id: str) -> str: + return self.annotations[image_id] + + def get_textual_label_for_category_id(self, category_id: str) -> str: + return self.categories[category_id].name + + def get_textual_label_for_category_no(self, category_no: int) -> str: + return self.categories[self.get_category_id(category_no)].name + + def get_category_number(self, category_id: str) -> int: + return self.category_number[category_id] + + def get_category_id(self, category_no: int) -> str: + return self.category_ids[category_no] + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + raise NotImplementedError() + + def get_path_structure(self): + raise NotImplementedError + + def get_image_path(self, image_id: str) -> Path: + raise NotImplementedError diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py new file mode 100644 index 0000000000000000000000000000000000000000..bc5c8185496ec1be6f16388dee8217f4cb686f3a --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/annotated_objects_open_images.py @@ -0,0 +1,137 @@ +from collections import defaultdict +from csv import DictReader, reader as TupleReader +from pathlib import Path +from typing import Dict, List, Any +import warnings + +from custom_controlnet_aux.diffusion_edge.taming.data.annotated_objects_dataset import AnnotatedObjectsDataset +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation, Category +from tqdm import tqdm + +OPEN_IMAGES_STRUCTURE = { + 'train': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'oidv6-train-annotations-bbox.csv', + 'file_list': 'train-images-boxable.csv', + 'files': 'train' + }, + 'validation': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'validation-annotations-bbox.csv', + 'file_list': 'validation-images.csv', + 'files': 'validation' + }, + 'test': { + 'top_level': '', + 'class_descriptions': 'class-descriptions-boxable.csv', + 'annotations': 'test-annotations-bbox.csv', + 'file_list': 'test-images.csv', + 'files': 'test' + } +} + + +def load_annotations(descriptor_path: Path, min_object_area: float, category_mapping: Dict[str, str], + category_no_for_id: Dict[str, int]) -> Dict[str, List[Annotation]]: + annotations: Dict[str, List[Annotation]] = defaultdict(list) + with open(descriptor_path) as file: + reader = DictReader(file) + for i, row in tqdm(enumerate(reader), total=14620000, desc='Loading OpenImages annotations'): + width = float(row['XMax']) - float(row['XMin']) + height = float(row['YMax']) - float(row['YMin']) + area = width * height + category_id = row['LabelName'] + if category_id in category_mapping: + category_id = category_mapping[category_id] + if area >= min_object_area and category_id in category_no_for_id: + annotations[row['ImageID']].append( + Annotation( + id=i, + image_id=row['ImageID'], + source=row['Source'], + category_id=category_id, + category_no=category_no_for_id[category_id], + confidence=float(row['Confidence']), + bbox=(float(row['XMin']), float(row['YMin']), width, height), + area=area, + is_occluded=bool(int(row['IsOccluded'])), + is_truncated=bool(int(row['IsTruncated'])), + is_group_of=bool(int(row['IsGroupOf'])), + is_depiction=bool(int(row['IsDepiction'])), + is_inside=bool(int(row['IsInside'])) + ) + ) + if 'train' in str(descriptor_path) and i < 14000000: + warnings.warn(f'Running with subset of Open Images. Train dataset has length [{len(annotations)}].') + return dict(annotations) + + +def load_image_ids(csv_path: Path) -> List[str]: + with open(csv_path) as file: + reader = DictReader(file) + return [row['image_name'] for row in reader] + + +def load_categories(csv_path: Path) -> Dict[str, Category]: + with open(csv_path) as file: + reader = TupleReader(file) + return {row[0]: Category(id=row[0], name=row[1], super_category=None) for row in reader} + + +class AnnotatedObjectsOpenImages(AnnotatedObjectsDataset): + def __init__(self, use_additional_parameters: bool, **kwargs): + """ + @param data_path: is the path to the following folder structure: + open_images/ + │ oidv6-train-annotations-bbox.csv + ├── class-descriptions-boxable.csv + ├── oidv6-train-annotations-bbox.csv + ├── test + │ ├── 000026e7ee790996.jpg + │ ├── 000062a39995e348.jpg + │ └── ... + ├── test-annotations-bbox.csv + ├── test-images.csv + ├── train + │ ├── 000002b66c9c498e.jpg + │ ├── 000002b97e5471a0.jpg + │ └── ... + ├── train-images-boxable.csv + ├── validation + │ ├── 0001eeaf4aed83f9.jpg + │ ├── 0004886b7d043cfd.jpg + │ └── ... + ├── validation-annotations-bbox.csv + └── validation-images.csv + @param: split: one of 'train', 'validation' or 'test' + @param: desired image size (returns square images) + """ + + super().__init__(**kwargs) + self.use_additional_parameters = use_additional_parameters + + self.categories = load_categories(self.paths['class_descriptions']) + self.filter_categories() + self.setup_category_id_and_number() + + self.image_descriptions = {} + annotations = load_annotations(self.paths['annotations'], self.min_object_area, self.category_mapping, + self.category_number) + self.annotations = self.filter_object_number(annotations, self.min_object_area, self.min_objects_per_image, + self.max_objects_per_image) + self.image_ids = list(self.annotations.keys()) + self.clean_up_annotations_and_image_descriptions() + + def get_path_structure(self) -> Dict[str, str]: + if self.split not in OPEN_IMAGES_STRUCTURE: + raise ValueError(f'Split [{self.split} does not exist for Open Images data.]') + return OPEN_IMAGES_STRUCTURE[self.split] + + def get_image_path(self, image_id: str) -> Path: + return self.paths['files'].joinpath(f'{image_id:0>16}.jpg') + + def get_image_description(self, image_id: str) -> Dict[str, Any]: + image_path = self.get_image_path(image_id) + return {'file_path': str(image_path), 'file_name': image_path.name} diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/base.py b/custom_controlnet_aux/diffusion_edge/taming/data/base.py new file mode 100644 index 0000000000000000000000000000000000000000..e43569ab0be34e4a91aa84478f56123d9d9e0429 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/base.py @@ -0,0 +1,70 @@ +import bisect +import numpy as np +import custom_albumentations as albumentations +from PIL import Image +from torch.utils.data import Dataset, ConcatDataset + + +class ConcatDatasetWithIndex(ConcatDataset): + """Modified from original pytorch code to return dataset idx""" + def __getitem__(self, idx): + if idx < 0: + if -idx > len(self): + raise ValueError("absolute value of index should not exceed dataset length") + idx = len(self) + idx + dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) + if dataset_idx == 0: + sample_idx = idx + else: + sample_idx = idx - self.cumulative_sizes[dataset_idx - 1] + return self.datasets[dataset_idx][sample_idx], dataset_idx + + +class ImagePaths(Dataset): + def __init__(self, paths, size=None, random_crop=False, labels=None): + self.size = size + self.random_crop = random_crop + + self.labels = dict() if labels is None else labels + self.labels["file_path_"] = paths + self._length = len(paths) + + if self.size is not None and self.size > 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.size) + if not self.random_crop: + self.cropper = albumentations.CenterCrop(height=self.size,width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size,width=self.size) + self.preprocessor = albumentations.Compose([self.rescaler, self.cropper]) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return self._length + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def __getitem__(self, i): + example = dict() + example["image"] = self.preprocess_image(self.labels["file_path_"][i]) + for k in self.labels: + example[k] = self.labels[k][i] + return example + + +class NumpyPaths(ImagePaths): + def preprocess_image(self, image_path): + image = np.load(image_path).squeeze(0) # 3 x 1024 x 1024 + image = np.transpose(image, (1,2,0)) + image = Image.fromarray(image, mode="RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/coco.py b/custom_controlnet_aux/diffusion_edge/taming/data/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..3dff4fccad0d5467ccf3c6ca1ee97ff7efd2647b --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/coco.py @@ -0,0 +1,176 @@ +import os +import json +import custom_albumentations as albumentations +import numpy as np +from PIL import Image +from tqdm import tqdm +from torch.utils.data import Dataset + +from custom_controlnet_aux.diffusion_edge.taming.data.sflckr import SegmentationBase # for examples included in repo + + +class Examples(SegmentationBase): + def __init__(self, size=256, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/coco_examples.txt", + data_root="data/coco_images", + segmentation_root="data/coco_segmentations", + size=size, random_crop=random_crop, + interpolation=interpolation, + n_labels=183, shift_segmentation=True) + + +class CocoBase(Dataset): + """needed for (image, caption, segmentation) pairs""" + def __init__(self, size=None, dataroot="", datajson="", onehot_segmentation=False, use_stuffthing=False, + crop_size=None, force_no_crop=False, given_files=None): + self.split = self.get_split() + self.size = size + if crop_size is None: + self.crop_size = size + else: + self.crop_size = crop_size + + self.onehot = onehot_segmentation # return segmentation as rgb or one hot + self.stuffthing = use_stuffthing # include thing in segmentation + if self.onehot and not self.stuffthing: + raise NotImplemented("One hot mode is only supported for the " + "stuffthings version because labels are stored " + "a bit different.") + + data_json = datajson + with open(data_json) as json_file: + self.json_data = json.load(json_file) + self.img_id_to_captions = dict() + self.img_id_to_filepath = dict() + self.img_id_to_segmentation_filepath = dict() + + assert data_json.split("/")[-1] in ["captions_train2017.json", + "captions_val2017.json"] + if self.stuffthing: + self.segmentation_prefix = ( + "data/cocostuffthings/val2017" if + data_json.endswith("captions_val2017.json") else + "data/cocostuffthings/train2017") + else: + self.segmentation_prefix = ( + "data/coco/annotations/stuff_val2017_pixelmaps" if + data_json.endswith("captions_val2017.json") else + "data/coco/annotations/stuff_train2017_pixelmaps") + + imagedirs = self.json_data["images"] + self.labels = {"image_ids": list()} + for imgdir in tqdm(imagedirs, desc="ImgToPath"): + self.img_id_to_filepath[imgdir["id"]] = os.path.join(dataroot, imgdir["file_name"]) + self.img_id_to_captions[imgdir["id"]] = list() + pngfilename = imgdir["file_name"].replace("jpg", "png") + self.img_id_to_segmentation_filepath[imgdir["id"]] = os.path.join( + self.segmentation_prefix, pngfilename) + if given_files is not None: + if pngfilename in given_files: + self.labels["image_ids"].append(imgdir["id"]) + else: + self.labels["image_ids"].append(imgdir["id"]) + + capdirs = self.json_data["annotations"] + for capdir in tqdm(capdirs, desc="ImgToCaptions"): + # there are in average 5 captions per image + self.img_id_to_captions[capdir["image_id"]].append(np.array([capdir["caption"]])) + + self.rescaler = albumentations.SmallestMaxSize(max_size=self.size) + if self.split=="validation": + self.cropper = albumentations.CenterCrop(height=self.crop_size, width=self.crop_size) + else: + self.cropper = albumentations.RandomCrop(height=self.crop_size, width=self.crop_size) + self.preprocessor = albumentations.Compose( + [self.rescaler, self.cropper], + additional_targets={"segmentation": "image"}) + if force_no_crop: + self.rescaler = albumentations.Resize(height=self.size, width=self.size) + self.preprocessor = albumentations.Compose( + [self.rescaler], + additional_targets={"segmentation": "image"}) + + def __len__(self): + return len(self.labels["image_ids"]) + + def preprocess_image(self, image_path, segmentation_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + + segmentation = Image.open(segmentation_path) + if not self.onehot and not segmentation.mode == "RGB": + segmentation = segmentation.convert("RGB") + segmentation = np.array(segmentation).astype(np.uint8) + if self.onehot: + assert self.stuffthing + # stored in caffe format: unlabeled==255. stuff and thing from + # 0-181. to be compatible with the labels in + # https://github.com/nightrome/cocostuff/blob/master/labels.txt + # we shift stuffthing one to the right and put unlabeled in zero + # as long as segmentation is uint8 shifting to right handles the + # latter too + assert segmentation.dtype == np.uint8 + segmentation = segmentation + 1 + + processed = self.preprocessor(image=image, segmentation=segmentation) + image, segmentation = processed["image"], processed["segmentation"] + image = (image / 127.5 - 1.0).astype(np.float32) + + if self.onehot: + assert segmentation.dtype == np.uint8 + # make it one hot + n_labels = 183 + flatseg = np.ravel(segmentation) + onehot = np.zeros((flatseg.size, n_labels), dtype=np.bool) + onehot[np.arange(flatseg.size), flatseg] = True + onehot = onehot.reshape(segmentation.shape + (n_labels,)).astype(int) + segmentation = onehot + else: + segmentation = (segmentation / 127.5 - 1.0).astype(np.float32) + return image, segmentation + + def __getitem__(self, i): + img_path = self.img_id_to_filepath[self.labels["image_ids"][i]] + seg_path = self.img_id_to_segmentation_filepath[self.labels["image_ids"][i]] + image, segmentation = self.preprocess_image(img_path, seg_path) + captions = self.img_id_to_captions[self.labels["image_ids"][i]] + # randomly draw one of all available captions per image + caption = captions[np.random.randint(0, len(captions))] + example = {"image": image, + "caption": [str(caption[0])], + "segmentation": segmentation, + "img_path": img_path, + "seg_path": seg_path, + "filename_": img_path.split(os.sep)[-1] + } + return example + + +class CocoImagesAndCaptionsTrain(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False): + super().__init__(size=size, + dataroot="data/coco/train2017", + datajson="data/coco/annotations/captions_train2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop) + + def get_split(self): + return "train" + + +class CocoImagesAndCaptionsValidation(CocoBase): + """returns a pair of (image, caption)""" + def __init__(self, size, onehot_segmentation=False, use_stuffthing=False, crop_size=None, force_no_crop=False, + given_files=None): + super().__init__(size=size, + dataroot="data/coco/val2017", + datajson="data/coco/annotations/captions_val2017.json", + onehot_segmentation=onehot_segmentation, + use_stuffthing=use_stuffthing, crop_size=crop_size, force_no_crop=force_no_crop, + given_files=given_files) + + def get_split(self): + return "validation" diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py new file mode 100644 index 0000000000000000000000000000000000000000..44fa4867952d60cdd84060dd186b126cf7eefbe2 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_bbox.py @@ -0,0 +1,60 @@ +from itertools import cycle +from typing import List, Tuple, Callable, Optional + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.objects_center_points import ObjectsCenterPointsConditionalBuilder +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, additional_parameters_string, \ + pad_list, get_plot_font_size, absolute_bbox + + +class ObjectsBoundingBoxConditionalBuilder(ObjectsCenterPointsConditionalBuilder): + @property + def object_descriptor_length(self) -> int: + return 3 + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_triples = [ + (self.object_representation(ann), *self.token_pair_from_bbox(ann.bbox)) + for ann in annotations + ] + empty_triple = (self.none, self.none, self.none) + object_triples = pad_list(object_triples, empty_triple, self.no_max_objects) + return object_triples + + def inverse_build(self, conditional: LongTensor) -> Tuple[List[Tuple[int, BoundingBox]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + object_triples = grouper(conditional_list, 3) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_triple[0], self.bbox_from_token_pair(object_triple[1], object_triple[2])) + for object_triple in object_triples if object_triple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + font = ImageFont.truetype( + "/usr/share/fonts/truetype/lato/Lato-Regular.ttf", + size=get_plot_font_size(font_size, figure_size) + ) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, bbox), color in zip(description, cycle(COLOR_PALETTE)): + annotation = self.representation_to_annotation(representation) + class_label = label_for_category_no(annotation.category_no) + ' ' + additional_parameters_string(annotation) + bbox = absolute_bbox(bbox, width, height) + draw.rectangle(bbox, outline=color, width=line_width) + draw.text((bbox[0] + line_width, bbox[1] + line_width), class_label, anchor='la', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py new file mode 100644 index 0000000000000000000000000000000000000000..8fd1989c6647f08384d69b3cad7fe57288776d0b --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/objects_center_points.py @@ -0,0 +1,168 @@ +import math +import random +import warnings +from itertools import cycle +from typing import List, Optional, Tuple, Callable + +from PIL import Image as pil_image, ImageDraw as pil_img_draw, ImageFont +from more_itertools.recipes import grouper +from custom_controlnet_aux.diffusion_edge.taming.data.conditional_builder.utils import COLOR_PALETTE, WHITE, GRAY_75, BLACK, FULL_CROP, filter_annotations, \ + additional_parameters_string, horizontally_flip_bbox, pad_list, get_circle_size, get_plot_font_size, \ + absolute_bbox, rescale_annotations +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation +from custom_controlnet_aux.diffusion_edge.taming.data.image_transforms import convert_pil_to_tensor +from torch import LongTensor, Tensor + + +class ObjectsCenterPointsConditionalBuilder: + def __init__(self, no_object_classes: int, no_max_objects: int, no_tokens: int, encode_crop: bool, + use_group_parameter: bool, use_additional_parameters: bool): + self.no_object_classes = no_object_classes + self.no_max_objects = no_max_objects + self.no_tokens = no_tokens + self.encode_crop = encode_crop + self.no_sections = int(math.sqrt(self.no_tokens)) + self.use_group_parameter = use_group_parameter + self.use_additional_parameters = use_additional_parameters + + @property + def none(self) -> int: + return self.no_tokens - 1 + + @property + def object_descriptor_length(self) -> int: + return 2 + + @property + def embedding_dim(self) -> int: + extra_length = 2 if self.encode_crop else 0 + return self.no_max_objects * self.object_descriptor_length + extra_length + + def tokenize_coordinates(self, x: float, y: float) -> int: + """ + Express 2d coordinates with one number. + Example: assume self.no_tokens = 16, then no_sections = 4: + 0 0 0 0 + 0 0 # 0 + 0 0 0 0 + 0 0 0 x + Then the # position corresponds to token 6, the x position to token 15. + @param x: float in [0, 1] + @param y: float in [0, 1] + @return: discrete tokenized coordinate + """ + x_discrete = int(round(x * (self.no_sections - 1))) + y_discrete = int(round(y * (self.no_sections - 1))) + return y_discrete * self.no_sections + x_discrete + + def coordinates_from_token(self, token: int) -> (float, float): + x = token % self.no_sections + y = token // self.no_sections + return x / (self.no_sections - 1), y / (self.no_sections - 1) + + def bbox_from_token_pair(self, token1: int, token2: int) -> BoundingBox: + x0, y0 = self.coordinates_from_token(token1) + x1, y1 = self.coordinates_from_token(token2) + return x0, y0, x1 - x0, y1 - y0 + + def token_pair_from_bbox(self, bbox: BoundingBox) -> Tuple[int, int]: + return self.tokenize_coordinates(bbox[0], bbox[1]), \ + self.tokenize_coordinates(bbox[0] + bbox[2], bbox[1] + bbox[3]) + + def inverse_build(self, conditional: LongTensor) \ + -> Tuple[List[Tuple[int, Tuple[float, float]]], Optional[BoundingBox]]: + conditional_list = conditional.tolist() + crop_coordinates = None + if self.encode_crop: + crop_coordinates = self.bbox_from_token_pair(conditional_list[-2], conditional_list[-1]) + conditional_list = conditional_list[:-2] + table_of_content = grouper(conditional_list, self.object_descriptor_length) + assert conditional.shape[0] == self.embedding_dim + return [ + (object_tuple[0], self.coordinates_from_token(object_tuple[1])) + for object_tuple in table_of_content if object_tuple[0] != self.none + ], crop_coordinates + + def plot(self, conditional: LongTensor, label_for_category_no: Callable[[int], str], figure_size: Tuple[int, int], + line_width: int = 3, font_size: Optional[int] = None) -> Tensor: + plot = pil_image.new('RGB', figure_size, WHITE) + draw = pil_img_draw.Draw(plot) + circle_size = get_circle_size(figure_size) + font = ImageFont.truetype('/usr/share/fonts/truetype/lato/Lato-Regular.ttf', + size=get_plot_font_size(font_size, figure_size)) + width, height = plot.size + description, crop_coordinates = self.inverse_build(conditional) + for (representation, (x, y)), color in zip(description, cycle(COLOR_PALETTE)): + x_abs, y_abs = x * width, y * height + ann = self.representation_to_annotation(representation) + label = label_for_category_no(ann.category_no) + ' ' + additional_parameters_string(ann) + ellipse_bbox = [x_abs - circle_size, y_abs - circle_size, x_abs + circle_size, y_abs + circle_size] + draw.ellipse(ellipse_bbox, fill=color, width=0) + draw.text((x_abs, y_abs), label, anchor='md', fill=BLACK, font=font) + if crop_coordinates is not None: + draw.rectangle(absolute_bbox(crop_coordinates, width, height), outline=GRAY_75, width=line_width) + return convert_pil_to_tensor(plot) / 127.5 - 1. + + def object_representation(self, annotation: Annotation) -> int: + modifier = 0 + if self.use_group_parameter: + modifier |= 1 * (annotation.is_group_of is True) + if self.use_additional_parameters: + modifier |= 2 * (annotation.is_occluded is True) + modifier |= 4 * (annotation.is_depiction is True) + modifier |= 8 * (annotation.is_inside is True) + return annotation.category_no + self.no_object_classes * modifier + + def representation_to_annotation(self, representation: int) -> Annotation: + category_no = representation % self.no_object_classes + modifier = representation // self.no_object_classes + # noinspection PyTypeChecker + return Annotation( + area=None, image_id=None, bbox=None, category_id=None, id=None, source=None, confidence=None, + category_no=category_no, + is_group_of=bool((modifier & 1) * self.use_group_parameter), + is_occluded=bool((modifier & 2) * self.use_additional_parameters), + is_depiction=bool((modifier & 4) * self.use_additional_parameters), + is_inside=bool((modifier & 8) * self.use_additional_parameters) + ) + + def _crop_encoder(self, crop_coordinates: BoundingBox) -> List[int]: + return list(self.token_pair_from_bbox(crop_coordinates)) + + def _make_object_descriptors(self, annotations: List[Annotation]) -> List[Tuple[int, ...]]: + object_tuples = [ + (self.object_representation(a), + self.tokenize_coordinates(a.bbox[0] + a.bbox[2] / 2, a.bbox[1] + a.bbox[3] / 2)) + for a in annotations + ] + empty_tuple = (self.none, self.none) + object_tuples = pad_list(object_tuples, empty_tuple, self.no_max_objects) + return object_tuples + + def build(self, annotations: List, crop_coordinates: Optional[BoundingBox] = None, horizontal_flip: bool = False) \ + -> LongTensor: + if len(annotations) == 0: + warnings.warn('Did not receive any annotations.') + if len(annotations) > self.no_max_objects: + warnings.warn('Received more annotations than allowed.') + annotations = annotations[:self.no_max_objects] + + if not crop_coordinates: + crop_coordinates = FULL_CROP + + random.shuffle(annotations) + annotations = filter_annotations(annotations, crop_coordinates) + if self.encode_crop: + annotations = rescale_annotations(annotations, FULL_CROP, horizontal_flip) + if horizontal_flip: + crop_coordinates = horizontally_flip_bbox(crop_coordinates) + extra = self._crop_encoder(crop_coordinates) + else: + annotations = rescale_annotations(annotations, crop_coordinates, horizontal_flip) + extra = [] + + object_tuples = self._make_object_descriptors(annotations) + flattened = [token for tuple_ in object_tuples for token in tuple_] + extra + assert len(flattened) == self.embedding_dim + assert all(0 <= value < self.no_tokens for value in flattened) + return LongTensor(flattened) diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b97140e32b3d0772247ddb7121c22ce56cc90c4f --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/conditional_builder/utils.py @@ -0,0 +1,105 @@ +import importlib +from typing import List, Any, Tuple, Optional + +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Annotation + +# source: seaborn, color palette tab10 +COLOR_PALETTE = [(30, 118, 179), (255, 126, 13), (43, 159, 43), (213, 38, 39), (147, 102, 188), + (139, 85, 74), (226, 118, 193), (126, 126, 126), (187, 188, 33), (22, 189, 206)] +BLACK = (0, 0, 0) +GRAY_75 = (63, 63, 63) +GRAY_50 = (127, 127, 127) +GRAY_25 = (191, 191, 191) +WHITE = (255, 255, 255) +FULL_CROP = (0., 0., 1., 1.) + + +def intersection_area(rectangle1: BoundingBox, rectangle2: BoundingBox) -> float: + """ + Give intersection area of two rectangles. + @param rectangle1: (x0, y0, w, h) of first rectangle + @param rectangle2: (x0, y0, w, h) of second rectangle + """ + rectangle1 = rectangle1[0], rectangle1[1], rectangle1[0] + rectangle1[2], rectangle1[1] + rectangle1[3] + rectangle2 = rectangle2[0], rectangle2[1], rectangle2[0] + rectangle2[2], rectangle2[1] + rectangle2[3] + x_overlap = max(0., min(rectangle1[2], rectangle2[2]) - max(rectangle1[0], rectangle2[0])) + y_overlap = max(0., min(rectangle1[3], rectangle2[3]) - max(rectangle1[1], rectangle2[1])) + return x_overlap * y_overlap + + +def horizontally_flip_bbox(bbox: BoundingBox) -> BoundingBox: + return 1 - (bbox[0] + bbox[2]), bbox[1], bbox[2], bbox[3] + + +def absolute_bbox(relative_bbox: BoundingBox, width: int, height: int) -> Tuple[int, int, int, int]: + bbox = relative_bbox + bbox = bbox[0] * width, bbox[1] * height, (bbox[0] + bbox[2]) * width, (bbox[1] + bbox[3]) * height + return int(bbox[0]), int(bbox[1]), int(bbox[2]), int(bbox[3]) + + +def pad_list(list_: List, pad_element: Any, pad_to_length: int) -> List: + return list_ + [pad_element for _ in range(pad_to_length - len(list_))] + + +def rescale_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox, flip: bool) -> \ + List[Annotation]: + def clamp(x: float): + return max(min(x, 1.), 0.) + + def rescale_bbox(bbox: BoundingBox) -> BoundingBox: + x0 = clamp((bbox[0] - crop_coordinates[0]) / crop_coordinates[2]) + y0 = clamp((bbox[1] - crop_coordinates[1]) / crop_coordinates[3]) + w = min(bbox[2] / crop_coordinates[2], 1 - x0) + h = min(bbox[3] / crop_coordinates[3], 1 - y0) + if flip: + x0 = 1 - (x0 + w) + return x0, y0, w, h + + return [a._replace(bbox=rescale_bbox(a.bbox)) for a in annotations] + + +def filter_annotations(annotations: List[Annotation], crop_coordinates: BoundingBox) -> List: + return [a for a in annotations if intersection_area(a.bbox, crop_coordinates) > 0.0] + + +def additional_parameters_string(annotation: Annotation, short: bool = True) -> str: + sl = slice(1) if short else slice(None) + string = '' + if not (annotation.is_group_of or annotation.is_occluded or annotation.is_depiction or annotation.is_inside): + return string + if annotation.is_group_of: + string += 'group'[sl] + ',' + if annotation.is_occluded: + string += 'occluded'[sl] + ',' + if annotation.is_depiction: + string += 'depiction'[sl] + ',' + if annotation.is_inside: + string += 'inside'[sl] + return '(' + string.strip(",") + ')' + + +def get_plot_font_size(font_size: Optional[int], figure_size: Tuple[int, int]) -> int: + if font_size is None: + font_size = 10 + if max(figure_size) >= 256: + font_size = 12 + if max(figure_size) >= 512: + font_size = 15 + return font_size + + +def get_circle_size(figure_size: Tuple[int, int]) -> int: + circle_size = 2 + if max(figure_size) >= 256: + circle_size = 3 + if max(figure_size) >= 512: + circle_size = 4 + return circle_size + + +def load_object_from_string(object_string: str) -> Any: + """ + Source: https://stackoverflow.com/a/10773699 + """ + module_name, class_name = object_string.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/custom.py b/custom_controlnet_aux/diffusion_edge/taming/data/custom.py new file mode 100644 index 0000000000000000000000000000000000000000..2f5029298f76ee285821fe8a0d97053eed09d6f9 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/custom.py @@ -0,0 +1,38 @@ +import os +import numpy as np +import custom_albumentations as albumentations +from torch.utils.data import Dataset + +from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class CustomBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + return example + + + +class CustomTrain(CustomBase): + def __init__(self, size, training_images_list_file): + super().__init__() + with open(training_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + +class CustomTest(CustomBase): + def __init__(self, size, test_images_list_file): + super().__init__() + with open(test_images_list_file, "r") as f: + paths = f.read().splitlines() + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + + diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py b/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py new file mode 100644 index 0000000000000000000000000000000000000000..b0b00493407b13c829023a89fedac823b90f5881 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/faceshq.py @@ -0,0 +1,134 @@ +import os +import numpy as np +import custom_albumentations as albumentations +from torch.utils.data import Dataset + +from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex + + +class FacesBase(Dataset): + def __init__(self, *args, **kwargs): + super().__init__() + self.data = None + self.keys = None + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + example = self.data[i] + ex = {} + if self.keys is not None: + for k in self.keys: + ex[k] = example[k] + else: + ex = example + return ex + + +class CelebAHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class CelebAHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/celebahq" + with open("data/celebahqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = NumpyPaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQTrain(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqtrain.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FFHQValidation(FacesBase): + def __init__(self, size, keys=None): + super().__init__() + root = "data/ffhq" + with open("data/ffhqvalidation.txt", "r") as f: + relpaths = f.read().splitlines() + paths = [os.path.join(root, relpath) for relpath in relpaths] + self.data = ImagePaths(paths=paths, size=size, random_crop=False) + self.keys = keys + + +class FacesHQTrain(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQTrain(size=size, keys=keys) + d2 = FFHQTrain(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex + + +class FacesHQValidation(Dataset): + # CelebAHQ [0] + FFHQ [1] + def __init__(self, size, keys=None, crop_size=None, coord=False): + d1 = CelebAHQValidation(size=size, keys=keys) + d2 = FFHQValidation(size=size, keys=keys) + self.data = ConcatDatasetWithIndex([d1, d2]) + self.coord = coord + if crop_size is not None: + self.cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + if self.coord: + self.cropper = albumentations.Compose([self.cropper], + additional_targets={"coord": "image"}) + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + ex, y = self.data[i] + if hasattr(self, "cropper"): + if not self.coord: + out = self.cropper(image=ex["image"]) + ex["image"] = out["image"] + else: + h,w,_ = ex["image"].shape + coord = np.arange(h*w).reshape(h,w,1)/(h*w) + out = self.cropper(image=ex["image"], coord=coord) + ex["image"] = out["image"] + ex["coord"] = out["coord"] + ex["class"] = y + return ex diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py b/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py new file mode 100644 index 0000000000000000000000000000000000000000..fb51e301da08602cfead5961c4f7e1d89f6aba79 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/helper_types.py @@ -0,0 +1,49 @@ +from typing import Dict, Tuple, Optional, NamedTuple, Union +from PIL.Image import Image as pil_image +from torch import Tensor + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal + +Image = Union[Tensor, pil_image] +BoundingBox = Tuple[float, float, float, float] # x0, y0, w, h +CropMethodType = Literal['none', 'random', 'center', 'random-2d'] +SplitType = Literal['train', 'validation', 'test'] + + +class ImageDescription(NamedTuple): + id: int + file_name: str + original_size: Tuple[int, int] # w, h + url: Optional[str] = None + license: Optional[int] = None + coco_url: Optional[str] = None + date_captured: Optional[str] = None + flickr_url: Optional[str] = None + flickr_id: Optional[str] = None + coco_id: Optional[str] = None + + +class Category(NamedTuple): + id: str + super_category: Optional[str] + name: str + + +class Annotation(NamedTuple): + area: float + image_id: str + bbox: BoundingBox + category_no: int + category_id: str + id: Optional[int] = None + source: Optional[str] = None + confidence: Optional[float] = None + is_group_of: Optional[bool] = None + is_truncated: Optional[bool] = None + is_occluded: Optional[bool] = None + is_depiction: Optional[bool] = None + is_inside: Optional[bool] = None + segmentation: Optional[Dict] = None diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py b/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d1289f06c7df1a4806eaa2c58fc02e5bdcbc62f8 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/image_transforms.py @@ -0,0 +1,132 @@ +import random +import warnings +from typing import Union + +import torch +from torch import Tensor +from torchvision.transforms import RandomCrop, functional as F, CenterCrop, RandomHorizontalFlip, PILToTensor +from torchvision.transforms.functional import _get_image_size as get_image_size + +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import BoundingBox, Image + +pil_to_tensor = PILToTensor() + + +def convert_pil_to_tensor(image: Image) -> Tensor: + with warnings.catch_warnings(): + # to filter PyTorch UserWarning as described here: https://github.com/pytorch/vision/issues/2194 + warnings.simplefilter("ignore") + return pil_to_tensor(image) + + +class RandomCrop1dReturnCoordinates(RandomCrop): + def forward(self, img: Image) -> (BoundingBox, Image): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + if self.padding is not None: + img = F.pad(img, self.padding, self.fill, self.padding_mode) + + width, height = get_image_size(img) + # pad the width if needed + if self.pad_if_needed and width < self.size[1]: + padding = [self.size[1] - width, 0] + img = F.pad(img, padding, self.fill, self.padding_mode) + # pad the height if needed + if self.pad_if_needed and height < self.size[0]: + padding = [0, self.size[0] - height] + img = F.pad(img, padding, self.fill, self.padding_mode) + + i, j, h, w = self.get_params(img, self.size) + bbox = (j / width, i / height, w / width, h / height) # x0, y0, w, h + return bbox, F.crop(img, i, j, h, w) + + +class Random2dCropReturnCoordinates(torch.nn.Module): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + + Based on: + torchvision.transforms.RandomCrop, torchvision 1.7.0 + """ + + def __init__(self, min_size: int): + super().__init__() + self.min_size = min_size + + def forward(self, img: Image) -> (BoundingBox, Image): + width, height = get_image_size(img) + max_size = min(width, height) + if max_size <= self.min_size: + size = max_size + else: + size = random.randint(self.min_size, max_size) + top = random.randint(0, height - size) + left = random.randint(0, width - size) + bbox = left / width, top / height, size / width, size / height + return bbox, F.crop(img, top, left, size, size) + + +class CenterCropReturnCoordinates(CenterCrop): + @staticmethod + def get_bbox_of_center_crop(width: int, height: int) -> BoundingBox: + if width > height: + w = height / width + h = 1.0 + x0 = 0.5 - w / 2 + y0 = 0. + else: + w = 1.0 + h = width / height + x0 = 0. + y0 = 0.5 - h / 2 + return x0, y0, w, h + + def forward(self, img: Union[Image, Tensor]) -> (BoundingBox, Union[Image, Tensor]): + """ + Additionally to cropping, returns the relative coordinates of the crop bounding box. + Args: + img (PIL Image or Tensor): Image to be cropped. + + Returns: + Bounding box: x0, y0, w, h + PIL Image or Tensor: Cropped image. + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + width, height = get_image_size(img) + return self.get_bbox_of_center_crop(width, height), F.center_crop(img, self.size) + + +class RandomHorizontalFlipReturn(RandomHorizontalFlip): + def forward(self, img: Image) -> (bool, Image): + """ + Additionally to flipping, returns a boolean whether it was flipped or not. + Args: + img (PIL Image or Tensor): Image to be flipped. + + Returns: + flipped: whether the image was flipped or not + PIL Image or Tensor: Randomly flipped image. + + Based on: + torchvision.transforms.RandomHorizontalFlip (version 1.7.0) + """ + if torch.rand(1) < self.p: + return True, F.hflip(img) + return False, img diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py b/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py new file mode 100644 index 0000000000000000000000000000000000000000..275d593ae35caf2b528e9326968e1389f463f7cd --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/imagenet.py @@ -0,0 +1,558 @@ +import os, tarfile, glob, shutil +import yaml +import numpy as np +from tqdm import tqdm +from PIL import Image +import custom_albumentations as albumentations +from omegaconf import OmegaConf +from torch.utils.data import Dataset + +from custom_controlnet_aux.diffusion_edge.taming.data.base import ImagePaths +from custom_controlnet_aux.diffusion_edge.taming.util import download, retrieve +import taming.data.utils as bdu + + +def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"): + synsets = [] + with open(path_to_yaml) as f: + di2s = yaml.load(f) + for idx in indices: + synsets.append(str(di2s[idx])) + print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets))) + return synsets + + +def str_to_indices(string): + """Expects a string in the format '32-123, 256, 280-321'""" + assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string) + subs = string.split(",") + indices = [] + for sub in subs: + subsubs = sub.split("-") + assert len(subsubs) > 0 + if len(subsubs) == 1: + indices.append(int(subsubs[0])) + else: + rang = [j for j in range(int(subsubs[0]), int(subsubs[1]))] + indices.extend(rang) + return sorted(indices) + + +class ImageNetBase(Dataset): + def __init__(self, config=None): + self.config = config or OmegaConf.create() + if not type(self.config)==dict: + self.config = OmegaConf.to_container(self.config) + self._prepare() + self._prepare_synset_to_human() + self._prepare_idx_to_synset() + self._load() + + def __len__(self): + return len(self.data) + + def __getitem__(self, i): + return self.data[i] + + def _prepare(self): + raise NotImplementedError() + + def _filter_relpaths(self, relpaths): + ignore = set([ + "n06596364_9591.JPEG", + ]) + relpaths = [rpath for rpath in relpaths if not rpath.split("/")[-1] in ignore] + if "sub_indices" in self.config: + indices = str_to_indices(self.config["sub_indices"]) + synsets = give_synsets_from_indices(indices, path_to_yaml=self.idx2syn) # returns a list of strings + files = [] + for rpath in relpaths: + syn = rpath.split("/")[0] + if syn in synsets: + files.append(rpath) + return files + else: + return relpaths + + def _prepare_synset_to_human(self): + SIZE = 2655750 + URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1" + self.human_dict = os.path.join(self.root, "synset_human.txt") + if (not os.path.exists(self.human_dict) or + not os.path.getsize(self.human_dict)==SIZE): + download(URL, self.human_dict) + + def _prepare_idx_to_synset(self): + URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1" + self.idx2syn = os.path.join(self.root, "index_synset.yaml") + if (not os.path.exists(self.idx2syn)): + download(URL, self.idx2syn) + + def _load(self): + with open(self.txt_filelist, "r") as f: + self.relpaths = f.read().splitlines() + l1 = len(self.relpaths) + self.relpaths = self._filter_relpaths(self.relpaths) + print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths))) + + self.synsets = [p.split("/")[0] for p in self.relpaths] + self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths] + + unique_synsets = np.unique(self.synsets) + class_dict = dict((synset, i) for i, synset in enumerate(unique_synsets)) + self.class_labels = [class_dict[s] for s in self.synsets] + + with open(self.human_dict, "r") as f: + human_dict = f.read().splitlines() + human_dict = dict(line.split(maxsplit=1) for line in human_dict) + + self.human_labels = [human_dict[s] for s in self.synsets] + + labels = { + "relpath": np.array(self.relpaths), + "synsets": np.array(self.synsets), + "class_label": np.array(self.class_labels), + "human_label": np.array(self.human_labels), + } + self.data = ImagePaths(self.abspaths, + labels=labels, + size=retrieve(self.config, "size", default=0), + random_crop=self.random_crop) + + +class ImageNetTrain(ImageNetBase): + NAME = "ILSVRC2012_train" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2" + FILES = [ + "ILSVRC2012_img_train.tar", + ] + SIZES = [ + 147897477120, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop", + default=True) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 1281167 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + print("Extracting sub-tars.") + subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar"))) + for subpath in tqdm(subpaths): + subdir = subpath[:-len(".tar")] + os.makedirs(subdir, exist_ok=True) + with tarfile.open(subpath, "r:") as tar: + tar.extractall(path=subdir) + + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +class ImageNetValidation(ImageNetBase): + NAME = "ILSVRC2012_validation" + URL = "http://www.image-net.org/challenges/LSVRC/2012/" + AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5" + VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1" + FILES = [ + "ILSVRC2012_img_val.tar", + "validation_synset.txt", + ] + SIZES = [ + 6744924160, + 1950000, + ] + + def _prepare(self): + self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop", + default=False) + cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache")) + self.root = os.path.join(cachedir, "autoencoders/data", self.NAME) + self.datadir = os.path.join(self.root, "data") + self.txt_filelist = os.path.join(self.root, "filelist.txt") + self.expected_length = 50000 + if not bdu.is_prepared(self.root): + # prep + print("Preparing dataset {} in {}".format(self.NAME, self.root)) + + datadir = self.datadir + if not os.path.exists(datadir): + path = os.path.join(self.root, self.FILES[0]) + if not os.path.exists(path) or not os.path.getsize(path)==self.SIZES[0]: + import academictorrents as at + atpath = at.get(self.AT_HASH, datastore=self.root) + assert atpath == path + + print("Extracting {} to {}".format(path, datadir)) + os.makedirs(datadir, exist_ok=True) + with tarfile.open(path, "r:") as tar: + tar.extractall(path=datadir) + + vspath = os.path.join(self.root, self.FILES[1]) + if not os.path.exists(vspath) or not os.path.getsize(vspath)==self.SIZES[1]: + download(self.VS_URL, vspath) + + with open(vspath, "r") as f: + synset_dict = f.read().splitlines() + synset_dict = dict(line.split() for line in synset_dict) + + print("Reorganizing into synset folders") + synsets = np.unique(list(synset_dict.values())) + for s in synsets: + os.makedirs(os.path.join(datadir, s), exist_ok=True) + for k, v in synset_dict.items(): + src = os.path.join(datadir, k) + dst = os.path.join(datadir, v) + shutil.move(src, dst) + + filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG")) + filelist = [os.path.relpath(p, start=datadir) for p in filelist] + filelist = sorted(filelist) + filelist = "\n".join(filelist)+"\n" + with open(self.txt_filelist, "w") as f: + f.write(filelist) + + bdu.mark_prepared(self.root) + + +def get_preprocessor(size=None, random_crop=False, additional_targets=None, + crop_size=None): + if size is not None and size > 0: + transforms = list() + rescaler = albumentations.SmallestMaxSize(max_size = size) + transforms.append(rescaler) + if not random_crop: + cropper = albumentations.CenterCrop(height=size,width=size) + transforms.append(cropper) + else: + cropper = albumentations.RandomCrop(height=size,width=size) + transforms.append(cropper) + flipper = albumentations.HorizontalFlip() + transforms.append(flipper) + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + elif crop_size is not None and crop_size > 0: + if not random_crop: + cropper = albumentations.CenterCrop(height=crop_size,width=crop_size) + else: + cropper = albumentations.RandomCrop(height=crop_size,width=crop_size) + transforms = [cropper] + preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + preprocessor = lambda **kwargs: kwargs + return preprocessor + + +def rgba_to_depth(x): + assert x.dtype == np.uint8 + assert len(x.shape) == 3 and x.shape[2] == 4 + y = x.copy() + y.dtype = np.float32 + y = y.reshape(x.shape[:2]) + return np.ascontiguousarray(y) + + +class BaseWithDepth(Dataset): + DEFAULT_DEPTH_ROOT="data/imagenet_depth" + + def __init__(self, config=None, size=None, random_crop=False, + crop_size=None, root=None): + self.config = config + self.base_dset = self.get_base_dset() + self.preprocessor = get_preprocessor( + size=size, + crop_size=crop_size, + random_crop=random_crop, + additional_targets={"depth": "image"}) + self.crop_size = crop_size + if self.crop_size is not None: + self.rescaler = albumentations.Compose( + [albumentations.SmallestMaxSize(max_size = self.crop_size)], + additional_targets={"depth": "image"}) + if root is not None: + self.DEFAULT_DEPTH_ROOT = root + + def __len__(self): + return len(self.base_dset) + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = self.base_dset[i] + e["depth"] = self.preprocess_depth(self.get_depth_path(e)) + # up if necessary + h,w,c = e["image"].shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + out = self.rescaler(image=e["image"], depth=e["depth"]) + e["image"] = out["image"] + e["depth"] = out["depth"] + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +class ImageNetTrainWithDepth(BaseWithDepth): + # default to random_crop=True + def __init__(self, random_crop=True, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(random_crop=random_crop, **kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetTrain() + else: + return ImageNetTrain({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "train", fid) + return fid + + +class ImageNetValidationWithDepth(BaseWithDepth): + def __init__(self, sub_indices=None, **kwargs): + self.sub_indices = sub_indices + super().__init__(**kwargs) + + def get_base_dset(self): + if self.sub_indices is None: + return ImageNetValidation() + else: + return ImageNetValidation({"sub_indices": self.sub_indices}) + + def get_depth_path(self, e): + fid = os.path.splitext(e["relpath"])[0]+".png" + fid = os.path.join(self.DEFAULT_DEPTH_ROOT, "val", fid) + return fid + + +class RINTrainWithDepth(ImageNetTrainWithDepth): + def __init__(self, config=None, size=None, random_crop=True, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class RINValidationWithDepth(ImageNetValidationWithDepth): + def __init__(self, config=None, size=None, random_crop=False, crop_size=None): + sub_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319" + super().__init__(config=config, size=size, random_crop=random_crop, + sub_indices=sub_indices, crop_size=crop_size) + + +class DRINExamples(Dataset): + def __init__(self): + self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"}) + with open("data/drin_examples.txt", "r") as f: + relpaths = f.read().splitlines() + self.image_paths = [os.path.join("data/drin_images", + relpath) for relpath in relpaths] + self.depth_paths = [os.path.join("data/drin_depth", + relpath.replace(".JPEG", ".png")) for relpath in relpaths] + + def __len__(self): + return len(self.image_paths) + + def preprocess_image(self, image_path): + image = Image.open(image_path) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + image = self.preprocessor(image=image)["image"] + image = (image/127.5 - 1.0).astype(np.float32) + return image + + def preprocess_depth(self, path): + rgba = np.array(Image.open(path)) + depth = rgba_to_depth(rgba) + depth = (depth - depth.min())/max(1e-8, depth.max()-depth.min()) + depth = 2.0*depth-1.0 + return depth + + def __getitem__(self, i): + e = dict() + e["image"] = self.preprocess_image(self.image_paths[i]) + e["depth"] = self.preprocess_depth(self.depth_paths[i]) + transformed = self.preprocessor(image=e["image"], depth=e["depth"]) + e["image"] = transformed["image"] + e["depth"] = transformed["depth"] + return e + + +def imscale(x, factor, keepshapes=False, keepmode="bicubic"): + if factor is None or factor==1: + return x + + dtype = x.dtype + assert dtype in [np.float32, np.float64] + assert x.min() >= -1 + assert x.max() <= 1 + + keepmode = {"nearest": Image.NEAREST, "bilinear": Image.BILINEAR, + "bicubic": Image.BICUBIC}[keepmode] + + lr = (x+1.0)*127.5 + lr = lr.clip(0,255).astype(np.uint8) + lr = Image.fromarray(lr) + + h, w, _ = x.shape + nh = h//factor + nw = w//factor + assert nh > 0 and nw > 0, (nh, nw) + + lr = lr.resize((nw,nh), Image.BICUBIC) + if keepshapes: + lr = lr.resize((w,h), keepmode) + lr = np.array(lr)/127.5-1.0 + lr = lr.astype(dtype) + + return lr + + +class ImageNetScale(Dataset): + def __init__(self, size=None, crop_size=None, random_crop=False, + up_factor=None, hr_factor=None, keep_mode="bicubic"): + self.base = self.get_base() + + self.size = size + self.crop_size = crop_size if crop_size is not None else self.size + self.random_crop = random_crop + self.up_factor = up_factor + self.hr_factor = hr_factor + self.keep_mode = keep_mode + + transforms = list() + + if self.size is not None and self.size > 0: + rescaler = albumentations.SmallestMaxSize(max_size = self.size) + self.rescaler = rescaler + transforms.append(rescaler) + + if self.crop_size is not None and self.crop_size > 0: + if len(transforms) == 0: + self.rescaler = albumentations.SmallestMaxSize(max_size = self.crop_size) + + if not self.random_crop: + cropper = albumentations.CenterCrop(height=self.crop_size,width=self.crop_size) + else: + cropper = albumentations.RandomCrop(height=self.crop_size,width=self.crop_size) + transforms.append(cropper) + + if len(transforms) > 0: + if self.up_factor is not None: + additional_targets = {"lr": "image"} + else: + additional_targets = None + self.preprocessor = albumentations.Compose(transforms, + additional_targets=additional_targets) + else: + self.preprocessor = lambda **kwargs: kwargs + + def __len__(self): + return len(self.base) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + # adjust resolution + image = imscale(image, self.hr_factor, keepshapes=False) + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + if self.up_factor is None: + image = self.preprocessor(image=image)["image"] + example["image"] = image + else: + lr = imscale(image, self.up_factor, keepshapes=True, + keepmode=self.keep_mode) + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + +class ImageNetScaleTrain(ImageNetScale): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetScaleValidation(ImageNetScale): + def get_base(self): + return ImageNetValidation() + + +from skimage.feature import canny +from skimage.color import rgb2gray + + +class ImageNetEdges(ImageNetScale): + def __init__(self, up_factor=1, **kwargs): + super().__init__(up_factor=1, **kwargs) + + def __getitem__(self, i): + example = self.base[i] + image = example["image"] + h,w,c = image.shape + if self.crop_size and min(h,w) < self.crop_size: + # have to upscale to be able to crop - this just uses bilinear + image = self.rescaler(image=image)["image"] + + lr = canny(rgb2gray(image), sigma=2) + lr = lr.astype(np.float32) + lr = lr[:,:,None][:,:,[0,0,0]] + + out = self.preprocessor(image=image, lr=lr) + example["image"] = out["image"] + example["lr"] = out["lr"] + + return example + + +class ImageNetEdgesTrain(ImageNetEdges): + def __init__(self, random_crop=True, **kwargs): + super().__init__(random_crop=random_crop, **kwargs) + + def get_base(self): + return ImageNetTrain() + +class ImageNetEdgesValidation(ImageNetEdges): + def get_base(self): + return ImageNetValidation() diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py b/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py new file mode 100644 index 0000000000000000000000000000000000000000..8feb7c6e705fc165d2983303192aaa88f579b243 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/open_images_helper.py @@ -0,0 +1,379 @@ +open_images_unify_categories_for_coco = { + '/m/03bt1vf': '/m/01g317', + '/m/04yx4': '/m/01g317', + '/m/05r655': '/m/01g317', + '/m/01bl7v': '/m/01g317', + '/m/0cnyhnx': '/m/01xq0k1', + '/m/01226z': '/m/018xm', + '/m/05ctyq': '/m/018xm', + '/m/058qzx': '/m/04ctx', + '/m/06pcq': '/m/0l515', + '/m/03m3pdh': '/m/02crq1', + '/m/046dlr': '/m/01x3z', + '/m/0h8mzrc': '/m/01x3z', +} + + +top_300_classes_plus_coco_compatibility = [ + ('Man', 1060962), + ('Clothing', 986610), + ('Tree', 748162), + ('Woman', 611896), + ('Person', 610294), + ('Human face', 442948), + ('Girl', 175399), + ('Building', 162147), + ('Car', 159135), + ('Plant', 155704), + ('Human body', 137073), + ('Flower', 133128), + ('Window', 127485), + ('Human arm', 118380), + ('House', 114365), + ('Wheel', 111684), + ('Suit', 99054), + ('Human hair', 98089), + ('Human head', 92763), + ('Chair', 88624), + ('Boy', 79849), + ('Table', 73699), + ('Jeans', 57200), + ('Tire', 55725), + ('Skyscraper', 53321), + ('Food', 52400), + ('Footwear', 50335), + ('Dress', 50236), + ('Human leg', 47124), + ('Toy', 46636), + ('Tower', 45605), + ('Boat', 43486), + ('Land vehicle', 40541), + ('Bicycle wheel', 34646), + ('Palm tree', 33729), + ('Fashion accessory', 32914), + ('Glasses', 31940), + ('Bicycle', 31409), + ('Furniture', 30656), + ('Sculpture', 29643), + ('Bottle', 27558), + ('Dog', 26980), + ('Snack', 26796), + ('Human hand', 26664), + ('Bird', 25791), + ('Book', 25415), + ('Guitar', 24386), + ('Jacket', 23998), + ('Poster', 22192), + ('Dessert', 21284), + ('Baked goods', 20657), + ('Drink', 19754), + ('Flag', 18588), + ('Houseplant', 18205), + ('Tableware', 17613), + ('Airplane', 17218), + ('Door', 17195), + ('Sports uniform', 17068), + ('Shelf', 16865), + ('Drum', 16612), + ('Vehicle', 16542), + ('Microphone', 15269), + ('Street light', 14957), + ('Cat', 14879), + ('Fruit', 13684), + ('Fast food', 13536), + ('Animal', 12932), + ('Vegetable', 12534), + ('Train', 12358), + ('Horse', 11948), + ('Flowerpot', 11728), + ('Motorcycle', 11621), + ('Fish', 11517), + ('Desk', 11405), + ('Helmet', 10996), + ('Truck', 10915), + ('Bus', 10695), + ('Hat', 10532), + ('Auto part', 10488), + ('Musical instrument', 10303), + ('Sunglasses', 10207), + ('Picture frame', 10096), + ('Sports equipment', 10015), + ('Shorts', 9999), + ('Wine glass', 9632), + ('Duck', 9242), + ('Wine', 9032), + ('Rose', 8781), + ('Tie', 8693), + ('Butterfly', 8436), + ('Beer', 7978), + ('Cabinetry', 7956), + ('Laptop', 7907), + ('Insect', 7497), + ('Goggles', 7363), + ('Shirt', 7098), + ('Dairy Product', 7021), + ('Marine invertebrates', 7014), + ('Cattle', 7006), + ('Trousers', 6903), + ('Van', 6843), + ('Billboard', 6777), + ('Balloon', 6367), + ('Human nose', 6103), + ('Tent', 6073), + ('Camera', 6014), + ('Doll', 6002), + ('Coat', 5951), + ('Mobile phone', 5758), + ('Swimwear', 5729), + ('Strawberry', 5691), + ('Stairs', 5643), + ('Goose', 5599), + ('Umbrella', 5536), + ('Cake', 5508), + ('Sun hat', 5475), + ('Bench', 5310), + ('Bookcase', 5163), + ('Bee', 5140), + ('Computer monitor', 5078), + ('Hiking equipment', 4983), + ('Office building', 4981), + ('Coffee cup', 4748), + ('Curtain', 4685), + ('Plate', 4651), + ('Box', 4621), + ('Tomato', 4595), + ('Coffee table', 4529), + ('Office supplies', 4473), + ('Maple', 4416), + ('Muffin', 4365), + ('Cocktail', 4234), + ('Castle', 4197), + ('Couch', 4134), + ('Pumpkin', 3983), + ('Computer keyboard', 3960), + ('Human mouth', 3926), + ('Christmas tree', 3893), + ('Mushroom', 3883), + ('Swimming pool', 3809), + ('Pastry', 3799), + ('Lavender (Plant)', 3769), + ('Football helmet', 3732), + ('Bread', 3648), + ('Traffic sign', 3628), + ('Common sunflower', 3597), + ('Television', 3550), + ('Bed', 3525), + ('Cookie', 3485), + ('Fountain', 3484), + ('Paddle', 3447), + ('Bicycle helmet', 3429), + ('Porch', 3420), + ('Deer', 3387), + ('Fedora', 3339), + ('Canoe', 3338), + ('Carnivore', 3266), + ('Bowl', 3202), + ('Human eye', 3166), + ('Ball', 3118), + ('Pillow', 3077), + ('Salad', 3061), + ('Beetle', 3060), + ('Orange', 3050), + ('Drawer', 2958), + ('Platter', 2937), + ('Elephant', 2921), + ('Seafood', 2921), + ('Monkey', 2915), + ('Countertop', 2879), + ('Watercraft', 2831), + ('Helicopter', 2805), + ('Kitchen appliance', 2797), + ('Personal flotation device', 2781), + ('Swan', 2739), + ('Lamp', 2711), + ('Boot', 2695), + ('Bronze sculpture', 2693), + ('Chicken', 2677), + ('Taxi', 2643), + ('Juice', 2615), + ('Cowboy hat', 2604), + ('Apple', 2600), + ('Tin can', 2590), + ('Necklace', 2564), + ('Ice cream', 2560), + ('Human beard', 2539), + ('Coin', 2536), + ('Candle', 2515), + ('Cart', 2512), + ('High heels', 2441), + ('Weapon', 2433), + ('Handbag', 2406), + ('Penguin', 2396), + ('Rifle', 2352), + ('Violin', 2336), + ('Skull', 2304), + ('Lantern', 2285), + ('Scarf', 2269), + ('Saucer', 2225), + ('Sheep', 2215), + ('Vase', 2189), + ('Lily', 2180), + ('Mug', 2154), + ('Parrot', 2140), + ('Human ear', 2137), + ('Sandal', 2115), + ('Lizard', 2100), + ('Kitchen & dining room table', 2063), + ('Spider', 1977), + ('Coffee', 1974), + ('Goat', 1926), + ('Squirrel', 1922), + ('Cello', 1913), + ('Sushi', 1881), + ('Tortoise', 1876), + ('Pizza', 1870), + ('Studio couch', 1864), + ('Barrel', 1862), + ('Cosmetics', 1841), + ('Moths and butterflies', 1841), + ('Convenience store', 1817), + ('Watch', 1792), + ('Home appliance', 1786), + ('Harbor seal', 1780), + ('Luggage and bags', 1756), + ('Vehicle registration plate', 1754), + ('Shrimp', 1751), + ('Jellyfish', 1730), + ('French fries', 1723), + ('Egg (Food)', 1698), + ('Football', 1697), + ('Musical keyboard', 1683), + ('Falcon', 1674), + ('Candy', 1660), + ('Medical equipment', 1654), + ('Eagle', 1651), + ('Dinosaur', 1634), + ('Surfboard', 1630), + ('Tank', 1628), + ('Grape', 1624), + ('Lion', 1624), + ('Owl', 1622), + ('Ski', 1613), + ('Waste container', 1606), + ('Frog', 1591), + ('Sparrow', 1585), + ('Rabbit', 1581), + ('Pen', 1546), + ('Sea lion', 1537), + ('Spoon', 1521), + ('Sink', 1512), + ('Teddy bear', 1507), + ('Bull', 1495), + ('Sofa bed', 1490), + ('Dragonfly', 1479), + ('Brassiere', 1478), + ('Chest of drawers', 1472), + ('Aircraft', 1466), + ('Human foot', 1463), + ('Pig', 1455), + ('Fork', 1454), + ('Antelope', 1438), + ('Tripod', 1427), + ('Tool', 1424), + ('Cheese', 1422), + ('Lemon', 1397), + ('Hamburger', 1393), + ('Dolphin', 1390), + ('Mirror', 1390), + ('Marine mammal', 1387), + ('Giraffe', 1385), + ('Snake', 1368), + ('Gondola', 1364), + ('Wheelchair', 1360), + ('Piano', 1358), + ('Cupboard', 1348), + ('Banana', 1345), + ('Trumpet', 1335), + ('Lighthouse', 1333), + ('Invertebrate', 1317), + ('Carrot', 1268), + ('Sock', 1260), + ('Tiger', 1241), + ('Camel', 1224), + ('Parachute', 1224), + ('Bathroom accessory', 1223), + ('Earrings', 1221), + ('Headphones', 1218), + ('Skirt', 1198), + ('Skateboard', 1190), + ('Sandwich', 1148), + ('Saxophone', 1141), + ('Goldfish', 1136), + ('Stool', 1104), + ('Traffic light', 1097), + ('Shellfish', 1081), + ('Backpack', 1079), + ('Sea turtle', 1078), + ('Cucumber', 1075), + ('Tea', 1051), + ('Toilet', 1047), + ('Roller skates', 1040), + ('Mule', 1039), + ('Bust', 1031), + ('Broccoli', 1030), + ('Crab', 1020), + ('Oyster', 1019), + ('Cannon', 1012), + ('Zebra', 1012), + ('French horn', 1008), + ('Grapefruit', 998), + ('Whiteboard', 997), + ('Zucchini', 997), + ('Crocodile', 992), + + ('Clock', 960), + ('Wall clock', 958), + + ('Doughnut', 869), + ('Snail', 868), + + ('Baseball glove', 859), + + ('Panda', 830), + ('Tennis racket', 830), + + ('Pear', 652), + + ('Bagel', 617), + ('Oven', 616), + ('Ladybug', 615), + ('Shark', 615), + ('Polar bear', 614), + ('Ostrich', 609), + + ('Hot dog', 473), + ('Microwave oven', 467), + ('Fire hydrant', 20), + ('Stop sign', 20), + ('Parking meter', 20), + ('Bear', 20), + ('Flying disc', 20), + ('Snowboard', 20), + ('Tennis ball', 20), + ('Kite', 20), + ('Baseball bat', 20), + ('Kitchen knife', 20), + ('Knife', 20), + ('Submarine sandwich', 20), + ('Computer mouse', 20), + ('Remote control', 20), + ('Toaster', 20), + ('Sink', 20), + ('Refrigerator', 20), + ('Alarm clock', 20), + ('Wall clock', 20), + ('Scissors', 20), + ('Hair dryer', 20), + ('Toothbrush', 20), + ('Suitcase', 20) +] diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py b/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py new file mode 100644 index 0000000000000000000000000000000000000000..c6bfea446bedab684e139df45252dc01734c6e9b --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/sflckr.py @@ -0,0 +1,91 @@ +import os +import numpy as np +import cv2 +import custom_albumentations as albumentations +from PIL import Image +from torch.utils.data import Dataset + + +class SegmentationBase(Dataset): + def __init__(self, + data_csv, data_root, segmentation_root, + size=None, random_crop=False, interpolation="bicubic", + n_labels=182, shift_segmentation=False, + ): + self.n_labels = n_labels + self.shift_segmentation = shift_segmentation + self.data_csv = data_csv + self.data_root = data_root + self.segmentation_root = segmentation_root + with open(self.data_csv, "r") as f: + self.image_paths = f.read().splitlines() + self._length = len(self.image_paths) + self.labels = { + "relative_file_path_": [l for l in self.image_paths], + "file_path_": [os.path.join(self.data_root, l) + for l in self.image_paths], + "segmentation_path_": [os.path.join(self.segmentation_root, l.replace(".jpg", ".png")) + for l in self.image_paths] + } + + size = None if size is not None and size<=0 else size + self.size = size + if self.size is not None: + self.interpolation = interpolation + self.interpolation = { + "nearest": cv2.INTER_NEAREST, + "bilinear": cv2.INTER_LINEAR, + "bicubic": cv2.INTER_CUBIC, + "area": cv2.INTER_AREA, + "lanczos": cv2.INTER_LANCZOS4}[self.interpolation] + self.image_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=self.interpolation) + self.segmentation_rescaler = albumentations.SmallestMaxSize(max_size=self.size, + interpolation=cv2.INTER_NEAREST) + self.center_crop = not random_crop + if self.center_crop: + self.cropper = albumentations.CenterCrop(height=self.size, width=self.size) + else: + self.cropper = albumentations.RandomCrop(height=self.size, width=self.size) + self.preprocessor = self.cropper + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = dict((k, self.labels[k][i]) for k in self.labels) + image = Image.open(example["file_path_"]) + if not image.mode == "RGB": + image = image.convert("RGB") + image = np.array(image).astype(np.uint8) + if self.size is not None: + image = self.image_rescaler(image=image)["image"] + segmentation = Image.open(example["segmentation_path_"]) + assert segmentation.mode == "L", segmentation.mode + segmentation = np.array(segmentation).astype(np.uint8) + if self.shift_segmentation: + # used to support segmentations containing unlabeled==255 label + segmentation = segmentation+1 + if self.size is not None: + segmentation = self.segmentation_rescaler(image=segmentation)["image"] + if self.size is not None: + processed = self.preprocessor(image=image, + mask=segmentation + ) + else: + processed = {"image": image, + "mask": segmentation + } + example["image"] = (processed["image"]/127.5 - 1.0).astype(np.float32) + segmentation = processed["mask"] + onehot = np.eye(self.n_labels)[segmentation] + example["segmentation"] = onehot + return example + + +class Examples(SegmentationBase): + def __init__(self, size=None, random_crop=False, interpolation="bicubic"): + super().__init__(data_csv="data/sflckr_examples.txt", + data_root="data/sflckr_images", + segmentation_root="data/sflckr_segmentations", + size=size, random_crop=random_crop, interpolation=interpolation) diff --git a/custom_controlnet_aux/diffusion_edge/taming/data/utils.py b/custom_controlnet_aux/diffusion_edge/taming/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8be640a37f4fcba5d7bb863b8913ba829ee09a85 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/data/utils.py @@ -0,0 +1,169 @@ +import collections +import os +import tarfile +import urllib +import zipfile +from pathlib import Path + +import numpy as np +import torch +from custom_controlnet_aux.diffusion_edge.taming.data.helper_types import Annotation +from torch._six import string_classes +from torch.utils.data._utils.collate import np_str_obj_array_pattern, default_collate_err_msg_format +from tqdm import tqdm + + +def unpack(path): + if path.endswith("tar.gz"): + with tarfile.open(path, "r:gz") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("tar"): + with tarfile.open(path, "r:") as tar: + tar.extractall(path=os.path.split(path)[0]) + elif path.endswith("zip"): + with zipfile.ZipFile(path, "r") as f: + f.extractall(path=os.path.split(path)[0]) + else: + raise NotImplementedError( + "Unknown file extension: {}".format(os.path.splitext(path)[1]) + ) + + +def reporthook(bar): + """tqdm progress bar for downloads.""" + + def hook(b=1, bsize=1, tsize=None): + if tsize is not None: + bar.total = tsize + bar.update(b * bsize - bar.n) + + return hook + + +def get_root(name): + base = "data/" + root = os.path.join(base, name) + os.makedirs(root, exist_ok=True) + return root + + +def is_prepared(root): + return Path(root).joinpath(".ready").exists() + + +def mark_prepared(root): + Path(root).joinpath(".ready").touch() + + +def prompt_download(file_, source, target_dir, content_dir=None): + targetpath = os.path.join(target_dir, file_) + while not os.path.exists(targetpath): + if content_dir is not None and os.path.exists( + os.path.join(target_dir, content_dir) + ): + break + print( + "Please download '{}' from '{}' to '{}'.".format(file_, source, targetpath) + ) + if content_dir is not None: + print( + "Or place its content into '{}'.".format( + os.path.join(target_dir, content_dir) + ) + ) + input("Press Enter when done...") + return targetpath + + +def download_url(file_, url, target_dir): + targetpath = os.path.join(target_dir, file_) + os.makedirs(target_dir, exist_ok=True) + with tqdm( + unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=file_ + ) as bar: + urllib.request.urlretrieve(url, targetpath, reporthook=reporthook(bar)) + return targetpath + + +def download_urls(urls, target_dir): + paths = dict() + for fname, url in urls.items(): + outpath = download_url(fname, url, target_dir) + paths[fname] = outpath + return paths + + +def quadratic_crop(x, bbox, alpha=1.0): + """bbox is xmin, ymin, xmax, ymax""" + im_h, im_w = x.shape[:2] + bbox = np.array(bbox, dtype=np.float32) + bbox = np.clip(bbox, 0, max(im_h, im_w)) + center = 0.5 * (bbox[0] + bbox[2]), 0.5 * (bbox[1] + bbox[3]) + w = bbox[2] - bbox[0] + h = bbox[3] - bbox[1] + l = int(alpha * max(w, h)) + l = max(l, 2) + + required_padding = -1 * min( + center[0] - l, center[1] - l, im_w - (center[0] + l), im_h - (center[1] + l) + ) + required_padding = int(np.ceil(required_padding)) + if required_padding > 0: + padding = [ + [required_padding, required_padding], + [required_padding, required_padding], + ] + padding += [[0, 0]] * (len(x.shape) - 2) + x = np.pad(x, padding, "reflect") + center = center[0] + required_padding, center[1] + required_padding + xmin = int(center[0] - l / 2) + ymin = int(center[1] - l / 2) + return np.array(x[ymin : ymin + l, xmin : xmin + l, ...]) + + +def custom_collate(batch): + r"""source: pytorch 1.9.0, only one modification to original code """ + + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + storage = elem.storage()._new_shared(numel) + out = elem.new(storage) + return torch.stack(batch, 0, out=out) + elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \ + and elem_type.__name__ != 'string_': + if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap': + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + + return custom_collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: custom_collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple + return elem_type(*(custom_collate(samples) for samples in zip(*batch))) + if isinstance(elem, collections.abc.Sequence) and isinstance(elem[0], Annotation): # added + return batch # added + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if not all(len(elem) == elem_size for elem in it): + raise RuntimeError('each element in list of batch should be of equal size') + transposed = zip(*batch) + return [custom_collate(samples) for samples in transposed] + + raise TypeError(default_collate_err_msg_format.format(elem_type)) diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py b/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a5db6aa2ef915e270f1ae135e4a9918fdd884c --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/diffusionmodules/model.py @@ -0,0 +1,776 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". + """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0,1,0,0)) + return emb + + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Model(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, use_timestep=True): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, t=None): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x): + #assert x.shape[2] == x.shape[3] == self.resolution, "{}, {}, {}".format(x.shape[2], x.shape[3], self.resolution) + + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, **ignorekwargs): + super().__init__() + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class VUNet(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, + in_channels, c_channels, + resolution, z_channels, use_timestep=False, **ignore_kwargs): + super().__init__() + self.ch = ch + self.temb_ch = self.ch*4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList([ + torch.nn.Linear(self.ch, + self.temb_ch), + torch.nn.Linear(self.temb_ch, + self.temb_ch), + ]) + + # downsampling + self.conv_in = torch.nn.Conv2d(c_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + self.z_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=1, + stride=1, + padding=0) + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=2*block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = AttnBlock(block_in) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + skip_in = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + if i_block == self.num_res_blocks: + skip_in = ch*in_ch_mult[i_level] + block.append(ResnetBlock(in_channels=block_in+skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(AttnBlock(block_in)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + + def forward(self, x, z): + #assert x.shape[2] == x.shape[3] == self.resolution + + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + z = self.z_in(z) + h = torch.cat((h,z),dim=1) + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block]( + torch.cat([h, hs.pop()], dim=1), temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class SimpleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, *args, **kwargs): + super().__init__() + self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1), + ResnetBlock(in_channels=in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=2 * in_channels, + out_channels=4 * in_channels, + temb_channels=0, dropout=0.0), + ResnetBlock(in_channels=4 * in_channels, + out_channels=2 * in_channels, + temb_channels=0, dropout=0.0), + nn.Conv2d(2*in_channels, in_channels, 1), + Upsample(in_channels, with_conv=True)]) + # end + self.norm_out = Normalize(in_channels) + self.conv_out = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + for i, layer in enumerate(self.model): + if i in [1,2,3]: + x = layer(x, None) + else: + x = layer(x) + + h = self.norm_out(x) + h = nonlinearity(h) + x = self.conv_out(h) + return x + + +class UpsampleDecoder(nn.Module): + def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution, + ch_mult=(2,2), dropout=0.0): + super().__init__() + # upsampling + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + block_in = in_channels + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.res_blocks = nn.ModuleList() + self.upsample_blocks = nn.ModuleList() + for i_level in range(self.num_resolutions): + res_block = [] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + res_block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + self.res_blocks.append(nn.ModuleList(res_block)) + if i_level != self.num_resolutions - 1: + self.upsample_blocks.append(Upsample(block_in, True)) + curr_res = curr_res * 2 + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # upsampling + h = x + for k, i_level in enumerate(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.res_blocks[i_level][i_block](h, None) + if i_level != self.num_resolutions - 1: + h = self.upsample_blocks[k](h) + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py b/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab2cc89aa2d9478fcfd60c80b91a1f7aa829e4e --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/discriminator/model.py @@ -0,0 +1,131 @@ +import functools +import torch.nn as nn + + +from custom_controlnet_aux.diffusion_edge.taming.modules.util import ActNorm + + +def weights_init(m): + classname = m.__class__.__name__ + if classname.find('Conv') != -1: + nn.init.normal_(m.weight.data, 0.0, 0.02) + elif classname.find('BatchNorm') != -1: + nn.init.normal_(m.weight.data, 1.0, 0.02) + nn.init.constant_(m.bias.data, 0) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm2d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm2d + else: + use_bias = norm_layer != nn.BatchNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + +class NLayerDiscriminator2(nn.Module): + """Defines a PatchGAN discriminator as in Pix2Pix + --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py + """ + def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): + """Construct a PatchGAN discriminator + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator2, self).__init__() + if not use_actnorm: + norm_layer = nn.BatchNorm3d + else: + norm_layer = ActNorm + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func != nn.BatchNorm3d + else: + use_bias = norm_layer != nn.BatchNorm3d + + kw = 4 + padw = 1 + sequence = [nn.Conv3d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, + padding=padw, bias=use_bias, groups=8), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv3d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, + padding=padw, bias=use_bias, groups=8), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [ + nn.Conv3d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw), + # nn.Sigmoid() + ] # output 1 channel prediction map + self.main = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.main(input) + +if __name__ == "__main__": + import torch + model = NLayerDiscriminator2(input_nc=3, ndf=64, n_layers=3) + x = torch.rand(1, 3, 64, 64, 64) + with torch.no_grad(): + y = model(x) + pause = 0 \ No newline at end of file diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c3d03a59762ab71d59bb7c3f8977ee0fd459d0bc --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/__init__.py @@ -0,0 +1,2 @@ +from custom_controlnet_aux.diffusion_edge.taming.modules.losses.vqperceptual import DummyLoss + diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..32666db18cb2e89e65b8e905028aaaed81ccb017 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/lpips.py @@ -0,0 +1,126 @@ +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + +import torch +import torch.nn as nn +from torchvision import models +from collections import namedtuple + +from .util import get_ckpt_path + +from custom_controlnet_aux.util import custom_torch_download + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=False, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, "taming/modules/autoencoder/lpips") + self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name != "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name) + model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=False): + super(vgg16, self).__init__() + vgg16_model = models.vgg16(pretrained=pretrained) + vgg16_model.load_state_dict(torch.load(custom_torch_download(filename="vgg16-397923af.pth")), strict=True) + vgg_pretrained_features = vgg16_model.features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x,eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2,3],keepdim=keepdim) + diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..4ba77deb5159a6307ed2acba9945e4764a4ff0a5 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/segmentation.py @@ -0,0 +1,22 @@ +import torch.nn as nn +import torch.nn.functional as F + + +class BCELoss(nn.Module): + def forward(self, prediction, target): + loss = F.binary_cross_entropy_with_logits(prediction,target) + return loss, {} + + +class BCELossWithQuant(nn.Module): + def __init__(self, codebook_weight=1.): + super().__init__() + self.codebook_weight = codebook_weight + + def forward(self, qloss, target, prediction, split): + bce_loss = F.binary_cross_entropy_with_logits(prediction,target) + loss = bce_loss + self.codebook_weight*qloss + return loss, {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/bce_loss".format(split): bce_loss.detach().mean(), + "{}/quant_loss".format(split): qloss.detach().mean() + } diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..67f01102c77e9793197059726c0165afb8d81ace --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/losses/vqperceptual.py @@ -0,0 +1,136 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_controlnet_aux.diffusion_edge.taming.modules.losses.lpips import LPIPS +from custom_controlnet_aux.diffusion_edge.taming.modules.discriminator.model import NLayerDiscriminator, weights_init, NLayerDiscriminator2 + + +class DummyLoss(nn.Module): + def __init__(self): + super().__init__() + + +def adopt_weight(weight, global_step, threshold=0, value=0.): + if global_step < threshold: + weight = value + return weight + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQLPIPSWithDiscriminator(nn.Module): + def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, + disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, + perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, + disc_ndf=64, disc_loss="hinge"): + super().__init__() + assert disc_loss in ["hinge", "vanilla"] + self.codebook_weight = codebook_weight + self.pixel_weight = pixelloss_weight + self.perceptual_loss = LPIPS().eval() + self.perceptual_weight = perceptual_weight + + self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, + n_layers=disc_num_layers, + use_actnorm=use_actnorm, + ndf=disc_ndf + ).apply(weights_init) + self.discriminator_iter_start = disc_start + if disc_loss == "hinge": + self.disc_loss = hinge_d_loss + elif disc_loss == "vanilla": + self.disc_loss = vanilla_d_loss + else: + raise ValueError(f"Unknown GAN loss '{disc_loss}'.") + print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") + self.disc_factor = disc_factor + self.discriminator_weight = disc_weight + self.disc_conditional = disc_conditional + + def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): + if last_layer is not None: + nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] + else: + nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] + g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] + + d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) + d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() + d_weight = d_weight * self.discriminator_weight + return d_weight + + def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, + global_step, last_layer=None, cond=None, split="train"): + rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) + if self.perceptual_weight > 0: + p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) + rec_loss = rec_loss + self.perceptual_weight * p_loss + else: + p_loss = torch.tensor([0.0]) + + nll_loss = rec_loss + #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] + nll_loss = torch.mean(nll_loss) + + # now the GAN part + if optimizer_idx == 0: + # generator update + if cond is None: + assert not self.disc_conditional + logits_fake = self.discriminator(reconstructions.contiguous()) + else: + assert self.disc_conditional + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) + g_loss = -torch.mean(logits_fake) + + try: + d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) + except RuntimeError: + assert not self.training + d_weight = torch.tensor(0.0) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() + + log = {"{}/total_loss".format(split): loss.clone().detach().mean(), + "{}/quant_loss".format(split): codebook_loss.detach().mean(), + "{}/nll_loss".format(split): nll_loss.detach().mean(), + "{}/rec_loss".format(split): rec_loss.detach().mean(), + "{}/p_loss".format(split): p_loss.detach().mean(), + "{}/d_weight".format(split): d_weight.detach(), + "{}/disc_factor".format(split): torch.tensor(disc_factor), + "{}/g_loss".format(split): g_loss.detach().mean(), + } + return loss, log + + if optimizer_idx == 1: + # second pass for discriminator update + if cond is None: + logits_real = self.discriminator(inputs.contiguous().detach()) + logits_fake = self.discriminator(reconstructions.contiguous().detach()) + else: + logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) + logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) + + disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) + d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) + + log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), + "{}/logits_real".format(split): logits_real.detach().mean(), + "{}/logits_fake".format(split): logits_fake.detach().mean() + } + return d_loss, log diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py b/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py new file mode 100644 index 0000000000000000000000000000000000000000..ee69b0c897b6b382ae673622e420f55e494f5b09 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/misc/coord.py @@ -0,0 +1,31 @@ +import torch + +class CoordStage(object): + def __init__(self, n_embed, down_factor): + self.n_embed = n_embed + self.down_factor = down_factor + + def eval(self): + return self + + def encode(self, c): + """fake vqmodel interface""" + assert 0.0 <= c.min() and c.max() <= 1.0 + b,ch,h,w = c.shape + assert ch == 1 + + c = torch.nn.functional.interpolate(c, scale_factor=1/self.down_factor, + mode="area") + c = c.clamp(0.0, 1.0) + c = self.n_embed*c + c_quant = c.round() + c_ind = c_quant.to(dtype=torch.long) + + info = None, None, c_ind + return c_quant, None, info + + def decode(self, c): + c = c/self.n_embed + c = torch.nn.functional.interpolate(c, scale_factor=self.down_factor, + mode="nearest") + return c diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/util.py b/custom_controlnet_aux/diffusion_edge/taming/modules/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9ee16385d8b1342a2d60a5f1aa5cadcfbe934bd8 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/util.py @@ -0,0 +1,130 @@ +import torch +import torch.nn as nn + + +def count_params(model): + total_params = sum(p.numel() for p in model.parameters()) + return total_params + + +class ActNorm(nn.Module): + def __init__(self, num_features, logdet=False, affine=True, + allow_reverse_init=False): + assert affine + super().__init__() + self.logdet = logdet + self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.allow_reverse_init = allow_reverse_init + + self.register_buffer('initialized', torch.tensor(0, dtype=torch.uint8)) + + def initialize(self, input): + with torch.no_grad(): + flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1) + mean = ( + flatten.mean(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + std = ( + flatten.std(1) + .unsqueeze(1) + .unsqueeze(2) + .unsqueeze(3) + .permute(1, 0, 2, 3) + ) + + self.loc.data.copy_(-mean) + self.scale.data.copy_(1 / (std + 1e-6)) + + def forward(self, input, reverse=False): + if reverse: + return self.reverse(input) + if len(input.shape) == 2: + input = input[:,:,None,None] + squeeze = True + else: + squeeze = False + + _, _, height, width = input.shape + + if self.training and self.initialized.item() == 0: + self.initialize(input) + self.initialized.fill_(1) + + h = self.scale * (input + self.loc) + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + + if self.logdet: + log_abs = torch.log(torch.abs(self.scale)) + logdet = height*width*torch.sum(log_abs) + logdet = logdet * torch.ones(input.shape[0]).to(input) + return h, logdet + + return h + + def reverse(self, output): + if self.training and self.initialized.item() == 0: + if not self.allow_reverse_init: + raise RuntimeError( + "Initializing ActNorm in reverse direction is " + "disabled by default. Use allow_reverse_init=True to enable." + ) + else: + self.initialize(output) + self.initialized.fill_(1) + + if len(output.shape) == 2: + output = output[:,:,None,None] + squeeze = True + else: + squeeze = False + + h = output / self.scale - self.loc + + if squeeze: + h = h.squeeze(-1).squeeze(-1) + return h + + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class Labelator(AbstractEncoder): + """Net2Net Interface for Class-Conditional Model""" + def __init__(self, n_classes, quantize_interface=True): + super().__init__() + self.n_classes = n_classes + self.quantize_interface = quantize_interface + + def encode(self, c): + c = c[:,None] + if self.quantize_interface: + return c, None, [None, None, c.long()] + return c + + +class SOSProvider(AbstractEncoder): + # for unconditional training + def __init__(self, sos_token, quantize_interface=True): + super().__init__() + self.sos_token = sos_token + self.quantize_interface = quantize_interface + + def encode(self, x): + # get batch size from data and replicate sos_token + c = torch.ones(x.shape[0], 1)*self.sos_token + c = c.long().to(x.device) + if self.quantize_interface: + return c, None, [None, None, c] + return c diff --git a/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py b/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..d75544e41fa01bce49dd822b1037963d62f79b51 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/modules/vqvae/quantize.py @@ -0,0 +1,445 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from torch import einsum +from einops import rearrange + + +class VectorQuantizer(nn.Module): + """ + see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py + ____________________________________________ + Discretization bottleneck part of the VQ-VAE. + Inputs: + - n_e : number of embeddings + - e_dim : dimension of embedding + - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + _____________________________________________ + """ + + # NOTE: this class contains a bug regarding beta; see VectorQuantizer2 for + # a fix and use legacy=False to apply that fix. VectorQuantizer2 can be + # used wherever VectorQuantizer has been used before and is additionally + # more efficient. + def __init__(self, n_e, e_dim, beta): + super(VectorQuantizer, self).__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + def forward(self, z): + """ + Inputs the output of the encoder network z and maps it to a discrete + one-hot vector that is the index of the closest embedding vector e_j + z (continuous) -> z_q (discrete) + z.shape = (batch, channel, height, width) + quantization pipeline: + 1. get encoder input (B,C,H,W) + 2. flatten input to (B*H*W,C) + """ + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.matmul(z_flattened, self.embedding.weight.t()) + + ## could possible replace this here + # #\start... + # find closest encodings + min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + + min_encodings = torch.zeros( + min_encoding_indices.shape[0], self.n_e).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # dtype min encodings: torch.float32 + # min_encodings shape: torch.Size([2048, 512]) + # min_encoding_indices.shape: torch.Size([2048, 1]) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + #.........\end + + # with: + # .........\start + #min_encoding_indices = torch.argmin(d, dim=1) + #z_q = self.embedding(min_encoding_indices) + # ......\end......... (TODO) + + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + # TODO: check for more easy handling with nn.Embedding + min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices) + min_encodings.scatter_(1, indices[:,None], 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: + z_q = z_q.view(shape) + + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantize(nn.Module): + """ + credit to @karpathy: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py (thanks!) + Gumbel Softmax trick quantizer + Categorical Reparameterization with Gumbel-Softmax, Jang et al. 2016 + https://arxiv.org/abs/1611.01144 + """ + def __init__(self, num_hiddens, embedding_dim, n_embed, straight_through=True, + kl_weight=5e-4, temp_init=1.0, use_vqinterface=True, + remap=None, unknown_index="random"): + super().__init__() + + self.embedding_dim = embedding_dim + self.n_embed = n_embed + + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + + self.proj = nn.Conv2d(num_hiddens, n_embed, 1) + self.embed = nn.Embedding(n_embed, embedding_dim) + + self.use_vqinterface = use_vqinterface + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, return_logits=False): + # force hard = True when we are in eval mode, as we must quantize. actually, always true seems to work + hard = self.straight_through if self.training else True + temp = self.temperature if temp is None else temp + + logits = self.proj(z) + if self.remap is not None: + # continue only with used logits + full_zeros = torch.zeros_like(logits) + logits = logits[:,self.used,...] + + soft_one_hot = F.gumbel_softmax(logits, tau=temp, dim=1, hard=hard) + if self.remap is not None: + # go back to all entries but unused set to zero + full_zeros[:,self.used,...] = soft_one_hot + soft_one_hot = full_zeros + z_q = einsum('b n h w, n d -> b d h w', soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.n_embed + 1e-10), dim=1).mean() + + ind = soft_one_hot.argmax(dim=1) + if self.remap is not None: + ind = self.remap_to_used(ind) + if self.use_vqinterface: + if return_logits: + return z_q, diff, (None, None, ind), logits + return z_q, diff, (None, None, ind) + return z_q, diff, ind + + def get_codebook_entry(self, indices, shape): + b, h, w, c = shape + assert b*h*w == indices.shape[0] + indices = rearrange(indices, '(b h w) -> b h w', b=b, h=h, w=w) + if self.remap is not None: + indices = self.unmap_to_all(indices) + one_hot = F.one_hot(indices, num_classes=self.n_embed).permute(0, 3, 1, 2).float() + z_q = einsum('b n h w, n d -> b d h w', one_hot, self.embed.weight) + return z_q + + +class VectorQuantizer2(nn.Module): + """ + Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly + avoids costly matrix multiplications and allows for post-hoc remapping of indices. + """ + # NOTE: due to a bug the beta term was applied to the wrong term. for + # backwards compatibility we use the buggy version by default, but you can + # specify legacy=False to fix it. + def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", + sane_index_shape=False, legacy=True): + super().__init__() + self.n_e = n_e + self.e_dim = e_dim + self.beta = beta + self.legacy = legacy + + self.embedding = nn.Embedding(self.n_e, self.e_dim) + self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_e} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_e + + self.sane_index_shape = sane_index_shape + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z, temp=None, rescale_logits=False, return_logits=False): + assert temp is None or temp==1.0, "Only for interface compatible with Gumbel" + assert rescale_logits==False, "Only for interface compatible with Gumbel" + assert return_logits==False, "Only for interface compatible with Gumbel" + # reshape z -> (batch, height, width, channel) and flatten + z = rearrange(z, 'b c h w -> b h w c').contiguous() + z_flattened = z.view(-1, self.e_dim) + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + + d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \ + torch.sum(self.embedding.weight**2, dim=1) - 2 * \ + torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n')) + + min_encoding_indices = torch.argmin(d, dim=1) + z_q = self.embedding(min_encoding_indices).view(z.shape) + perplexity = None + min_encodings = None + + # compute loss for embedding + if not self.legacy: + loss = self.beta * torch.mean((z_q.detach()-z)**2) + \ + torch.mean((z_q - z.detach()) ** 2) + else: + loss = torch.mean((z_q.detach()-z)**2) + self.beta * \ + torch.mean((z_q - z.detach()) ** 2) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous() + + if self.remap is not None: + min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis + min_encoding_indices = self.remap_to_used(min_encoding_indices) + min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten + + if self.sane_index_shape: + min_encoding_indices = min_encoding_indices.reshape( + z_q.shape[0], z_q.shape[2], z_q.shape[3]) + + return z_q, loss, (perplexity, min_encodings, min_encoding_indices) + + def get_codebook_entry(self, indices, shape): + # shape specifying (batch, height, width, channel) + if self.remap is not None: + indices = indices.reshape(shape[0],-1) # add batch axis + indices = self.unmap_to_all(indices) + indices = indices.reshape(-1) # flatten again + + # get quantized latent vectors + z_q = self.embedding(indices) + + if shape is not None: + z_q = z_q.view(shape) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q + +class EmbeddingEMA(nn.Module): + def __init__(self, num_tokens, codebook_dim, decay=0.99, eps=1e-5): + super().__init__() + self.decay = decay + self.eps = eps + weight = torch.randn(num_tokens, codebook_dim) + self.weight = nn.Parameter(weight, requires_grad = False) + self.cluster_size = nn.Parameter(torch.zeros(num_tokens), requires_grad = False) + self.embed_avg = nn.Parameter(weight.clone(), requires_grad = False) + self.update = True + + def forward(self, embed_id): + return F.embedding(embed_id, self.weight) + + def cluster_size_ema_update(self, new_cluster_size): + self.cluster_size.data.mul_(self.decay).add_(new_cluster_size, alpha=1 - self.decay) + + def embed_avg_ema_update(self, new_embed_avg): + self.embed_avg.data.mul_(self.decay).add_(new_embed_avg, alpha=1 - self.decay) + + def weight_update(self, num_tokens): + n = self.cluster_size.sum() + smoothed_cluster_size = ( + (self.cluster_size + self.eps) / (n + num_tokens * self.eps) * n + ) + #normalize embedding average with smoothed cluster size + embed_normalized = self.embed_avg / smoothed_cluster_size.unsqueeze(1) + self.weight.data.copy_(embed_normalized) + + +class EMAVectorQuantizer(nn.Module): + def __init__(self, n_embed, embedding_dim, beta, decay=0.99, eps=1e-5, + remap=None, unknown_index="random"): + super().__init__() + self.codebook_dim = codebook_dim + self.num_tokens = num_tokens + self.beta = beta + self.embedding = EmbeddingEMA(self.num_tokens, self.codebook_dim, decay, eps) + + self.remap = remap + if self.remap is not None: + self.register_buffer("used", torch.tensor(np.load(self.remap))) + self.re_embed = self.used.shape[0] + self.unknown_index = unknown_index # "random" or "extra" or integer + if self.unknown_index == "extra": + self.unknown_index = self.re_embed + self.re_embed = self.re_embed+1 + print(f"Remapping {self.n_embed} indices to {self.re_embed} indices. " + f"Using {self.unknown_index} for unknown indices.") + else: + self.re_embed = n_embed + + def remap_to_used(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + match = (inds[:,:,None]==used[None,None,...]).long() + new = match.argmax(-1) + unknown = match.sum(2)<1 + if self.unknown_index == "random": + new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device) + else: + new[unknown] = self.unknown_index + return new.reshape(ishape) + + def unmap_to_all(self, inds): + ishape = inds.shape + assert len(ishape)>1 + inds = inds.reshape(ishape[0],-1) + used = self.used.to(inds) + if self.re_embed > self.used.shape[0]: # extra token + inds[inds>=self.used.shape[0]] = 0 # simply set to zero + back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds) + return back.reshape(ishape) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + #z, 'b c h w -> b h w c' + z = rearrange(z, 'b c h w -> b h w c') + z_flattened = z.reshape(-1, self.codebook_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = z_flattened.pow(2).sum(dim=1, keepdim=True) + \ + self.embedding.weight.pow(2).sum(dim=1) - 2 * \ + torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight) # 'n d -> d n' + + + encoding_indices = torch.argmin(d, dim=1) + + z_q = self.embedding(encoding_indices).view(z.shape) + encodings = F.one_hot(encoding_indices, self.num_tokens).type(z.dtype) + avg_probs = torch.mean(encodings, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10))) + + if self.training and self.embedding.update: + #EMA cluster size + encodings_sum = encodings.sum(0) + self.embedding.cluster_size_ema_update(encodings_sum) + #EMA embedding average + embed_sum = encodings.transpose(0,1) @ z_flattened + self.embedding.embed_avg_ema_update(embed_sum) + #normalize embed_avg and update weight + self.embedding.weight_update(self.num_tokens) + + # compute loss for embedding + loss = self.beta * F.mse_loss(z_q.detach(), z) + + # preserve gradients + z_q = z + (z_q - z).detach() + + # reshape back to match original input shape + #z_q, 'b h w c -> b c h w' + z_q = rearrange(z_q, 'b h w c -> b c h w') + return z_q, loss, (perplexity, encodings, encoding_indices) diff --git a/custom_controlnet_aux/diffusion_edge/taming/util.py b/custom_controlnet_aux/diffusion_edge/taming/util.py new file mode 100644 index 0000000000000000000000000000000000000000..06053e5defb87977f9ab07e69bf4da12201de9b7 --- /dev/null +++ b/custom_controlnet_aux/diffusion_edge/taming/util.py @@ -0,0 +1,157 @@ +import os, hashlib +import requests +from tqdm import tqdm + +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class KeyNotFoundError(Exception): + def __init__(self, cause, keys=None, visited=None): + self.cause = cause + self.keys = keys + self.visited = visited + messages = list() + if keys is not None: + messages.append("Key not found: {}".format(keys)) + if visited is not None: + messages.append("Visited: {}".format(visited)) + messages.append("Cause:\n{}".format(cause)) + message = "\n".join(messages) + super().__init__(message) + + +def retrieve( + list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False +): + """Given a nested list or dict return the desired value at key expanding + callable nodes if necessary and :attr:`expand` is ``True``. The expansion + is done in-place. + + Parameters + ---------- + list_or_dict : list or dict + Possibly nested list or dictionary. + key : str + key/to/value, path like string describing all keys necessary to + consider to get to the desired value. List indices can also be + passed here. + splitval : str + String that defines the delimiter between keys of the + different depth levels in `key`. + default : obj + Value returned if :attr:`key` is not found. + expand : bool + Whether to expand callable nodes on the path or not. + + Returns + ------- + The desired value or if :attr:`default` is not ``None`` and the + :attr:`key` is not found returns ``default``. + + Raises + ------ + Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is + ``None``. + """ + + keys = key.split(splitval) + + success = True + try: + visited = [] + parent = None + last_key = None + for key in keys: + if callable(list_or_dict): + if not expand: + raise KeyNotFoundError( + ValueError( + "Trying to get past callable node with expand=False." + ), + keys=keys, + visited=visited, + ) + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + + last_key = key + parent = list_or_dict + + try: + if isinstance(list_or_dict, dict): + list_or_dict = list_or_dict[key] + else: + list_or_dict = list_or_dict[int(key)] + except (KeyError, IndexError, ValueError) as e: + raise KeyNotFoundError(e, keys=keys, visited=visited) + + visited += [key] + # final expansion of retrieved value + if expand and callable(list_or_dict): + list_or_dict = list_or_dict() + parent[last_key] = list_or_dict + except KeyNotFoundError as e: + if default is None: + raise e + else: + list_or_dict = default + success = False + + if not pass_success: + return list_or_dict + else: + return list_or_dict, success + + +if __name__ == "__main__": + config = {"keya": "a", + "keyb": "b", + "keyc": + {"cc1": 1, + "cc2": 2, + } + } + from omegaconf import OmegaConf + config = OmegaConf.create(config) + print(config) + retrieve(config, "keya") + diff --git a/custom_controlnet_aux/dsine/LICENSE b/custom_controlnet_aux/dsine/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..2b677c1a1ec5f491a189df741896bc7207093c8f --- /dev/null +++ b/custom_controlnet_aux/dsine/LICENSE @@ -0,0 +1,230 @@ +DSINE SOFTWARE + +LICENCE AGREEMENT + +WE (Imperial College of Science, Technology and Medicine, (“Imperial College +London”)) ARE WILLING TO LICENSE THIS SOFTWARE TO YOU (a licensee “You”) ONLY +ON THE CONDITION THAT YOU ACCEPT ALL OF THE TERMS CONTAINED IN THE FOLLOWING +AGREEMENT. PLEASE READ THE AGREEMENT CAREFULLY BEFORE DOWNLOADING THE SOFTWARE. +BY EXERCISING THE OPTION TO DOWNLOAD THE SOFTWARE YOU AGREE TO BE BOUND BY THE +TERMS OF THE AGREEMENT. + +SOFTWARE LICENCE AGREEMENT (EXCLUDING BSD COMPONENTS) + +1. This Agreement pertains to a worldwide, non-exclusive, temporary, fully +paid-up, royalty free, non-transferable, non-sub- licensable licence (the +“Licence”) to use the elastic fusion source code, including any modification, +part or derivative (the “Software”). + +Ownership and Licence. Your rights to use and download the Software onto your +computer, and all other copies that You are authorised to make, are specified +in this Agreement. However, we (or our licensors) retain all rights, including +but not limited to all copyright and other intellectual property rights +anywhere in the world, in the Software not expressly granted to You in this +Agreement. + +2. Permitted use of the Licence: + +(a) You may download and install the Software onto one computer or server for +use in accordance with Clause 2(b) of this Agreement provided that You ensure +that the Software is not accessible by other users unless they have themselves +accepted the terms of this licence agreement. + +(b) You may use the Software solely for non-commercial, internal or academic +research purposes and only in accordance with the terms of this Agreement. You +may not use the Software for commercial purposes, including but not limited to +(1) integration of all or part of the source code or the Software into a +product for sale or licence by or on behalf of You to third parties or (2) use +of the Software or any derivative of it for research to develop software +products for sale or licence to a third party or (3) use of the Software or any +derivative of it for research to develop non-software products for sale or +licence to a third party, or (4) use of the Software to provide any service to +an external organisation for which payment is received. + +Should You wish to use the Software for commercial purposes, You shall +email researchcontracts.engineering@imperial.ac.uk . + +(c) Right to Copy. You may copy the Software for back-up and archival purposes, +provided that each copy is kept in your possession and provided You reproduce +our copyright notice (set out in Schedule 1) on each copy. + +(d) Transfer and sub-licensing. You may not rent, lend, or lease the Software +and You may not transmit, transfer or sub-license this licence to use the +Software or any of your rights or obligations under this Agreement to another +party. + +(e) Identity of Licensee. The licence granted herein is personal to You. You +shall not permit any third party to access, modify or otherwise use the +Software nor shall You access modify or otherwise use the Software on behalf of +any third party. If You wish to obtain a licence for mutiple users or a site +licence for the Software please contact us +at researchcontracts.engineering@imperial.ac.uk . + +(f) Publications and presentations. You may make public, results or data +obtained from, dependent on or arising from research carried out using the +Software, provided that any such presentation or publication identifies the +Software as the source of the results or the data, including the Copyright +Notice given in each element of the Software, and stating that the Software has +been made available for use by You under licence from Imperial College London +and You provide a copy of any such publication to Imperial College London. + +3. Prohibited Uses. You may not, without written permission from us +at researchcontracts.engineering@imperial.ac.uk : + +(a) Use, copy, modify, merge, or transfer copies of the Software or any +documentation provided by us which relates to the Software except as provided +in this Agreement; + +(b) Use any back-up or archival copies of the Software (or allow anyone else to +use such copies) for any purpose other than to replace the original copy in the +event it is destroyed or becomes defective; or + +(c) Disassemble, decompile or "unlock", reverse translate, or in any manner +decode the Software for any reason. + +4. Warranty Disclaimer + +(a) Disclaimer. The Software has been developed for research purposes only. You +acknowledge that we are providing the Software to You under this licence +agreement free of charge and on condition that the disclaimer set out below +shall apply. We do not represent or warrant that the Software as to: (i) the +quality, accuracy or reliability of the Software; (ii) the suitability of the +Software for any particular use or for use under any specific conditions; and +(iii) whether use of the Software will infringe third-party rights. + +You acknowledge that You have reviewed and evaluated the Software to determine +that it meets your needs and that You assume all responsibility and liability +for determining the suitability of the Software as fit for your particular +purposes and requirements. Subject to Clause 4(b), we exclude and expressly +disclaim all express and implied representations, warranties, conditions and +terms not stated herein (including the implied conditions or warranties of +satisfactory quality, merchantable quality, merchantability and fitness for +purpose). + +(b) Savings. Some jurisdictions may imply warranties, conditions or terms or +impose obligations upon us which cannot, in whole or in part, be excluded, +restricted or modified or otherwise do not allow the exclusion of implied +warranties, conditions or terms, in which case the above warranty disclaimer +and exclusion will only apply to You to the extent permitted in the relevant +jurisdiction and does not in any event exclude any implied warranties, +conditions or terms which may not under applicable law be excluded. + +(c) Imperial College London disclaims all responsibility for the use which is +made of the Software and any liability for the outcomes arising from using the +Software. + +5. Limitation of Liability + +(a) You acknowledge that we are providing the Software to You under this +licence agreement free of charge and on condition that the limitation of +liability set out below shall apply. Accordingly, subject to Clause 5(b), we +exclude all liability whether in contract, tort, negligence or otherwise, in +respect of the Software and/or any related documentation provided to You by us +including, but not limited to, liability for loss or corruption of data, loss +of contracts, loss of income, loss of profits, loss of cover and any +consequential or indirect loss or damage of any kind arising out of or in +connection with this licence agreement, however caused. This exclusion shall +apply even if we have been advised of the possibility of such loss or damage. + +(b) You agree to indemnify Imperial College London and hold it harmless from +and against any and all claims, damages and liabilities asserted by third +parties (including claims for negligence) which arise directly or indirectly +from the use of the Software or any derivative of it or the sale of any +products based on the Software. You undertake to make no liability claim +against any employee, student, agent or appointee of Imperial College London, +in connection with this Licence or the Software. + +(c) Nothing in this Agreement shall have the effect of excluding or limiting +our statutory liability. + +(d) Some jurisdictions do not allow these limitations or exclusions either +wholly or in part, and, to that extent, they may not apply to you. Nothing in +this licence agreement will affect your statutory rights or other relevant +statutory provisions which cannot be excluded, restricted or modified, and its +terms and conditions must be read and construed subject to any such statutory +rights and/or provisions. + +6. Confidentiality. You agree not to disclose any confidential information +provided to You by us pursuant to this Agreement to any third party without our +prior written consent. The obligations in this Clause 6 shall survive the +termination of this Agreement for any reason. + +7. Termination. + +(a) We may terminate this licence agreement and your right to use the Software +at any time with immediate effect upon written notice to You. + +(b) This licence agreement and your right to use the Software automatically +terminate if You: + + (i) fail to comply with any provisions of this Agreement; or + + (ii) destroy the copies of the Software in your possession, or voluntarily + return the Software to us. + +(c) Upon termination You will destroy all copies of the Software. + +(d) Otherwise, the restrictions on your rights to use the Software will expire +10 (ten) years after first use of the Software under this licence agreement. + +8. Miscellaneous Provisions. + +(a) This Agreement will be governed by and construed in accordance with the +substantive laws of England and Wales whose courts shall have exclusive +jurisdiction over all disputes which may arise between us. + +(b) This is the entire agreement between us relating to the Software, and +supersedes any prior purchase order, communications, advertising or +representations concerning the Software. + +(c) No change or modification of this Agreement will be valid unless it is in +writing, and is signed by us. + +(d) The unenforceability or invalidity of any part of this Agreement will not +affect the enforceability or validity of the remaining parts. + +BSD Elements of the Software + +For BSD elements of the Software, the following terms shall apply: +Copyright as indicated in the header of the individual element of the Software. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this +list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, +this list of conditions and the following disclaimer in the documentation +and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors +may be used to endorse or promote products derived from this software without +specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +SCHEDULE 1 + +The Software + +DSINE is a framework for estimating surface normals from a single image. It is based on the techniques described in the following publication: + + • Gwangbin Bae, Andrew J. Davison. Rethinking Inductive Biases for Surface Normal Estimation. CVPR, 2024 +_________________________ + +Acknowledgments + +If you use the software, you should reference the following paper in any publication: + + • Gwangbin Bae, Andrew J. Davison. Rethinking Inductive Biases for Surface Normal Estimation. CVPR, 2024 \ No newline at end of file diff --git a/custom_controlnet_aux/dsine/__init__.py b/custom_controlnet_aux/dsine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..06b3a3625212241e6dc35786d1411aa48a3b5f8e --- /dev/null +++ b/custom_controlnet_aux/dsine/__init__.py @@ -0,0 +1,100 @@ +import os +import types +import warnings + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, DIFFUSION_EDGE_MODEL_NAME +from .models.dsine_arch import DSINE +from custom_controlnet_aux.dsine.utils.utils import get_intrins_from_fov + +# load model +def load_checkpoint(fpath, model): + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + return model + +def get_pad(orig_H, orig_W): + if orig_W % 64 == 0: + l = 0 + r = 0 + else: + new_W = 64 * ((orig_W // 64) + 1) + l = (new_W - orig_W) // 2 + r = (new_W - orig_W) - l + + if orig_H % 64 == 0: + t = 0 + b = 0 + else: + new_H = 64 * ((orig_H // 64) + 1) + t = (new_H - orig_H) // 2 + b = (new_H - orig_H) - t + return l, r, t, b + +class DsineDetector: + def __init__(self, model): + self.model = model + self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=DIFFUSION_EDGE_MODEL_NAME, filename="dsine.pt"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + model = DSINE() + model = load_checkpoint(model_path, model) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + self.model.pixel_coords = self.model.pixel_coords.to(device) + self.device = device + return self + + + def __call__(self, input_image, fov=60.0, iterations=5, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + self.model.num_iter = iterations + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + orig_H, orig_W = input_image.shape[:2] + l, r, t, b = get_pad(orig_H, orig_W) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method, mode="constant") + with torch.no_grad(): + input_image = torch.from_numpy(input_image).float().to(self.device) + input_image = input_image / 255.0 + input_image = rearrange(input_image, 'h w c -> 1 c h w') + input_image = self.norm(input_image) + + intrins = get_intrins_from_fov(new_fov=fov, H=orig_H, W=orig_W, device=self.device).unsqueeze(0) + intrins[:, 0, 2] += l + intrins[:, 1, 2] += t + + normal = self.model(input_image, intrins) + normal = normal[-1][0] + normal = ((normal + 1) * 0.5).clip(0, 1) + + normal = rearrange(normal, 'c h w -> h w c').cpu().numpy() + normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = HWC3(normal_image) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + \ No newline at end of file diff --git a/custom_controlnet_aux/dsine/models/dsine_arch.py b/custom_controlnet_aux/dsine/models/dsine_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..13becb2f576a81b58dbf6d5e4667a51e92433c57 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/dsine_arch.py @@ -0,0 +1,232 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from custom_controlnet_aux.dsine.models.submodules import Encoder, ConvGRU, UpSampleBN, UpSampleGN, RayReLU, \ + convex_upsampling, get_unfold, get_prediction_head, \ + INPUT_CHANNELS_DICT +from custom_controlnet_aux.dsine.utils.rotation import axis_angle_to_matrix + + +class Decoder(nn.Module): + def __init__(self, output_dims, B=5, NF=2048, BN=False, downsample_ratio=8): + super(Decoder, self).__init__() + input_channels = INPUT_CHANNELS_DICT[B] + output_dim, feature_dim, hidden_dim = output_dims + features = bottleneck_features = NF + self.downsample_ratio = downsample_ratio + + UpSample = UpSampleBN if BN else UpSampleGN + self.conv2 = nn.Conv2d(bottleneck_features + 2, features, kernel_size=1, stride=1, padding=0) + self.up1 = UpSample(skip_input=features // 1 + input_channels[1] + 2, output_features=features // 2, align_corners=False) + self.up2 = UpSample(skip_input=features // 2 + input_channels[2] + 2, output_features=features // 4, align_corners=False) + + # prediction heads + i_dim = features // 4 + h_dim = 128 + self.normal_head = get_prediction_head(i_dim+2, h_dim, output_dim) + self.feature_head = get_prediction_head(i_dim+2, h_dim, feature_dim) + self.hidden_head = get_prediction_head(i_dim+2, h_dim, hidden_dim) + + def forward(self, features, uvs): + _, _, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + uv_32, uv_16, uv_8 = uvs + + x_d0 = self.conv2(torch.cat([x_block4, uv_32], dim=1)) + x_d1 = self.up1(x_d0, torch.cat([x_block3, uv_16], dim=1)) + x_feat = self.up2(x_d1, torch.cat([x_block2, uv_8], dim=1)) + x_feat = torch.cat([x_feat, uv_8], dim=1) + + normal = self.normal_head(x_feat) + normal = F.normalize(normal, dim=1) + f = self.feature_head(x_feat) + h = self.hidden_head(x_feat) + return normal, f, h + + +class DSINE(nn.Module): + def __init__(self): + super(DSINE, self).__init__() + self.downsample_ratio = 8 + self.ps = 5 # patch size + self.num_iter = 5 # num iterations + + # define encoder + self.encoder = Encoder(B=5, pretrained=True) + + # define decoder + self.output_dim = output_dim = 3 + self.feature_dim = feature_dim = 64 + self.hidden_dim = hidden_dim = 64 + self.decoder = Decoder([output_dim, feature_dim, hidden_dim], B=5, NF=2048, BN=False) + + # ray direction-based ReLU + self.ray_relu = RayReLU(eps=1e-2) + + # pixel_coords (1, 3, H, W) + # NOTE: this is set to some arbitrarily high number, + # if your input is 2000+ pixels wide/tall, increase these values + h = 2000 + w = 2000 + pixel_coords = np.ones((3, h, w)).astype(np.float32) + x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0) + y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1) + pixel_coords[0, :, :] = x_range + 0.5 + pixel_coords[1, :, :] = y_range + 0.5 + self.pixel_coords = torch.from_numpy(pixel_coords).unsqueeze(0) + + # define ConvGRU cell + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=feature_dim+2, ks=self.ps) + + # padding used during NRN + self.pad = (self.ps - 1) // 2 + + # prediction heads + self.prob_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # weights assigned for each nghbr pixel + self.xy_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps*2) # rotation axis for each nghbr pixel + self.angle_head = get_prediction_head(self.hidden_dim+2, 64, self.ps*self.ps) # rotation angle for each nghbr pixel + + # prediction heads - weights used for upsampling the coarse resolution output + self.up_prob_head = get_prediction_head(self.hidden_dim+2, 64, 9 * self.downsample_ratio * self.downsample_ratio) + + def get_ray(self, intrins, H, W, orig_H, orig_W, return_uv=False): + B, _, _ = intrins.shape + fu = intrins[:, 0, 0][:,None,None] * (W / orig_W) + cu = intrins[:, 0, 2][:,None,None] * (W / orig_W) + fv = intrins[:, 1, 1][:,None,None] * (H / orig_H) + cv = intrins[:, 1, 2][:,None,None] * (H / orig_H) + + # (B, 2, H, W) + ray = self.pixel_coords[:, :, :H, :W].repeat(B, 1, 1, 1) + ray[:, 0, :, :] = (ray[:, 0, :, :] - cu) / fu + ray[:, 1, :, :] = (ray[:, 1, :, :] - cv) / fv + + if return_uv: + return ray[:, :2, :, :] + else: + return F.normalize(ray, dim=1) + + def upsample(self, h, pred_norm, uv_8): + up_mask = self.up_prob_head(torch.cat([h, uv_8], dim=1)) + up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio) + up_pred_norm = F.normalize(up_pred_norm, dim=1) + return up_pred_norm + + def refine(self, h, feat_map, pred_norm, intrins, orig_H, orig_W, uv_8, ray_8): + B, C, H, W = pred_norm.shape + fu = intrins[:, 0, 0][:,None,None,None] * (W / orig_W) # (B, 1, 1, 1) + cu = intrins[:, 0, 2][:,None,None,None] * (W / orig_W) + fv = intrins[:, 1, 1][:,None,None,None] * (H / orig_H) + cv = intrins[:, 1, 2][:,None,None,None] * (H / orig_H) + + h_new = self.gru(h, feat_map) + + # get nghbr prob (B, 1, ps*ps, h, w) + nghbr_prob = self.prob_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1) + nghbr_prob = torch.sigmoid(nghbr_prob) + + # get nghbr normals (B, 3, ps*ps, h, w) + nghbr_normals = get_unfold(pred_norm, ps=self.ps, pad=self.pad) + + # get nghbr xy (B, 2, ps*ps, h, w) + nghbr_xys = self.xy_head(torch.cat([h_new, uv_8], dim=1)) + nghbr_xs, nghbr_ys = torch.split(nghbr_xys, [self.ps*self.ps, self.ps*self.ps], dim=1) + nghbr_xys = torch.cat([nghbr_xs.unsqueeze(1), nghbr_ys.unsqueeze(1)], dim=1) + nghbr_xys = F.normalize(nghbr_xys, dim=1) + + # get nghbr theta (B, 1, ps*ps, h, w) + nghbr_angle = self.angle_head(torch.cat([h_new, uv_8], dim=1)).unsqueeze(1) + nghbr_angle = torch.sigmoid(nghbr_angle) * np.pi + + # get nghbr pixel coord (1, 3, ps*ps, h, w) + nghbr_pixel_coord = get_unfold(self.pixel_coords[:, :, :H, :W], ps=self.ps, pad=self.pad) + + # nghbr axes (B, 3, ps*ps, h, w) + nghbr_axes = torch.zeros_like(nghbr_normals) + + du_over_fu = nghbr_xys[:, 0, ...] / fu # (B, ps*ps, h, w) + dv_over_fv = nghbr_xys[:, 1, ...] / fv # (B, ps*ps, h, w) + + term_u = (nghbr_pixel_coord[:, 0, ...] + nghbr_xys[:, 0, ...] - cu) / fu # (B, ps*ps, h, w) + term_v = (nghbr_pixel_coord[:, 1, ...] + nghbr_xys[:, 1, ...] - cv) / fv # (B, ps*ps, h, w) + + nx = nghbr_normals[:, 0, ...] # (B, ps*ps, h, w) + ny = nghbr_normals[:, 1, ...] # (B, ps*ps, h, w) + nz = nghbr_normals[:, 2, ...] # (B, ps*ps, h, w) + + nghbr_delta_z_num = - (du_over_fu * nx + dv_over_fv * ny) + nghbr_delta_z_denom = (term_u * nx + term_v * ny + nz) + nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8] = 1e-8 * torch.sign(nghbr_delta_z_denom[torch.abs(nghbr_delta_z_denom) < 1e-8]) + nghbr_delta_z = nghbr_delta_z_num / nghbr_delta_z_denom + + nghbr_axes[:, 0, ...] = du_over_fu + nghbr_delta_z * term_u + nghbr_axes[:, 1, ...] = dv_over_fv + nghbr_delta_z * term_v + nghbr_axes[:, 2, ...] = nghbr_delta_z + nghbr_axes = F.normalize(nghbr_axes, dim=1) # (B, 3, ps*ps, h, w) + + # make sure axes are all valid + invalid = torch.sum(torch.logical_or(torch.isnan(nghbr_axes), torch.isinf(nghbr_axes)).float(), dim=1) > 0.5 # (B, ps*ps, h, w) + nghbr_axes[:, 0, ...][invalid] = 0.0 + nghbr_axes[:, 1, ...][invalid] = 0.0 + nghbr_axes[:, 2, ...][invalid] = 0.0 + + # nghbr_axes_angle (B, 3, ps*ps, h, w) + nghbr_axes_angle = nghbr_axes * nghbr_angle + nghbr_axes_angle = nghbr_axes_angle.permute(0, 2, 3, 4, 1) # (B, ps*ps, h, w, 3) + nghbr_R = axis_angle_to_matrix(nghbr_axes_angle) # (B, ps*ps, h, w, 3, 3) + + # (B, 3, ps*ps, h, w) + nghbr_normals_rot = torch.bmm( + nghbr_R.reshape(B * self.ps * self.ps * H * W, 3, 3), + nghbr_normals.permute(0, 2, 3, 4, 1).reshape(B * self.ps * self.ps * H * W, 3).unsqueeze(-1) + ).reshape(B, self.ps*self.ps, H, W, 3, 1).squeeze(-1).permute(0, 4, 1, 2, 3) # (B, 3, ps*ps, h, w) + nghbr_normals_rot = F.normalize(nghbr_normals_rot, dim=1) + + # ray ReLU + nghbr_normals_rot = torch.cat([ + self.ray_relu(nghbr_normals_rot[:, :, i, :, :], ray_8).unsqueeze(2) + for i in range(nghbr_normals_rot.size(2)) + ], dim=2) + + # (B, 1, ps*ps, h, w) * (B, 3, ps*ps, h, w) + pred_norm = torch.sum(nghbr_prob * nghbr_normals_rot, dim=2) # (B, C, H, W) + pred_norm = F.normalize(pred_norm, dim=1) + + up_mask = self.up_prob_head(torch.cat([h_new, uv_8], dim=1)) + up_pred_norm = convex_upsampling(pred_norm, up_mask, self.downsample_ratio) + up_pred_norm = F.normalize(up_pred_norm, dim=1) + + return h_new, pred_norm, up_pred_norm + + + def forward(self, img, intrins=None): + # Step 1. encoder + features = self.encoder(img) + + # Step 2. get uv encoding + B, _, orig_H, orig_W = img.shape + intrins[:, 0, 2] += 0.5 + intrins[:, 1, 2] += 0.5 + uv_32 = self.get_ray(intrins, orig_H//32, orig_W//32, orig_H, orig_W, return_uv=True) + uv_16 = self.get_ray(intrins, orig_H//16, orig_W//16, orig_H, orig_W, return_uv=True) + uv_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W, return_uv=True) + ray_8 = self.get_ray(intrins, orig_H//8, orig_W//8, orig_H, orig_W) + + # Step 3. decoder - initial prediction + pred_norm, feat_map, h = self.decoder(features, uvs=(uv_32, uv_16, uv_8)) + pred_norm = self.ray_relu(pred_norm, ray_8) + + # Step 4. add ray direction encoding + feat_map = torch.cat([feat_map, uv_8], dim=1) + + # iterative refinement + up_pred_norm = self.upsample(h, pred_norm, uv_8) + pred_list = [up_pred_norm] + for i in range(self.num_iter): + h, pred_norm, up_pred_norm = self.refine(h, feat_map, + pred_norm.detach(), + intrins, orig_H, orig_W, uv_8, ray_8) + pred_list.append(up_pred_norm) + return pred_list + diff --git a/custom_controlnet_aux/dsine/models/submodules/__init__.py b/custom_controlnet_aux/dsine/models/submodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..664ba06cc2f87dfadf0f2c6cd8d1f2c70b149b05 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/__init__.py @@ -0,0 +1,199 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os + + +INPUT_CHANNELS_DICT = { + 0: [1280, 112, 40, 24, 16], + 1: [1280, 112, 40, 24, 16], + 2: [1408, 120, 48, 24, 16], + 3: [1536, 136, 48, 32, 24], + 4: [1792, 160, 56, 32, 24], + 5: [2048, 176, 64, 40, 24], + 6: [2304, 200, 72, 40, 32], + 7: [2560, 224, 80, 48, 32] +} + + +class Encoder(nn.Module): + def __init__(self, B=5, pretrained=True): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b%s_ap' % B + print('Loading base model ()...'.format(basemodel_name), end='') + repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo') + basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local') + print('Done.') + + # Remove last layer + print('Removing last two layers (global_pool & classifier).') + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, ks=3): + super(ConvGRU, self).__init__() + p = (ks - 1) // 2 + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, ks, padding=p) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + return h + + +class RayReLU(nn.Module): + def __init__(self, eps=1e-2): + super(RayReLU, self).__init__() + self.eps = eps + + def forward(self, pred_norm, ray): + # angle between the predicted normal and ray direction + cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze(1) # (B, 1, H, W) + + # component of pred_norm along view + norm_along_view = ray * cos + + # cos should be bigger than eps + norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps) + + # difference + diff = norm_along_view_relu - norm_along_view + + # updated pred_norm + new_pred_norm = pred_norm + diff + new_pred_norm = F.normalize(new_pred_norm, dim=1) + + return new_pred_norm + + +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleBN, self).__init__() + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +class Conv2d_WS(nn.Conv2d): + """ weight standardization + """ + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d_WS, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class UpSampleGN(nn.Module): + """ UpSample with GroupNorm + """ + def __init__(self, skip_input, output_features, align_corners=True): + super(UpSampleGN, self).__init__() + self._net = nn.Sequential(Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU(), + Conv2d_WS(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU()) + self.align_corners = align_corners + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=self.align_corners) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +def upsample_via_bilinear(out, up_mask, downsample_ratio): + """ bilinear upsampling (up_mask is a dummy variable) + """ + return F.interpolate(out, scale_factor=downsample_ratio, mode='bilinear', align_corners=True) + + +def upsample_via_mask(out, up_mask, downsample_ratio): + """ convex upsampling + """ + # out: low-resolution output (B, o_dim, H, W) + # up_mask: (B, 9*k*k, H, W) + k = downsample_ratio + + N, o_dim, H, W = out.shape + up_mask = up_mask.view(N, 1, 9, k, k, H, W) + up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) + + up_out = F.unfold(out, [3, 3], padding=1) # (B, 2, H, W) -> (B, 2 X 3*3, H*W) + up_out = up_out.view(N, o_dim, 9, 1, 1, H, W) # (B, 2, 3*3, 1, 1, H, W) + up_out = torch.sum(up_mask * up_out, dim=2) # (B, 2, k, k, H, W) + + up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, 2, H, k, W, k) + return up_out.reshape(N, o_dim, k*H, k*W) # (B, 2, kH, kW) + + +def convex_upsampling(out, up_mask, k): + # out: low-resolution output (B, C, H, W) + # up_mask: (B, 9*k*k, H, W) + B, C, H, W = out.shape + up_mask = up_mask.view(B, 1, 9, k, k, H, W) + up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) + + out = F.pad(out, pad=(1,1,1,1), mode='replicate') + up_out = F.unfold(out, [3, 3], padding=0) # (B, C, H, W) -> (B, C X 3*3, H*W) + up_out = up_out.view(B, C, 9, 1, 1, H, W) # (B, C, 9, 1, 1, H, W) + + up_out = torch.sum(up_mask * up_out, dim=2) # (B, C, k, k, H, W) + up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, C, H, k, W, k) + return up_out.reshape(B, C, k*H, k*W) # (B, C, kH, kW) + + +def get_unfold(pred_norm, ps, pad): + B, C, H, W = pred_norm.shape + pred_norm = F.pad(pred_norm, pad=(pad,pad,pad,pad), mode='replicate') # (B, C, h, w) + pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0) # (B, C X ps*ps, h*w) + pred_norm_unfold = pred_norm_unfold.view(B, C, ps*ps, H, W) # (B, C, ps*ps, h, w) + return pred_norm_unfold + + +def get_prediction_head(input_dim, hidden_dim, output_dim): + return nn.Sequential( + nn.Conv2d(input_dim, hidden_dim, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, hidden_dim, 1), + nn.ReLU(inplace=True), + nn.Conv2d(hidden_dim, output_dim, 1), + ) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/BENCHMARK.md b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/BENCHMARK.md new file mode 100644 index 0000000000000000000000000000000000000000..6ead7171ce5a5bbd2702f6b5c825dc9808ba5658 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/BENCHMARK.md @@ -0,0 +1,555 @@ +# Model Performance Benchmarks + +All benchmarks run as per: + +``` +python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx +python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx +python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 +python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt +python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb +python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb +``` + +## EfficientNet-B0 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 +Time per operator type: + 29.7378 ms. 60.5145%. Conv + 12.1785 ms. 24.7824%. Sigmoid + 3.62811 ms. 7.38297%. SpatialBN + 2.98444 ms. 6.07314%. Mul + 0.326902 ms. 0.665225%. AveragePool + 0.197317 ms. 0.401528%. FC + 0.0852877 ms. 0.173555%. Add + 0.0032607 ms. 0.00663532%. Squeeze + 49.1416 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 95.2696%. Conv + 0.0269508 GFLOP. 3.33857%. SpatialBN + 0.00846444 GFLOP. 1.04855%. Mul + 0.002561 GFLOP. 0.317248%. FC + 0.000210112 GFLOP. 0.0260279%. Add + 0.807256 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 43.0891%. Mul + 43.2015 MB. 31.807%. Conv + 27.2869 MB. 20.0899%. SpatialBN + 5.12912 MB. 3.77631%. FC + 1.6809 MB. 1.23756%. Add + 135.824 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 38.1965%. Mul + 26.9881 MB. 30.4465%. Conv + 26.9508 MB. 30.4044%. SpatialBN + 0.840448 MB. 0.948147%. Add + 0.004 MB. 0.00451258%. FC + 88.6412 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 74.9391%. Conv + 5.124 MB. 24.265%. FC + 0.168064 MB. 0.795877%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 21.1168 MB in Total +``` +### Optimized +``` +Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996 +Time per operator type: + 29.776 ms. 65.002%. Conv + 12.2803 ms. 26.8084%. Sigmoid + 3.15073 ms. 6.87815%. Mul + 0.328651 ms. 0.717456%. AveragePool + 0.186237 ms. 0.406563%. FC + 0.0832429 ms. 0.181722%. Add + 0.0026184 ms. 0.00571606%. Squeeze + 45.8078 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 98.5601%. Conv + 0.00846444 GFLOP. 1.08476%. Mul + 0.002561 GFLOP. 0.328205%. FC + 0.000210112 GFLOP. 0.0269269%. Add + 0.780305 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 53.8803%. Mul + 43.2855 MB. 39.8501%. Conv + 5.12912 MB. 4.72204%. FC + 1.6809 MB. 1.54749%. Add + 108.621 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 54.8834%. Mul + 26.9881 MB. 43.7477%. Conv + 0.840448 MB. 1.36237%. Add + 0.004 MB. 0.00648399%. FC + 61.6904 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 75.5403%. Conv + 5.124 MB. 24.4597%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 20.9488 MB in Total +``` + +## EfficientNet-B1 +### Optimized +``` +Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 +Time per operator type: + 45.7915 ms. 66.3206%. Conv + 17.8718 ms. 25.8841%. Sigmoid + 4.44132 ms. 6.43244%. Mul + 0.51001 ms. 0.738658%. AveragePool + 0.233283 ms. 0.337868%. Add + 0.194986 ms. 0.282402%. FC + 0.00268255 ms. 0.00388519%. Squeeze + 69.0456 ms in Total +FLOP per operator type: + 1.37105 GFLOP. 98.7673%. Conv + 0.0138759 GFLOP. 0.99959%. Mul + 0.002561 GFLOP. 0.184489%. FC + 0.000674432 GFLOP. 0.0485847%. Add + 1.38816 GFLOP in Total +Feature Memory Read per operator type: + 94.624 MB. 54.0789%. Mul + 69.8255 MB. 39.9062%. Conv + 5.39546 MB. 3.08357%. Add + 5.12912 MB. 2.93136%. FC + 174.974 MB in Total +Feature Memory Written per operator type: + 55.5035 MB. 54.555%. Mul + 43.5333 MB. 42.7894%. Conv + 2.69773 MB. 2.65163%. Add + 0.004 MB. 0.00393165%. FC + 101.739 MB in Total +Parameter Memory per operator type: + 25.7479 MB. 83.4024%. Conv + 5.124 MB. 16.5976%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 30.8719 MB in Total +``` + +## EfficientNet-B2 +### Optimized +``` +Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366 +Time per operator type: + 61.4627 ms. 67.5845%. Conv + 22.7458 ms. 25.0113%. Sigmoid + 5.59931 ms. 6.15701%. Mul + 0.642567 ms. 0.706568%. AveragePool + 0.272795 ms. 0.299965%. Add + 0.216178 ms. 0.237709%. FC + 0.00268895 ms. 0.00295677%. Squeeze + 90.942 ms in Total +FLOP per operator type: + 1.98431 GFLOP. 98.9343%. Conv + 0.0177039 GFLOP. 0.882686%. Mul + 0.002817 GFLOP. 0.140451%. FC + 0.000853984 GFLOP. 0.0425782%. Add + 2.00568 GFLOP in Total +Feature Memory Read per operator type: + 120.609 MB. 54.9637%. Mul + 86.3512 MB. 39.3519%. Conv + 6.83187 MB. 3.11341%. Add + 5.64163 MB. 2.571%. FC + 219.433 MB in Total +Feature Memory Written per operator type: + 70.8155 MB. 54.6573%. Mul + 55.3273 MB. 42.7031%. Conv + 3.41594 MB. 2.63651%. Add + 0.004 MB. 0.00308731%. FC + 129.563 MB in Total +Parameter Memory per operator type: + 30.4721 MB. 84.3913%. Conv + 5.636 MB. 15.6087%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 36.1081 MB in Total +``` + +## MixNet-M +### Optimized +``` +Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 +Time per operator type: + 48.1139 ms. 75.2052%. Conv + 7.1341 ms. 11.1511%. Sigmoid + 2.63706 ms. 4.12189%. SpatialBN + 1.73186 ms. 2.70701%. Mul + 1.38707 ms. 2.16809%. Split + 1.29322 ms. 2.02139%. Concat + 1.00093 ms. 1.56452%. Relu + 0.235309 ms. 0.367803%. Add + 0.221579 ms. 0.346343%. FC + 0.219315 ms. 0.342803%. AveragePool + 0.00250145 ms. 0.00390993%. Squeeze + 63.9768 ms in Total +FLOP per operator type: + 0.675273 GFLOP. 95.5827%. Conv + 0.0221072 GFLOP. 3.12921%. SpatialBN + 0.00538445 GFLOP. 0.762152%. Mul + 0.003073 GFLOP. 0.434973%. FC + 0.000642488 GFLOP. 0.0909421%. Add + 0 GFLOP. 0%. Concat + 0 GFLOP. 0%. Relu + 0.70648 GFLOP in Total +Feature Memory Read per operator type: + 46.8424 MB. 30.502%. Conv + 36.8626 MB. 24.0036%. Mul + 22.3152 MB. 14.5309%. SpatialBN + 22.1074 MB. 14.3955%. Concat + 14.1496 MB. 9.21372%. Relu + 6.15414 MB. 4.00735%. FC + 5.1399 MB. 3.34692%. Add + 153.571 MB in Total +Feature Memory Written per operator type: + 32.7672 MB. 28.4331%. Conv + 22.1072 MB. 19.1831%. Concat + 22.1072 MB. 19.1831%. SpatialBN + 21.5378 MB. 18.689%. Mul + 14.1496 MB. 12.2781%. Relu + 2.56995 MB. 2.23003%. Add + 0.004 MB. 0.00347092%. FC + 115.243 MB in Total +Parameter Memory per operator type: + 13.7059 MB. 68.674%. Conv + 6.148 MB. 30.8049%. FC + 0.104 MB. 0.521097%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Concat + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 19.9579 MB in Total +``` + +## TF MobileNet-V3 Large 1.0 + +### Optimized +``` +Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 +Time per operator type: + 17.437 ms. 80.0087%. Conv + 1.27662 ms. 5.8577%. Add + 1.12759 ms. 5.17387%. Div + 0.701155 ms. 3.21721%. Mul + 0.562654 ms. 2.58171%. Relu + 0.431144 ms. 1.97828%. Clip + 0.156902 ms. 0.719936%. FC + 0.0996858 ms. 0.457402%. AveragePool + 0.00112455 ms. 0.00515993%. Flatten + 21.7939 ms in Total +FLOP per operator type: + 0.43062 GFLOP. 98.1484%. Conv + 0.002561 GFLOP. 0.583713%. FC + 0.00210867 GFLOP. 0.480616%. Mul + 0.00193868 GFLOP. 0.441871%. Add + 0.00151532 GFLOP. 0.345377%. Div + 0 GFLOP. 0%. Relu + 0.438743 GFLOP in Total +Feature Memory Read per operator type: + 34.7967 MB. 43.9391%. Conv + 14.496 MB. 18.3046%. Mul + 9.44828 MB. 11.9307%. Add + 9.26157 MB. 11.6949%. Relu + 6.0614 MB. 7.65395%. Div + 5.12912 MB. 6.47673%. FC + 79.193 MB in Total +Feature Memory Written per operator type: + 17.6247 MB. 35.8656%. Conv + 9.26157 MB. 18.847%. Relu + 8.43469 MB. 17.1643%. Mul + 7.75472 MB. 15.7806%. Add + 6.06128 MB. 12.3345%. Div + 0.004 MB. 0.00813985%. FC + 49.1409 MB in Total +Parameter Memory per operator type: + 16.6851 MB. 76.5052%. Conv + 5.124 MB. 23.4948%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8091 MB in Total +``` + +## MobileNet-V3 (RW) + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 +Time per operator type: + 15.9266 ms. 69.2624%. Conv + 2.36551 ms. 10.2873%. SpatialBN + 1.39102 ms. 6.04936%. Add + 1.30327 ms. 5.66773%. Div + 0.737014 ms. 3.20517%. Mul + 0.639697 ms. 2.78195%. Relu + 0.375681 ms. 1.63378%. Clip + 0.153126 ms. 0.665921%. FC + 0.0993787 ms. 0.432184%. AveragePool + 0.0032632 ms. 0.0141912%. Squeeze + 22.9946 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 94.4041%. Conv + 0.0175992 GFLOP. 3.85829%. SpatialBN + 0.002561 GFLOP. 0.561449%. FC + 0.00210961 GFLOP. 0.46249%. Mul + 0.00173891 GFLOP. 0.381223%. Add + 0.00151626 GFLOP. 0.33241%. Div + 0 GFLOP. 0%. Relu + 0.456141 GFLOP in Total +Feature Memory Read per operator type: + 34.7354 MB. 36.4363%. Conv + 17.7944 MB. 18.6658%. SpatialBN + 14.5035 MB. 15.2137%. Mul + 9.25778 MB. 9.71113%. Relu + 7.84641 MB. 8.23064%. Add + 6.06516 MB. 6.36216%. Div + 5.12912 MB. 5.38029%. FC + 95.3317 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 26.7264%. Conv + 17.5992 MB. 26.6878%. SpatialBN + 9.25778 MB. 14.0387%. Relu + 8.43843 MB. 12.7962%. Mul + 6.95565 MB. 10.5477%. Add + 6.06502 MB. 9.19713%. Div + 0.004 MB. 0.00606568%. FC + 65.9447 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.1564%. Conv + 5.124 MB. 23.3979%. FC + 0.0976 MB. 0.445674%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8994 MB in Total + +``` +### Optimized + +``` +Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 +Time per operator type: + 17.146 ms. 78.8965%. Conv + 1.38453 ms. 6.37084%. Add + 1.30991 ms. 6.02749%. Div + 0.685417 ms. 3.15391%. Mul + 0.532589 ms. 2.45068%. Relu + 0.418263 ms. 1.92461%. Clip + 0.15128 ms. 0.696106%. FC + 0.102065 ms. 0.469648%. AveragePool + 0.0022143 ms. 0.010189%. Squeeze + 21.7323 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 98.1927%. Conv + 0.002561 GFLOP. 0.583981%. FC + 0.00210961 GFLOP. 0.481051%. Mul + 0.00173891 GFLOP. 0.396522%. Add + 0.00151626 GFLOP. 0.34575%. Div + 0 GFLOP. 0%. Relu + 0.438542 GFLOP in Total +Feature Memory Read per operator type: + 34.7842 MB. 44.833%. Conv + 14.5035 MB. 18.6934%. Mul + 9.25778 MB. 11.9323%. Relu + 7.84641 MB. 10.1132%. Add + 6.06516 MB. 7.81733%. Div + 5.12912 MB. 6.61087%. FC + 77.5861 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 36.4556%. Conv + 9.25778 MB. 19.1492%. Relu + 8.43843 MB. 17.4544%. Mul + 6.95565 MB. 14.3874%. Add + 6.06502 MB. 12.5452%. Div + 0.004 MB. 0.00827378%. FC + 48.3455 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.4973%. Conv + 5.124 MB. 23.5027%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8018 MB in Total + +``` + +## MnasNet-A1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 +Time per operator type: + 24.4656 ms. 79.0905%. Conv + 4.14958 ms. 13.4144%. SpatialBN + 1.60598 ms. 5.19169%. Relu + 0.295219 ms. 0.95436%. Mul + 0.187609 ms. 0.606486%. FC + 0.120556 ms. 0.389724%. AveragePool + 0.09036 ms. 0.292109%. Add + 0.015727 ms. 0.050841%. Sigmoid + 0.00306205 ms. 0.00989875%. Squeeze + 30.9337 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 95.6434%. Conv + 0.0248873 GFLOP. 3.8355%. SpatialBN + 0.002561 GFLOP. 0.394688%. FC + 0.000597408 GFLOP. 0.0920695%. Mul + 0.000222656 GFLOP. 0.0343146%. Add + 0 GFLOP. 0%. Relu + 0.648867 GFLOP in Total +Feature Memory Read per operator type: + 35.5457 MB. 38.4109%. Conv + 25.1552 MB. 27.1829%. SpatialBN + 22.5235 MB. 24.339%. Relu + 5.12912 MB. 5.54256%. FC + 2.40586 MB. 2.59978%. Mul + 1.78125 MB. 1.92483%. Add + 92.5406 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 32.9424%. Conv + 24.8873 MB. 32.92%. SpatialBN + 22.5235 MB. 29.7932%. Relu + 2.38963 MB. 3.16092%. Mul + 0.890624 MB. 1.17809%. Add + 0.004 MB. 0.00529106%. FC + 75.5993 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.1459%. Conv + 5.124 MB. 32.9917%. FC + 0.133952 MB. 0.86247%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.5312 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 +Time per operator type: + 22.0547 ms. 91.1375%. Conv + 1.49096 ms. 6.16116%. Relu + 0.253417 ms. 1.0472%. Mul + 0.18506 ms. 0.76473%. FC + 0.112942 ms. 0.466717%. AveragePool + 0.086769 ms. 0.358559%. Add + 0.0127889 ms. 0.0528479%. Sigmoid + 0.0027346 ms. 0.0113003%. Squeeze + 24.1994 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 99.4581%. Conv + 0.002561 GFLOP. 0.41043%. FC + 0.000597408 GFLOP. 0.0957417%. Mul + 0.000222656 GFLOP. 0.0356832%. Add + 0 GFLOP. 0%. Relu + 0.623979 GFLOP in Total +Feature Memory Read per operator type: + 35.6127 MB. 52.7968%. Conv + 22.5235 MB. 33.3917%. Relu + 5.12912 MB. 7.60406%. FC + 2.40586 MB. 3.56675%. Mul + 1.78125 MB. 2.64075%. Add + 67.4524 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 49.1092%. Conv + 22.5235 MB. 44.4145%. Relu + 2.38963 MB. 4.71216%. Mul + 0.890624 MB. 1.75624%. Add + 0.004 MB. 0.00788768%. FC + 50.712 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.7213%. Conv + 5.124 MB. 33.2787%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.3972 MB in Total +``` +## MnasNet-B1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 +Time per operator type: + 29.1121 ms. 83.3081%. Conv + 4.14959 ms. 11.8746%. SpatialBN + 1.35823 ms. 3.88675%. Relu + 0.186188 ms. 0.532802%. FC + 0.116244 ms. 0.332647%. Add + 0.018641 ms. 0.0533437%. AveragePool + 0.0040904 ms. 0.0117052%. Squeeze + 34.9451 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 96.2088%. Conv + 0.0218266 GFLOP. 3.35303%. SpatialBN + 0.002561 GFLOP. 0.393424%. FC + 0.000291648 GFLOP. 0.0448034%. Add + 0 GFLOP. 0%. Relu + 0.650951 GFLOP in Total +Feature Memory Read per operator type: + 34.4354 MB. 41.3788%. Conv + 22.1299 MB. 26.5921%. SpatialBN + 19.1923 MB. 23.0622%. Relu + 5.12912 MB. 6.16333%. FC + 2.33318 MB. 2.80364%. Add + 83.2199 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 34.0955%. Conv + 21.8266 MB. 34.0955%. SpatialBN + 19.1923 MB. 29.9805%. Relu + 1.16659 MB. 1.82234%. Add + 0.004 MB. 0.00624844%. FC + 64.016 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 69.9104%. Conv + 5.124 MB. 29.2245%. FC + 0.15168 MB. 0.865099%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.5332 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426 +Time per operator type: + 24.9888 ms. 94.0962%. Conv + 1.26147 ms. 4.75011%. Relu + 0.176234 ms. 0.663619%. FC + 0.113309 ms. 0.426672%. Add + 0.0138708 ms. 0.0522311%. AveragePool + 0.00295685 ms. 0.0111341%. Squeeze + 26.5566 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 99.5466%. Conv + 0.002561 GFLOP. 0.407074%. FC + 0.000291648 GFLOP. 0.0463578%. Add + 0 GFLOP. 0%. Relu + 0.629124 GFLOP in Total +Feature Memory Read per operator type: + 34.5112 MB. 56.4224%. Conv + 19.1923 MB. 31.3775%. Relu + 5.12912 MB. 8.3856%. FC + 2.33318 MB. 3.81452%. Add + 61.1658 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 51.7346%. Conv + 19.1923 MB. 45.4908%. Relu + 1.16659 MB. 2.76513%. Add + 0.004 MB. 0.00948104%. FC + 42.1895 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 70.5205%. Conv + 5.124 MB. 29.4795%. FC + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.3816 MB in Total +``` diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/LICENSE b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..80e7d15508202f3262a50db27f5198460d7f509f --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Ross Wightman + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/README.md b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..463368280d6a5015060eb73d20fe6512f8e04c50 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/README.md @@ -0,0 +1,323 @@ +# (Generic) EfficientNets for PyTorch + +A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search. + +All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)) + +## What's New + +### Aug 19, 2020 +* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1) +* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1) +* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX +* ONNX runtime based validation script added +* activations (mostly) brought in sync with `timm` equivalents + + +### April 5, 2020 +* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite + * 3.5M param MobileNet-V2 100 @ 73% + * 4.5M param MobileNet-V2 110d @ 75% + * 6.1M param MobileNet-V2 140 @ 76.5% + * 5.8M param MobileNet-V2 120d @ 77.3% + +### March 23, 2020 + * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1 + * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior + +### Feb 12, 2020 + * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) + * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization. + * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) + +### Jan 22, 2020 + * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models) + * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict + * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues + +### Nov 22, 2019 + * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different + preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights. + +### Nov 15, 2019 + * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights + * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine + +### Oct 30, 2019 + * Many of the models will now work with torch.jit.script, MixNet being the biggest exception + * Improved interface for enabling torchscript or ONNX export compatible modes (via config) + * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn + * Activation factory to select best version of activation by name or override one globally + * Add pretrained checkpoint load helper that handles input conv and classifier changes + +### Oct 27, 2019 + * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base + * Switch activations and global pooling to modules + * Add memory-efficient Swish/Mish impl + * Add as_sequential() method to all models and allow as an argument in entrypoint fns + * Move MobileNetV3 into own file since it has a different head + * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here + +## Models + +Implemented models include: + * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252) + * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) + * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946) + * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971) + * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * MixNet (https://arxiv.org/abs/1907.09595) + * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) + * MobileNet-V3 (https://arxiv.org/abs/1905.02244) + * FBNet-C (https://arxiv.org/abs/1812.03443) + * Single-Path NAS (https://arxiv.org/abs/1904.02877) + +I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code. + +## Pretrained + +I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models + + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | +|---|---|---|---|---|---|---|---| +| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | +| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | +| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | +| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | +| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | +| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | +| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | +| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | +| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | +| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | +| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | +| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | +| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | +| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | +| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | +| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | +| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | +| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | + + +More pretrained models to come... + + +## Ported Weights + +The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. + +**IMPORTANT:** +* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. +* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. + +To run validation for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` + +To run validation w/ TF preprocessing for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` + +To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | +|---|---|---|---|---|---|---| +| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | +| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | +| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | +| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | +| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | +| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | +| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | +| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | +| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | +| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | +| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | +| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | +| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | +| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | +| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | +| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | +| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | +| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | +| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | +| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | +| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | +| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | +| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | +| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A | +| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | +| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | +| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | +| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | +| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A | +| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | + + +*tfp models validated with `tf-preprocessing` pipeline + +Google tf and tflite weights ported from official Tensorflow repositories +* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet +* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet +* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet + +## Usage + +### Environment + +All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x. + +Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself. + +PyTorch versions 1.4, 1.5, 1.6 have been tested with this code. + +I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: +``` +conda create -n torch-env +conda activate torch-env +conda install -c pytorch pytorch torchvision cudatoolkit=10.2 +``` + +### PyTorch Hub + +Models can be accessed via the PyTorch Hub API + +``` +>>> torch.hub.list('rwightman/gen-efficientnet-pytorch') +['efficientnet_b0', ...] +>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True) +>>> model.eval() +>>> output = model(torch.randn(1,3,224,224)) +``` + +### Pip +This package can be installed via pip. + +Install (after conda env/install): +``` +pip install geffnet +``` + +Eval use: +``` +>>> import geffnet +>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True) +>>> m.eval() +``` + +Train use: +``` +>>> import geffnet +>>> # models can also be created by using the entrypoint directly +>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2) +>>> m.train() +``` + +Create in a nn.Sequential container, for fast.ai, etc: +``` +>>> import geffnet +>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True) +``` + +### Exporting + +Scripts are included to +* export models to ONNX (`onnx_export.py`) +* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg) +* validate with ONNX runtime (`onnx_validate.py`) +* convert ONNX model to Caffe2 (`onnx_to_caffe.py`) +* validate in Caffe2 (`caffe2_validate.py`) +* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`) + +As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation: +``` +python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx +python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx +``` + +These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible +export now requires additional args mentioned in the export script (not needed in earlier versions). + +#### Export Notes +1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script. +2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working. +3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization. +3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here. + + diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/__init__.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_benchmark.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..93f28a1e63d9f7287ca02997c7991fe66dd0aeb9 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_benchmark.py @@ -0,0 +1,65 @@ +""" Caffe2 validation script + +This script runs Caffe2 benchmark on exported ONNX model. +It is a useful tool for reporting model FLOPS. + +Copyright 2020 Ross Wightman +""" +import argparse +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 + + +parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=224, type=int, + metavar='N', help='Input image dimension, uses model default if empty') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="le_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + # CUDA performance not impressive + #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + input_blob = model.net.external_inputs[0] + model.param_init_net.GaussianFill( + [], + input_blob.GetUnscopedName(), + shape=(args.batch_size, 3, args.img_size, args.img_size), + mean=0.0, + std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True) + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_validate.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfaab38c095663fe32e4addbdf06b57bcb53614 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/caffe2_validate.py @@ -0,0 +1,138 @@ +""" Caffe2 validation script + +This script is created to verify exported ONNX models running in Caffe2 +It utilizes the same PyTorch dataloader/processing pipeline for a +fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="validation_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + # this is so obvious, wonderful interface + input_blob = model.net.external_inputs[0] + output_blob = model.net.external_outputs[0] + + if True: + device_opts = None + else: + # CUDA is crashing, no idea why, awesome error message, give it a try for kicks + device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + model.param_init_net.GaussianFill( + [], input_blob.GetUnscopedName(), + shape=(1,) + data_config['input_size'], mean=0.0, std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + caffe2_in = input.data.numpy() + workspace.FeedBlob(input_blob, caffe2_in, device_opts) + workspace.RunNet(model.net, num_iter=1) + output = workspace.FetchBlob(output_blob) + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output.data, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/__init__.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e441a5838d1e972823b9668ac8d459445f6f6ce --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/__init__.py @@ -0,0 +1,5 @@ +from .gen_efficientnet import * +from .mobilenetv3 import * +from .model_factory import create_model +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable +from .activations import * \ No newline at end of file diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/__init__.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..813421a743ffc33b8eb53ebf62dd4a03d831b654 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/__init__.py @@ -0,0 +1,137 @@ +from geffnet import config +from geffnet.activations.activations_me import * +from geffnet.activations.activations_jit import * +from geffnet.activations.activations import * +import torch + +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_swish=hard_swish_me, + hard_sigmoid_jit=hard_sigmoid_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_swish=HardSwishMe, + hard_sigmoid=HardSigmoidMe +) + +_OVERRIDE_FN = dict() +_OVERRIDE_LAYER = dict() + + +def add_override_act_fn(name, fn): + global _OVERRIDE_FN + _OVERRIDE_FN[name] = fn + + +def update_override_act_fn(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_FN + _OVERRIDE_FN.update(overrides) + + +def clear_override_act_fn(): + global _OVERRIDE_FN + _OVERRIDE_FN = dict() + + +def add_override_act_layer(name, fn): + _OVERRIDE_LAYER[name] = fn + + +def update_override_act_layer(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_LAYER + _OVERRIDE_LAYER.update(overrides) + + +def clear_override_act_layer(): + global _OVERRIDE_LAYER + _OVERRIDE_LAYER = dict() + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_FN: + return _OVERRIDE_FN[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_FN_ME: + # If not exporting or scripting the model, first look for a memory optimized version + # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin + return _ACT_FN_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_LAYER: + return _OVERRIDE_LAYER[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..bdea692d1397673b2513d898c33edbcb37d94240 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations.py @@ -0,0 +1,102 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..7176b05e779787528a47f20d55d64d4a0f219360 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_jit.py @@ -0,0 +1,79 @@ +""" Activations (jit) + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + +__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', + 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df5a50fdbe40bc386e2541a4fda743ad95e9a --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/activations/activations_me.py @@ -0,0 +1,174 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', + 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + + Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/config.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..27d5307fd9ee0246f1e35f41520f17385d23f1dd --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/config.py @@ -0,0 +1,123 @@ +""" Global layer config state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def layer_config_kwargs(kwargs): + """ Consume config kwargs and return contextmgr obj """ + return set_layer_config( + scriptable=kwargs.pop('scriptable', None), + exportable=kwargs.pop('exportable', None), + no_jit=kwargs.pop('no_jit', None)) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d8467460c4b36e54c83ce2dcd3ebe91d3432cad2 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/conv2d_layers.py @@ -0,0 +1,304 @@ +""" Conv2D w/ SAME padding, CondConv, MixedConv + +A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and +MobileNetV3 models that maintain weight compatibility with original Tensorflow models. + +Copyright 2020 Ross Wightman +""" +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Tuple, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import * + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _same_pad_arg(input_size, kernel_size, stride, dilation): + ih, iw = input_size + kh, kw = kernel_size + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dSameExport(nn.Conv2d): + """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + if self.pad is not None: + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + if is_exportable(): + assert not is_scriptable() + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95dd63d400e70d70664c5a433a2772363f865e61 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/efficientnet_builder.py @@ -0,0 +1,683 @@ +""" EfficientNet / MobileNetV3 Blocks and Builder + +Copyright 2020 Ross Wightman +""" +import re +from copy import deepcopy + +from .conv2d_layers import * +from geffnet.activations import * + +__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', + 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', + 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', + 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' +] + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +# +# PyTorch defaults are momentum = .1, eps = 1e-5 +# +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, # None == use containing block's activation layer + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v: int, divisor: int = 8, min_value: int = None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class SqueezeExcite(nn.Module): + + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion + factor of 1.0. This is an alternative to having a IR with optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs: int = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() # for jit.script compat + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) + self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class EfficientNetBuilder: + """ Build Trunk Blocks for Efficient/Mobile Networks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = EdgeResidual(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stack_idx, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return blocks + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +def initialize_weight_goog(m, n='', fix_group_fanout=True): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def initialize_weight_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cd170d4cc5bed6ca82b61539902b470d3320c691 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/gen_efficientnet.py @@ -0,0 +1,1450 @@ +""" Generic Efficient Networks + +A generic MobileNet class with building blocks to support a variety of models: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* EfficientNet-Lite + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .config import layer_config_kwargs, is_scriptable +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', + 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', + 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', + 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', + 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', + 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', + 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', + 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', + 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', + 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', + 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', + 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', + 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', + 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', + 'tf_efficientnet_lite4', + 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] + + +model_urls = { + 'mnasnet_050': None, + 'mnasnet_075': None, + 'mnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + 'mnasnet_140': None, + 'mnasnet_small': None, + + 'semnasnet_050': None, + 'semnasnet_075': None, + 'semnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + 'semnasnet_140': None, + + 'mobilenetv2_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + 'mobilenetv2_110d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + 'mobilenetv2_120d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + 'mobilenetv2_140': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + + 'fbnetc_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + 'spnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + + 'efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + 'efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 'efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + 'efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + 'efficientnet_b4': None, + 'efficientnet_b5': None, + 'efficientnet_b6': None, + 'efficientnet_b7': None, + 'efficientnet_b8': None, + 'efficientnet_l2': None, + + 'efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + 'efficientnet_em': None, + 'efficientnet_el': None, + + 'efficientnet_cc_b0_4e': None, + 'efficientnet_cc_b0_8e': None, + 'efficientnet_cc_b1_8e': None, + + 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + 'efficientnet_lite1': None, + 'efficientnet_lite2': None, + 'efficientnet_lite3': None, + 'efficientnet_lite4': None, + + 'tf_efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + 'tf_efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + 'tf_efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + 'tf_efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + 'tf_efficientnet_b4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + 'tf_efficientnet_b5': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b6': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + 'tf_efficientnet_b7': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b8': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + + 'tf_efficientnet_b0_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 'tf_efficientnet_b1_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + 'tf_efficientnet_b2_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 'tf_efficientnet_b3_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + 'tf_efficientnet_b4_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + 'tf_efficientnet_b5_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + 'tf_efficientnet_b6_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + 'tf_efficientnet_b7_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + 'tf_efficientnet_b8_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + + 'tf_efficientnet_b0_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + 'tf_efficientnet_b1_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + 'tf_efficientnet_b2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + 'tf_efficientnet_b3_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + 'tf_efficientnet_b4_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + 'tf_efficientnet_b5_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + 'tf_efficientnet_b6_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + 'tf_efficientnet_b7_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + 'tf_efficientnet_l2_ns_475': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + 'tf_efficientnet_l2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + + 'tf_efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + 'tf_efficientnet_em': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + 'tf_efficientnet_el': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + + 'tf_efficientnet_cc_b0_4e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + 'tf_efficientnet_cc_b0_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + 'tf_efficientnet_cc_b1_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + + 'tf_efficientnet_lite0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + 'tf_efficientnet_lite1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + 'tf_efficientnet_lite2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + 'tf_efficientnet_lite3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + 'tf_efficientnet_lite4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + + 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + + 'tf_mixnet_s': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + 'tf_mixnet_m': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + 'tf_mixnet_l': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', +} + + +class GenEfficientNet(nn.Module): + """ Generic EfficientNets + + An implementation of mobile optimized networks that covers: + * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) + * MixNet (Small, Medium, and Large, XL) + * MNASNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + weight_init='goog'): + super(GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, + pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) + self.bn2 = norm_layer(num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(num_features, num_classes) + + for n, m in self.named_modules(): + if weight_init == 'goog': + initialize_weight_goog(m, n) + else: + initialize_weight_default(m, n) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.conv_head, self.bn2, self.act2, + self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = GenEfficientNet(**model_kwargs) + if pretrained: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2. """ + # NOTE for train, drop_rate should be 0.5 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + # NOTE for train set drop_rate=0.2 + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/helpers.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f83a07d690c7ad681c777c19b1e7a5bb95da007 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/helpers.py @@ -0,0 +1,71 @@ +""" Checkpoint loading / state_dict helpers +Copyright 2020 Ross Wightman +""" +import torch +import os +from collections import OrderedDict +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, url, filter_fn=None, strict=True): + if not url: + print("=> Warning: Pretrained model URL is empty, using random initialization.") + return + + state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') + + input_conv = 'conv_stem' + classifier = 'classifier' + in_chans = getattr(model, input_conv).weight.shape[1] + num_classes = getattr(model, classifier).weight.shape[0] + + input_conv_weight = input_conv + '.weight' + pretrained_in_chans = state_dict[input_conv_weight].shape[1] + if in_chans != pretrained_in_chans: + if in_chans == 1: + print('=> Converting pretrained input conv {} from {} to 1 channel'.format( + input_conv_weight, pretrained_in_chans)) + conv1_weight = state_dict[input_conv_weight] + state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) + else: + print('=> Discarding pretrained input conv {} since input channel count != {}'.format( + input_conv_weight, pretrained_in_chans)) + del state_dict[input_conv_weight] + strict = False + + classifier_weight = classifier + '.weight' + pretrained_num_classes = state_dict[classifier_weight].shape[0] + if num_classes != pretrained_num_classes: + print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) + del state_dict[classifier_weight] + del state_dict[classifier + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..b5966c28f7207e98ee50745b1bc8f3663c650f9d --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/mobilenetv3.py @@ -0,0 +1,364 @@ +""" MobileNet-V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_act_fn, get_act_layer, HardSwish +from .config import layer_config_kwargs +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', + 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', + 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] + +model_urls = { + 'mobilenetv3_rw': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + 'mobilenetv3_large_075': None, + 'mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + 'mobilenetv3_large_minimal_100': None, + 'mobilenetv3_small_075': None, + 'mobilenetv3_small_100': None, + 'mobilenetv3_small_minimal_100': None, + 'tf_mobilenetv3_large_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + 'tf_mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + 'tf_mobilenetv3_large_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 'tf_mobilenetv3_small_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + 'tf_mobilenetv3_small_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + 'tf_mobilenetv3_small_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the + head convolution without a final batch-norm layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3, self).__init__() + self.drop_rate = drop_rate + + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if weight_init == 'goog': + initialize_weight_goog(m) + else: + initialize_weight_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.global_pool, self.conv_head, self.act2, + nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = MobileNetV3(**model_kwargs) + if pretrained and model_urls[variant]: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model (RW variant). + + Paper: https://arxiv.org/abs/1905.02244 + + This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the + eventual Tensorflow reference impl but has a few differences: + 1. This model has no bias on the head convolution + 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet + 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer + from their parent block + 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count + + Overall the changes are fairly minor and result in a very small parameter count difference and no + top-1/5 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, # one of my mistakes + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 large/small/minimal models. + + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, act_layer), + se_kwargs=dict( + act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet-V3 RW + Attn: See note in gen function for this variant. + """ + # NOTE for train set drop_rate=0.2 + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75""" + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large (Minimalistic) 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small (Minimalistic) 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/model_factory.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46ea8baedaf3d787826eb3bb314b4230514647 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/model_factory.py @@ -0,0 +1,27 @@ +from .config import set_layer_config +from .helpers import load_checkpoint + +from .gen_efficientnet import * +from .mobilenetv3 import * + + +def create_model( + model_name='mnasnet_100', + pretrained=None, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + + model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**model_kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path and not pretrained: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/version.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a6221b3de7b1490c5e712e8b5fcc94c3d9d04295 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/geffnet/version.py @@ -0,0 +1 @@ +__version__ = '1.0.2' diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/hubconf.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..45b17b99bbeba34596569e6e50f6e8a2ebc45c54 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/hubconf.py @@ -0,0 +1,84 @@ +dependencies = ['torch', 'math'] + +from geffnet import efficientnet_b0 +from geffnet import efficientnet_b1 +from geffnet import efficientnet_b2 +from geffnet import efficientnet_b3 + +from geffnet import efficientnet_es + +from geffnet import efficientnet_lite0 + +from geffnet import mixnet_s +from geffnet import mixnet_m +from geffnet import mixnet_l +from geffnet import mixnet_xl + +from geffnet import mobilenetv2_100 +from geffnet import mobilenetv2_110d +from geffnet import mobilenetv2_120d +from geffnet import mobilenetv2_140 + +from geffnet import mobilenetv3_large_100 +from geffnet import mobilenetv3_rw +from geffnet import mnasnet_a1 +from geffnet import mnasnet_b1 +from geffnet import fbnetc_100 +from geffnet import spnasnet_100 + +from geffnet import tf_efficientnet_b0 +from geffnet import tf_efficientnet_b1 +from geffnet import tf_efficientnet_b2 +from geffnet import tf_efficientnet_b3 +from geffnet import tf_efficientnet_b4 +from geffnet import tf_efficientnet_b5 +from geffnet import tf_efficientnet_b6 +from geffnet import tf_efficientnet_b7 +from geffnet import tf_efficientnet_b8 + +from geffnet import tf_efficientnet_b0_ap +from geffnet import tf_efficientnet_b1_ap +from geffnet import tf_efficientnet_b2_ap +from geffnet import tf_efficientnet_b3_ap +from geffnet import tf_efficientnet_b4_ap +from geffnet import tf_efficientnet_b5_ap +from geffnet import tf_efficientnet_b6_ap +from geffnet import tf_efficientnet_b7_ap +from geffnet import tf_efficientnet_b8_ap + +from geffnet import tf_efficientnet_b0_ns +from geffnet import tf_efficientnet_b1_ns +from geffnet import tf_efficientnet_b2_ns +from geffnet import tf_efficientnet_b3_ns +from geffnet import tf_efficientnet_b4_ns +from geffnet import tf_efficientnet_b5_ns +from geffnet import tf_efficientnet_b6_ns +from geffnet import tf_efficientnet_b7_ns +from geffnet import tf_efficientnet_l2_ns_475 +from geffnet import tf_efficientnet_l2_ns + +from geffnet import tf_efficientnet_es +from geffnet import tf_efficientnet_em +from geffnet import tf_efficientnet_el + +from geffnet import tf_efficientnet_cc_b0_4e +from geffnet import tf_efficientnet_cc_b0_8e +from geffnet import tf_efficientnet_cc_b1_8e + +from geffnet import tf_efficientnet_lite0 +from geffnet import tf_efficientnet_lite1 +from geffnet import tf_efficientnet_lite2 +from geffnet import tf_efficientnet_lite3 +from geffnet import tf_efficientnet_lite4 + +from geffnet import tf_mixnet_s +from geffnet import tf_mixnet_m +from geffnet import tf_mixnet_l + +from geffnet import tf_mobilenetv3_large_075 +from geffnet import tf_mobilenetv3_large_100 +from geffnet import tf_mobilenetv3_large_minimal_100 +from geffnet import tf_mobilenetv3_small_075 +from geffnet import tf_mobilenetv3_small_100 +from geffnet import tf_mobilenetv3_small_minimal_100 + diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_export.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5162ce214830df501bdb81edb66c095122f69d --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_export.py @@ -0,0 +1,120 @@ +""" ONNX export script + +Export PyTorch models as ONNX graphs. + +This export script originally started as an adaptation of code snippets found at +https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html + +The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph +for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible +with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback +flags are currently required. + +Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for +caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime. + +Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models. +Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks. + +Copyright 2020 Ross Wightman +""" +import argparse +import torch +import numpy as np + +import onnx +import geffnet + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('output', metavar='ONNX_FILE', + help='output model filename') +parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100', + help='model architecture (default: mobilenetv3_large_100)') +parser.add_argument('--opset', type=int, default=10, + help='ONNX opset to use (default: 10)') +parser.add_argument('--keep-init', action='store_true', default=False, + help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') +parser.add_argument('--aten-fallback', action='store_true', default=False, + help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') +parser.add_argument('--dynamic-size', action='store_true', default=False, + help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to checkpoint (default: none)') + + +def main(): + args = parser.parse_args() + + args.pretrained = True + if args.checkpoint: + args.pretrained = False + + print("==> Creating PyTorch {} model".format(args.model)) + # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers + # for models using SAME padding + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + exportable=True) + + model.eval() + + example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True) + + # Run model once before export trace, sets padding for models with Conv2dSameExport. This means + # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for + # the input img_size specified in this script. + # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to + # issues in the tracing of the dynamic padding or errors attempting to export the model after jit + # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... + model(example_input) + + print("==> Exporting model to ONNX format at '{}'".format(args.output)) + input_names = ["input0"] + output_names = ["output0"] + dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} + if args.dynamic_size: + dynamic_axes['input0'][2] = 'height' + dynamic_axes['input0'][3] = 'width' + if args.aten_fallback: + export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + else: + export_type = torch.onnx.OperatorExportTypes.ONNX + + torch_out = torch.onnx._export( + model, example_input, args.output, export_params=True, verbose=True, input_names=input_names, + output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes, + opset_version=args.opset, operator_export_type=export_type) + + print("==> Loading and checking exported model from '{}'".format(args.output)) + onnx_model = onnx.load(args.output) + onnx.checker.check_model(onnx_model) # assuming throw on error + print("==> Passed") + + if args.keep_init and args.aten_fallback: + import caffe2.python.onnx.backend as onnx_caffe2 + # Caffe2 loading only works properly in newer PyTorch/ONNX combos when + # keep_initializers_as_inputs and aten_fallback are set to True. + print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output)) + caffe2_backend = onnx_caffe2.prepare(onnx_model) + B = {onnx_model.graph.input[0].name: x.data.numpy()} + c2_out = caffe2_backend.run(B)[0] + np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5) + print("==> Passed") + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_optimize.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..ee20bbf9f0f9473370489512eb96ca0b570b5388 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_optimize.py @@ -0,0 +1,84 @@ +""" ONNX optimization script + +Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. + +NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), +it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). + +Copyright 2020 Ross Wightman +""" +import argparse +import warnings + +import onnx +from onnx import optimizer + + +parser = argparse.ArgumentParser(description="Optimize ONNX model") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--output", required=True, help="The optimized model output filename") + + +def traverse_graph(graph, prefix=''): + content = [] + indent = prefix + ' ' + graphs = [] + num_nodes = 0 + for node in graph.node: + pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) + assert isinstance(gs, list) + content.append(pn) + graphs.extend(gs) + num_nodes += 1 + for g in graphs: + g_count, g_str = traverse_graph(g) + content.append('\n' + g_str) + num_nodes += g_count + return num_nodes, '\n'.join(content) + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) + + # Optimizer passes to perform + passes = [ + #'eliminate_deadend', + 'eliminate_identity', + 'eliminate_nop_dropout', + 'eliminate_nop_pad', + 'eliminate_nop_transpose', + 'eliminate_unused_initializer', + 'extract_constant_to_initializer', + 'fuse_add_bias_into_conv', + 'fuse_bn_into_conv', + 'fuse_consecutive_concats', + 'fuse_consecutive_reduce_unsqueeze', + 'fuse_consecutive_squeezes', + 'fuse_consecutive_transposes', + #'fuse_matmul_add_bias_into_gemm', + 'fuse_pad_into_conv', + #'fuse_transpose_into_gemm', + #'lift_lexical_references', + ] + + # Apply the optimization on the original serialized model + # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing + # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 + # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. + warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." + "Try onnxruntime optimization if this doesn't work.") + optimized_model = optimizer.optimize(onnx_model, passes) + + num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) + print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) + print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) + + # Save the ONNX model + onnx.save(optimized_model, args.output) + + +if __name__ == "__main__": + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_to_caffe.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_to_caffe.py new file mode 100644 index 0000000000000000000000000000000000000000..44399aafababcdf6b84147a0613eb0909730db4b --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_to_caffe.py @@ -0,0 +1,27 @@ +import argparse + +import onnx +from caffe2.python.onnx.backend import Caffe2Backend + + +parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--c2-prefix", required=True, + help="The output file prefix for the caffe2 model init and predict file. ") + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) + caffe2_init_str = caffe2_init.SerializeToString() + with open(args.c2_prefix + '.init.pb', "wb") as f: + f.write(caffe2_init_str) + caffe2_predict_str = caffe2_predict.SerializeToString() + with open(args.c2_prefix + '.predict.pb', "wb") as f: + f.write(caffe2_predict_str) + + +if __name__ == "__main__": + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_validate.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3e4fb141b6ef660dcc5b447fd9f368a2ea19a0 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/onnx_validate.py @@ -0,0 +1,112 @@ +""" ONNX-runtime validation script + +This script was created to verify accuracy and performance of exported ONNX +models running with the onnxruntime. It utilizes the PyTorch dataloader/processing +pipeline for a fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +import onnxruntime +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', + help='path to onnx model/weights file') +parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', + help='path to output optimized onnx graph') +parser.add_argument('--profile', action='store_true', default=False, + help='Enable profiler output.') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + + # Set graph optimization level + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.profile: + sess_options.enable_profiling = True + if args.onnx_output_opt: + sess_options.optimized_model_filepath = args.onnx_output_opt + + session = onnxruntime.InferenceSession(args.onnx_input, sess_options) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + input_name = session.get_inputs()[0].name + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + output = session.run([], {input_name: input.data.numpy()}) + output = output[0] + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/requirements.txt b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac3ffc13bae15f9b11f7cbe3705760056ecd7f13 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.2.0 +torchvision>=0.4.0 diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/setup.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..023e4c30f98164595964423e3a83eefaf7ffdad6 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/setup.py @@ -0,0 +1,47 @@ +""" Setup +""" +from setuptools import setup, find_packages +from codecs import open +from os import path + +here = path.abspath(path.dirname(__file__)) + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +exec(open('geffnet/version.py').read()) +setup( + name='geffnet', + version=__version__, + description='(Generic) EfficientNets for PyTorch', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/rwightman/gen-efficientnet-pytorch', + author='Ross Wightman', + author_email='hello@rwightman.com', + classifiers=[ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + + # Note that this is a string of words separated by whitespace, not a list. + keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', + packages=find_packages(exclude=['data']), + install_requires=['torch >= 1.4', 'torchvision'], + python_requires='>=3.6', +) diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/utils.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d327e8bd8120c5cd09ae6c15c3991ccbe27f6c1f --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/utils.py @@ -0,0 +1,52 @@ +import os + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + diff --git a/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/validate.py b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd44fbb3165ef81ef81251b6299f6aaa80bf2c2 --- /dev/null +++ b/custom_controlnet_aux/dsine/models/submodules/efficientnet_repo/validate.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import time +import torch +import torch.nn as nn +import torch.nn.parallel +from contextlib import suppress + +import geffnet +from data import Dataset, create_loader, resolve_data_config +from utils import accuracy, AverageMeter + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00', + help='model architecture (default: dpn92)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--no-cuda', dest='no_cuda', action='store_true', + help='') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use native Torch AMP mixed precision.') + + +def main(): + args = parser.parse_args() + + if not args.checkpoint and not args.pretrained: + args.pretrained = True + + amp_autocast = suppress # do nothing + if args.amp: + if not has_native_amp: + print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.") + else: + amp_autocast = torch.cuda.amp.autocast + + # create model + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + scriptable=args.torchscript) + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + + print('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + + data_config = resolve_data_config(model, args) + + criterion = nn.CrossEntropyLoss() + + if not args.no_cuda: + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=not args.no_cuda, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + end = time.time() + with torch.no_grad(): + for i, (input, target) in enumerate(loader): + if not args.no_cuda: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/dsine/utils/rotation.py b/custom_controlnet_aux/dsine/utils/rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..8123f900cee404d10dd96e03ce31a681c594e284 --- /dev/null +++ b/custom_controlnet_aux/dsine/utils/rotation.py @@ -0,0 +1,85 @@ +import torch +import numpy as np + + +# NOTE: from PyTorch3D +def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to quaternions. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) + half_angles = angles * 0.5 + eps = 1e-6 + small_angles = angles.abs() < eps + sin_half_angles_over_angles = torch.empty_like(angles) + sin_half_angles_over_angles[~small_angles] = ( + torch.sin(half_angles[~small_angles]) / angles[~small_angles] + ) + # for x small, sin(x/2) is about x/2 - (x/2)^3/6 + # so sin(x/2)/x is about 1/2 - (x*x)/48 + sin_half_angles_over_angles[small_angles] = ( + 0.5 - (angles[small_angles] * angles[small_angles]) / 48 + ) + quaternions = torch.cat( + [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 + ) + return quaternions + + +# NOTE: from PyTorch3D +def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +# NOTE: from PyTorch3D +def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as axis/angle to rotation matrices. + + Args: + axis_angle: Rotations given as a vector in axis angle form, + as a tensor of shape (..., 3), where the magnitude is + the angle turned anticlockwise in radians around the + vector's direction. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) \ No newline at end of file diff --git a/custom_controlnet_aux/dsine/utils/utils.py b/custom_controlnet_aux/dsine/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..61cd4cca0dd2f1ba05a77d139c90f226fec5af04 --- /dev/null +++ b/custom_controlnet_aux/dsine/utils/utils.py @@ -0,0 +1,105 @@ +""" utils +""" +import os +import torch +import numpy as np + + +def load_checkpoint(fpath, model): + print('loading checkpoint... {}'.format(fpath)) + + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + print('loading checkpoint... / done') + return model + + +def compute_normal_error(pred_norm, gt_norm): + pred_error = torch.cosine_similarity(pred_norm, gt_norm, dim=1) + pred_error = torch.clamp(pred_error, min=-1.0, max=1.0) + pred_error = torch.acos(pred_error) * 180.0 / np.pi + pred_error = pred_error.unsqueeze(1) # (B, 1, H, W) + return pred_error + + +def compute_normal_metrics(total_normal_errors): + total_normal_errors = total_normal_errors.detach().cpu().numpy() + num_pixels = total_normal_errors.shape[0] + + metrics = { + 'mean': np.average(total_normal_errors), + 'median': np.median(total_normal_errors), + 'rmse': np.sqrt(np.sum(total_normal_errors * total_normal_errors) / num_pixels), + 'a1': 100.0 * (np.sum(total_normal_errors < 5) / num_pixels), + 'a2': 100.0 * (np.sum(total_normal_errors < 7.5) / num_pixels), + 'a3': 100.0 * (np.sum(total_normal_errors < 11.25) / num_pixels), + 'a4': 100.0 * (np.sum(total_normal_errors < 22.5) / num_pixels), + 'a5': 100.0 * (np.sum(total_normal_errors < 30) / num_pixels) + } + + return metrics + + +def pad_input(orig_H, orig_W): + if orig_W % 32 == 0: + l = 0 + r = 0 + else: + new_W = 32 * ((orig_W // 32) + 1) + l = (new_W - orig_W) // 2 + r = (new_W - orig_W) - l + + if orig_H % 32 == 0: + t = 0 + b = 0 + else: + new_H = 32 * ((orig_H // 32) + 1) + t = (new_H - orig_H) // 2 + b = (new_H - orig_H) - t + return l, r, t, b + + +def get_intrins_from_fov(new_fov, H, W, device): + # NOTE: top-left pixel should be (0,0) + if W >= H: + new_fu = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + new_fv = (W / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + else: + new_fu = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + new_fv = (H / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) + + new_cu = (W / 2.0) - 0.5 + new_cv = (H / 2.0) - 0.5 + + new_intrins = torch.tensor([ + [new_fu, 0, new_cu ], + [0, new_fv, new_cv ], + [0, 0, 1 ] + ], dtype=torch.float32, device=device) + + return new_intrins + + +def get_intrins_from_txt(intrins_path, device): + # NOTE: top-left pixel should be (0,0) + with open(intrins_path, 'r') as f: + intrins_ = f.readlines()[0].split()[0].split(',') + intrins_ = [float(i) for i in intrins_] + fx, fy, cx, cy = intrins_ + + intrins = torch.tensor([ + [fx, 0,cx], + [ 0,fy,cy], + [ 0, 0, 1] + ], dtype=torch.float32, device=device) + + return intrins \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/LICENSE b/custom_controlnet_aux/dwpose/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd --- /dev/null +++ b/custom_controlnet_aux/dwpose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/__init__.py b/custom_controlnet_aux/dwpose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05cc7ad6ac19221571b6fd9b55d46427f5421c99 --- /dev/null +++ b/custom_controlnet_aux/dwpose/__init__.py @@ -0,0 +1,328 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) +# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) +# This preprocessor is licensed by CMU for non-commercial use only. + +import os +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import json +import torch +import numpy as np +from . import util +from .body import Body, BodyResult, Keypoint +from .hand import Hand +from .face import Face +from .types import PoseResult, HandResult, FaceResult, AnimalPoseResult +from huggingface_hub import hf_hub_download +from .wholebody import Wholebody +import warnings +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download +import cv2 +from PIL import Image +from .animalpose import AnimalPoseImage + +from typing import Tuple, List, Callable, Union, Optional + + +def draw_animalposes(animals: list[list[Keypoint]], H: int, W: int) -> np.ndarray: + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + for animal_pose in animals: + canvas = draw_animalpose(canvas, animal_pose) + return canvas + + +def draw_animalpose(canvas: np.ndarray, keypoints: list[Keypoint]) -> np.ndarray: + # order of the keypoints for AP10k and a standardized list of colors for limbs + keypointPairsList = [ + (1, 2), + (2, 3), + (1, 3), + (3, 4), + (4, 9), + (9, 10), + (10, 11), + (4, 6), + (6, 7), + (7, 8), + (4, 5), + (5, 15), + (15, 16), + (16, 17), + (5, 12), + (12, 13), + (13, 14), + ] + colorsList = [ + (255, 255, 255), + (100, 255, 100), + (150, 255, 255), + (100, 50, 255), + (50, 150, 200), + (0, 255, 255), + (0, 150, 0), + (0, 0, 255), + (0, 0, 150), + (255, 50, 255), + (255, 0, 255), + (255, 0, 0), + (150, 0, 0), + (255, 255, 100), + (0, 150, 0), + (255, 255, 0), + (150, 150, 150), + ] # 16 colors needed + + for ind, (i, j) in enumerate(keypointPairsList): + p1 = keypoints[i - 1] + p2 = keypoints[j - 1] + + if p1 is not None and p2 is not None: + cv2.line( + canvas, + (int(p1.x), int(p1.y)), + (int(p2.x), int(p2.y)), + colorsList[ind], + 5, + ) + return canvas + + +def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): + """ + Draw the detected poses on an empty canvas. + + Args: + poses (List[PoseResult]): A list of PoseResult objects containing the detected poses. + H (int): The height of the canvas. + W (int): The width of the canvas. + draw_body (bool, optional): Whether to draw body keypoints. Defaults to True. + draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True. + draw_face (bool, optional): Whether to draw face keypoints. Defaults to True. + + Returns: + numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses. + """ + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + for pose in poses: + if draw_body: + canvas = util.draw_bodypose(canvas, pose.body.keypoints) + + if draw_hand: + canvas = util.draw_handpose(canvas, pose.left_hand) + canvas = util.draw_handpose(canvas, pose.right_hand) + + if draw_face: + canvas = util.draw_facepose(canvas, pose.face) + + return canvas + + +def decode_json_as_poses( + pose_json: dict, +) -> Tuple[List[PoseResult], List[AnimalPoseResult], int, int]: + """Decode the json_string complying with the openpose JSON output format + to poses that controlnet recognizes. + https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md + + Args: + json_string: The json string to decode. + + Returns: + human_poses + animal_poses + canvas_height + canvas_width + """ + height = pose_json["canvas_height"] + width = pose_json["canvas_width"] + + def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + def decompress_keypoints( + numbers: Optional[List[float]], + ) -> Optional[List[Optional[Keypoint]]]: + if not numbers: + return None + + assert len(numbers) % 3 == 0 + + def create_keypoint(x, y, c): + if c < 1.0: + return None + keypoint = Keypoint(x, y) + return keypoint + + return [create_keypoint(x, y, c) for x, y, c in chunks(numbers, n=3)] + + return ( + [ + PoseResult( + body=BodyResult( + keypoints=decompress_keypoints(pose.get("pose_keypoints_2d")) + ), + left_hand=decompress_keypoints(pose.get("hand_left_keypoints_2d")), + right_hand=decompress_keypoints(pose.get("hand_right_keypoints_2d")), + face=decompress_keypoints(pose.get("face_keypoints_2d")), + ) + for pose in pose_json.get("people", []) + ], + [decompress_keypoints(pose) for pose in pose_json.get("animals", [])], + height, + width, + ) + + +def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str: + """ Encode the pose as a dict following openpose JSON output format: + https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md + """ + def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]: + if not keypoints: + return None + + return [ + value + for keypoint in keypoints + for value in ( + [float(keypoint.x), float(keypoint.y), 1.0] + if keypoint is not None + else [0.0, 0.0, 0.0] + ) + ] + + return { + 'people': [ + { + 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints), + "face_keypoints_2d": compress_keypoints(pose.face), + "hand_left_keypoints_2d": compress_keypoints(pose.left_hand), + "hand_right_keypoints_2d":compress_keypoints(pose.right_hand), + } + for pose in poses + ], + 'canvas_height': canvas_height, + 'canvas_width': canvas_width, + } + +global_cached_dwpose = Wholebody() + +class DwposeDetector: + """ + A class for detecting human poses in images using the Dwpose model. + + Attributes: + model_dir (str): Path to the directory where the pose models are stored. + """ + def __init__(self, dw_pose_estimation): + self.dw_pose_estimation = dw_pose_estimation + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename=None, pose_filename=None, torchscript_device="cuda"): + global global_cached_dwpose + pretrained_det_model_or_path = pretrained_det_model_or_path or pretrained_model_or_path + det_filename = det_filename or "yolox_l.onnx" + pose_filename = pose_filename or "dw-ll_ucoco_384.onnx" + det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename) + pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename) + + print(f"\nDWPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation") + if global_cached_dwpose.det is None or global_cached_dwpose.det_filename != det_filename: + t = Wholebody(det_model_path, None, torchscript_device=torchscript_device) + t.pose = global_cached_dwpose.pose + t.pose_filename = global_cached_dwpose.pose + global_cached_dwpose = t + + if global_cached_dwpose.pose is None or global_cached_dwpose.pose_filename != pose_filename: + t = Wholebody(None, pose_model_path, torchscript_device=torchscript_device) + t.det = global_cached_dwpose.det + t.det_filename = global_cached_dwpose.det_filename + global_cached_dwpose = t + return cls(global_cached_dwpose) + + def detect_poses(self, oriImg) -> List[PoseResult]: + with torch.no_grad(): + keypoints_info = self.dw_pose_estimation(oriImg.copy()) + return Wholebody.format_result(keypoints_info) + + def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs): + if hand_and_face is not None: + warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning) + include_hand = hand_and_face + include_face = hand_and_face + + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + poses = self.detect_poses(input_image) + + canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face) + canvas, remove_pad = resize_image_with_pad(canvas, detect_resolution, upscale_method) + detected_map = HWC3(remove_pad(canvas)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + if image_and_json: + return (detected_map, encode_poses_as_dict(poses, input_image.shape[0], input_image.shape[1])) + + return detected_map + +global_cached_animalpose = AnimalPoseImage() +class AnimalposeDetector: + """ + A class for detecting animal poses in images using the RTMPose AP10k model. + + Attributes: + model_dir (str): Path to the directory where the pose models are stored. + """ + def __init__(self, animal_pose_estimation): + self.animal_pose_estimation = animal_pose_estimation + + @classmethod + def from_pretrained(cls, pretrained_model_or_path, pretrained_det_model_or_path=None, det_filename="yolox_l.onnx", pose_filename="dw-ll_ucoco_384.onnx", torchscript_device="cuda"): + global global_cached_animalpose + det_model_path = custom_hf_download(pretrained_det_model_or_path, det_filename) + pose_model_path = custom_hf_download(pretrained_model_or_path, pose_filename) + + print(f"\nAnimalPose: Using {det_filename} for bbox detection and {pose_filename} for pose estimation") + if global_cached_animalpose.det is None or global_cached_animalpose.det_filename != det_filename: + t = AnimalPoseImage(det_model_path, None, torchscript_device=torchscript_device) + t.pose = global_cached_animalpose.pose + t.pose_filename = global_cached_animalpose.pose + global_cached_animalpose = t + + if global_cached_animalpose.pose is None or global_cached_animalpose.pose_filename != pose_filename: + t = AnimalPoseImage(None, pose_model_path, torchscript_device=torchscript_device) + t.det = global_cached_animalpose.det + t.det_filename = global_cached_animalpose.det_filename + global_cached_animalpose = t + return cls(global_cached_animalpose) + + def __call__(self, input_image, detect_resolution=512, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + result = self.animal_pose_estimation(input_image) + if result is None: + detected_map = np.zeros_like(input_image) + openpose_dict = { + 'version': 'ap10k', + 'animals': [], + 'canvas_height': input_image.shape[0], + 'canvas_width': input_image.shape[1] + } + else: + detected_map, openpose_dict = result + detected_map = remove_pad(detected_map) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + if image_and_json: + return (detected_map, openpose_dict) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/animalpose.py b/custom_controlnet_aux/dwpose/animalpose.py new file mode 100644 index 0000000000000000000000000000000000000000..d61e0c152182c133e2a2d5bd3ab789876ac9b47e --- /dev/null +++ b/custom_controlnet_aux/dwpose/animalpose.py @@ -0,0 +1,271 @@ +import numpy as np +import cv2 +import os +import cv2 +from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox +from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas +from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose + +from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox +from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose +from typing import List, Optional +from .types import PoseResult, BodyResult, Keypoint +from custom_controlnet_aux.dwpose.util import guess_onnx_input_shape_dtype, get_ort_providers, get_model_type, is_model_torchscript +from timeit import default_timer +import torch + +def drawBetweenKeypoints(pose_img, keypoints, indexes, color, scaleFactor): + ind0 = indexes[0] - 1 + ind1 = indexes[1] - 1 + + point1 = (keypoints[ind0][0], keypoints[ind0][1]) + point2 = (keypoints[ind1][0], keypoints[ind1][1]) + + thickness = int(5 // scaleFactor) + + + cv2.line(pose_img, (int(point1[0]), int(point1[1])), (int(point2[0]), int(point2[1])), color, thickness) + + +def drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor): + for ind, keypointPair in enumerate(keypointPairsList): + drawBetweenKeypoints(pose_img, keypoints, keypointPair, colorsList[ind], scaleFactor) + +def drawBetweenSetofKeypointLists(pose_img, keypoints_set, keypointPairsList, colorsList, scaleFactor): + for keypoints in keypoints_set: + drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor) + + +def padImg(img, size, blackBorder=True): + left, right, top, bottom = 0, 0, 0, 0 + + # pad x + if img.shape[1] < size[1]: + sidePadding = int((size[1] - img.shape[1]) // 2) + left = sidePadding + right = sidePadding + + # pad extra on right if padding needed is an odd number + if img.shape[1] % 2 == 1: + right += 1 + + # pad y + if img.shape[0] < size[0]: + topBottomPadding = int((size[0] - img.shape[0]) // 2) + top = topBottomPadding + bottom = topBottomPadding + + # pad extra on bottom if padding needed is an odd number + if img.shape[0] % 2 == 1: + bottom += 1 + + if blackBorder: + paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_CONSTANT, value=(0,0,0)) + else: + paddedImg = cv2.copyMakeBorder(src=img, top=top, bottom=bottom, left=left, right=right, borderType=cv2.BORDER_REPLICATE) + + return paddedImg + +def smartCrop(img, size, center): + + width = img.shape[1] + height = img.shape[0] + xSize = size[1] + ySize = size[0] + xCenter = center[0] + yCenter = center[1] + + if img.shape[0] > size[0] or img.shape[1] > size[1]: + + + leftMargin = xCenter - xSize//2 + rightMargin = xCenter + xSize//2 + upMargin = yCenter - ySize//2 + downMargin = yCenter + ySize//2 + + + if(leftMargin < 0): + xCenter += (-leftMargin) + if(rightMargin > width): + xCenter -= (rightMargin - width) + + if(upMargin < 0): + yCenter -= -upMargin + if(downMargin > height): + yCenter -= (downMargin - height) + + + img = cv2.getRectSubPix(img, size, (xCenter, yCenter)) + + + + return img + + + +def calculateScaleFactor(img, size, poseSpanX, poseSpanY): + + poseSpanX = max(poseSpanX, size[0]) + + scaleFactorX = 1 + + + if poseSpanX > size[0]: + scaleFactorX = size[0] / poseSpanX + + scaleFactorY = 1 + if poseSpanY > size[1]: + scaleFactorY = size[1] / poseSpanY + + scaleFactor = min(scaleFactorX, scaleFactorY) + + + return scaleFactor + + + +def scaleImg(img, size, poseSpanX, poseSpanY, scaleFactor): + scaledImg = img + + scaledImg = cv2.resize(img, (0, 0), fx=scaleFactor, fy=scaleFactor) + + return scaledImg, scaleFactor + +class AnimalPoseImage: + def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"): + self.det_filename = det_model_path and os.path.basename(det_model_path) + self.pose_filename = pose_model_path and os.path.basename(pose_model_path) + self.det, self.pose = None, None + # return type: None ort cv2 torchscript + self.det_model_type = get_model_type("AnimalPose",self.det_filename) + self.pose_model_type = get_model_type("AnimalPose",self.pose_filename) + # Always loads to CPU to avoid building OpenCV. + cv2_device = 'cpu' + cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA + # You need to manually build OpenCV through cmake to work with your GPU. + cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA + ort_providers = get_ort_providers() + + if self.det_model_type is None: + pass + elif self.det_model_type == "ort": + try: + import onnxruntime as ort + self.det = ort.InferenceSession(det_model_path, providers=ort_providers) + except: + print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI") + self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"]) + elif self.det_model_type == "cv2": + try: + self.det = cv2.dnn.readNetFromONNX(det_model_path) + self.det.setPreferableBackend(cv2_backend) + self.det.setPreferableTarget(cv2_providers) + except: + print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider") + try: + import onnxruntime as ort + self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"]) + except: + print(f"Failed to load {det_model_path}, you can use other models instead") + else: + self.det = torch.jit.load(det_model_path) + self.det.to(torchscript_device) + + if self.pose_model_type is None: + pass + elif self.pose_model_type == "ort": + try: + import onnxruntime as ort + self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers) + except: + print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI") + self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"]) + elif self.pose_model_type == "cv2": + self.pose = cv2.dnn.readNetFromONNX(pose_model_path) + self.pose.setPreferableBackend(cv2_backend) + self.pose.setPreferableTarget(cv2_providers) + else: + self.pose = torch.jit.load(pose_model_path) + self.pose.to(torchscript_device) + + if self.pose_filename is not None: + self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename) + + def __call__(self, oriImg): + detect_classes = list(range(14, 23 + 1)) #https://github.com/ultralytics/ultralytics/blob/main/ultralytics/cfg/datasets/coco.yaml + + #Sacrifice accurate time measurement for compatibility + det_start = default_timer() + if is_model_torchscript(self.det): + det_result = inference_jit_yolox(self.det, oriImg, detect_classes=detect_classes) + else: + det_start = default_timer() + det_onnx_dtype = np.float32 if "yolox" in self.det_filename else np.uint8 + if "yolox" in self.det_filename: + det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype) + else: + #FP16 and INT8 YOLO NAS accept uint8 input + det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=detect_classes, dtype=det_onnx_dtype) + print(f"AnimalPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms") + + if (det_result is None) or (det_result.shape[0] == 0): + openpose_dict = { + 'version': 'ap10k', + 'animals': [], + 'canvas_height': oriImg.shape[0], + 'canvas_width': oriImg.shape[1] + } + return np.zeros_like(oriImg), openpose_dict + + pose_start = default_timer() + if is_model_torchscript(self.pose): + keypoint_sets, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size) + else: + pose_start = default_timer() + _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename) + keypoint_sets, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype) + print(f"AnimalPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} animals\n") + + animal_kps_scores = [] + pose_img = np.zeros((oriImg.shape[0], oriImg.shape[1], 3), dtype = np.uint8) + for (idx, keypoints) in enumerate(keypoint_sets): + # don't use keypoints that go outside the frame in calculations for the center + interorKeypoints = keypoints[((keypoints[:,0] > 0) & (keypoints[:,0] < oriImg.shape[1])) & ((keypoints[:,1] > 0) & (keypoints[:,1] < oriImg.shape[0]))] + + xVals = interorKeypoints[:,0] + yVals = interorKeypoints[:,1] + + minX = np.amin(xVals) + minY = np.amin(yVals) + maxX = np.amax(xVals) + maxY = np.amax(yVals) + + poseSpanX = maxX - minX + poseSpanY = maxY - minY + + # find mean center + + xSum = np.sum(xVals) + ySum = np.sum(yVals) + + xCenter = xSum // xVals.shape[0] + yCenter = ySum // yVals.shape[0] + center_of_keypoints = (xCenter,yCenter) + + # order of the keypoints for AP10k and a standardized list of colors for limbs + keypointPairsList = [(1,2), (2,3), (1,3), (3,4), (4,9), (9,10), (10,11), (4,6), (6,7), (7,8), (4,5), (5,15), (15,16), (16,17), (5,12), (12,13), (13,14)] + colorsList = [(255,255,255), (100,255,100), (150,255,255), (100,50,255), (50,150,200), (0,255,255), (0,150,0), (0,0,255), (0,0,150), (255,50,255), (255,0,255), (255,0,0), (150,0,0), (255,255,100), (0,150,0), (255,255,0), (150,150,150)] # 16 colors needed + + drawBetweenKeypointsList(pose_img, keypoints, keypointPairsList, colorsList, scaleFactor=1.0) + score = scores[idx, ..., None] + score[score > 1.0] = 1.0 + score[score < 0.0] = 0.0 + animal_kps_scores.append(np.concatenate((keypoints, score), axis=-1)) + + openpose_dict = { + 'version': 'ap10k', + 'animals': [keypoints.tolist() for keypoints in animal_kps_scores], + 'canvas_height': oriImg.shape[0], + 'canvas_width': oriImg.shape[1] + } + return pose_img, openpose_dict \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/body.py b/custom_controlnet_aux/dwpose/body.py new file mode 100644 index 0000000000000000000000000000000000000000..32934f19eba4b7e762678fd1fcd6b2bd193811d6 --- /dev/null +++ b/custom_controlnet_aux/dwpose/body.py @@ -0,0 +1,261 @@ +import cv2 +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from torchvision import transforms +from typing import NamedTuple, List, Union + +from . import util +from .model import bodypose_model +from .types import Keypoint, BodyResult + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + # if torch.cuda.is_available(): + # self.model = self.model.cuda() + # print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImg): + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + data = data.to(self.cn_device) + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = util.smart_resize_k(paf, fx=stride, fy=stride) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += + paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ + for I in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ + for I in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + + @staticmethod + def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]: + """ + Format the body results from the candidate and subset arrays into a list of BodyResult objects. + + Args: + candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id + for each body part. + subset (np.ndarray): An array of subsets containing indices to the candidate array for each + person detected. The last two columns of each row hold the total score and total parts + of the person. + + Returns: + List[BodyResult]: A list of BodyResult objects, where each object represents a person with + detected keypoints, total score, and total parts. + """ + return [ + BodyResult( + keypoints=[ + Keypoint( + x=candidate[candidate_index][0], + y=candidate[candidate_index][1], + score=candidate[candidate_index][2], + id=candidate[candidate_index][3] + ) if candidate_index != -1 else None + for candidate_index in person[:18].astype(int) + ], + total_score=person[18], + total_parts=person[19] + ) + for person in subset + ] + + +if __name__ == "__main__": + body_estimation = Body('../model/body_pose_model.pth') + + test_image = '../images/ski.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + candidate, subset = body_estimation(oriImg) + bodies = body_estimation.format_body_result(candidate, subset) + + canvas = oriImg + for body in bodies: + canvas = util.draw_bodypose(canvas, body) + + plt.imshow(canvas[:, :, [2, 1, 0]]) + plt.show() \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_onnx/__init__.py b/custom_controlnet_aux/dwpose/dw_onnx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_onnx/__init__.py @@ -0,0 +1 @@ +#Dummy file ensuring this package will be recognized \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py new file mode 100644 index 0000000000000000000000000000000000000000..0365234c2caef3b98fc01304ba5365da2115ba65 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_det.py @@ -0,0 +1,129 @@ +import cv2 +import numpy as np + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg, detect_classes=[0], dtype=np.float32): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + input = img[None, :, :, :] + input = input.astype(dtype) + if "InferenceSession" in type(session).__name__: + input_name = session.get_inputs()[0].name + output = session.run(None, {input_name: input}) + else: + outNames = session.getUnconnectedOutLayersNames() + session.setInput(input) + output = session.forward(outNames) + + predictions = demo_postprocess(output[0], input_shape)[0] + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is None: + return None + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = np.isin(final_cls_inds, detect_classes) + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + return final_boxes \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..956c4bc715214bcc2e6228166032418294df46bc --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_pose.py @@ -0,0 +1,363 @@ +from typing import List, Tuple + +import cv2 +import numpy as np + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for DWPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + + +def inference(sess, img, dtype=np.float32): + """Inference DWPose model. Processing all image segments at once to take advantage of GPU's parallelism ability if onnxruntime is installed + + Args: + sess : ONNXRuntime session. + img : Input image in shape. + + Returns: + outputs : Output of DWPose model. + """ + all_out = [] + # build input + input = np.stack(img, axis=0).transpose(0, 3, 1, 2) + input = input.astype(dtype) + if "InferenceSession" in type(sess).__name__: + input_name = sess.get_inputs()[0].name + all_outputs = sess.run(None, {input_name: input}) + for batch_idx in range(len(all_outputs[0])): + outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))] + all_out.append(outputs) + return all_out + + #OpenCV doesn't support batch processing sadly + for i in range(len(img)): + input = img[i].transpose(2, 0, 1) + input = input[None, :, :, :] + + outNames = sess.getUnconnectedOutLayersNames() + sess.setInput(input) + outputs = sess.forward(outNames) + all_out.append(outputs) + + return all_out + +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for DWPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + + +def inference_pose(session, out_bbox, oriImg, model_input_size=(288, 384), dtype=np.float32): + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + outputs = inference(session, resized_img, dtype) + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py new file mode 100644 index 0000000000000000000000000000000000000000..67ff249be283b11e0eb7d95ef7c0adc024c48285 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_onnx/cv_ox_yolo_nas.py @@ -0,0 +1,60 @@ +# Source: https://github.com/Hyuto/yolo-nas-onnx/tree/master/yolo-nas-py +# Inspired from: https://github.com/Deci-AI/super-gradients/blob/3.1.1/src/super_gradients/training/processing/processing.py + +import numpy as np +import cv2 + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(session, oriImg, detect_classes=[0], dtype=np.uint8): + """ + This function is only compatible with onnx models exported from the new API with built-in NMS + ```py + from super_gradients.conversion.conversion_enums import ExportQuantizationMode + from super_gradients.common.object_names import Models + from super_gradients.training import models + + model = models.get(Models.YOLO_NAS_L, pretrained_weights="coco") + + export_result = model.export( + "yolo_nas/yolo_nas_l_fp16.onnx", + quantization_mode=ExportQuantizationMode.FP16, + device="cuda" + ) + ``` + """ + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + input = img[None, :, :, :] + input = input.astype(dtype) + if "InferenceSession" in type(session).__name__: + input_name = session.get_inputs()[0].name + output = session.run(None, {input_name: input}) + else: + outNames = session.getUnconnectedOutLayersNames() + session.setInput(input) + output = session.forward(outNames) + num_preds, pred_boxes, pred_scores, pred_classes = output + num_preds = num_preds[0,0] + if num_preds == 0: + return None + idxs = np.where((np.isin(pred_classes[0, :num_preds], detect_classes)) & (pred_scores[0, :num_preds] > 0.3)) + if (len(idxs) == 0) or (idxs[0].size == 0): + return None + return pred_boxes[0, idxs].squeeze(axis=0) / ratio diff --git a/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py b/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..33e7a7f594ef441479257c788e4c0d6e08657fc8 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_torchscript/__init__.py @@ -0,0 +1 @@ +#Dummy file ensuring this package will be recognized \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py b/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py new file mode 100644 index 0000000000000000000000000000000000000000..2ae863509df37386365791df8b4a7635c05d9344 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_torchscript/jit_det.py @@ -0,0 +1,125 @@ +import cv2 +import numpy as np +import torch + +def nms(boxes, scores, nms_thr): + """Single class NMS implemented in Numpy.""" + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + + areas = (x2 - x1 + 1) * (y2 - y1 + 1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1 + 1) + h = np.maximum(0.0, yy2 - yy1 + 1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= nms_thr)[0] + order = order[inds + 1] + + return keep + +def multiclass_nms(boxes, scores, nms_thr, score_thr): + """Multiclass NMS implemented in Numpy. Class-aware version.""" + final_dets = [] + num_classes = scores.shape[1] + for cls_ind in range(num_classes): + cls_scores = scores[:, cls_ind] + valid_score_mask = cls_scores > score_thr + if valid_score_mask.sum() == 0: + continue + else: + valid_scores = cls_scores[valid_score_mask] + valid_boxes = boxes[valid_score_mask] + keep = nms(valid_boxes, valid_scores, nms_thr) + if len(keep) > 0: + cls_inds = np.ones((len(keep), 1)) * cls_ind + dets = np.concatenate( + [valid_boxes[keep], valid_scores[keep, None], cls_inds], 1 + ) + final_dets.append(dets) + if len(final_dets) == 0: + return None + return np.concatenate(final_dets, 0) + +def demo_postprocess(outputs, img_size, p6=False): + grids = [] + expanded_strides = [] + strides = [8, 16, 32] if not p6 else [8, 16, 32, 64] + + hsizes = [img_size[0] // stride for stride in strides] + wsizes = [img_size[1] // stride for stride in strides] + + for hsize, wsize, stride in zip(hsizes, wsizes, strides): + xv, yv = np.meshgrid(np.arange(wsize), np.arange(hsize)) + grid = np.stack((xv, yv), 2).reshape(1, -1, 2) + grids.append(grid) + shape = grid.shape[:2] + expanded_strides.append(np.full((*shape, 1), stride)) + + grids = np.concatenate(grids, 1) + expanded_strides = np.concatenate(expanded_strides, 1) + outputs[..., :2] = (outputs[..., :2] + grids) * expanded_strides + outputs[..., 2:4] = np.exp(outputs[..., 2:4]) * expanded_strides + + return outputs + +def preprocess(img, input_size, swap=(2, 0, 1)): + if len(img.shape) == 3: + padded_img = np.ones((input_size[0], input_size[1], 3), dtype=np.uint8) * 114 + else: + padded_img = np.ones(input_size, dtype=np.uint8) * 114 + + r = min(input_size[0] / img.shape[0], input_size[1] / img.shape[1]) + resized_img = cv2.resize( + img, + (int(img.shape[1] * r), int(img.shape[0] * r)), + interpolation=cv2.INTER_LINEAR, + ).astype(np.uint8) + padded_img[: int(img.shape[0] * r), : int(img.shape[1] * r)] = resized_img + + padded_img = padded_img.transpose(swap) + padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) + return padded_img, r + +def inference_detector(model, oriImg, detect_classes=[0]): + input_shape = (640,640) + img, ratio = preprocess(oriImg, input_shape) + + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + input = img[None, :, :, :] + input = torch.from_numpy(input).to(device, dtype) + + output = model(input).float().cpu().detach().numpy() + predictions = demo_postprocess(output[0], input_shape) + + boxes = predictions[:, :4] + scores = predictions[:, 4:5] * predictions[:, 5:] + + boxes_xyxy = np.ones_like(boxes) + boxes_xyxy[:, 0] = boxes[:, 0] - boxes[:, 2]/2. + boxes_xyxy[:, 1] = boxes[:, 1] - boxes[:, 3]/2. + boxes_xyxy[:, 2] = boxes[:, 0] + boxes[:, 2]/2. + boxes_xyxy[:, 3] = boxes[:, 1] + boxes[:, 3]/2. + boxes_xyxy /= ratio + dets = multiclass_nms(boxes_xyxy, scores, nms_thr=0.45, score_thr=0.1) + if dets is None: + return None + final_boxes, final_scores, final_cls_inds = dets[:, :4], dets[:, 4], dets[:, 5] + isscore = final_scores>0.3 + iscat = np.isin(final_cls_inds, detect_classes) + isbbox = [ i and j for (i, j) in zip(isscore, iscat)] + final_boxes = final_boxes[isbbox] + return final_boxes \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py b/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..7a7b6acee1cd39657cec719aef1aaa09648da4e2 --- /dev/null +++ b/custom_controlnet_aux/dwpose/dw_torchscript/jit_pose.py @@ -0,0 +1,363 @@ +from typing import List, Tuple + +import cv2 +import numpy as np +import torch + +def preprocess( + img: np.ndarray, out_bbox, input_size: Tuple[int, int] = (192, 256) +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Do preprocessing for DWPose model inference. + + Args: + img (np.ndarray): Input image in shape. + input_size (tuple): Input image size in shape (w, h). + + Returns: + tuple: + - resized_img (np.ndarray): Preprocessed image. + - center (np.ndarray): Center of image. + - scale (np.ndarray): Scale of image. + """ + # get shape of image + img_shape = img.shape[:2] + out_img, out_center, out_scale = [], [], [] + if len(out_bbox) == 0: + out_bbox = [[0, 0, img_shape[1], img_shape[0]]] + for i in range(len(out_bbox)): + x0 = out_bbox[i][0] + y0 = out_bbox[i][1] + x1 = out_bbox[i][2] + y1 = out_bbox[i][3] + bbox = np.array([x0, y0, x1, y1]) + + # get center and scale + center, scale = bbox_xyxy2cs(bbox, padding=1.25) + + # do affine transformation + resized_img, scale = top_down_affine(input_size, scale, center, img) + + # normalize image + mean = np.array([123.675, 116.28, 103.53]) + std = np.array([58.395, 57.12, 57.375]) + resized_img = (resized_img - mean) / std + + out_img.append(resized_img) + out_center.append(center) + out_scale.append(scale) + + return out_img, out_center, out_scale + +def inference(model, img, bs=5): + """Inference DWPose model implemented in TorchScript. + + Args: + model : TorchScript Model. + img : Input image in shape. + + Returns: + outputs : Output of DWPose model. + """ + all_out = [] + # build input + orig_img_count = len(img) + #Pad zeros to fit batch size + for _ in range(bs - (orig_img_count % bs)): + img.append(np.zeros_like(img[0])) + input = np.stack(img, axis=0).transpose(0, 3, 1, 2) + device, dtype = next(model.parameters()).device, next(model.parameters()).dtype + input = torch.from_numpy(input).to(device, dtype) + + out1, out2 = [], [] + for i in range(input.shape[0] // bs): + curr_batch_output = model(input[i*bs:(i+1)*bs]) + out1.append(curr_batch_output[0].float()) + out2.append(curr_batch_output[1].float()) + out1, out2 = torch.cat(out1, dim=0)[:orig_img_count], torch.cat(out2, dim=0)[:orig_img_count] + out1, out2 = out1.float().cpu().detach().numpy(), out2.float().cpu().detach().numpy() + all_outputs = out1, out2 + + for batch_idx in range(len(all_outputs[0])): + outputs = [all_outputs[i][batch_idx:batch_idx+1,...] for i in range(len(all_outputs))] + all_out.append(outputs) + return all_out +def postprocess(outputs: List[np.ndarray], + model_input_size: Tuple[int, int], + center: Tuple[int, int], + scale: Tuple[int, int], + simcc_split_ratio: float = 2.0 + ) -> Tuple[np.ndarray, np.ndarray]: + """Postprocess for DWPose model output. + + Args: + outputs (np.ndarray): Output of RTMPose model. + model_input_size (tuple): RTMPose model Input image size. + center (tuple): Center of bbox in shape (x, y). + scale (tuple): Scale of bbox in shape (w, h). + simcc_split_ratio (float): Split ratio of simcc. + + Returns: + tuple: + - keypoints (np.ndarray): Rescaled keypoints. + - scores (np.ndarray): Model predict scores. + """ + all_key = [] + all_score = [] + for i in range(len(outputs)): + # use simcc to decode + simcc_x, simcc_y = outputs[i] + keypoints, scores = decode(simcc_x, simcc_y, simcc_split_ratio) + + # rescale keypoints + keypoints = keypoints / model_input_size * scale[i] + center[i] - scale[i] / 2 + all_key.append(keypoints[0]) + all_score.append(scores[0]) + + return np.array(all_key), np.array(all_score) + + +def bbox_xyxy2cs(bbox: np.ndarray, + padding: float = 1.) -> Tuple[np.ndarray, np.ndarray]: + """Transform the bbox format from (x,y,w,h) into (center, scale) + + Args: + bbox (ndarray): Bounding box(es) in shape (4,) or (n, 4), formatted + as (left, top, right, bottom) + padding (float): BBox padding factor that will be multilied to scale. + Default: 1.0 + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: Center (x, y) of the bbox in shape (2,) or + (n, 2) + - np.ndarray[float32]: Scale (w, h) of the bbox in shape (2,) or + (n, 2) + """ + # convert single bbox from (4, ) to (1, 4) + dim = bbox.ndim + if dim == 1: + bbox = bbox[None, :] + + # get bbox center and scale + x1, y1, x2, y2 = np.hsplit(bbox, [1, 2, 3]) + center = np.hstack([x1 + x2, y1 + y2]) * 0.5 + scale = np.hstack([x2 - x1, y2 - y1]) * padding + + if dim == 1: + center = center[0] + scale = scale[0] + + return center, scale + + +def _fix_aspect_ratio(bbox_scale: np.ndarray, + aspect_ratio: float) -> np.ndarray: + """Extend the scale to match the given aspect ratio. + + Args: + scale (np.ndarray): The image scale (w, h) in shape (2, ) + aspect_ratio (float): The ratio of ``w/h`` + + Returns: + np.ndarray: The reshaped image scale in (2, ) + """ + w, h = np.hsplit(bbox_scale, [1]) + bbox_scale = np.where(w > h * aspect_ratio, + np.hstack([w, w / aspect_ratio]), + np.hstack([h * aspect_ratio, h])) + return bbox_scale + + +def _rotate_point(pt: np.ndarray, angle_rad: float) -> np.ndarray: + """Rotate a point by an angle. + + Args: + pt (np.ndarray): 2D point coordinates (x, y) in shape (2, ) + angle_rad (float): rotation angle in radian + + Returns: + np.ndarray: Rotated point in shape (2, ) + """ + sn, cs = np.sin(angle_rad), np.cos(angle_rad) + rot_mat = np.array([[cs, -sn], [sn, cs]]) + return rot_mat @ pt + + +def _get_3rd_point(a: np.ndarray, b: np.ndarray) -> np.ndarray: + """To calculate the affine matrix, three pairs of points are required. This + function is used to get the 3rd point, given 2D points a & b. + + The 3rd point is defined by rotating vector `a - b` by 90 degrees + anticlockwise, using b as the rotation center. + + Args: + a (np.ndarray): The 1st point (x,y) in shape (2, ) + b (np.ndarray): The 2nd point (x,y) in shape (2, ) + + Returns: + np.ndarray: The 3rd point. + """ + direction = a - b + c = b + np.r_[-direction[1], direction[0]] + return c + + +def get_warp_matrix(center: np.ndarray, + scale: np.ndarray, + rot: float, + output_size: Tuple[int, int], + shift: Tuple[float, float] = (0., 0.), + inv: bool = False) -> np.ndarray: + """Calculate the affine transformation matrix that can warp the bbox area + in the input image to the output size. + + Args: + center (np.ndarray[2, ]): Center of the bounding box (x, y). + scale (np.ndarray[2, ]): Scale of the bounding box + wrt [width, height]. + rot (float): Rotation angle (degree). + output_size (np.ndarray[2, ] | list(2,)): Size of the + destination heatmaps. + shift (0-100%): Shift translation ratio wrt the width/height. + Default (0., 0.). + inv (bool): Option to inverse the affine transform direction. + (inv=False: src->dst or inv=True: dst->src) + + Returns: + np.ndarray: A 2x3 transformation matrix + """ + shift = np.array(shift) + src_w = scale[0] + dst_w = output_size[0] + dst_h = output_size[1] + + # compute transformation matrix + rot_rad = np.deg2rad(rot) + src_dir = _rotate_point(np.array([0., src_w * -0.5]), rot_rad) + dst_dir = np.array([0., dst_w * -0.5]) + + # get four corners of the src rectangle in the original image + src = np.zeros((3, 2), dtype=np.float32) + src[0, :] = center + scale * shift + src[1, :] = center + src_dir + scale * shift + src[2, :] = _get_3rd_point(src[0, :], src[1, :]) + + # get four corners of the dst rectangle in the input image + dst = np.zeros((3, 2), dtype=np.float32) + dst[0, :] = [dst_w * 0.5, dst_h * 0.5] + dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir + dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :]) + + if inv: + warp_mat = cv2.getAffineTransform(np.float32(dst), np.float32(src)) + else: + warp_mat = cv2.getAffineTransform(np.float32(src), np.float32(dst)) + + return warp_mat + + +def top_down_affine(input_size: dict, bbox_scale: dict, bbox_center: dict, + img: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get the bbox image as the model input by affine transform. + + Args: + input_size (dict): The input size of the model. + bbox_scale (dict): The bbox scale of the img. + bbox_center (dict): The bbox center of the img. + img (np.ndarray): The original image. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: img after affine transform. + - np.ndarray[float32]: bbox scale after affine transform. + """ + w, h = input_size + warp_size = (int(w), int(h)) + + # reshape bbox to fixed aspect ratio + bbox_scale = _fix_aspect_ratio(bbox_scale, aspect_ratio=w / h) + + # get the affine matrix + center = bbox_center + scale = bbox_scale + rot = 0 + warp_mat = get_warp_matrix(center, scale, rot, output_size=(w, h)) + + # do affine transform + img = cv2.warpAffine(img, warp_mat, warp_size, flags=cv2.INTER_LINEAR) + + return img, bbox_scale + + +def get_simcc_maximum(simcc_x: np.ndarray, + simcc_y: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """Get maximum response location and value from simcc representations. + + Note: + instance number: N + num_keypoints: K + heatmap height: H + heatmap width: W + + Args: + simcc_x (np.ndarray): x-axis SimCC in shape (K, Wx) or (N, K, Wx) + simcc_y (np.ndarray): y-axis SimCC in shape (K, Wy) or (N, K, Wy) + + Returns: + tuple: + - locs (np.ndarray): locations of maximum heatmap responses in shape + (K, 2) or (N, K, 2) + - vals (np.ndarray): values of maximum heatmap responses in shape + (K,) or (N, K) + """ + N, K, Wx = simcc_x.shape + simcc_x = simcc_x.reshape(N * K, -1) + simcc_y = simcc_y.reshape(N * K, -1) + + # get maximum value locations + x_locs = np.argmax(simcc_x, axis=1) + y_locs = np.argmax(simcc_y, axis=1) + locs = np.stack((x_locs, y_locs), axis=-1).astype(np.float32) + max_val_x = np.amax(simcc_x, axis=1) + max_val_y = np.amax(simcc_y, axis=1) + + # get maximum value across x and y axis + mask = max_val_x > max_val_y + max_val_x[mask] = max_val_y[mask] + vals = max_val_x + locs[vals <= 0.] = -1 + + # reshape + locs = locs.reshape(N, K, 2) + vals = vals.reshape(N, K) + + return locs, vals + + +def decode(simcc_x: np.ndarray, simcc_y: np.ndarray, + simcc_split_ratio) -> Tuple[np.ndarray, np.ndarray]: + """Modulate simcc distribution with Gaussian. + + Args: + simcc_x (np.ndarray[K, Wx]): model predicted simcc in x. + simcc_y (np.ndarray[K, Wy]): model predicted simcc in y. + simcc_split_ratio (int): The split ratio of simcc. + + Returns: + tuple: A tuple containing center and scale. + - np.ndarray[float32]: keypoints in shape (K, 2) or (n, K, 2) + - np.ndarray[float32]: scores in shape (K,) or (n, K) + """ + keypoints, scores = get_simcc_maximum(simcc_x, simcc_y) + keypoints /= simcc_split_ratio + + return keypoints, scores + +def inference_pose(model, out_bbox, oriImg, model_input_size=(288, 384)): + resized_img, center, scale = preprocess(oriImg, out_bbox, model_input_size) + #outputs = inference(session, resized_img, dtype) + outputs = inference(model, resized_img) + + keypoints, scores = postprocess(outputs, model_input_size, center, scale) + + return keypoints, scores \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/face.py b/custom_controlnet_aux/dwpose/face.py new file mode 100644 index 0000000000000000000000000000000000000000..f3c46d77664aa9fa91c63785a1485a396f05cacc --- /dev/null +++ b/custom_controlnet_aux/dwpose/face.py @@ -0,0 +1,362 @@ +import logging +import numpy as np +from torchvision.transforms import ToTensor, ToPILImage +import torch +import torch.nn.functional as F +import cv2 + +from . import util +from torch.nn import Conv2d, Module, ReLU, MaxPool2d, init + + +class FaceNet(Module): + """Model the cascading heatmaps. """ + def __init__(self): + super(FaceNet, self).__init__() + # cnn to make feature map + self.relu = ReLU() + self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, + kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1) + self.conv2_1 = Conv2d( + in_channels=64, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv2_2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv3_1 = Conv2d( + in_channels=128, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_2 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_3 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_4 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv4_1 = Conv2d( + in_channels=256, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_3 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_4 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_1 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_3_CPM = Conv2d( + in_channels=512, out_channels=128, kernel_size=3, stride=1, + padding=1) + + # stage1 + self.conv6_1_CPM = Conv2d( + in_channels=128, out_channels=512, kernel_size=1, stride=1, + padding=0) + self.conv6_2_CPM = Conv2d( + in_channels=512, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage2 + self.Mconv1_stage2 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage2 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage3 + self.Mconv1_stage3 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage3 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage4 + self.Mconv1_stage4 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage4 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage5 + self.Mconv1_stage5 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage5 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage6 + self.Mconv1_stage6 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage6 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + for m in self.modules(): + if isinstance(m, Conv2d): + init.constant_(m.bias, 0) + + def forward(self, x): + """Return a list of heatmaps.""" + heatmaps = [] + + h = self.relu(self.conv1_1(x)) + h = self.relu(self.conv1_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv2_1(h)) + h = self.relu(self.conv2_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv3_1(h)) + h = self.relu(self.conv3_2(h)) + h = self.relu(self.conv3_3(h)) + h = self.relu(self.conv3_4(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv4_1(h)) + h = self.relu(self.conv4_2(h)) + h = self.relu(self.conv4_3(h)) + h = self.relu(self.conv4_4(h)) + h = self.relu(self.conv5_1(h)) + h = self.relu(self.conv5_2(h)) + h = self.relu(self.conv5_3_CPM(h)) + feature_map = h + + # stage1 + h = self.relu(self.conv6_1_CPM(h)) + h = self.conv6_2_CPM(h) + heatmaps.append(h) + + # stage2 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage2(h)) + h = self.relu(self.Mconv2_stage2(h)) + h = self.relu(self.Mconv3_stage2(h)) + h = self.relu(self.Mconv4_stage2(h)) + h = self.relu(self.Mconv5_stage2(h)) + h = self.relu(self.Mconv6_stage2(h)) + h = self.Mconv7_stage2(h) + heatmaps.append(h) + + # stage3 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage3(h)) + h = self.relu(self.Mconv2_stage3(h)) + h = self.relu(self.Mconv3_stage3(h)) + h = self.relu(self.Mconv4_stage3(h)) + h = self.relu(self.Mconv5_stage3(h)) + h = self.relu(self.Mconv6_stage3(h)) + h = self.Mconv7_stage3(h) + heatmaps.append(h) + + # stage4 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage4(h)) + h = self.relu(self.Mconv2_stage4(h)) + h = self.relu(self.Mconv3_stage4(h)) + h = self.relu(self.Mconv4_stage4(h)) + h = self.relu(self.Mconv5_stage4(h)) + h = self.relu(self.Mconv6_stage4(h)) + h = self.Mconv7_stage4(h) + heatmaps.append(h) + + # stage5 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage5(h)) + h = self.relu(self.Mconv2_stage5(h)) + h = self.relu(self.Mconv3_stage5(h)) + h = self.relu(self.Mconv4_stage5(h)) + h = self.relu(self.Mconv5_stage5(h)) + h = self.relu(self.Mconv6_stage5(h)) + h = self.Mconv7_stage5(h) + heatmaps.append(h) + + # stage6 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage6(h)) + h = self.relu(self.Mconv2_stage6(h)) + h = self.relu(self.Mconv3_stage6(h)) + h = self.relu(self.Mconv4_stage6(h)) + h = self.relu(self.Mconv5_stage6(h)) + h = self.relu(self.Mconv6_stage6(h)) + h = self.Mconv7_stage6(h) + heatmaps.append(h) + + return heatmaps + + +LOG = logging.getLogger(__name__) +TOTEN = ToTensor() +TOPIL = ToPILImage() + + +params = { + 'gaussian_sigma': 2.5, + 'inference_img_size': 736, # 368, 736, 1312 + 'heatmap_peak_thresh': 0.1, + 'crop_scale': 1.5, + 'line_indices': [ + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], + [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], + [13, 14], [14, 15], [15, 16], + [17, 18], [18, 19], [19, 20], [20, 21], + [22, 23], [23, 24], [24, 25], [25, 26], + [27, 28], [28, 29], [29, 30], + [31, 32], [32, 33], [33, 34], [34, 35], + [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], + [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], + [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], + [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], + [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], + [66, 67], [67, 60] + ], +} + + +class Face(object): + """ + The OpenPose face landmark detector model. + + Args: + inference_size: set the size of the inference image size, suggested: + 368, 736, 1312, default 736 + gaussian_sigma: blur the heatmaps, default 2.5 + heatmap_peak_thresh: return landmark if over threshold, default 0.1 + + """ + def __init__(self, face_model_path, + inference_size=None, + gaussian_sigma=None, + heatmap_peak_thresh=None): + self.inference_size = inference_size or params["inference_img_size"] + self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] + self.model = FaceNet() + self.model.load_state_dict(torch.load(face_model_path)) + # if torch.cuda.is_available(): + # self.model = self.model.cuda() + # print('cuda') + self.model.eval() + + def __call__(self, face_img): + H, W, C = face_img.shape + + w_size = 384 + x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5 + + x_data = x_data.to(self.cn_device) + + with torch.no_grad(): + hs = self.model(x_data[None, ...]) + heatmaps = F.interpolate( + hs[-1], + (H, W), + mode='bilinear', align_corners=True).cpu().numpy()[0] + return heatmaps + + def compute_peaks_from_heatmaps(self, heatmaps): + all_peaks = [] + for part in range(heatmaps.shape[0]): + map_ori = heatmaps[part].copy() + binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8) + + if np.sum(binary) == 0: + continue + + positions = np.where(binary > 0.5) + intensities = map_ori[positions] + mi = np.argmax(intensities) + y, x = positions[0][mi], positions[1][mi] + all_peaks.append([x, y]) + + return np.array(all_peaks) \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/hand.py b/custom_controlnet_aux/dwpose/hand.py new file mode 100644 index 0000000000000000000000000000000000000000..74767def506c72612954fe3b79056d17a83b1e16 --- /dev/null +++ b/custom_controlnet_aux/dwpose/hand.py @@ -0,0 +1,94 @@ +import cv2 +import json +import numpy as np +import math +import time +from scipy.ndimage.filters import gaussian_filter +import matplotlib.pyplot as plt +import matplotlib +import torch +from skimage.measure import label + +from .model import handpose_model +from . import util + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + # if torch.cuda.is_available(): + # self.model = self.model.cuda() + # print('cuda') + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + + def __call__(self, oriImgRaw): + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + if torch.cuda.is_available(): + data = data.cuda() + + with torch.no_grad(): + data = data.to(self.cn_device) + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + y = int(float(y) * float(Hr) / float(wsize)) + x = int(float(x) * float(Wr) / float(wsize)) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) \ No newline at end of file diff --git a/custom_controlnet_aux/dwpose/model.py b/custom_controlnet_aux/dwpose/model.py new file mode 100644 index 0000000000000000000000000000000000000000..72dc79ad857933a7c108d21494d6395572b816e6 --- /dev/null +++ b/custom_controlnet_aux/dwpose/model.py @@ -0,0 +1,218 @@ +import torch +from collections import OrderedDict + +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 + diff --git a/custom_controlnet_aux/dwpose/types.py b/custom_controlnet_aux/dwpose/types.py new file mode 100644 index 0000000000000000000000000000000000000000..6a5fa889e401bcd85e50977a674ce601bb02a46a --- /dev/null +++ b/custom_controlnet_aux/dwpose/types.py @@ -0,0 +1,30 @@ +from typing import NamedTuple, List, Optional + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +class BodyResult(NamedTuple): + # Note: Using `Optional` instead of `|` operator as the ladder is a Python + # 3.10 feature. + # Annotator code should be Python 3.8 Compatible, as controlnet repo uses + # Python 3.8 environment. + # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6 + keypoints: List[Optional[Keypoint]] + total_score: float = 0.0 + total_parts: int = 0 + + +HandResult = List[Keypoint] +FaceResult = List[Keypoint] +AnimalPoseResult = List[Keypoint] + + +class PoseResult(NamedTuple): + body: BodyResult + left_hand: Optional[HandResult] + right_hand: Optional[HandResult] + face: Optional[FaceResult] diff --git a/custom_controlnet_aux/dwpose/util.py b/custom_controlnet_aux/dwpose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..cce0dc28a0989af7ae1f04ca6f782d694e27a4bd --- /dev/null +++ b/custom_controlnet_aux/dwpose/util.py @@ -0,0 +1,457 @@ +import math +import numpy as np +import matplotlib +import cv2 +import os +from typing import List, Tuple, Union, Optional + +from .body import BodyResult, Keypoint + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def is_normalized(keypoints: List[Optional[Keypoint]]) -> bool: + point_normalized = [ + 0 <= abs(k.x) <= 1 and 0 <= abs(k.y) <= 1 + for k in keypoints + if k is not None + ] + if not point_normalized: + return False + return all(point_normalized) + + +def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: + """ + Draw keypoints and limbs representing body pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose. + keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not is_normalized(keypoints): + H, W = 1.0, 1.0 + else: + H, W, _ = canvas.shape + + stickwidth = 4 + + limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], + [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], + [2, 1], [1, 15], [15, 17], [1, 16], + [16, 18], + ] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for (k1_index, k2_index), color in zip(limbSeq, colors): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + continue + + Y = np.array([keypoint1.x, keypoint2.x]) * float(W) + X = np.array([keypoint1.y, keypoint2.y]) * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) + + for keypoint, color in zip(keypoints, colors): + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + + return canvas + + +def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + if not is_normalized(keypoints): + H, W = 1.0, 1.0 + else: + H, W, _ = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + + x1 = int(k1.x * W) + y1 = int(k1.y * H) + x2 = int(k2.x * W) + y2 = int(k2.y * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for keypoint in keypoints: + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints representing face pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + if not is_normalized(keypoints): + H, W = 1.0, 1.0 + else: + H, W, _ = canvas.shape + + for keypoint in keypoints: + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: + """ + Detect hands in the input body pose keypoints and calculate the bounding box for each hand. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left + corner of the bounding box, the width (height) of the bounding box, and + a boolean flag indicating whether the hand is a left hand (True) or a + right hand (False). + + Notes: + - The width and height of the bounding boxes are equal since the network requires squared input. + - The minimum bounding box size is 20 pixels. + """ + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + left_shoulder = keypoints[5] + left_elbow = keypoints[6] + left_wrist = keypoints[7] + right_shoulder = keypoints[2] + right_elbow = keypoints[3] + right_wrist = keypoints[4] + + # if any of three not detected + has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) + has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) + if not (has_left or has_right): + return [] + + hands = [] + #left hand + if has_left: + hands.append([ + left_shoulder.x, left_shoulder.y, + left_elbow.x, left_elbow.y, + left_wrist.x, left_wrist.y, + True + ]) + # right hand + if has_right: + hands.append([ + right_shoulder.x, right_shoulder.y, + right_elbow.x, right_elbow.y, + right_wrist.x, right_wrist.y, + False + ]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append((int(x), int(y), int(width), is_left)) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]: + """ + Detect the face in the input body pose keypoints and calculate the bounding box for the face. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the + bounding box and the width (height) of the bounding box, or None if the + face is not detected or the bounding box width is less than 20 pixels. + + Notes: + - The width and height of the bounding box are equal. + - The minimum bounding box size is 20 pixels. + """ + # left right eye ear 14 15 16 17 + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + head = keypoints[0] + left_eye = keypoints[14] + right_eye = keypoints[15] + left_ear = keypoints[16] + right_ear = keypoints[17] + + if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)): + return None + + width = 0.0 + x0, y0 = head.x, head.y + + if left_eye is not None: + x1, y1 = left_eye.x, left_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if right_eye is not None: + x1, y1 = right_eye.x, right_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if left_ear is not None: + x1, y1 = left_ear.x, left_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if right_ear is not None: + x1, y1 = right_ear.x, right_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + return int(x), int(y), int(width) + else: + return None + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j + +def guess_onnx_input_shape_dtype(filename): + dtype = np.float32 + if "fp16" in filename: + dtype = np.float16 + elif "int8" in filename: + dtype = np.uint8 + input_size = (640, 640) if "yolo" in filename else (192, 256) + if "384" in filename: + input_size = (288, 384) + elif "256" in filename: + input_size = (256, 256) + return input_size, dtype + +if os.getenv('AUX_ORT_PROVIDERS'): + ONNX_PROVIDERS = os.getenv('AUX_ORT_PROVIDERS').split(',') +else: + ONNX_PROVIDERS = ["CUDAExecutionProvider", "DirectMLExecutionProvider", "OpenVINOExecutionProvider", "ROCMExecutionProvider", "CPUExecutionProvider"] +def get_ort_providers() -> List[str]: + providers = [] + try: + import onnxruntime as ort + for provider in ONNX_PROVIDERS: + if provider in ort.get_available_providers(): + providers.append(provider) + return providers + except: + return [] + +def is_model_torchscript(model) -> bool: + return bool(type(model).__name__ == "RecursiveScriptModule") + +def get_model_type(Nodesname, filename) -> str: + ort_providers = list(filter(lambda x : x != "CPUExecutionProvider", get_ort_providers())) + if filename is None: + return None + elif ("onnx" in filename) and ort_providers: + print(f"{Nodesname}: Caching ONNXRuntime session {filename}...") + return "ort" + elif ("onnx" in filename): + print(f"{Nodesname}: Caching OpenCV DNN module {filename} on cv2.DNN...") + return "cv2" + else: + print(f"{Nodesname}: Caching TorchScript module {filename} on ...") + return "torchscript" diff --git a/custom_controlnet_aux/dwpose/wholebody.py b/custom_controlnet_aux/dwpose/wholebody.py new file mode 100644 index 0000000000000000000000000000000000000000..03b68cd6da1e8c1c95c630b2508f7988cd0cfae6 --- /dev/null +++ b/custom_controlnet_aux/dwpose/wholebody.py @@ -0,0 +1,167 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import cv2 +import numpy as np + +from .dw_onnx.cv_ox_det import inference_detector as inference_onnx_yolox +from .dw_onnx.cv_ox_yolo_nas import inference_detector as inference_onnx_yolo_nas +from .dw_onnx.cv_ox_pose import inference_pose as inference_onnx_pose + +from .dw_torchscript.jit_det import inference_detector as inference_jit_yolox +from .dw_torchscript.jit_pose import inference_pose as inference_jit_pose + +from typing import List, Optional +from .types import PoseResult, BodyResult, Keypoint +from timeit import default_timer +import os +from custom_controlnet_aux.dwpose.util import guess_onnx_input_shape_dtype, get_model_type, get_ort_providers, is_model_torchscript +import torch + +class Wholebody: + def __init__(self, det_model_path: Optional[str] = None, pose_model_path: Optional[str] = None, torchscript_device="cuda"): + self.det_filename = det_model_path and os.path.basename(det_model_path) + self.pose_filename = pose_model_path and os.path.basename(pose_model_path) + self.det, self.pose = None, None + # return type: None ort cv2 torchscript + self.det_model_type = get_model_type("DWPose",self.det_filename) + self.pose_model_type = get_model_type("DWPose",self.pose_filename) + # Always loads to CPU to avoid building OpenCV. + cv2_device = 'cpu' + cv2_backend = cv2.dnn.DNN_BACKEND_OPENCV if cv2_device == 'cpu' else cv2.dnn.DNN_BACKEND_CUDA + # You need to manually build OpenCV through cmake to work with your GPU. + cv2_providers = cv2.dnn.DNN_TARGET_CPU if cv2_device == 'cpu' else cv2.dnn.DNN_TARGET_CUDA + ort_providers = get_ort_providers() + + if self.det_model_type is None: + pass + elif self.det_model_type == "ort": + try: + import onnxruntime as ort + self.det = ort.InferenceSession(det_model_path, providers=ort_providers) + except: + print(f"Failed to load onnxruntime with {self.det.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI") + self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"]) + elif self.det_model_type == "cv2": + try: + self.det = cv2.dnn.readNetFromONNX(det_model_path) + self.det.setPreferableBackend(cv2_backend) + self.det.setPreferableTarget(cv2_providers) + except: + print("TopK operators may not work on your OpenCV, try use onnxruntime with CPUExecutionProvider") + try: + import onnxruntime as ort + self.det = ort.InferenceSession(det_model_path, providers=["CPUExecutionProvider"]) + except: + print(f"Failed to load {det_model_path}, you can use other models instead") + else: + self.det = torch.jit.load(det_model_path) + self.det.to(torchscript_device) + + if self.pose_model_type is None: + pass + elif self.pose_model_type == "ort": + try: + import onnxruntime as ort + self.pose = ort.InferenceSession(pose_model_path, providers=ort_providers) + except: + print(f"Failed to load onnxruntime with {self.pose.get_providers()}.\nPlease change EP_list in the config.yaml and restart ComfyUI") + self.pose = ort.InferenceSession(pose_model_path, providers=["CPUExecutionProvider"]) + elif self.pose_model_type == "cv2": + self.pose = cv2.dnn.readNetFromONNX(pose_model_path) + self.pose.setPreferableBackend(cv2_backend) + self.pose.setPreferableTarget(cv2_providers) + else: + self.pose = torch.jit.load(pose_model_path) + self.pose.to(torchscript_device) + + if self.pose_filename is not None: + self.pose_input_size, _ = guess_onnx_input_shape_dtype(self.pose_filename) + + def __call__(self, oriImg) -> Optional[np.ndarray]: + #Sacrifice accurate time measurement for compatibility + det_start = default_timer() + if is_model_torchscript(self.det): + det_result = inference_jit_yolox(self.det, oriImg, detect_classes=[0]) + else: + if "yolox" in self.det_filename: + det_result = inference_onnx_yolox(self.det, oriImg, detect_classes=[0], dtype=np.float32) + else: + #FP16 and INT8 YOLO NAS accept uint8 input + det_result = inference_onnx_yolo_nas(self.det, oriImg, detect_classes=[0], dtype=np.uint8) + print(f"DWPose: Bbox {((default_timer() - det_start) * 1000):.2f}ms") + if (det_result is None) or (det_result.shape[0] == 0): + return None + + pose_start = default_timer() + if is_model_torchscript(self.pose): + keypoints, scores = inference_jit_pose(self.pose, det_result, oriImg, self.pose_input_size) + else: + _, pose_onnx_dtype = guess_onnx_input_shape_dtype(self.pose_filename) + keypoints, scores = inference_onnx_pose(self.pose, det_result, oriImg, self.pose_input_size, dtype=pose_onnx_dtype) + print(f"DWPose: Pose {((default_timer() - pose_start) * 1000):.2f}ms on {det_result.shape[0]} people\n") + + keypoints_info = np.concatenate( + (keypoints, scores[..., None]), axis=-1) + # compute neck joint + neck = np.mean(keypoints_info[:, [5, 6]], axis=1) + # neck score when visualizing pred + neck[:, 2:4] = np.logical_and( + keypoints_info[:, 5, 2:4] > 0.3, + keypoints_info[:, 6, 2:4] > 0.3).astype(int) + new_keypoints_info = np.insert( + keypoints_info, 17, neck, axis=1) + mmpose_idx = [ + 17, 6, 8, 10, 7, 9, 12, 14, 16, 13, 15, 2, 1, 4, 3 + ] + openpose_idx = [ + 1, 2, 3, 4, 6, 7, 8, 9, 10, 12, 13, 14, 15, 16, 17 + ] + new_keypoints_info[:, openpose_idx] = \ + new_keypoints_info[:, mmpose_idx] + keypoints_info = new_keypoints_info + + return keypoints_info + + @staticmethod + def format_result(keypoints_info: Optional[np.ndarray]) -> List[PoseResult]: + def format_keypoint_part( + part: np.ndarray, + ) -> Optional[List[Optional[Keypoint]]]: + keypoints = [ + Keypoint(x, y, score, i) if score >= 0.3 else None + for i, (x, y, score) in enumerate(part) + ] + return ( + None if all(keypoint is None for keypoint in keypoints) else keypoints + ) + + def total_score(keypoints: Optional[List[Optional[Keypoint]]]) -> float: + return ( + sum(keypoint.score for keypoint in keypoints if keypoint is not None) + if keypoints is not None + else 0.0 + ) + + pose_results = [] + if keypoints_info is None: + return pose_results + + for instance in keypoints_info: + body_keypoints = format_keypoint_part(instance[:18]) or ([None] * 18) + left_hand = format_keypoint_part(instance[92:113]) + right_hand = format_keypoint_part(instance[113:134]) + face = format_keypoint_part(instance[24:92]) + + # Openpose face consists of 70 points in total, while DWPose only + # provides 68 points. Padding the last 2 points. + if face is not None: + # left eye + face.append(body_keypoints[14]) + # right eye + face.append(body_keypoints[15]) + + body = BodyResult( + body_keypoints, total_score(body_keypoints), len(body_keypoints) + ) + pose_results.append(PoseResult(body, left_hand, right_hand, face)) + + return pose_results \ No newline at end of file diff --git a/custom_controlnet_aux/hed/__init__.py b/custom_controlnet_aux/hed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..022c4f999303b8496dbb64bce9457ecebf728efb --- /dev/null +++ b/custom_controlnet_aux/hed/__init__.py @@ -0,0 +1,110 @@ +# This is an improved version and model of HED edge detection with Apache License, Version 2.0. +# Please use this implementation in your products +# This implementation may produce slightly different results from Saining Xie's official implementations, +# but it generates smoother edges and is more suitable for ControlNet as well as other image-to-image translations. +# Different from official models and other implementations, this is an RGB-input model (rather than BGR) +# and in this way it works better for gradio's RGB protocol + +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, nms, resize_image_with_pad, safe_step, common_input_validate, custom_hf_download, HF_MODEL_NAME + + +class DoubleConvBlock(torch.nn.Module): + def __init__(self, input_channel, output_channel, layer_number): + super().__init__() + self.convs = torch.nn.Sequential() + self.convs.append(torch.nn.Conv2d(in_channels=input_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + for i in range(1, layer_number): + self.convs.append(torch.nn.Conv2d(in_channels=output_channel, out_channels=output_channel, kernel_size=(3, 3), stride=(1, 1), padding=1)) + self.projection = torch.nn.Conv2d(in_channels=output_channel, out_channels=1, kernel_size=(1, 1), stride=(1, 1), padding=0) + + def __call__(self, x, down_sampling=False): + h = x + if down_sampling: + h = torch.nn.functional.max_pool2d(h, kernel_size=(2, 2), stride=(2, 2)) + for conv in self.convs: + h = conv(h) + h = torch.nn.functional.relu(h) + return h, self.projection(h) + + +class ControlNetHED_Apache2(torch.nn.Module): + def __init__(self): + super().__init__() + self.norm = torch.nn.Parameter(torch.zeros(size=(1, 3, 1, 1))) + self.block1 = DoubleConvBlock(input_channel=3, output_channel=64, layer_number=2) + self.block2 = DoubleConvBlock(input_channel=64, output_channel=128, layer_number=2) + self.block3 = DoubleConvBlock(input_channel=128, output_channel=256, layer_number=3) + self.block4 = DoubleConvBlock(input_channel=256, output_channel=512, layer_number=3) + self.block5 = DoubleConvBlock(input_channel=512, output_channel=512, layer_number=3) + + def __call__(self, x): + h = x - self.norm + h, projection1 = self.block1(h) + h, projection2 = self.block2(h, down_sampling=True) + h, projection3 = self.block3(h, down_sampling=True) + h, projection4 = self.block4(h, down_sampling=True) + h, projection5 = self.block5(h, down_sampling=True) + return projection1, projection2, projection3, projection4, projection5 + +class HEDdetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="ControlNetHED.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + netNetwork = ControlNetHED_Apache2() + netNetwork.load_state_dict(torch.load(model_path, map_location='cpu')) + netNetwork.float().eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + self.device = device + return self + + + def __call__(self, input_image, detect_resolution=512, safe=False, output_type="pil", scribble=False, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + assert input_image.ndim == 3 + H, W, C = input_image.shape + with torch.no_grad(): + image_hed = torch.from_numpy(input_image).float().to(self.device) + image_hed = rearrange(image_hed, 'h w c -> 1 c h w') + edges = self.netNetwork(image_hed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + detected_map = HWC3(remove_pad(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/leres/__init__.py b/custom_controlnet_aux/leres/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e348228730393dc3d2f79f571cba0c0ed95972d --- /dev/null +++ b/custom_controlnet_aux/leres/__init__.py @@ -0,0 +1,93 @@ +import os + +import cv2 +import numpy as np +import torch +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +from .leres.depthmap import estimateboost, estimateleres +from .leres.multi_depth_model_woauxi import RelDepthModel +from .leres.net_tools import strip_prefix_if_present +from .pix2pix.models.pix2pix4depth_model import Pix2Pix4DepthModel +from .pix2pix.options.test_options import TestOptions + + +class LeresDetector: + def __init__(self, model, pix2pixmodel): + self.model = model + self.pix2pixmodel = pix2pixmodel + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="res101.pth", pix2pix_filename="latest_net_G.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + checkpoint = torch.load(model_path, map_location=torch.device('cpu')) + + model = RelDepthModel(backbone='resnext101') + model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), strict=True) + del checkpoint + + pix2pix_model_path = custom_hf_download(pretrained_model_or_path, pix2pix_filename) + + opt = TestOptions().parse() + if not torch.cuda.is_available(): + opt.gpu_ids = [] # cpu mode + pix2pixmodel = Pix2Pix4DepthModel(opt) + pix2pixmodel.save_dir = os.path.dirname(pix2pix_model_path) + pix2pixmodel.load_networks('latest') + pix2pixmodel.eval() + + return cls(model, pix2pixmodel) + + def to(self, device): + self.model.to(device) + # TODO - refactor pix2pix implementation to support device migration + # self.pix2pixmodel.to(device) + return self + + def __call__(self, input_image, thr_a=0, thr_b=0, boost=False, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + with torch.no_grad(): + if boost: + depth = estimateboost(detected_map, self.model, 0, self.pix2pixmodel, max(detected_map.shape[1], detected_map.shape[0])) + else: + depth = estimateleres(detected_map, self.model, detected_map.shape[1], detected_map.shape[0]) + + numbytes=2 + depth_min = depth.min() + depth_max = depth.max() + max_val = (2**(8*numbytes))-1 + + # check output before normalizing and mapping to 16 bit + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape) + + # single channel, 16 bit image + depth_image = out.astype("uint16") + + # convert to uint8 + depth_image = cv2.convertScaleAbs(depth_image, alpha=(255.0/65535.0)) + + # remove near + if thr_a != 0: + thr_a = ((thr_a/100)*255) + depth_image = cv2.threshold(depth_image, thr_a, 255, cv2.THRESH_TOZERO)[1] + + # invert image + depth_image = cv2.bitwise_not(depth_image) + + # remove bg + if thr_b != 0: + thr_b = ((thr_b/100)*255) + depth_image = cv2.threshold(depth_image, thr_b, 255, cv2.THRESH_TOZERO)[1] + + detected_map = HWC3(remove_pad(depth_image)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/leres/leres/LICENSE b/custom_controlnet_aux/leres/leres/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e0f1d07d98d4e85e684734d058dfe2515d215405 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/LICENSE @@ -0,0 +1,23 @@ +https://github.com/thygate/stable-diffusion-webui-depthmap-script + +MIT License + +Copyright (c) 2023 Bob Thiry + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/leres/leres/Resnet.py b/custom_controlnet_aux/leres/leres/Resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..f12c9975c1aa05401269be3ca3dbaa56bde55581 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/Resnet.py @@ -0,0 +1,199 @@ +import torch.nn as nn +import torch.nn as NN + +__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', + 'resnet152'] + + +model_urls = { + 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', + 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', + 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', + 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', + 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000): + self.inplanes = 64 + super(ResNet, self).__init__() + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + #self.avgpool = nn.AvgPool2d(7, stride=1) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.Sequential(*layers) + + def forward(self, x): + features = [] + + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + x = self.layer2(x) + features.append(x) + x = self.layer3(x) + features.append(x) + x = self.layer4(x) + features.append(x) + + return features + + +def resnet18(pretrained=True, **kwargs): + """Constructs a ResNet-18 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) + return model + + +def resnet34(pretrained=True, **kwargs): + """Constructs a ResNet-34 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) + return model + + +def resnet50(pretrained=True, **kwargs): + """Constructs a ResNet-50 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) + + return model + + +def resnet101(pretrained=True, **kwargs): + """Constructs a ResNet-101 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + + return model + + +def resnet152(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) + return model diff --git a/custom_controlnet_aux/leres/leres/Resnext_torch.py b/custom_controlnet_aux/leres/leres/Resnext_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..9af54fcc3e5b363935ef60c8aaf269110c0d6611 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/Resnext_torch.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python +# coding: utf-8 +import torch.nn as nn + +try: + from urllib import urlretrieve +except ImportError: + from urllib.request import urlretrieve + +__all__ = ['resnext101_32x8d'] + + +model_urls = { + 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', + 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', +} + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) + # while original implementation places the stride at the first 1x1 convolution(self.conv1) + # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. + # This variant is also known as ResNet V1.5 and improves accuracy according to + # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Module): + + def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, + groups=1, width_per_group=64, replace_stride_with_dilation=None, + norm_layer=None): + super(ResNet, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self._norm_layer = norm_layer + + self.inplanes = 64 + self.dilation = 1 + if replace_stride_with_dilation is None: + # each element in the tuple indicates if we should replace + # the 2x2 stride with a dilated convolution instead + replace_stride_with_dilation = [False, False, False] + if len(replace_stride_with_dilation) != 3: + raise ValueError("replace_stride_with_dilation should be None " + "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) + self.groups = groups + self.base_width = width_per_group + self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, + bias=False) + self.bn1 = norm_layer(self.inplanes) + self.relu = nn.ReLU(inplace=True) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2, + dilate=replace_stride_with_dilation[0]) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2, + dilate=replace_stride_with_dilation[1]) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2, + dilate=replace_stride_with_dilation[2]) + #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + #self.fc = nn.Linear(512 * block.expansion, num_classes) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + # Zero-initialize the last BN in each residual branch, + # so that the residual branch starts with zeros, and each residual block behaves like an identity. + # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 + if zero_init_residual: + for m in self.modules(): + if isinstance(m, Bottleneck): + nn.init.constant_(m.bn3.weight, 0) + elif isinstance(m, BasicBlock): + nn.init.constant_(m.bn2.weight, 0) + + def _make_layer(self, block, planes, blocks, stride=1, dilate=False): + norm_layer = self._norm_layer + downsample = None + previous_dilation = self.dilation + if dilate: + self.dilation *= stride + stride = 1 + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + conv1x1(self.inplanes, planes * block.expansion, stride), + norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample, self.groups, + self.base_width, previous_dilation, norm_layer)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes, groups=self.groups, + base_width=self.base_width, dilation=self.dilation, + norm_layer=norm_layer)) + + return nn.Sequential(*layers) + + def _forward_impl(self, x): + # See note [TorchScript super()] + features = [] + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.maxpool(x) + + x = self.layer1(x) + features.append(x) + + x = self.layer2(x) + features.append(x) + + x = self.layer3(x) + features.append(x) + + x = self.layer4(x) + features.append(x) + + #x = self.avgpool(x) + #x = torch.flatten(x, 1) + #x = self.fc(x) + + return features + + def forward(self, x): + return self._forward_impl(x) + + + +def resnext101_32x8d(pretrained=True, **kwargs): + """Constructs a ResNet-152 model. + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + """ + kwargs['groups'] = 32 + kwargs['width_per_group'] = 8 + + model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) + return model + diff --git a/custom_controlnet_aux/leres/leres/__init__.py b/custom_controlnet_aux/leres/leres/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/leres/leres/depthmap.py b/custom_controlnet_aux/leres/leres/depthmap.py new file mode 100644 index 0000000000000000000000000000000000000000..fc743bf4946b514a53f8d286a395e33c7b612582 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/depthmap.py @@ -0,0 +1,548 @@ +# Author: thygate +# https://github.com/thygate/stable-diffusion-webui-depthmap-script + +import gc +from operator import getitem + +import cv2 +import numpy as np +import skimage.measure +import torch +from torchvision.transforms import transforms + +from ...util import torch_gc + +whole_size_threshold = 1600 # R_max from the paper +pix2pixsize = 1024 + +def scale_torch(img): + """ + Scale the image and output it in torch.tensor. + :param img: input rgb is in shape [H, W, C], input depth/disp is in shape [H, W] + :param scale: the scale factor. float + :return: img. [C, H, W] + """ + if len(img.shape) == 2: + img = img[np.newaxis, :, :] + if img.shape[2] == 3: + transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225) )]) + img = transform(img.astype(np.float32)) + else: + img = img.astype(np.float32) + img = torch.from_numpy(img) + return img + +def estimateleres(img, model, w, h): + device = next(iter(model.parameters())).device + # leres transform input + rgb_c = img[:, :, ::-1].copy() + A_resize = cv2.resize(rgb_c, (w, h)) + img_torch = scale_torch(A_resize)[None, :, :, :] + + # compute + with torch.no_grad(): + img_torch = img_torch.to(device) + prediction = model.depth_model(img_torch) + + prediction = prediction.squeeze().cpu().numpy() + prediction = cv2.resize(prediction, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_CUBIC) + + return prediction + +def generatemask(size): + # Generates a Guassian mask + mask = np.zeros(size, dtype=np.float32) + sigma = int(size[0]/16) + k_size = int(2 * np.ceil(2 * int(size[0]/16)) + 1) + mask[int(0.15*size[0]):size[0] - int(0.15*size[0]), int(0.15*size[1]): size[1] - int(0.15*size[1])] = 1 + mask = cv2.GaussianBlur(mask, (int(k_size), int(k_size)), sigma) + mask = (mask - mask.min()) / (mask.max() - mask.min()) + mask = mask.astype(np.float32) + return mask + +def resizewithpool(img, size): + i_size = img.shape[0] + n = int(np.floor(i_size/size)) + + out = skimage.measure.block_reduce(img, (n, n), np.max) + return out + +def rgb2gray(rgb): + # Converts rgb to gray + return np.dot(rgb[..., :3], [0.2989, 0.5870, 0.1140]) + +def calculateprocessingres(img, basesize, confidence=0.1, scale_threshold=3, whole_size_threshold=3000): + # Returns the R_x resolution described in section 5 of the main paper. + + # Parameters: + # img :input rgb image + # basesize : size the dilation kernel which is equal to receptive field of the network. + # confidence: value of x in R_x; allowed percentage of pixels that are not getting any contextual cue. + # scale_threshold: maximum allowed upscaling on the input image ; it has been set to 3. + # whole_size_threshold: maximum allowed resolution. (R_max from section 6 of the main paper) + + # Returns: + # outputsize_scale*speed_scale :The computed R_x resolution + # patch_scale: K parameter from section 6 of the paper + + # speed scale parameter is to process every image in a smaller size to accelerate the R_x resolution search + speed_scale = 32 + image_dim = int(min(img.shape[0:2])) + + gray = rgb2gray(img) + grad = np.abs(cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)) + np.abs(cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)) + grad = cv2.resize(grad, (image_dim, image_dim), cv2.INTER_AREA) + + # thresholding the gradient map to generate the edge-map as a proxy of the contextual cues + m = grad.min() + M = grad.max() + middle = m + (0.4 * (M - m)) + grad[grad < middle] = 0 + grad[grad >= middle] = 1 + + # dilation kernel with size of the receptive field + kernel = np.ones((int(basesize/speed_scale), int(basesize/speed_scale)), float) + # dilation kernel with size of the a quarter of receptive field used to compute k + # as described in section 6 of main paper + kernel2 = np.ones((int(basesize / (4*speed_scale)), int(basesize / (4*speed_scale))), float) + + # Output resolution limit set by the whole_size_threshold and scale_threshold. + threshold = min(whole_size_threshold, scale_threshold * max(img.shape[:2])) + + outputsize_scale = basesize / speed_scale + for p_size in range(int(basesize/speed_scale), int(threshold/speed_scale), int(basesize / (2*speed_scale))): + grad_resized = resizewithpool(grad, p_size) + grad_resized = cv2.resize(grad_resized, (p_size, p_size), cv2.INTER_NEAREST) + grad_resized[grad_resized >= 0.5] = 1 + grad_resized[grad_resized < 0.5] = 0 + + dilated = cv2.dilate(grad_resized, kernel, iterations=1) + meanvalue = (1-dilated).mean() + if meanvalue > confidence: + break + else: + outputsize_scale = p_size + + grad_region = cv2.dilate(grad_resized, kernel2, iterations=1) + patch_scale = grad_region.mean() + + return int(outputsize_scale*speed_scale), patch_scale + +# Generate a double-input depth estimation +def doubleestimate(img, size1, size2, pix2pixsize, model, net_type, pix2pixmodel): + # Generate the low resolution estimation + estimate1 = singleestimate(img, size1, model, net_type) + # Resize to the inference size of merge network. + estimate1 = cv2.resize(estimate1, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Generate the high resolution estimation + estimate2 = singleestimate(img, size2, model, net_type) + # Resize to the inference size of merge network. + estimate2 = cv2.resize(estimate2, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Inference on the merge model + pix2pixmodel.set_input(estimate1, estimate2) + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = (prediction_mapped - torch.min(prediction_mapped)) / ( + torch.max(prediction_mapped) - torch.min(prediction_mapped)) + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + return prediction_mapped + +# Generate a single-input depth estimation +def singleestimate(img, msize, model, net_type): + # if net_type == 0: + return estimateleres(img, model, msize, msize) + # else: + # return estimatemidasBoost(img, model, msize, msize) + +def applyGridpatch(blsize, stride, img, box): + # Extract a simple grid patch. + counter1 = 0 + patch_bound_list = {} + for k in range(blsize, img.shape[1] - blsize, stride): + for j in range(blsize, img.shape[0] - blsize, stride): + patch_bound_list[str(counter1)] = {} + patchbounds = [j - blsize, k - blsize, j - blsize + 2 * blsize, k - blsize + 2 * blsize] + patch_bound = [box[0] + patchbounds[1], box[1] + patchbounds[0], patchbounds[3] - patchbounds[1], + patchbounds[2] - patchbounds[0]] + patch_bound_list[str(counter1)]['rect'] = patch_bound + patch_bound_list[str(counter1)]['size'] = patch_bound[2] + counter1 = counter1 + 1 + return patch_bound_list + +# Generating local patches to perform the local refinement described in section 6 of the main paper. +def generatepatchs(img, base_size): + + # Compute the gradients as a proxy of the contextual cues. + img_gray = rgb2gray(img) + whole_grad = np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)) +\ + np.abs(cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)) + + threshold = whole_grad[whole_grad > 0].mean() + whole_grad[whole_grad < threshold] = 0 + + # We use the integral image to speed-up the evaluation of the amount of gradients for each patch. + gf = whole_grad.sum()/len(whole_grad.reshape(-1)) + grad_integral_image = cv2.integral(whole_grad) + + # Variables are selected such that the initial patch size would be the receptive field size + # and the stride is set to 1/3 of the receptive field size. + blsize = int(round(base_size/2)) + stride = int(round(blsize*0.75)) + + # Get initial Grid + patch_bound_list = applyGridpatch(blsize, stride, img, [0, 0, 0, 0]) + + # Refine initial Grid of patches by discarding the flat (in terms of gradients of the rgb image) ones. Refine + # each patch size to ensure that there will be enough depth cues for the network to generate a consistent depth map. + print("Selecting patches ...") + patch_bound_list = adaptiveselection(grad_integral_image, patch_bound_list, gf) + + # Sort the patch list to make sure the merging operation will be done with the correct order: starting from biggest + # patch + patchset = sorted(patch_bound_list.items(), key=lambda x: getitem(x[1], 'size'), reverse=True) + return patchset + +def getGF_fromintegral(integralimage, rect): + # Computes the gradient density of a given patch from the gradient integral image. + x1 = rect[1] + x2 = rect[1]+rect[3] + y1 = rect[0] + y2 = rect[0]+rect[2] + value = integralimage[x2, y2]-integralimage[x1, y2]-integralimage[x2, y1]+integralimage[x1, y1] + return value + +# Adaptively select patches +def adaptiveselection(integral_grad, patch_bound_list, gf): + patchlist = {} + count = 0 + height, width = integral_grad.shape + + search_step = int(32/factor) + + # Go through all patches + for c in range(len(patch_bound_list)): + # Get patch + bbox = patch_bound_list[str(c)]['rect'] + + # Compute the amount of gradients present in the patch from the integral image. + cgf = getGF_fromintegral(integral_grad, bbox)/(bbox[2]*bbox[3]) + + # Check if patching is beneficial by comparing the gradient density of the patch to + # the gradient density of the whole image + if cgf >= gf: + bbox_test = bbox.copy() + patchlist[str(count)] = {} + + # Enlarge each patch until the gradient density of the patch is equal + # to the whole image gradient density + while True: + + bbox_test[0] = bbox_test[0] - int(search_step/2) + bbox_test[1] = bbox_test[1] - int(search_step/2) + + bbox_test[2] = bbox_test[2] + search_step + bbox_test[3] = bbox_test[3] + search_step + + # Check if we are still within the image + if bbox_test[0] < 0 or bbox_test[1] < 0 or bbox_test[1] + bbox_test[3] >= height \ + or bbox_test[0] + bbox_test[2] >= width: + break + + # Compare gradient density + cgf = getGF_fromintegral(integral_grad, bbox_test)/(bbox_test[2]*bbox_test[3]) + if cgf < gf: + break + bbox = bbox_test.copy() + + # Add patch to selected patches + patchlist[str(count)]['rect'] = bbox + patchlist[str(count)]['size'] = bbox[2] + count = count + 1 + + # Return selected patches + return patchlist + +def impatch(image, rect): + # Extract the given patch pixels from a given image. + w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + image_patch = image[h1:h2, w1:w2] + return image_patch + +class ImageandPatchs: + def __init__(self, root_dir, name, patchsinfo, rgb_image, scale=1): + self.root_dir = root_dir + self.patchsinfo = patchsinfo + self.name = name + self.patchs = patchsinfo + self.scale = scale + + self.rgb_image = cv2.resize(rgb_image, (round(rgb_image.shape[1]*scale), round(rgb_image.shape[0]*scale)), + interpolation=cv2.INTER_CUBIC) + + self.do_have_estimate = False + self.estimation_updated_image = None + self.estimation_base_image = None + + def __len__(self): + return len(self.patchs) + + def set_base_estimate(self, est): + self.estimation_base_image = est + if self.estimation_updated_image is not None: + self.do_have_estimate = True + + def set_updated_estimate(self, est): + self.estimation_updated_image = est + if self.estimation_base_image is not None: + self.do_have_estimate = True + + def __getitem__(self, index): + patch_id = int(self.patchs[index][0]) + rect = np.array(self.patchs[index][1]['rect']) + msize = self.patchs[index][1]['size'] + + ## applying scale to rect: + rect = np.round(rect * self.scale) + rect = rect.astype('int') + msize = round(msize * self.scale) + + patch_rgb = impatch(self.rgb_image, rect) + if self.do_have_estimate: + patch_whole_estimate_base = impatch(self.estimation_base_image, rect) + patch_whole_estimate_updated = impatch(self.estimation_updated_image, rect) + return {'patch_rgb': patch_rgb, 'patch_whole_estimate_base': patch_whole_estimate_base, + 'patch_whole_estimate_updated': patch_whole_estimate_updated, 'rect': rect, + 'size': msize, 'id': patch_id} + else: + return {'patch_rgb': patch_rgb, 'rect': rect, 'size': msize, 'id': patch_id} + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + """ + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + """ + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt + + +def estimateboost(img, model, model_type, pix2pixmodel, max_res=512, depthmap_script_boost_rmax=None): + global whole_size_threshold + + # get settings + if depthmap_script_boost_rmax: + whole_size_threshold = depthmap_script_boost_rmax + + if model_type == 0: #leres + net_receptive_field_size = 448 + patch_netsize = 2 * net_receptive_field_size + elif model_type == 1: #dpt_beit_large_512 + net_receptive_field_size = 512 + patch_netsize = 2 * net_receptive_field_size + else: #other midas + net_receptive_field_size = 384 + patch_netsize = 2 * net_receptive_field_size + + gc.collect() + torch_gc() + + # Generate mask used to smoothly blend the local pathc estimations to the base estimate. + # It is arbitrarily large to avoid artifacts during rescaling for each crop. + mask_org = generatemask((3000, 3000)) + mask = mask_org.copy() + + # Value x of R_x defined in the section 5 of the main paper. + r_threshold_value = 0.2 + #if R0: + # r_threshold_value = 0 + + input_resolution = img.shape + scale_threshold = 3 # Allows up-scaling with a scale up to 3 + + # Find the best input resolution R-x. The resolution search described in section 5-double estimation of the main paper and section B of the + # supplementary material. + whole_image_optimal_size, patch_scale = calculateprocessingres(img, net_receptive_field_size, r_threshold_value, scale_threshold, whole_size_threshold) + + # print('wholeImage being processed in :', whole_image_optimal_size) + + # Generate the base estimate using the double estimation. + whole_estimate = doubleestimate(img, net_receptive_field_size, whole_image_optimal_size, pix2pixsize, model, model_type, pix2pixmodel) + + # Compute the multiplier described in section 6 of the main paper to make sure our initial patch can select + # small high-density regions of the image. + global factor + factor = max(min(1, 4 * patch_scale * whole_image_optimal_size / whole_size_threshold), 0.2) + # print('Adjust factor is:', 1/factor) + + # Check if Local boosting is beneficial. + if max_res < whole_image_optimal_size: + # print("No Local boosting. Specified Max Res is smaller than R20, Returning doubleestimate result") + return cv2.resize(whole_estimate, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) + + # Compute the default target resolution. + if img.shape[0] > img.shape[1]: + a = 2 * whole_image_optimal_size + b = round(2 * whole_image_optimal_size * img.shape[1] / img.shape[0]) + else: + a = round(2 * whole_image_optimal_size * img.shape[0] / img.shape[1]) + b = 2 * whole_image_optimal_size + b = int(round(b / factor)) + a = int(round(a / factor)) + + """ + # recompute a, b and saturate to max res. + if max(a,b) > max_res: + print('Default Res is higher than max-res: Reducing final resolution') + if img.shape[0] > img.shape[1]: + a = max_res + b = round(max_res * img.shape[1] / img.shape[0]) + else: + a = round(max_res * img.shape[0] / img.shape[1]) + b = max_res + b = int(b) + a = int(a) + """ + + img = cv2.resize(img, (b, a), interpolation=cv2.INTER_CUBIC) + + # Extract selected patches for local refinement + base_size = net_receptive_field_size * 2 + patchset = generatepatchs(img, base_size) + + # print('Target resolution: ', img.shape) + + # Computing a scale in case user prompted to generate the results as the same resolution of the input. + # Notice that our method output resolution is independent of the input resolution and this parameter will only + # enable a scaling operation during the local patch merge implementation to generate results with the same resolution + # as the input. + """ + if output_resolution == 1: + mergein_scale = input_resolution[0] / img.shape[0] + print('Dynamicly change merged-in resolution; scale:', mergein_scale) + else: + mergein_scale = 1 + """ + # always rescale to input res for now + mergein_scale = input_resolution[0] / img.shape[0] + + imageandpatchs = ImageandPatchs('', '', patchset, img, mergein_scale) + whole_estimate_resized = cv2.resize(whole_estimate, (round(img.shape[1]*mergein_scale), + round(img.shape[0]*mergein_scale)), interpolation=cv2.INTER_CUBIC) + imageandpatchs.set_base_estimate(whole_estimate_resized.copy()) + imageandpatchs.set_updated_estimate(whole_estimate_resized.copy()) + + print('Resulting depthmap resolution will be :', whole_estimate_resized.shape[:2]) + print('Patches to process: '+str(len(imageandpatchs))) + + # Enumerate through all patches, generate their estimations and refining the base estimate. + for patch_ind in range(len(imageandpatchs)): + + # Get patch information + patch = imageandpatchs[patch_ind] # patch object + patch_rgb = patch['patch_rgb'] # rgb patch + patch_whole_estimate_base = patch['patch_whole_estimate_base'] # corresponding patch from base + rect = patch['rect'] # patch size and location + patch_id = patch['id'] # patch ID + org_size = patch_whole_estimate_base.shape # the original size from the unscaled input + print('\t Processing patch', patch_ind, '/', len(imageandpatchs)-1, '|', rect) + + # We apply double estimation for patches. The high resolution value is fixed to twice the receptive + # field size of the network for patches to accelerate the process. + patch_estimation = doubleestimate(patch_rgb, net_receptive_field_size, patch_netsize, pix2pixsize, model, model_type, pix2pixmodel) + patch_estimation = cv2.resize(patch_estimation, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + patch_whole_estimate_base = cv2.resize(patch_whole_estimate_base, (pix2pixsize, pix2pixsize), interpolation=cv2.INTER_CUBIC) + + # Merging the patch estimation into the base estimate using our merge network: + # We feed the patch estimation and the same region from the updated base estimate to the merge network + # to generate the target estimate for the corresponding region. + pix2pixmodel.set_input(patch_whole_estimate_base, patch_estimation) + + # Run merging network + pix2pixmodel.test() + visuals = pix2pixmodel.get_current_visuals() + + prediction_mapped = visuals['fake_B'] + prediction_mapped = (prediction_mapped+1)/2 + prediction_mapped = prediction_mapped.squeeze().cpu().numpy() + + mapped = prediction_mapped + + # We use a simple linear polynomial to make sure the result of the merge network would match the values of + # base estimate + p_coef = np.polyfit(mapped.reshape(-1), patch_whole_estimate_base.reshape(-1), deg=1) + merged = np.polyval(p_coef, mapped.reshape(-1)).reshape(mapped.shape) + + merged = cv2.resize(merged, (org_size[1],org_size[0]), interpolation=cv2.INTER_CUBIC) + + # Get patch size and location + w1 = rect[0] + h1 = rect[1] + w2 = w1 + rect[2] + h2 = h1 + rect[3] + + # To speed up the implementation, we only generate the Gaussian mask once with a sufficiently large size + # and resize it to our needed size while merging the patches. + if mask.shape != org_size: + mask = cv2.resize(mask_org, (org_size[1],org_size[0]), interpolation=cv2.INTER_LINEAR) + + tobemergedto = imageandpatchs.estimation_updated_image + + # Update the whole estimation: + # We use a simple Gaussian mask to blend the merged patch region with the base estimate to ensure seamless + # blending at the boundaries of the patch region. + tobemergedto[h1:h2, w1:w2] = np.multiply(tobemergedto[h1:h2, w1:w2], 1 - mask) + np.multiply(merged, mask) + imageandpatchs.set_updated_estimate(tobemergedto) + + # output + return cv2.resize(imageandpatchs.estimation_updated_image, (input_resolution[1], input_resolution[0]), interpolation=cv2.INTER_CUBIC) diff --git a/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py b/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py new file mode 100644 index 0000000000000000000000000000000000000000..fdf35d7843e00be5d3c831d72b9ab5d64d130f93 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/multi_depth_model_woauxi.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from . import network_auxi as network +from .net_tools import get_func + + +class RelDepthModel(nn.Module): + def __init__(self, backbone='resnet50'): + super(RelDepthModel, self).__init__() + if backbone == 'resnet50': + encoder = 'resnet50_stride32' + elif backbone == 'resnext101': + encoder = 'resnext101_stride32x8d' + self.depth_model = DepthModel(encoder) + + def inference(self, rgb): + with torch.no_grad(): + input = rgb.to(self.depth_model.device) + depth = self.depth_model(input) + #pred_depth_out = depth - depth.min() + 0.01 + return depth #pred_depth_out + + +class DepthModel(nn.Module): + def __init__(self, encoder): + super(DepthModel, self).__init__() + backbone = network.__name__.split('.')[-1] + '.' + encoder + self.encoder_modules = get_func(backbone)() + self.decoder_modules = network.Decoder() + + def forward(self, x): + lateral_out = self.encoder_modules(x) + out_logit = self.decoder_modules(lateral_out) + return out_logit \ No newline at end of file diff --git a/custom_controlnet_aux/leres/leres/net_tools.py b/custom_controlnet_aux/leres/leres/net_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..d44c794d6a81cb0309fb4873f83489de377e30a8 --- /dev/null +++ b/custom_controlnet_aux/leres/leres/net_tools.py @@ -0,0 +1,54 @@ +import importlib +import torch +import os +from collections import OrderedDict + + +def get_func(func_name): + """Helper to return a function object by name. func_name must identify a + function in this module or the path to a function relative to the base + 'modeling' module. + """ + if func_name == '': + return None + try: + parts = func_name.split('.') + # Refers to a function in this module + if len(parts) == 1: + return globals()[parts[0]] + # Otherwise, assume we're referencing a module under modeling + module_name = 'custom_controlnet_aux.leres.leres.' + '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + return getattr(module, parts[-1]) + except Exception: + print('Failed to f1ind function: %s', func_name) + raise + +def load_ckpt(args, depth_model, shift_model, focal_model): + """ + Load checkpoint. + """ + if os.path.isfile(args.load_ckpt): + print("loading checkpoint %s" % args.load_ckpt) + checkpoint = torch.load(args.load_ckpt) + if shift_model is not None: + shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'), + strict=True) + if focal_model is not None: + focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'), + strict=True) + depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), + strict=True) + del checkpoint + if torch.cuda.is_available(): + torch.cuda.empty_cache() + + +def strip_prefix_if_present(state_dict, prefix): + keys = sorted(state_dict.keys()) + if not all(key.startswith(prefix) for key in keys): + return state_dict + stripped_state_dict = OrderedDict() + for key, value in state_dict.items(): + stripped_state_dict[key.replace(prefix, "")] = value + return stripped_state_dict \ No newline at end of file diff --git a/custom_controlnet_aux/leres/leres/network_auxi.py b/custom_controlnet_aux/leres/leres/network_auxi.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd87011a5339aca632d1a10b217c8737bdc794f --- /dev/null +++ b/custom_controlnet_aux/leres/leres/network_auxi.py @@ -0,0 +1,417 @@ +import torch +import torch.nn as nn +import torch.nn.init as init + +from . import Resnet, Resnext_torch + + +def resnet50_stride32(): + return DepthNet(backbone='resnet', depth=50, upfactors=[2, 2, 2, 2]) + +def resnext101_stride32x8d(): + return DepthNet(backbone='resnext101_32x8d', depth=101, upfactors=[2, 2, 2, 2]) + + +class Decoder(nn.Module): + def __init__(self): + super(Decoder, self).__init__() + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = [2,2,2,2] + self.outchannels = 1 + + self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3]) + self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True) + self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True) + + self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2]) + self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1]) + self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0]) + + self.outconv = AO(inchannels=self.midchannels[0], outchannels=self.outchannels, upfactor=2) + self._init_params() + + def _init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): #NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + def forward(self, features): + x_32x = self.conv(features[3]) # 1/32 + x_32 = self.conv1(x_32x) + x_16 = self.upsample(x_32) # 1/16 + + x_8 = self.ffm2(features[2], x_16) # 1/8 + x_4 = self.ffm1(features[1], x_8) # 1/4 + x_2 = self.ffm0(features[0], x_4) # 1/2 + #----------------------------------------- + x = self.outconv(x_2) # original size + return x + +class DepthNet(nn.Module): + __factory = { + 18: Resnet.resnet18, + 34: Resnet.resnet34, + 50: Resnet.resnet50, + 101: Resnet.resnet101, + 152: Resnet.resnet152 + } + def __init__(self, + backbone='resnet', + depth=50, + upfactors=[2, 2, 2, 2]): + super(DepthNet, self).__init__() + self.backbone = backbone + self.depth = depth + self.pretrained = False + self.inchannels = [256, 512, 1024, 2048] + self.midchannels = [256, 256, 256, 512] + self.upfactors = upfactors + self.outchannels = 1 + + # Build model + if self.backbone == 'resnet': + if self.depth not in DepthNet.__factory: + raise KeyError("Unsupported depth:", self.depth) + self.encoder = DepthNet.__factory[depth](pretrained=self.pretrained) + elif self.backbone == 'resnext101_32x8d': + self.encoder = Resnext_torch.resnext101_32x8d(pretrained=self.pretrained) + else: + self.encoder = Resnext_torch.resnext101(pretrained=self.pretrained) + + def forward(self, x): + x = self.encoder(x) # 1/32, 1/16, 1/8, 1/4 + return x + + +class FTB(nn.Module): + def __init__(self, inchannels, midchannels=512): + super(FTB, self).__init__() + self.in1 = inchannels + self.mid = midchannels + self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, + bias=True) + # NN.BatchNorm2d + self.conv_branch = nn.Sequential(nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.mid), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, + padding=1, stride=1, bias=True)) + self.relu = nn.ReLU(inplace=True) + + self.init_params() + + def forward(self, x): + x = self.conv1(x) + x = x + self.conv_branch(x) + x = self.relu(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class ATA(nn.Module): + def __init__(self, inchannels, reduction=8): + super(ATA, self).__init__() + self.inchannels = inchannels + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction), + nn.ReLU(inplace=True), + nn.Linear(self.inchannels // reduction, self.inchannels), + nn.Sigmoid()) + self.init_params() + + def forward(self, low_x, high_x): + n, c, _, _ = low_x.size() + x = torch.cat([low_x, high_x], 1) + x = self.avg_pool(x) + x = x.view(n, -1) + x = self.fc(x).view(n, c, 1, 1) + x = low_x * x + high_x + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + # init.normal_(m.weight, std=0.01) + init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FFM(nn.Module): + def __init__(self, inchannels, midchannels, outchannels, upfactor=2): + super(FFM, self).__init__() + self.inchannels = inchannels + self.midchannels = midchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) + # self.ata = ATA(inchannels = self.midchannels) + self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) + + self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) + + self.init_params() + + def forward(self, low_x, high_x): + x = self.ftb1(low_x) + x = x + high_x + x = self.ftb2(x) + x = self.upsample(x) + + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class AO(nn.Module): + # Adaptive output module + def __init__(self, inchannels, outchannels, upfactor=2): + super(AO, self).__init__() + self.inchannels = inchannels + self.outchannels = outchannels + self.upfactor = upfactor + + self.adapt_conv = nn.Sequential( + nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.BatchNorm2d(num_features=self.inchannels // 2), \ + nn.ReLU(inplace=True), \ + nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=3, padding=1, + stride=1, bias=True), \ + nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) + + self.init_params() + + def forward(self, x): + x = self.adapt_conv(x) + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.Batchnorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + + +# ============================================================================================================== + + +class ResidualConv(nn.Module): + def __init__(self, inchannels): + super(ResidualConv, self).__init__() + # NN.BatchNorm2d + self.conv = nn.Sequential( + # nn.BatchNorm2d(num_features=inchannels), + nn.ReLU(inplace=False), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True), + # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True) + nn.Conv2d(in_channels=inchannels, out_channels=inchannels / 2, kernel_size=3, padding=1, stride=1, + bias=False), + nn.BatchNorm2d(num_features=inchannels / 2), + nn.ReLU(inplace=False), + nn.Conv2d(in_channels=inchannels / 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, + bias=False) + ) + self.init_params() + + def forward(self, x): + x = self.conv(x) + x + return x + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class FeatureFusion(nn.Module): + def __init__(self, inchannels, outchannels): + super(FeatureFusion, self).__init__() + self.conv = ResidualConv(inchannels=inchannels) + # NN.BatchNorm2d + self.up = nn.Sequential(ResidualConv(inchannels=inchannels), + nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, + stride=2, padding=1, output_padding=1), + nn.BatchNorm2d(num_features=outchannels), + nn.ReLU(inplace=True)) + + def forward(self, lowfeat, highfeat): + return self.up(highfeat + self.conv(lowfeat)) + + def init_params(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.ConvTranspose2d): + # init.kaiming_normal_(m.weight, mode='fan_out') + init.normal_(m.weight, std=0.01) + # init.xavier_normal_(m.weight) + if m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): # NN.BatchNorm2d + init.constant_(m.weight, 1) + init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + init.normal_(m.weight, std=0.01) + if m.bias is not None: + init.constant_(m.bias, 0) + + +class SenceUnderstand(nn.Module): + def __init__(self, channels): + super(SenceUnderstand, self).__init__() + self.channels = channels + self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1), + nn.ReLU(inplace=True)) + self.pool = nn.AdaptiveAvgPool2d(8) + self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels), + nn.ReLU(inplace=True)) + self.conv2 = nn.Sequential( + nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0), + nn.ReLU(inplace=True)) + self.initial_params() + + def forward(self, x): + n, c, h, w = x.size() + x = self.conv1(x) + x = self.pool(x) + x = x.view(n, -1) + x = self.fc(x) + x = x.view(n, self.channels, 1, 1) + x = self.conv2(x) + x = x.repeat(1, 1, h, w) + return x + + def initial_params(self, dev=0.01): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.ConvTranspose2d): + # print torch.sum(m.weight) + m.weight.data.normal_(0, dev) + if m.bias is not None: + m.bias.data.fill_(0) + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, dev) + + +if __name__ == '__main__': + net = DepthNet(depth=50, pretrained=True) + print(net) + inputs = torch.ones(4,3,128,128) + out = net(inputs) + print(out.size()) + diff --git a/custom_controlnet_aux/leres/pix2pix/LICENSE b/custom_controlnet_aux/leres/pix2pix/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..38b1a24fd389a138b930dcf1ee606ef97a0186c8 --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/LICENSE @@ -0,0 +1,19 @@ +https://github.com/compphoto/BoostingMonocularDepth + +Copyright 2021, Seyed Mahdi Hosseini Miangoleh, Sebastian Dille, Computational Photography Laboratory. All rights reserved. + +This software is for academic use only. A redistribution of this +software, with or without modifications, has to be for academic +use only, while giving the appropriate credit to the original +authors of the software. The methods implemented as a part of +this software may be covered under patents or patent applications. + +THIS SOFTWARE IS PROVIDED BY THE AUTHOR ''AS IS'' AND ANY EXPRESS OR IMPLIED +WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND +FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR +CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF +ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/custom_controlnet_aux/leres/pix2pix/__init__.py b/custom_controlnet_aux/leres/pix2pix/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/leres/pix2pix/models/__init__.py b/custom_controlnet_aux/leres/pix2pix/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cb83d333a655103146dde012f119f05c05635c3e --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/models/__init__.py @@ -0,0 +1,67 @@ +"""This package contains modules related to objective functions, optimizations, and network architectures. + +To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. +You need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate loss, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + +In the function <__init__>, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. + +Now you can use the model class by specifying flag '--model dummy'. +See our template model class 'template_model.py' for more details. +""" + +import importlib +from .base_model import BaseModel + + +def find_model_using_name(model_name): + """Import the module "models/[model_name]_model.py". + + In the file, the class called DatasetNameModel() will + be instantiated. It has to be a subclass of BaseModel, + and it is case-insensitive. + """ + model_filename = "custom_controlnet_aux.leres.pix2pix.models." + model_name + "_model" + modellib = importlib.import_module(model_filename) + model = None + target_model_name = model_name.replace('_', '') + 'model' + for name, cls in modellib.__dict__.items(): + if name.lower() == target_model_name.lower() \ + and issubclass(cls, BaseModel): + model = cls + + if model is None: + print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) + exit(0) + + return model + + +def get_option_setter(model_name): + """Return the static method of the model class.""" + model_class = find_model_using_name(model_name) + return model_class.modify_commandline_options + + +def create_model(opt): + """Create a model given the option. + + This function warps the class CustomDatasetDataLoader. + This is the main interface between this package and 'train.py'/'test.py' + + Example: + >>> from models import create_model + >>> model = create_model(opt) + """ + model = find_model_using_name(opt.model) + instance = model(opt) + print("model [%s] was created" % type(instance).__name__) + return instance diff --git a/custom_controlnet_aux/leres/pix2pix/models/base_model.py b/custom_controlnet_aux/leres/pix2pix/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..66ec298f77cf769e39da38d1107e0b6dc38d519d --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/models/base_model.py @@ -0,0 +1,244 @@ +import gc +import os +from abc import ABC, abstractmethod +from collections import OrderedDict + +import torch + +from ....util import torch_gc +from . import networks + + +class BaseModel(ABC): + """This class is an abstract base class (ABC) for models. + To create a subclass, you need to implement the following five functions: + -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). + -- : unpack data from dataset and apply preprocessing. + -- : produce intermediate results. + -- : calculate losses, gradients, and update network weights. + -- : (optionally) add model-specific options and set default options. + """ + + def __init__(self, opt): + """Initialize the BaseModel class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + + When creating your custom class, you need to implement your own initialization. + In this function, you should first call + Then, you need to define four lists: + -- self.loss_names (str list): specify the training losses that you want to plot and save. + -- self.model_names (str list): define networks used in our training. + -- self.visual_names (str list): specify the images that you want to display and save. + -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. + """ + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir + if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. + torch.backends.cudnn.benchmark = True + self.loss_names = [] + self.model_names = [] + self.visual_names = [] + self.optimizers = [] + self.image_paths = [] + self.metric = 0 # used for learning rate policy 'plateau' + + @staticmethod + def modify_commandline_options(parser, is_train): + """Add new model-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + """ + return parser + + @abstractmethod + def set_input(self, input): + """Unpack input data from the dataloader and perform necessary pre-processing steps. + + Parameters: + input (dict): includes the data itself and its metadata information. + """ + pass + + @abstractmethod + def forward(self): + """Run forward pass; called by both functions and .""" + pass + + @abstractmethod + def optimize_parameters(self): + """Calculate losses, gradients, and update network weights; called in every training iteration""" + pass + + def setup(self, opt): + """Load and print networks; create schedulers + + Parameters: + opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + if self.isTrain: + self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] + if not self.isTrain or opt.continue_train: + load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch + self.load_networks(load_suffix) + self.print_networks(opt.verbose) + + def eval(self): + """Make models eval mode during test time""" + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + net.eval() + + def test(self): + """Forward function used in test time. + + This function wraps function in no_grad() so we don't save intermediate steps for backprop + It also calls to produce additional visualization results + """ + with torch.no_grad(): + self.forward() + self.compute_visuals() + + def compute_visuals(self): + """Calculate additional output images for visdom and HTML visualization""" + pass + + def get_image_paths(self): + """ Return image paths that are used to load current data""" + return self.image_paths + + def update_learning_rate(self): + """Update learning rates for all the networks; called at the end of every epoch""" + old_lr = self.optimizers[0].param_groups[0]['lr'] + for scheduler in self.schedulers: + if self.opt.lr_policy == 'plateau': + scheduler.step(self.metric) + else: + scheduler.step() + + lr = self.optimizers[0].param_groups[0]['lr'] + print('learning rate %.7f -> %.7f' % (old_lr, lr)) + + def get_current_visuals(self): + """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" + visual_ret = OrderedDict() + for name in self.visual_names: + if isinstance(name, str): + visual_ret[name] = getattr(self, name) + return visual_ret + + def get_current_losses(self): + """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" + errors_ret = OrderedDict() + for name in self.loss_names: + if isinstance(name, str): + errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number + return errors_ret + + def save_networks(self, epoch): + """Save all the networks to the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + save_filename = '%s_net_%s.pth' % (epoch, name) + save_path = os.path.join(self.save_dir, save_filename) + net = getattr(self, 'net' + name) + + if len(self.gpu_ids) > 0 and torch.cuda.is_available(): + torch.save(net.module.cpu().state_dict(), save_path) + net.cuda(self.gpu_ids[0]) + else: + torch.save(net.cpu().state_dict(), save_path) + + def unload_network(self, name): + """Unload network and gc. + """ + if isinstance(name, str): + net = getattr(self, 'net' + name) + del net + gc.collect() + torch_gc() + return None + + def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): + """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" + key = keys[i] + if i + 1 == len(keys): # at the end, pointing to a parameter/buffer + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'running_mean' or key == 'running_var'): + if getattr(module, key) is None: + state_dict.pop('.'.join(keys)) + if module.__class__.__name__.startswith('InstanceNorm') and \ + (key == 'num_batches_tracked'): + state_dict.pop('.'.join(keys)) + else: + self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) + + def load_networks(self, epoch): + """Load all the networks from the disk. + + Parameters: + epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) + """ + for name in self.model_names: + if isinstance(name, str): + load_filename = '%s_net_%s.pth' % (epoch, name) + load_path = os.path.join(self.save_dir, load_filename) + net = getattr(self, 'net' + name) + if isinstance(net, torch.nn.DataParallel): + net = net.module + # print('Loading depth boost model from %s' % load_path) + # if you are using PyTorch newer than 0.4 (e.g., built from + # GitHub source), you can remove str() on self.device + state_dict = torch.load(load_path, map_location=str(self.device)) + if hasattr(state_dict, '_metadata'): + del state_dict._metadata + + # patch InstanceNorm checkpoints prior to 0.4 + for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop + self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) + net.load_state_dict(state_dict) + + def print_networks(self, verbose): + """Print the total number of parameters in the network and (if verbose) network architecture + + Parameters: + verbose (bool) -- if verbose: print the network architecture + """ + print('---------- Networks initialized -------------') + for name in self.model_names: + if isinstance(name, str): + net = getattr(self, 'net' + name) + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + if verbose: + print(net) + print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) + print('-----------------------------------------------') + + def set_requires_grad(self, nets, requires_grad=False): + """Set requies_grad=Fasle for all the networks to avoid unnecessary computations + Parameters: + nets (network list) -- a list of networks + requires_grad (bool) -- whether the networks require gradients or not + """ + if not isinstance(nets, list): + nets = [nets] + for net in nets: + if net is not None: + for param in net.parameters(): + param.requires_grad = requires_grad diff --git a/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py b/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py new file mode 100644 index 0000000000000000000000000000000000000000..1709accdf0b048b3793dfd1f58d1b06c35f7b907 --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/models/base_model_hg.py @@ -0,0 +1,58 @@ +import os +import torch + +class BaseModelHG(): + def name(self): + return 'BaseModel' + + def initialize(self, opt): + self.opt = opt + self.gpu_ids = opt.gpu_ids + self.isTrain = opt.isTrain + self.Tensor = torch.cuda.FloatTensor if self.gpu_ids else torch.Tensor + self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) + + def set_input(self, input): + self.input = input + + def forward(self): + pass + + # used in test time, no backprop + def test(self): + pass + + def get_image_paths(self): + pass + + def optimize_parameters(self): + pass + + def get_current_visuals(self): + return self.input + + def get_current_errors(self): + return {} + + def save(self, label): + pass + + # helper saving function that can be used by subclasses + def save_network(self, network, network_label, epoch_label, gpu_ids): + save_filename = '_%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + torch.save(network.cpu().state_dict(), save_path) + if len(gpu_ids) and torch.cuda.is_available(): + network.cuda(device_id=gpu_ids[0]) + + # helper loading function that can be used by subclasses + def load_network(self, network, network_label, epoch_label): + save_filename = '%s_net_%s.pth' % (epoch_label, network_label) + save_path = os.path.join(self.save_dir, save_filename) + print(save_path) + model = torch.load(save_path) + return model + # network.load_state_dict(torch.load(save_path)) + + def update_learning_rate(): + pass diff --git a/custom_controlnet_aux/leres/pix2pix/models/networks.py b/custom_controlnet_aux/leres/pix2pix/models/networks.py new file mode 100644 index 0000000000000000000000000000000000000000..0cf912b2973721a02deefd042af621e732bad59f --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/models/networks.py @@ -0,0 +1,623 @@ +import torch +import torch.nn as nn +from torch.nn import init +import functools +from torch.optim import lr_scheduler + + +############################################################################### +# Helper Functions +############################################################################### + + +class Identity(nn.Module): + def forward(self, x): + return x + + +def get_norm_layer(norm_type='instance'): + """Return a normalization layer + + Parameters: + norm_type (str) -- the name of the normalization layer: batch | instance | none + + For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). + For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. + """ + if norm_type == 'batch': + norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True) + elif norm_type == 'instance': + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + elif norm_type == 'none': + def norm_layer(x): return Identity() + else: + raise NotImplementedError('normalization layer [%s] is not found' % norm_type) + return norm_layer + + +def get_scheduler(optimizer, opt): + """Return a learning rate scheduler + + Parameters: + optimizer -- the optimizer of the network + opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  + opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine + + For 'linear', we keep the same learning rate for the first epochs + and linearly decay the rate to zero over the next epochs. + For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. + See https://pytorch.org/docs/stable/optim.html for more details. + """ + if opt.lr_policy == 'linear': + def lambda_rule(epoch): + lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) + return lr_l + scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) + elif opt.lr_policy == 'step': + scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) + elif opt.lr_policy == 'plateau': + scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) + elif opt.lr_policy == 'cosine': + scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) + else: + return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) + return scheduler + + +def init_weights(net, init_type='normal', init_gain=0.02): + """Initialize network weights. + + Parameters: + net (network) -- network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + + We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might + work better for some applications. Feel free to try yourself. + """ + def init_func(m): # define the initialization function + classname = m.__class__.__name__ + if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): + if init_type == 'normal': + init.normal_(m.weight.data, 0.0, init_gain) + elif init_type == 'xavier': + init.xavier_normal_(m.weight.data, gain=init_gain) + elif init_type == 'kaiming': + init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + init.orthogonal_(m.weight.data, gain=init_gain) + else: + raise NotImplementedError('initialization method [%s] is not implemented' % init_type) + if hasattr(m, 'bias') and m.bias is not None: + init.constant_(m.bias.data, 0.0) + elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. + init.normal_(m.weight.data, 1.0, init_gain) + init.constant_(m.bias.data, 0.0) + + # print('initialize network with %s' % init_type) + net.apply(init_func) # apply the initialization function + + +def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights + Parameters: + net (network) -- the network to be initialized + init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal + gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Return an initialized network. + """ + if len(gpu_ids) > 0: + assert(torch.cuda.is_available()) + net.to(gpu_ids[0]) + net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs + init_weights(net, init_type, init_gain=init_gain) + return net + + +def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[]): + """Create a generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm (str) -- the name of normalization layers used in the network: batch | instance | none + use_dropout (bool) -- if use dropout layers. + init_type (str) -- the name of our initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a generator + + Our current implementation provides two types of generators: + U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) + The original U-Net paper: https://arxiv.org/abs/1505.04597 + + Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) + Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. + We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). + + + The generator has been initialized by . It uses RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netG == 'resnet_9blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9) + elif netG == 'resnet_6blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6) + elif netG == 'resnet_12blocks': + net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=12) + elif netG == 'unet_128': + net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_256': + net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_672': + net = UnetGenerator(input_nc, output_nc, 5, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_960': + net = UnetGenerator(input_nc, output_nc, 6, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + elif netG == 'unet_1024': + net = UnetGenerator(input_nc, output_nc, 10, ngf, norm_layer=norm_layer, use_dropout=use_dropout) + else: + raise NotImplementedError('Generator model name [%s] is not recognized' % netG) + return init_net(net, init_type, init_gain, gpu_ids) + + +def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02, gpu_ids=[]): + """Create a discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the first conv layer + netD (str) -- the architecture's name: basic | n_layers | pixel + n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str) -- the type of normalization layers used in the network. + init_type (str) -- the name of the initialization method. + init_gain (float) -- scaling factor for normal, xavier and orthogonal. + gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 + + Returns a discriminator + + Our current implementation provides three types of discriminators: + [basic]: 'PatchGAN' classifier described in the original pix2pix paper. + It can classify whether 70×70 overlapping patches are real or fake. + Such a patch-level discriminator architecture has fewer parameters + than a full-image discriminator and can work on arbitrarily-sized images + in a fully convolutional fashion. + + [n_layers]: With this mode, you can specify the number of conv layers in the discriminator + with the parameter (default=3 as used in [basic] (PatchGAN).) + + [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. + It encourages greater color diversity but has no effect on spatial statistics. + + The discriminator has been initialized by . It uses Leakly RELU for non-linearity. + """ + net = None + norm_layer = get_norm_layer(norm_type=norm) + + if netD == 'basic': # default PatchGAN classifier + net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) + elif netD == 'n_layers': # more options + net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) + elif netD == 'pixel': # classify if each pixel is real or fake + net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) + else: + raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) + return init_net(net, init_type, init_gain, gpu_ids) + + +############################################################################## +# Classes +############################################################################## +class GANLoss(nn.Module): + """Define different GAN objectives. + + The GANLoss class abstracts away the need to create the target label tensor + that has the same size as the input. + """ + + def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): + """ Initialize the GANLoss class. + + Parameters: + gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + target_real_label (bool) - - label for a real image + target_fake_label (bool) - - label of a fake image + + Note: Do not use sigmoid as the last layer of Discriminator. + LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. + """ + super(GANLoss, self).__init__() + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + self.gan_mode = gan_mode + if gan_mode == 'lsgan': + self.loss = nn.MSELoss() + elif gan_mode == 'vanilla': + self.loss = nn.BCEWithLogitsLoss() + elif gan_mode in ['wgangp']: + self.loss = None + else: + raise NotImplementedError('gan mode %s not implemented' % gan_mode) + + def get_target_tensor(self, prediction, target_is_real): + """Create label tensors with the same size as the input. + + Parameters: + prediction (tensor) - - tpyically the prediction from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + A label tensor filled with ground truth label, and with the size of the input + """ + + if target_is_real: + target_tensor = self.real_label + else: + target_tensor = self.fake_label + return target_tensor.expand_as(prediction) + + def __call__(self, prediction, target_is_real): + """Calculate loss given Discriminator's output and grount truth labels. + + Parameters: + prediction (tensor) - - tpyically the prediction output from a discriminator + target_is_real (bool) - - if the ground truth label is for real images or fake images + + Returns: + the calculated loss. + """ + if self.gan_mode in ['lsgan', 'vanilla']: + target_tensor = self.get_target_tensor(prediction, target_is_real) + loss = self.loss(prediction, target_tensor) + elif self.gan_mode == 'wgangp': + if target_is_real: + loss = -prediction.mean() + else: + loss = prediction.mean() + return loss + + +def cal_gradient_penalty(netD, real_data, fake_data, device, type='mixed', constant=1.0, lambda_gp=10.0): + """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 + + Arguments: + netD (network) -- discriminator network + real_data (tensor array) -- real images + fake_data (tensor array) -- generated images from the generator + device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + type (str) -- if we mix real and fake data or not [real | fake | mixed]. + constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 + lambda_gp (float) -- weight for this loss + + Returns the gradient penalty loss + """ + if lambda_gp > 0.0: + if type == 'real': # either use real images, fake images, or a linear interpolation of two. + interpolatesv = real_data + elif type == 'fake': + interpolatesv = fake_data + elif type == 'mixed': + alpha = torch.rand(real_data.shape[0], 1, device=device) + alpha = alpha.expand(real_data.shape[0], real_data.nelement() // real_data.shape[0]).contiguous().view(*real_data.shape) + interpolatesv = alpha * real_data + ((1 - alpha) * fake_data) + else: + raise NotImplementedError('{} not implemented'.format(type)) + interpolatesv.requires_grad_(True) + disc_interpolates = netD(interpolatesv) + gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolatesv, + grad_outputs=torch.ones(disc_interpolates.size()).to(device), + create_graph=True, retain_graph=True, only_inputs=True) + gradients = gradients[0].view(real_data.size(0), -1) # flat the data + gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - constant) ** 2).mean() * lambda_gp # added eps + return gradient_penalty, gradients + else: + return 0.0, None + + +class ResnetGenerator(nn.Module): + """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. + + We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + """ + + def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='reflect'): + """Construct a Resnet-based generator + + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers + n_blocks (int) -- the number of ResNet blocks + padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + """ + assert(n_blocks >= 0) + super(ResnetGenerator, self).__init__() + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + model = [nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias), + norm_layer(ngf), + nn.ReLU(True)] + + n_downsampling = 2 + for i in range(n_downsampling): # add downsampling layers + mult = 2 ** i + model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias), + norm_layer(ngf * mult * 2), + nn.ReLU(True)] + + mult = 2 ** n_downsampling + for i in range(n_blocks): # add ResNet blocks + + model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)] + + for i in range(n_downsampling): # add upsampling layers + mult = 2 ** (n_downsampling - i) + model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), + kernel_size=3, stride=2, + padding=1, output_padding=1, + bias=use_bias), + norm_layer(int(ngf * mult / 2)), + nn.ReLU(True)] + model += [nn.ReflectionPad2d(3)] + model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)] + model += [nn.Tanh()] + + self.model = nn.Sequential(*model) + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class ResnetBlock(nn.Module): + """Define a Resnet block""" + + def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Initialize the Resnet block + + A resnet block is a conv block with skip connections + We construct a conv block with build_conv_block function, + and implement skip connections in function. + Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + """ + super(ResnetBlock, self).__init__() + self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias) + + def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias): + """Construct a convolutional block. + + Parameters: + dim (int) -- the number of channels in the conv layer. + padding_type (str) -- the name of padding layer: reflect | replicate | zero + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + use_bias (bool) -- if the conv layer uses bias or not + + Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + """ + conv_block = [] + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)] + if use_dropout: + conv_block += [nn.Dropout(0.5)] + + p = 0 + if padding_type == 'reflect': + conv_block += [nn.ReflectionPad2d(1)] + elif padding_type == 'replicate': + conv_block += [nn.ReplicationPad2d(1)] + elif padding_type == 'zero': + p = 1 + else: + raise NotImplementedError('padding [%s] is not implemented' % padding_type) + conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)] + + return nn.Sequential(*conv_block) + + def forward(self, x): + """Forward function (with skip connections)""" + out = x + self.conv_block(x) # add skip connections + return out + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class NLayerDiscriminator(nn.Module): + """Defines a PatchGAN discriminator""" + + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): + """Construct a PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + n_layers (int) -- the number of conv layers in the discriminator + norm_layer -- normalization layer + """ + super(NLayerDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 4 + padw = 1 + sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True) + ] + + sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map + self.model = nn.Sequential(*sequence) + + def forward(self, input): + """Standard forward.""" + return self.model(input) + + +class PixelDiscriminator(nn.Module): + """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" + + def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): + """Construct a 1x1 PatchGAN discriminator + + Parameters: + input_nc (int) -- the number of channels in input images + ndf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + """ + super(PixelDiscriminator, self).__init__() + if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + self.net = [ + nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), + norm_layer(ndf * 2), + nn.LeakyReLU(0.2, True), + nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] + + self.net = nn.Sequential(*self.net) + + def forward(self, input): + """Standard forward.""" + return self.net(input) diff --git a/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py b/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..89e89652feb96314973a050c5a2477b474630abb --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/models/pix2pix4depth_model.py @@ -0,0 +1,155 @@ +import torch +from .base_model import BaseModel +from . import networks + + +class Pix2Pix4DepthModel(BaseModel): + """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. + + The model training requires '--dataset_mode aligned' dataset. + By default, it uses a '--netG unet256' U-Net generator, + a '--netD basic' discriminator (PatchGAN), + and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). + + pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf + """ + @staticmethod + def modify_commandline_options(parser, is_train=True): + """Add new dataset-specific options, and rewrite default values for existing options. + + Parameters: + parser -- original option parser + is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. + + Returns: + the modified parser. + + For pix2pix, we do not use image buffer + The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 + By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. + """ + # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) + parser.set_defaults(input_nc=2,output_nc=1,norm='none', netG='unet_1024', dataset_mode='depthmerge') + if is_train: + parser.set_defaults(pool_size=0, gan_mode='vanilla',) + parser.add_argument('--lambda_L1', type=float, default=1000, help='weight for L1 loss') + return parser + + def __init__(self, opt): + """Initialize the pix2pix class. + + Parameters: + opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions + """ + BaseModel.__init__(self, opt) + # specify the training losses you want to print out. The training/test scripts will call + + self.loss_names = ['G_GAN', 'G_L1', 'D_real', 'D_fake'] + # self.loss_names = ['G_L1'] + + # specify the images you want to save/display. The training/test scripts will call + if self.isTrain: + self.visual_names = ['outer','inner', 'fake_B', 'real_B'] + else: + self.visual_names = ['fake_B'] + + # specify the models you want to save to the disk. The training/test scripts will call and + if self.isTrain: + self.model_names = ['G','D'] + else: # during test time, only load G + self.model_names = ['G'] + + # define networks (both generator and discriminator) + self.netG = networks.define_G(opt.input_nc, opt.output_nc, 64, 'unet_1024', 'none', + False, 'normal', 0.02, self.gpu_ids) + + if self.isTrain: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc + self.netD = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD, + opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids) + + if self.isTrain: + # define loss functions + self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) + self.criterionL1 = torch.nn.L1Loss() + # initialize optimizers; schedulers will be automatically created by function . + self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=1e-4, betas=(opt.beta1, 0.999)) + self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=2e-06, betas=(opt.beta1, 0.999)) + self.optimizers.append(self.optimizer_G) + self.optimizers.append(self.optimizer_D) + + def set_input_train(self, input): + self.outer = input['data_outer'].to(self.device) + self.outer = torch.nn.functional.interpolate(self.outer,(1024,1024),mode='bilinear',align_corners=False) + + self.inner = input['data_inner'].to(self.device) + self.inner = torch.nn.functional.interpolate(self.inner,(1024,1024),mode='bilinear',align_corners=False) + + self.image_paths = input['image_path'] + + if self.isTrain: + self.gtfake = input['data_gtfake'].to(self.device) + self.gtfake = torch.nn.functional.interpolate(self.gtfake, (1024, 1024), mode='bilinear', align_corners=False) + self.real_B = self.gtfake + + self.real_A = torch.cat((self.outer, self.inner), 1) + + def set_input(self, outer, inner): + inner = torch.from_numpy(inner).unsqueeze(0).unsqueeze(0) + outer = torch.from_numpy(outer).unsqueeze(0).unsqueeze(0) + + inner = (inner - torch.min(inner))/(torch.max(inner)-torch.min(inner)) + outer = (outer - torch.min(outer))/(torch.max(outer)-torch.min(outer)) + + inner = self.normalize(inner) + outer = self.normalize(outer) + + self.real_A = torch.cat((outer, inner), 1).to(self.device) + + + def normalize(self, input): + input = input * 2 + input = input - 1 + return input + + def forward(self): + """Run forward pass; called by both functions and .""" + self.fake_B = self.netG(self.real_A) # G(A) + + def backward_D(self): + """Calculate GAN loss for the discriminator""" + # Fake; stop backprop to the generator by detaching fake_B + fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator + pred_fake = self.netD(fake_AB.detach()) + self.loss_D_fake = self.criterionGAN(pred_fake, False) + # Real + real_AB = torch.cat((self.real_A, self.real_B), 1) + pred_real = self.netD(real_AB) + self.loss_D_real = self.criterionGAN(pred_real, True) + # combine loss and calculate gradients + self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 + self.loss_D.backward() + + def backward_G(self): + """Calculate GAN and L1 loss for the generator""" + # First, G(A) should fake the discriminator + fake_AB = torch.cat((self.real_A, self.fake_B), 1) + pred_fake = self.netD(fake_AB) + self.loss_G_GAN = self.criterionGAN(pred_fake, True) + # Second, G(A) = B + self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_L1 + # combine loss and calculate gradients + self.loss_G = self.loss_G_L1 + self.loss_G_GAN + self.loss_G.backward() + + def optimize_parameters(self): + self.forward() # compute fake images: G(A) + # update D + self.set_requires_grad(self.netD, True) # enable backprop for D + self.optimizer_D.zero_grad() # set D's gradients to zero + self.backward_D() # calculate gradients for D + self.optimizer_D.step() # update D's weights + # update G + self.set_requires_grad(self.netD, False) # D requires no gradients when optimizing G + self.optimizer_G.zero_grad() # set G's gradients to zero + self.backward_G() # calculate graidents for G + self.optimizer_G.step() # udpate G's weights \ No newline at end of file diff --git a/custom_controlnet_aux/leres/pix2pix/options/__init__.py b/custom_controlnet_aux/leres/pix2pix/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7eedebe54aa70169fd25951b3034d819e396c90 --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/options/__init__.py @@ -0,0 +1 @@ +"""This package options includes option modules: training options, test options, and basic options (used in both training and test).""" diff --git a/custom_controlnet_aux/leres/pix2pix/options/base_options.py b/custom_controlnet_aux/leres/pix2pix/options/base_options.py new file mode 100644 index 0000000000000000000000000000000000000000..533a1e88a7e8494223f6994e6861c93667754f83 --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/options/base_options.py @@ -0,0 +1,156 @@ +import argparse +import os +from ...pix2pix.util import util +# import torch +from ...pix2pix import models +# import pix2pix.data +import numpy as np + +class BaseOptions(): + """This class defines options used during both training and test time. + + It also implements several helper functions such as parsing, printing, and saving the options. + It also gathers additional options defined in functions in both dataset class and model class. + """ + + def __init__(self): + """Reset the class; indicates the class hasn't been initailized""" + self.initialized = False + + def initialize(self, parser): + """Define the common options that are used in both training and test.""" + # basic parameters + parser.add_argument('--dataroot', help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') + parser.add_argument('--name', type=str, default='void', help='mahdi_unet_new, scaled_unet') + parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') + parser.add_argument('--checkpoints_dir', type=str, default='./pix2pix/checkpoints', help='models are saved here') + # model parameters + parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') + parser.add_argument('--input_nc', type=int, default=2, help='# of input image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--output_nc', type=int, default=1, help='# of output image channels: 3 for RGB and 1 for grayscale') + parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') + parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') + parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') + parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') + parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') + parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') + parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') + parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') + parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') + # dataset parameters + parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') + parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') + parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') + parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') + parser.add_argument('--batch_size', type=int, default=1, help='input batch size') + parser.add_argument('--load_size', type=int, default=672, help='scale images to this size') + parser.add_argument('--crop_size', type=int, default=672, help='then crop to this size') + parser.add_argument('--max_dataset_size', type=int, default=10000, help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.') + parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') + parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') + parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') + # additional parameters + parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') + parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') + parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') + parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') + + parser.add_argument('--data_dir', type=str, required=False, + help='input files directory images can be .png .jpg .tiff') + parser.add_argument('--output_dir', type=str, required=False, + help='result dir. result depth will be png. vides are JMPG as avi') + parser.add_argument('--savecrops', type=int, required=False) + parser.add_argument('--savewholeest', type=int, required=False) + parser.add_argument('--output_resolution', type=int, required=False, + help='0 for no restriction 1 for resize to input size') + parser.add_argument('--net_receptive_field_size', type=int, required=False) + parser.add_argument('--pix2pixsize', type=int, required=False) + parser.add_argument('--generatevideo', type=int, required=False) + parser.add_argument('--depthNet', type=int, required=False, help='0: midas 1:strurturedRL') + parser.add_argument('--R0', action='store_true') + parser.add_argument('--R20', action='store_true') + parser.add_argument('--Final', action='store_true') + parser.add_argument('--colorize_results', action='store_true') + parser.add_argument('--max_res', type=float, default=np.inf) + + self.initialized = True + return parser + + def gather_options(self): + """Initialize our parser with basic options(only once). + Add additional model-specific and dataset-specific options. + These options are defined in the function + in model and dataset classes. + """ + if not self.initialized: # check if it has been initialized + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser = self.initialize(parser) + + # get the basic options + opt, _ = parser.parse_known_args() + + # modify model-related parser options + model_name = opt.model + model_option_setter = models.get_option_setter(model_name) + parser = model_option_setter(parser, self.isTrain) + opt, _ = parser.parse_known_args() # parse again with new defaults + + # modify dataset-related parser options + # dataset_name = opt.dataset_mode + # dataset_option_setter = pix2pix.data.get_option_setter(dataset_name) + # parser = dataset_option_setter(parser, self.isTrain) + + # save and return the parser + self.parser = parser + #return parser.parse_args() #EVIL + return opt + + def print_options(self, opt): + """Print and save options + + It will print both current options and default values(if different). + It will save options into a text file / [checkpoints_dir] / opt.txt + """ + message = '' + message += '----------------- Options ---------------\n' + for k, v in sorted(vars(opt).items()): + comment = '' + default = self.parser.get_default(k) + if v != default: + comment = '\t[default: %s]' % str(default) + message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) + message += '----------------- End -------------------' + print(message) + + # save to the disk + expr_dir = os.path.join(opt.checkpoints_dir, opt.name) + util.mkdirs(expr_dir) + file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) + with open(file_name, 'wt') as opt_file: + opt_file.write(message) + opt_file.write('\n') + + def parse(self): + """Parse our options, create checkpoints directory suffix, and set up gpu device.""" + opt = self.gather_options() + opt.isTrain = self.isTrain # train or test + + # process opt.suffix + if opt.suffix: + suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' + opt.name = opt.name + suffix + + #self.print_options(opt) + + # set gpu ids + str_ids = opt.gpu_ids.split(',') + opt.gpu_ids = [] + for str_id in str_ids: + id = int(str_id) + if id >= 0: + opt.gpu_ids.append(id) + #if len(opt.gpu_ids) > 0: + # torch.cuda.set_device(opt.gpu_ids[0]) + + self.opt = opt + return self.opt diff --git a/custom_controlnet_aux/leres/pix2pix/options/test_options.py b/custom_controlnet_aux/leres/pix2pix/options/test_options.py new file mode 100644 index 0000000000000000000000000000000000000000..a3424b5e3b66d6813f74c8cecad691d7488d121c --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/options/test_options.py @@ -0,0 +1,22 @@ +from .base_options import BaseOptions + + +class TestOptions(BaseOptions): + """This class includes test options. + + It also includes shared options defined in BaseOptions. + """ + + def initialize(self, parser): + parser = BaseOptions.initialize(self, parser) # define shared options + parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') + parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') + # Dropout and Batchnorm has different behavioir during training and test. + parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') + parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') + # rewrite devalue values + parser.set_defaults(model='pix2pix4depth') + # To avoid cropping, the load_size should be the same as crop_size + parser.set_defaults(load_size=parser.get_default('crop_size')) + self.isTrain = False + return parser diff --git a/custom_controlnet_aux/leres/pix2pix/util/__init__.py b/custom_controlnet_aux/leres/pix2pix/util/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae36f63d8859ec0c60dcbfe67c4ac324e751ddf7 --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/util/__init__.py @@ -0,0 +1 @@ +"""This package includes a miscellaneous collection of useful helper functions.""" diff --git a/custom_controlnet_aux/leres/pix2pix/util/util.py b/custom_controlnet_aux/leres/pix2pix/util/util.py new file mode 100644 index 0000000000000000000000000000000000000000..8a7aceaa00681cb76675df7866bf8db58c8d2caf --- /dev/null +++ b/custom_controlnet_aux/leres/pix2pix/util/util.py @@ -0,0 +1,105 @@ +"""This module contains simple helper functions """ +from __future__ import print_function +import torch +import numpy as np +from PIL import Image +import os + + +def tensor2im(input_image, imtype=np.uint16): + """"Converts a Tensor array into a numpy image array. + + Parameters: + input_image (tensor) -- the input image tensor array + imtype (type) -- the desired type of the converted numpy array + """ + if not isinstance(input_image, np.ndarray): + if isinstance(input_image, torch.Tensor): # get the data from a variable + image_tensor = input_image.data + else: + return input_image + image_numpy = torch.squeeze(image_tensor).cpu().numpy() # convert it into a numpy array + image_numpy = (image_numpy + 1) / 2.0 * (2**16-1) # + else: # if it is a numpy array, do nothing + image_numpy = input_image + return image_numpy.astype(imtype) + + +def diagnose_network(net, name='network'): + """Calculate and print the mean of average absolute(gradients) + + Parameters: + net (torch network) -- Torch network + name (str) -- the name of the network + """ + mean = 0.0 + count = 0 + for param in net.parameters(): + if param.grad is not None: + mean += torch.mean(torch.abs(param.grad.data)) + count += 1 + if count > 0: + mean = mean / count + print(name) + print(mean) + + +def save_image(image_numpy, image_path, aspect_ratio=1.0): + """Save a numpy image to the disk + + Parameters: + image_numpy (numpy array) -- input numpy array + image_path (str) -- the path of the image + """ + image_pil = Image.fromarray(image_numpy) + + image_pil = image_pil.convert('I;16') + + # image_pil = Image.fromarray(image_numpy) + # h, w, _ = image_numpy.shape + # + # if aspect_ratio > 1.0: + # image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) + # if aspect_ratio < 1.0: + # image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) + + image_pil.save(image_path) + + +def print_numpy(x, val=True, shp=False): + """Print the mean, min, max, median, std, and size of a numpy array + + Parameters: + val (bool) -- if print the values of the numpy array + shp (bool) -- if print the shape of the numpy array + """ + x = x.astype(np.float64) + if shp: + print('shape,', x.shape) + if val: + x = x.flatten() + print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( + np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) + + +def mkdirs(paths): + """create empty directories if they don't exist + + Parameters: + paths (str list) -- a list of directory paths + """ + if isinstance(paths, list) and not isinstance(paths, str): + for path in paths: + mkdir(path) + else: + mkdir(paths) + + +def mkdir(path): + """create a single empty directory if it didn't exist + + Parameters: + path (str) -- a single directory path + """ + if not os.path.exists(path): + os.makedirs(path) diff --git a/custom_controlnet_aux/lineart/LICENSE b/custom_controlnet_aux/lineart/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/custom_controlnet_aux/lineart/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/lineart/__init__.py b/custom_controlnet_aux/lineart/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d42dc501f67f8fa333b783883373ddb271bfceaf --- /dev/null +++ b/custom_controlnet_aux/lineart/__init__.py @@ -0,0 +1,141 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME + +norm_layer = nn.InstanceNorm2d + + +class ResidualBlock(nn.Module): + def __init__(self, in_features): + super(ResidualBlock, self).__init__() + + conv_block = [ nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features), + nn.ReLU(inplace=True), + nn.ReflectionPad2d(1), + nn.Conv2d(in_features, in_features, 3), + norm_layer(in_features) + ] + + self.conv_block = nn.Sequential(*conv_block) + + def forward(self, x): + return x + self.conv_block(x) + + +class Generator(nn.Module): + def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True): + super(Generator, self).__init__() + + # Initial convolution block + model0 = [ nn.ReflectionPad2d(3), + nn.Conv2d(input_nc, 64, 7), + norm_layer(64), + nn.ReLU(inplace=True) ] + self.model0 = nn.Sequential(*model0) + + # Downsampling + model1 = [] + in_features = 64 + out_features = in_features*2 + for _ in range(2): + model1 += [ nn.Conv2d(in_features, out_features, 3, stride=2, padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features*2 + self.model1 = nn.Sequential(*model1) + + model2 = [] + # Residual blocks + for _ in range(n_residual_blocks): + model2 += [ResidualBlock(in_features)] + self.model2 = nn.Sequential(*model2) + + # Upsampling + model3 = [] + out_features = in_features//2 + for _ in range(2): + model3 += [ nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1), + norm_layer(out_features), + nn.ReLU(inplace=True) ] + in_features = out_features + out_features = in_features//2 + self.model3 = nn.Sequential(*model3) + + # Output layer + model4 = [ nn.ReflectionPad2d(3), + nn.Conv2d(64, output_nc, 7)] + if sigmoid: + model4 += [nn.Sigmoid()] + + self.model4 = nn.Sequential(*model4) + + def forward(self, x, cond=None): + out = self.model0(x) + out = self.model1(out) + out = self.model2(out) + out = self.model3(out) + out = self.model4(out) + + return out + + +class LineartDetector: + def __init__(self, model, coarse_model): + self.model = model + self.model_coarse = coarse_model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="sk_model.pth", coarse_filename="sk_model2.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + coarse_model_path = custom_hf_download(pretrained_model_or_path, coarse_filename) + + model = Generator(3, 1, 3) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))) + model.eval() + + coarse_model = Generator(3, 1, 3) + coarse_model.load_state_dict(torch.load(coarse_model_path, map_location=torch.device('cpu'))) + coarse_model.eval() + + return cls(model, coarse_model) + + def to(self, device): + self.model.to(device) + self.model_coarse.to(device) + self.device = device + return self + + def __call__(self, input_image, coarse=False, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + model = self.model_coarse if coarse else self.model + assert detected_map.ndim == 3 + with torch.no_grad(): + image = torch.from_numpy(detected_map).float().to(self.device) + image = image / 255.0 + image = rearrange(image, 'h w c -> 1 c h w') + line = model(image)[0][0] + + line = line.cpu().numpy() + line = (line * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = HWC3(line) + detected_map = remove_pad(255 - detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/lineart_anime/LICENSE b/custom_controlnet_aux/lineart_anime/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/custom_controlnet_aux/lineart_anime/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/lineart_anime/__init__.py b/custom_controlnet_aux/lineart_anime/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f42dfeb3a5dd504121c4d22aa0bab7a16dc681ed --- /dev/null +++ b/custom_controlnet_aux/lineart_anime/__init__.py @@ -0,0 +1,167 @@ +import functools +import os +import warnings + +import cv2 +import numpy as np +import torch +import torch.nn as nn +from einops import rearrange +from huggingface_hub import hf_hub_download +from PIL import Image + +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME + + +class UnetGenerator(nn.Module): + """Create a Unet-based generator""" + + def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet generator + Parameters: + input_nc (int) -- the number of channels in input images + output_nc (int) -- the number of channels in output images + num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int) -- the number of filters in the last conv layer + norm_layer -- normalization layer + We construct the U-Net from the innermost layer to the outermost layer. + It is a recursive process. + """ + super(UnetGenerator, self).__init__() + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True) # add the innermost layer + for _ in range(num_downs - 5): # add intermediate layers with ngf * 8 filters + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout) + # gradually reduce the number of filters from ngf * 8 to ngf + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer) # add the outermost layer + + def forward(self, input): + """Standard forward""" + return self.model(input) + + +class UnetSkipConnectionBlock(nn.Module): + """Defines the Unet submodule with skip connection. + X -------------------identity---------------------- + |-- downsampling -- |submodule| -- upsampling --| + """ + + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + """Construct a Unet submodule with skip connections. + Parameters: + outer_nc (int) -- the number of filters in the outer conv layer + inner_nc (int) -- the number of filters in the inner conv layer + input_nc (int) -- the number of channels in input images/features + submodule (UnetSkipConnectionBlock) -- previously defined submodules + outermost (bool) -- if this module is the outermost module + innermost (bool) -- if this module is the innermost module + norm_layer -- normalization layer + use_dropout (bool) -- if use dropout layers. + """ + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: # add skip connections + return torch.cat([x, self.model(x)], 1) + + +class LineartAnimeDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="netG.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) + net = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False) + ckpt = torch.load(model_path) + for key in list(ckpt.keys()): + if 'module.' in key: + ckpt[key.replace('module.', '')] = ckpt[key] + del ckpt[key] + net.load_state_dict(ckpt) + net.eval() + + return cls(net) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + H, W, C = input_image.shape + Hn = 256 * int(np.ceil(float(H) / 256.0)) + Wn = 256 * int(np.ceil(float(W) / 256.0)) + input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_CUBIC) + + with torch.no_grad(): + image_feed = torch.from_numpy(input_image).float().to(self.device) + image_feed = image_feed / 127.5 - 1.0 + image_feed = rearrange(image_feed, 'h w c -> 1 c h w') + + line = self.model(image_feed)[0, 0] * 127.5 + 127.5 + line = line.cpu().numpy() + line = line.clip(0, 255).astype(np.uint8) + + #A1111 uses INTER AREA for downscaling so ig that is the best choice + detected_map = cv2.resize(HWC3(line), (W, H), interpolation=cv2.INTER_AREA) + detected_map = remove_pad(255 - detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/lineart_standard/__init__.py b/custom_controlnet_aux/lineart_standard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7932e733cf8147a7a500c86ed3eb192a23504626 --- /dev/null +++ b/custom_controlnet_aux/lineart_standard/__init__.py @@ -0,0 +1,21 @@ +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3 + +class LineartStandardDetector: + def __call__(self, input_image=None, guassian_sigma=6.0, intensity_threshold=8, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + x = input_image.astype(np.float32) + g = cv2.GaussianBlur(x, (0, 0), guassian_sigma) + intensity = np.min(g - x, axis=2).clip(0, 255) + intensity /= max(16, np.median(intensity[intensity > intensity_threshold])) + intensity *= 127 + detected_map = intensity.clip(0, 255).astype(np.uint8) + + detected_map = HWC3(remove_pad(detected_map)) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/manga_line/LICENSE b/custom_controlnet_aux/manga_line/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..bdca75a54d05781782e3d939401e93161cdd88f7 --- /dev/null +++ b/custom_controlnet_aux/manga_line/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Miaomiao Li + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/manga_line/__init__.py b/custom_controlnet_aux/manga_line/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..16dea641c70110d920b2bcd72754bb94c4ff087f --- /dev/null +++ b/custom_controlnet_aux/manga_line/__init__.py @@ -0,0 +1,63 @@ +# MangaLineExtraction_PyTorch +# https://github.com/ljsabc/MangaLineExtraction_PyTorch + +#NOTE: This preprocessor is designed to work with lineart_anime ControlNet so the result will be white lines on black canvas + +import torch +import numpy as np +import os +import cv2 +from einops import rearrange +from .model_torch import res_skip +from PIL import Image +import warnings + +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, HF_MODEL_NAME + +class LineartMangaDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="erika.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + net = res_skip() + ckpt = torch.load(model_path) + for key in list(ckpt.keys()): + if 'module.' in key: + ckpt[key.replace('module.', '')] = ckpt[key] + del ckpt[key] + net.load_state_dict(ckpt) + net.eval() + return cls(net) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, 256 * int(np.ceil(float(detect_resolution) / 256.0)), upscale_method) + + img = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY) + with torch.no_grad(): + image_feed = torch.from_numpy(img).float().to(self.device) + image_feed = rearrange(image_feed, 'h w -> 1 1 h w') + + line = self.model(image_feed) + line = line.cpu().numpy()[0,0,:,:] + line[line > 255] = 255 + line[line < 0] = 0 + + line = line.astype(np.uint8) + + detected_map = HWC3(line) + detected_map = remove_pad(255 - detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/manga_line/model_torch.py b/custom_controlnet_aux/manga_line/model_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..de5828ccc486d74490b8da710d644651067bd5f3 --- /dev/null +++ b/custom_controlnet_aux/manga_line/model_torch.py @@ -0,0 +1,196 @@ +import torch.nn as nn +import numpy as np + +#torch.set_printoptions(precision=10) + + +class _bn_relu_conv(nn.Module): + def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): + super(_bn_relu_conv, self).__init__() + self.model = nn.Sequential( + nn.BatchNorm2d(in_filters, eps=1e-3), + nn.LeakyReLU(0.2), + nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2), padding_mode='zeros') + ) + + def forward(self, x): + return self.model(x) + + # the following are for debugs + print("****", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape) + for i,layer in enumerate(self.model): + if i != 2: + x = layer(x) + else: + x = layer(x) + #x = nn.functional.pad(x, (1, 1, 1, 1), mode='constant', value=0) + print("____", np.max(x.cpu().numpy()), np.min(x.cpu().numpy()), np.mean(x.cpu().numpy()), np.std(x.cpu().numpy()), x.shape) + print(x[0]) + return x + + +class _u_bn_relu_conv(nn.Module): + def __init__(self, in_filters, nb_filters, fw, fh, subsample=1): + super(_u_bn_relu_conv, self).__init__() + self.model = nn.Sequential( + nn.BatchNorm2d(in_filters, eps=1e-3), + nn.LeakyReLU(0.2), + nn.Conv2d(in_filters, nb_filters, (fw, fh), stride=subsample, padding=(fw//2, fh//2)), + nn.Upsample(scale_factor=2, mode='nearest') + ) + + def forward(self, x): + return self.model(x) + + + +class _shortcut(nn.Module): + def __init__(self, in_filters, nb_filters, subsample=1): + super(_shortcut, self).__init__() + self.process = False + self.model = None + if in_filters != nb_filters or subsample != 1: + self.process = True + self.model = nn.Sequential( + nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample) + ) + + def forward(self, x, y): + #print(x.size(), y.size(), self.process) + if self.process: + y0 = self.model(x) + #print("merge+", torch.max(y0+y), torch.min(y0+y),torch.mean(y0+y), torch.std(y0+y), y0.shape) + return y0 + y + else: + #print("merge", torch.max(x+y), torch.min(x+y),torch.mean(x+y), torch.std(x+y), y.shape) + return x + y + +class _u_shortcut(nn.Module): + def __init__(self, in_filters, nb_filters, subsample): + super(_u_shortcut, self).__init__() + self.process = False + self.model = None + if in_filters != nb_filters: + self.process = True + self.model = nn.Sequential( + nn.Conv2d(in_filters, nb_filters, (1, 1), stride=subsample, padding_mode='zeros'), + nn.Upsample(scale_factor=2, mode='nearest') + ) + + def forward(self, x, y): + if self.process: + return self.model(x) + y + else: + return x + y + + +class basic_block(nn.Module): + def __init__(self, in_filters, nb_filters, init_subsample=1): + super(basic_block, self).__init__() + self.conv1 = _bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample) + self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3) + self.shortcut = _shortcut(in_filters, nb_filters, subsample=init_subsample) + + def forward(self, x): + x1 = self.conv1(x) + x2 = self.residual(x1) + return self.shortcut(x, x2) + +class _u_basic_block(nn.Module): + def __init__(self, in_filters, nb_filters, init_subsample=1): + super(_u_basic_block, self).__init__() + self.conv1 = _u_bn_relu_conv(in_filters, nb_filters, 3, 3, subsample=init_subsample) + self.residual = _bn_relu_conv(nb_filters, nb_filters, 3, 3) + self.shortcut = _u_shortcut(in_filters, nb_filters, subsample=init_subsample) + + def forward(self, x): + y = self.residual(self.conv1(x)) + return self.shortcut(x, y) + + +class _residual_block(nn.Module): + def __init__(self, in_filters, nb_filters, repetitions, is_first_layer=False): + super(_residual_block, self).__init__() + layers = [] + for i in range(repetitions): + init_subsample = 1 + if i == repetitions - 1 and not is_first_layer: + init_subsample = 2 + if i == 0: + l = basic_block(in_filters=in_filters, nb_filters=nb_filters, init_subsample=init_subsample) + else: + l = basic_block(in_filters=nb_filters, nb_filters=nb_filters, init_subsample=init_subsample) + layers.append(l) + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class _upsampling_residual_block(nn.Module): + def __init__(self, in_filters, nb_filters, repetitions): + super(_upsampling_residual_block, self).__init__() + layers = [] + for i in range(repetitions): + l = None + if i == 0: + l = _u_basic_block(in_filters=in_filters, nb_filters=nb_filters)#(input) + else: + l = basic_block(in_filters=nb_filters, nb_filters=nb_filters)#(input) + layers.append(l) + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class res_skip(nn.Module): + + def __init__(self): + super(res_skip, self).__init__() + self.block0 = _residual_block(in_filters=1, nb_filters=24, repetitions=2, is_first_layer=True)#(input) + self.block1 = _residual_block(in_filters=24, nb_filters=48, repetitions=3)#(block0) + self.block2 = _residual_block(in_filters=48, nb_filters=96, repetitions=5)#(block1) + self.block3 = _residual_block(in_filters=96, nb_filters=192, repetitions=7)#(block2) + self.block4 = _residual_block(in_filters=192, nb_filters=384, repetitions=12)#(block3) + + self.block5 = _upsampling_residual_block(in_filters=384, nb_filters=192, repetitions=7)#(block4) + self.res1 = _shortcut(in_filters=192, nb_filters=192)#(block3, block5, subsample=(1,1)) + + self.block6 = _upsampling_residual_block(in_filters=192, nb_filters=96, repetitions=5)#(res1) + self.res2 = _shortcut(in_filters=96, nb_filters=96)#(block2, block6, subsample=(1,1)) + + self.block7 = _upsampling_residual_block(in_filters=96, nb_filters=48, repetitions=3)#(res2) + self.res3 = _shortcut(in_filters=48, nb_filters=48)#(block1, block7, subsample=(1,1)) + + self.block8 = _upsampling_residual_block(in_filters=48, nb_filters=24, repetitions=2)#(res3) + self.res4 = _shortcut(in_filters=24, nb_filters=24)#(block0,block8, subsample=(1,1)) + + self.block9 = _residual_block(in_filters=24, nb_filters=16, repetitions=2, is_first_layer=True)#(res4) + self.conv15 = _bn_relu_conv(in_filters=16, nb_filters=1, fh=1, fw=1, subsample=1)#(block7) + + def forward(self, x): + x0 = self.block0(x) + x1 = self.block1(x0) + x2 = self.block2(x1) + x3 = self.block3(x2) + x4 = self.block4(x3) + + x5 = self.block5(x4) + res1 = self.res1(x3, x5) + + x6 = self.block6(res1) + res2 = self.res2(x2, x6) + + x7 = self.block7(res2) + res3 = self.res3(x1, x7) + + x8 = self.block8(res3) + res4 = self.res4(x0, x8) + + x9 = self.block9(res4) + y = self.conv15(x9) + + return y \ No newline at end of file diff --git a/custom_controlnet_aux/mediapipe_face/__init__.py b/custom_controlnet_aux/mediapipe_face/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..607db69d6d76224f8450c90212e766d21bf0d33b --- /dev/null +++ b/custom_controlnet_aux/mediapipe_face/__init__.py @@ -0,0 +1,31 @@ +import warnings +from typing import Union + +import cv2 +import numpy as np +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad +from .mediapipe_face_common import generate_annotation + + +class MediapipeFaceDetector: + def __call__(self, + input_image: Union[np.ndarray, Image.Image] = None, + max_faces: int = 1, + min_confidence: float = 0.5, + output_type: str = "pil", + detect_resolution: int = 512, + image_resolution: int = 512, + upscale_method="INTER_CUBIC", + **kwargs): + + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + detected_map = generate_annotation(detected_map, max_faces, min_confidence) + detected_map = remove_pad(HWC3(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py b/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py new file mode 100644 index 0000000000000000000000000000000000000000..32eeaf7455df2dd9efa5976def5e617b08757598 --- /dev/null +++ b/custom_controlnet_aux/mediapipe_face/mediapipe_face_common.py @@ -0,0 +1,156 @@ +from typing import Mapping +import warnings + +import mediapipe as mp +import numpy + +if mp: + mp_drawing = mp.solutions.drawing_utils + mp_drawing_styles = mp.solutions.drawing_styles + mp_face_detection = mp.solutions.face_detection # Only for counting faces. + mp_face_mesh = mp.solutions.face_mesh + mp_face_connections = mp.solutions.face_mesh_connections.FACEMESH_TESSELATION + mp_hand_connections = mp.solutions.hands_connections.HAND_CONNECTIONS + mp_body_connections = mp.solutions.pose_connections.POSE_CONNECTIONS + + DrawingSpec = mp.solutions.drawing_styles.DrawingSpec + PoseLandmark = mp.solutions.drawing_styles.PoseLandmark + + min_face_size_pixels: int = 64 + f_thick = 2 + f_rad = 1 + right_iris_draw = DrawingSpec(color=(10, 200, 250), thickness=f_thick, circle_radius=f_rad) + right_eye_draw = DrawingSpec(color=(10, 200, 180), thickness=f_thick, circle_radius=f_rad) + right_eyebrow_draw = DrawingSpec(color=(10, 220, 180), thickness=f_thick, circle_radius=f_rad) + left_iris_draw = DrawingSpec(color=(250, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eye_draw = DrawingSpec(color=(180, 200, 10), thickness=f_thick, circle_radius=f_rad) + left_eyebrow_draw = DrawingSpec(color=(180, 220, 10), thickness=f_thick, circle_radius=f_rad) + mouth_draw = DrawingSpec(color=(10, 180, 10), thickness=f_thick, circle_radius=f_rad) + head_draw = DrawingSpec(color=(10, 200, 10), thickness=f_thick, circle_radius=f_rad) + + # mp_face_mesh.FACEMESH_CONTOURS has all the items we care about. + face_connection_spec = {} + for edge in mp_face_mesh.FACEMESH_FACE_OVAL: + face_connection_spec[edge] = head_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYE: + face_connection_spec[edge] = left_eye_draw + for edge in mp_face_mesh.FACEMESH_LEFT_EYEBROW: + face_connection_spec[edge] = left_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_LEFT_IRIS: + # face_connection_spec[edge] = left_iris_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYE: + face_connection_spec[edge] = right_eye_draw + for edge in mp_face_mesh.FACEMESH_RIGHT_EYEBROW: + face_connection_spec[edge] = right_eyebrow_draw + # for edge in mp_face_mesh.FACEMESH_RIGHT_IRIS: + # face_connection_spec[edge] = right_iris_draw + for edge in mp_face_mesh.FACEMESH_LIPS: + face_connection_spec[edge] = mouth_draw + iris_landmark_spec = {468: right_iris_draw, 473: left_iris_draw} + + +def draw_pupils(image, landmark_list, drawing_spec, halfwidth: int = 2): + """We have a custom function to draw the pupils because the mp.draw_landmarks method requires a parameter for all + landmarks. Until our PR is merged into mediapipe, we need this separate method.""" + if len(image.shape) != 3: + raise ValueError("Input image must be H,W,C.") + image_rows, image_cols, image_channels = image.shape + if image_channels != 3: # BGR channels + raise ValueError('Input image must contain three channel bgr data.') + for idx, landmark in enumerate(landmark_list.landmark): + if ( + (landmark.HasField('visibility') and landmark.visibility < 0.9) or + (landmark.HasField('presence') and landmark.presence < 0.5) + ): + continue + if landmark.x >= 1.0 or landmark.x < 0 or landmark.y >= 1.0 or landmark.y < 0: + continue + image_x = int(image_cols*landmark.x) + image_y = int(image_rows*landmark.y) + draw_color = None + if isinstance(drawing_spec, Mapping): + if drawing_spec.get(idx) is None: + continue + else: + draw_color = drawing_spec[idx].color + elif isinstance(drawing_spec, DrawingSpec): + draw_color = drawing_spec.color + image[image_y-halfwidth:image_y+halfwidth, image_x-halfwidth:image_x+halfwidth, :] = draw_color + + +def reverse_channels(image): + """Given a numpy array in RGB form, convert to BGR. Will also convert from BGR to RGB.""" + # im[:,:,::-1] is a neat hack to convert BGR to RGB by reversing the indexing order. + # im[:,:,::[2,1,0]] would also work but makes a copy of the data. + return image[:, :, ::-1] + + +def generate_annotation( + img_rgb, + max_faces: int, + min_confidence: float +): + """ + Find up to 'max_faces' inside the provided input image. + If min_face_size_pixels is provided and nonzero it will be used to filter faces that occupy less than this many + pixels in the image. + """ + with mp_face_mesh.FaceMesh( + static_image_mode=True, + max_num_faces=max_faces, + refine_landmarks=True, + min_detection_confidence=min_confidence, + ) as facemesh: + img_height, img_width, img_channels = img_rgb.shape + assert(img_channels == 3) + + results = facemesh.process(img_rgb).multi_face_landmarks + + if results is None: + print("No faces detected in controlnet image for Mediapipe face annotator.") + return numpy.zeros_like(img_rgb) + + # Filter faces that are too small + filtered_landmarks = [] + for lm in results: + landmarks = lm.landmark + face_rect = [ + landmarks[0].x, + landmarks[0].y, + landmarks[0].x, + landmarks[0].y, + ] # Left, up, right, down. + for i in range(len(landmarks)): + face_rect[0] = min(face_rect[0], landmarks[i].x) + face_rect[1] = min(face_rect[1], landmarks[i].y) + face_rect[2] = max(face_rect[2], landmarks[i].x) + face_rect[3] = max(face_rect[3], landmarks[i].y) + if min_face_size_pixels > 0: + face_width = abs(face_rect[2] - face_rect[0]) + face_height = abs(face_rect[3] - face_rect[1]) + face_width_pixels = face_width * img_width + face_height_pixels = face_height * img_height + face_size = min(face_width_pixels, face_height_pixels) + if face_size >= min_face_size_pixels: + filtered_landmarks.append(lm) + else: + filtered_landmarks.append(lm) + + # Annotations are drawn in BGR for some reason, but we don't need to flip a zero-filled image at the start. + empty = numpy.zeros_like(img_rgb) + + # Draw detected faces: + for face_landmarks in filtered_landmarks: + mp_drawing.draw_landmarks( + empty, + face_landmarks, + connections=face_connection_spec.keys(), + landmark_drawing_spec=None, + connection_drawing_spec=face_connection_spec + ) + draw_pupils(empty, face_landmarks, iris_landmark_spec, 2) + + # Flip BGR back to RGB. + empty = reverse_channels(empty).copy() + + return empty \ No newline at end of file diff --git a/custom_controlnet_aux/mesh_graphormer/__init__.py b/custom_controlnet_aux/mesh_graphormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c7a35976d1936a2a9356894fadbcede71c6eff25 --- /dev/null +++ b/custom_controlnet_aux/mesh_graphormer/__init__.py @@ -0,0 +1,48 @@ +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3, custom_hf_download, MESH_GRAPHORMER_MODEL_NAME +from custom_controlnet_aux.mesh_graphormer.pipeline import MeshGraphormerMediapipe, args +import random, torch + +def set_seed(seed, n_gpu): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(seed) + +class MeshGraphormerDetector: + def __init__(self, pipeline): + self.pipeline = pipeline + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=MESH_GRAPHORMER_MODEL_NAME, filename="graphormer_hand_state_dict.bin", hrnet_filename="hrnetv2_w64_imagenet_pretrained.pth", detect_thr=0.6, presence_thr=0.6): + args.resume_checkpoint = custom_hf_download(pretrained_model_or_path, filename) + args.hrnet_checkpoint = custom_hf_download(pretrained_model_or_path, hrnet_filename) + pipeline = MeshGraphormerMediapipe(args, detect_thr=detect_thr, presence_thr=presence_thr) + return cls(pipeline) + + def to(self, device): + self.pipeline._model.to(device) + self.pipeline.mano_model.to(device) + self.pipeline.mano_model.layer.to(device) + return self + + def __call__(self, input_image=None, mask_bbox_padding=30, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", seed=88, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + set_seed(seed, 0) + depth_map, mask, info = self.pipeline.get_depth(input_image, mask_bbox_padding) + if depth_map is None: + depth_map = np.zeros_like(input_image) + mask = np.zeros_like(input_image) + + #The hand is small + depth_map, mask = HWC3(depth_map), HWC3(mask) + depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method) + depth_map = remove_pad(depth_map) + if output_type == "pil": + depth_map = Image.fromarray(depth_map) + mask = Image.fromarray(mask) + + return depth_map, mask, info diff --git a/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml b/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml new file mode 100644 index 0000000000000000000000000000000000000000..65d710339c15ac3fc94b91bc54b51196feb433b1 --- /dev/null +++ b/custom_controlnet_aux/mesh_graphormer/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml @@ -0,0 +1,92 @@ +GPUS: (0,1,2,3) +LOG_DIR: 'log/' +DATA_DIR: '' +OUTPUT_DIR: 'output/' +WORKERS: 4 +PRINT_FREQ: 1000 + +MODEL: + NAME: cls_hrnet + IMAGE_SIZE: + - 224 + - 224 + EXTRA: + STAGE1: + NUM_MODULES: 1 + NUM_RANCHES: 1 + BLOCK: BOTTLENECK + NUM_BLOCKS: + - 4 + NUM_CHANNELS: + - 64 + FUSE_METHOD: SUM + STAGE2: + NUM_MODULES: 1 + NUM_BRANCHES: 2 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + FUSE_METHOD: SUM + STAGE3: + NUM_MODULES: 4 + NUM_BRANCHES: 3 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + - 256 + FUSE_METHOD: SUM + STAGE4: + NUM_MODULES: 3 + NUM_BRANCHES: 4 + BLOCK: BASIC + NUM_BLOCKS: + - 4 + - 4 + - 4 + - 4 + NUM_CHANNELS: + - 64 + - 128 + - 256 + - 512 + FUSE_METHOD: SUM +CUDNN: + BENCHMARK: true + DETERMINISTIC: false + ENABLED: true +DATASET: + DATASET: 'imagenet' + DATA_FORMAT: 'jpg' + ROOT: 'data/imagenet/' + TEST_SET: 'val' + TRAIN_SET: 'train' +TEST: + BATCH_SIZE_PER_GPU: 32 + MODEL_FILE: '' +TRAIN: + BATCH_SIZE_PER_GPU: 32 + BEGIN_EPOCH: 0 + END_EPOCH: 100 + RESUME: true + LR_FACTOR: 0.1 + LR_STEP: + - 30 + - 60 + - 90 + OPTIMIZER: sgd + LR: 0.05 + WD: 0.0001 + MOMENTUM: 0.9 + NESTEROV: true + SHUFFLE: true +DEBUG: + DEBUG: false diff --git a/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py b/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..496313ad8cff2337bf6b4bee2467082caa838c96 --- /dev/null +++ b/custom_controlnet_aux/mesh_graphormer/depth_preprocessor.py @@ -0,0 +1,6 @@ +class Preprocessor: + def __init__(self) -> None: + pass + + def get_depth(self, input_dir, file_name): + return \ No newline at end of file diff --git a/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task b/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task new file mode 100644 index 0000000000000000000000000000000000000000..5ecab741879892d97c2f90bbf03bf55d7213db7c --- /dev/null +++ b/custom_controlnet_aux/mesh_graphormer/hand_landmarker.task @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbc2a30080c3c557093b5ddfc334698132eb341044ccee322ccf8bcf3607cde1 +size 7819105 diff --git a/custom_controlnet_aux/mesh_graphormer/pipeline.py b/custom_controlnet_aux/mesh_graphormer/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..d6fafd3236243b6982ebd7cc72ab6eb058767332 --- /dev/null +++ b/custom_controlnet_aux/mesh_graphormer/pipeline.py @@ -0,0 +1,472 @@ +import os +import torch +import gc +import numpy as np +from custom_controlnet_aux.mesh_graphormer.depth_preprocessor import Preprocessor + +import torchvision.models as models +from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer +from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from custom_mesh_graphormer.modeling._mano import MANO, Mesh +from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +from custom_mesh_graphormer.utils.miscellaneous import set_seed +from argparse import Namespace +from pathlib import Path +import cv2 +from torchvision import transforms +import numpy as np +import cv2 +from trimesh import Trimesh +from trimesh.ray.ray_triangle import RayMeshIntersector +import mediapipe as mp +from mediapipe.tasks import python +from mediapipe.tasks.python import vision +from torchvision import transforms +from pathlib import Path +from custom_controlnet_aux.util import custom_hf_download +import custom_mesh_graphormer +from comfy.model_management import soft_empty_cache +from packaging import version + +args = Namespace( + num_workers=4, + img_scale_factor=1, + image_file_or_path=os.path.join('', 'MeshGraphormer', 'samples', 'hand'), + model_name_or_path=str(Path(custom_mesh_graphormer.__file__).parent / "modeling/bert/bert-base-uncased"), + resume_checkpoint=None, + output_dir='output/', + config_name='', + a='hrnet-w64', + arch='hrnet-w64', + num_hidden_layers=4, + hidden_size=-1, + num_attention_heads=4, + intermediate_size=-1, + input_feat_dim='2051,512,128', + hidden_feat_dim='1024,256,64', + which_gcn='0,0,1', + mesh_type='hand', + run_eval_only=True, + device="cpu", + seed=88, + hrnet_checkpoint=custom_hf_download("hr16/ControlNet-HandRefiner-pruned", 'hrnetv2_w64_imagenet_pretrained.pth') +) + +#Since mediapipe v0.10.5, the hand category has been correct +if version.parse(mp.__version__) >= version.parse('0.10.5'): + true_hand_category = {"Right": "right", "Left": "left"} +else: + true_hand_category = {"Right": "left", "Left": "right"} + +class MeshGraphormerMediapipe(Preprocessor): + def __init__(self, args=args, detect_thr=0.6, presence_thr=0.6) -> None: + #global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + #mkdir(args.output_dir) + #logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + #logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and MANO utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(args.device) + mesh_sampler = Mesh(device=args.device) + + # Renderer for visualization + # renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + #logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + #logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + #logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + #logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = args.hrnet_checkpoint + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + #logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = Path(__file__).parent / 'cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = args.hrnet_checkpoint + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + #logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + #logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + #logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + #logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + soft_empty_cache() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + self._model = _model + self.mano_model = mano_model + self.mesh_sampler = mesh_sampler + + self.transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + #Fix File loading is not yet supported on Windows + with open(str( Path(__file__).parent / "hand_landmarker.task" ), 'rb') as file: + model_data = file.read() + base_options = python.BaseOptions(model_asset_buffer=model_data) + options = vision.HandLandmarkerOptions(base_options=base_options, + min_hand_detection_confidence=detect_thr, + min_hand_presence_confidence=presence_thr, + min_tracking_confidence=0.6, + num_hands=2) + + self.detector = vision.HandLandmarker.create_from_options(options) + + + def get_rays(self, W, H, fx, fy, cx, cy, c2w_t, center_pixels): # rot = I + + j, i = np.meshgrid(np.arange(H, dtype=np.float32), np.arange(W, dtype=np.float32)) + if center_pixels: + i = i.copy() + 0.5 + j = j.copy() + 0.5 + + directions = np.stack([(i - cx) / fx, (j - cy) / fy, np.ones_like(i)], -1) + directions /= np.linalg.norm(directions, axis=-1, keepdims=True) + + rays_o = np.expand_dims(c2w_t,0).repeat(H*W, 0) + + rays_d = directions # (H, W, 3) + rays_d = (rays_d / np.linalg.norm(rays_d, axis=-1, keepdims=True)).reshape(-1,3) + + return rays_o, rays_d + + def get_mask_bounding_box(self, extrema, H, W, padding=30, dynamic_resize=0.15): + x_min, x_max, y_min, y_max = extrema + bb_xpad = max(int((x_max - x_min + 1) * dynamic_resize), padding) + bb_ypad = max(int((y_max - y_min + 1) * dynamic_resize), padding) + bbx_min = np.max((x_min - bb_xpad, 0)) + bbx_max = np.min((x_max + bb_xpad, W-1)) + bby_min = np.max((y_min - bb_ypad, 0)) + bby_max = np.min((y_max + bb_ypad, H-1)) + return bbx_min, bbx_max, bby_min, bby_max + + def run_inference(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len): + global args + H, W = int(crop_len), int(crop_len) + Graphormer_model.eval() + mano.eval() + device = next(Graphormer_model.parameters()).device + with torch.no_grad(): + img_tensor = self.transform(img) + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + pred_camera = pred_camera.cpu() + pred_vertices = pred_vertices.cpu() + mesh = Trimesh(vertices=pred_vertices[0], faces=mano.face) + res = crop_len + focal_length = 1000 * scale + camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)]) + pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t + z_3d_dist = pred_3d_joints_camera[:,2].clone() + + pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2)) + + rays_o, rays_d = self.get_rays(W, H, focal_length, focal_length, W/2, H/2, camera_t, True) + coords = np.array(list(np.ndindex(H,W))).reshape(H,W,-1).transpose(1,0,2).reshape(-1,2) + intersector = RayMeshIntersector(mesh) + points, index_ray, _ = intersector.intersects_location(rays_o, rays_d, multiple_hits=False) + + tri_index = intersector.intersects_first(rays_o, rays_d) + + tri_index = tri_index[index_ray] + + assert len(index_ray) == len(tri_index) + + discriminator = (np.sum(mesh.face_normals[tri_index]* rays_d[index_ray], axis=-1)<= 0) + points = points[discriminator] # ray intesects in interior faces, discard them + + if len(points) == 0: + return None, None + depth = (points + camera_t)[:,-1] + index_ray = index_ray[discriminator] + pixel_ray = coords[index_ray] + + minval = np.min(depth) + maxval = np.max(depth) + depthmap = np.zeros([H,W]) + + depthmap[pixel_ray[:, 0], pixel_ray[:, 1]] = 1.0 - (0.8 * (depth - minval) / (maxval - minval)) + depthmap *= 255 + return depthmap, pred_2d_joints_img_space + + + def get_depth(self, np_image, padding): + info = {} + + # STEP 3: Load the input image. + #https://stackoverflow.com/a/76407270 + image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np_image.copy()) + + # STEP 4: Detect hand landmarks from the input image. + detection_result = self.detector.detect(image) + + handedness_list = detection_result.handedness + hand_landmarks_list = detection_result.hand_landmarks + + raw_image = image.numpy_view() + H, W, C = raw_image.shape + + + # HANDLANDMARKS CAN BE EMPTY, HANDLE THIS! + if len(hand_landmarks_list) == 0: + return None, None, None + raw_image = raw_image[:, :, :3] + + padded_image = np.zeros((H*2, W*2, 3)) + padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = raw_image + + hand_landmarks_list, handedness_list = zip( + *sorted( + zip(hand_landmarks_list, handedness_list), key=lambda x: x[0][9].z, reverse=True + ) + ) + + padded_depthmap = np.zeros((H*2, W*2)) + mask = np.zeros((H, W)) + crop_boxes = [] + #bboxes = [] + groundtruth_2d_keypoints = [] + hands = [] + depth_failure = False + crop_lens = [] + abs_boxes = [] + + for idx in range(len(hand_landmarks_list)): + hand = true_hand_category[handedness_list[idx][0].category_name] + hands.append(hand) + hand_landmarks = hand_landmarks_list[idx] + handedness = handedness_list[idx] + height, width, _ = raw_image.shape + x_coordinates = [landmark.x for landmark in hand_landmarks] + y_coordinates = [landmark.y for landmark in hand_landmarks] + + # x_min, x_max, y_min, y_max: extrema from mediapipe keypoint detection + x_min = int(min(x_coordinates) * width) + x_max = int(max(x_coordinates) * width) + x_c = (x_min + x_max)//2 + y_min = int(min(y_coordinates) * height) + y_max = int(max(y_coordinates) * height) + y_c = (y_min + y_max)//2 + abs_boxes.append([x_min, x_max, y_min, y_max]) + + #if x_max - x_min < 60 or y_max - y_min < 60: + # continue + + crop_len = (max(x_max - x_min, y_max - y_min) * 1.6) //2 * 2 + + # crop_x_min, crop_x_max, crop_y_min, crop_y_max: bounding box for mesh reconstruction + crop_x_min = int(x_c - (crop_len/2 - 1) + W/2) + crop_x_max = int(x_c + crop_len/2 + W/2) + crop_y_min = int(y_c - (crop_len/2 - 1) + H/2) + crop_y_max = int(y_c + crop_len/2 + H/2) + + cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1] + crop_boxes.append([crop_y_min, crop_y_max, crop_x_min, crop_x_max]) + crop_lens.append(crop_len) + if hand == "left": + cropped = cv2.flip(cropped, 1) + + if crop_len < 224: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC) + else: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA) + scale = crop_len/224 + cropped_depthmap, pred_2d_keypoints = self.run_inference(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, int(crop_len)) + + if cropped_depthmap is None: + depth_failure = True + break + #keypoints_image_space = pred_2d_keypoints * (crop_y_max - crop_y_min + 1)/224 + groundtruth_2d_keypoints.append(pred_2d_keypoints) + + if hand == "left": + cropped_depthmap = cv2.flip(cropped_depthmap, 1) + resized_cropped_depthmap = cv2.resize(cropped_depthmap, (int(crop_len), int(crop_len)), interpolation=cv2.INTER_LINEAR) + nonzero_y, nonzero_x = (resized_cropped_depthmap != 0).nonzero() + if len(nonzero_y) == 0 or len(nonzero_x) == 0: + depth_failure = True + break + padded_depthmap[crop_y_min+nonzero_y, crop_x_min+nonzero_x] = resized_cropped_depthmap[nonzero_y, nonzero_x] + + # nonzero stands for nonzero value on the depth map + # coordinates of nonzero depth pixels in original image space + original_nonzero_x = crop_x_min+nonzero_x - int(W/2) + original_nonzero_y = crop_y_min+nonzero_y - int(H/2) + + nonzerox_min = min(np.min(original_nonzero_x), x_min) + nonzerox_max = max(np.max(original_nonzero_x), x_max) + nonzeroy_min = min(np.min(original_nonzero_y), y_min) + nonzeroy_max = max(np.max(original_nonzero_y), y_max) + + bbx_min, bbx_max, bby_min, bby_max = self.get_mask_bounding_box((nonzerox_min, nonzerox_max, nonzeroy_min, nonzeroy_max), H, W, padding) + mask[bby_min:bby_max+1, bbx_min:bbx_max+1] = 1.0 + #bboxes.append([int(bbx_min), int(bbx_max), int(bby_min), int(bby_max)]) + if depth_failure: + #print("cannot detect normal hands") + return None, None, None + depthmap = padded_depthmap[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)].astype(np.uint8) + mask = (255.0 * mask).astype(np.uint8) + info["groundtruth_2d_keypoints"] = groundtruth_2d_keypoints + info["hands"] = hands + info["crop_boxes"] = crop_boxes + info["crop_lens"] = crop_lens + info["abs_boxes"] = abs_boxes + return depthmap, mask, info + + def get_keypoints(self, img, Graphormer_model, mano, mesh_sampler, scale, crop_len): + global args + H, W = int(crop_len), int(crop_len) + Graphormer_model.eval() + mano.eval() + device = next(Graphormer_model.parameters()).device + with torch.no_grad(): + img_tensor = self.transform(img) + #print(img_tensor) + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + #pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + #pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + pred_camera = pred_camera.cpu() + pred_vertices = pred_vertices.cpu() + # + res = crop_len + focal_length = 1000 * scale + camera_t = np.array([-pred_camera[1], -pred_camera[2], -2*focal_length/(res * pred_camera[0] +1e-9)]) + pred_3d_joints_camera = pred_3d_joints_from_mesh.cpu()[0] - camera_t + z_3d_dist = pred_3d_joints_camera[:,2].clone() + pred_2d_joints_img_space = ((pred_3d_joints_camera/z_3d_dist[:,None]) * np.array((focal_length, focal_length, 1)))[:,:2] + np.array((W/2, H/2)) + + return pred_2d_joints_img_space + + + def eval_mpjpe(self, sample, info): + H, W, C = sample.shape + padded_image = np.zeros((H*2, W*2, 3)) + padded_image[int(1/2 * H):int(3/2 * H), int(1/2 * W):int(3/2 * W)] = sample + crop_boxes = info["crop_boxes"] + hands = info["hands"] + groundtruth_2d_keypoints = info["groundtruth_2d_keypoints"] + crop_lens = info["crop_lens"] + pjpe = 0 + for i in range(len(crop_boxes)):#box in crop_boxes: + crop_y_min, crop_y_max, crop_x_min, crop_x_max = crop_boxes[i] + cropped = padded_image[crop_y_min:crop_y_max+1, crop_x_min:crop_x_max+1] + hand = hands[i] + if hand == "left": + cropped = cv2.flip(cropped, 1) + crop_len = crop_lens[i] + scale = crop_len/224 + if crop_len < 224: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_CUBIC) + else: + graphormer_input = cv2.resize(cropped, (224, 224), interpolation=cv2.INTER_AREA) + generated_keypoint = self.get_keypoints(graphormer_input.astype(np.uint8), self._model, self.mano_model, self.mesh_sampler, scale, crop_len) + #generated_keypoint = generated_keypoint * ((crop_y_max - crop_y_min + 1)/224) + pjpe += np.sum(np.sqrt(np.sum(((generated_keypoint - groundtruth_2d_keypoints[i]) ** 2).numpy(), axis=1))) + pass + mpjpe = pjpe/(len(crop_boxes) * 21) + return mpjpe diff --git a/custom_controlnet_aux/metric3d/__init__.py b/custom_controlnet_aux/metric3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..18cc06e06c01f6c165e9db6f3c7dec03b82b9901 --- /dev/null +++ b/custom_controlnet_aux/metric3d/__init__.py @@ -0,0 +1,124 @@ + +import torch +import os +from pathlib import Path + +CODE_SPACE=Path(os.path.dirname(os.path.abspath(__file__))) + +from custom_mmpkg.custom_mmcv.utils import Config, DictAction +from custom_controlnet_aux.metric3d.mono.model.monodepth_model import get_configured_monodepth_model +from custom_controlnet_aux.metric3d.mono.utils.running import load_ckpt +from custom_controlnet_aux.metric3d.mono.utils.do_test import transform_test_data_scalecano, get_prediction +import numpy as np +from custom_controlnet_aux.metric3d.mono.utils.visualization import vis_surface_normal +from einops import repeat +from PIL import Image +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, METRIC3D_MODEL_NAME +import re +import matplotlib.pyplot as plt + +def load_model(model_selection, model_path): + if model_selection == "vit-small": + cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.small.py') + elif model_selection == "vit-large": + cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.large.py') + elif model_selection == "vit-giant2": + cfg = Config.fromfile(CODE_SPACE / 'mono/configs/HourglassDecoder/vit.raft5.giant2.py') + else: + raise NotImplementedError(f"metric3d model: {model_selection}") + model = get_configured_monodepth_model(cfg, ) + model, _, _, _ = load_ckpt(model_path, model, strict_match=False) + model.eval() + model = model + return model, cfg + +def gray_to_colormap(img, cmap='rainbow'): + """ + Transfer gray map to matplotlib colormap + """ + assert img.ndim == 2 + + img[img<0] = 0 + mask_invalid = img < 1e-10 + img = img / (img.max() + 1e-8) + norm = plt.Normalize(vmin=0, vmax=1.1) # Use plt.Normalize instead of matplotlib.colors.Normalize + cmap_m = plt.get_cmap(cmap) # Access the colormap directly from plt + map = plt.cm.ScalarMappable(norm=norm, cmap=cmap_m) + colormap = (map.to_rgba(img)[:, :, :3] * 255).astype(np.uint8) + colormap[mask_invalid] = 0 + return colormap + +def predict_depth_normal(model, cfg, np_img, fx=1000.0, fy=1000.0, state_cache={}): + intrinsic = [fx, fy, np_img.shape[1]/2, np_img.shape[0]/2] + rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(np_img, intrinsic, cfg.data_basic, device=next(model.parameters()).device) + + with torch.no_grad(): + pred_depth, confidence, output = get_prediction( + model = model, + input = rgb_input.unsqueeze(0), + cam_model = cam_models_stacks, + pad_info = pad, + scale_info = label_scale_factor, + gt_depth = None, + normalize_scale = cfg.data_basic.depth_range[1], + ori_shape=[np_img.shape[0], np_img.shape[1]], + ) + + pred_normal = output['normal_out_list'][0][:, :3, :, :] + H, W = pred_normal.shape[2:] + pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]] + pred_depth = pred_depth[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3] ] + + pred_depth = pred_depth.squeeze().cpu().numpy() + pred_color = gray_to_colormap(pred_depth, 'Greys') + + pred_normal = torch.nn.functional.interpolate(pred_normal, [np_img.shape[0], np_img.shape[1]], mode='bilinear').squeeze() + pred_normal = pred_normal.permute(1,2,0) + pred_color_normal = vis_surface_normal(pred_normal) + pred_normal = pred_normal.cpu().numpy() + + # Storing depth and normal map in state for potential 3D reconstruction + state_cache['depth'] = pred_depth + state_cache['normal'] = pred_normal + state_cache['img'] = np_img + state_cache['intrinsic'] = intrinsic + state_cache['confidence'] = confidence + + return pred_color, pred_color_normal, state_cache + +class Metric3DDetector: + def __init__(self, model, cfg): + self.model = model + self.cfg = cfg + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=METRIC3D_MODEL_NAME, filename="metric_depth_vit_small_800k.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + backbone = re.findall(r"metric_depth_vit_(\w+)_", model_path)[0] + model, cfg = load_model(f'vit-{backbone}', model_path) + return cls(model, cfg) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, fx=1000, fy=1000, output_type=None, upscale_method="INTER_CUBIC", depth_and_normal=True, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + + depth_map, normal_map, _ = predict_depth_normal(self.model, self.cfg, input_image, fx=fx, fy=fy) + # ControlNet uses inverse depth and normal + depth_map, normal_map = depth_map, 255 - normal_map + depth_map, remove_pad = resize_image_with_pad(depth_map, detect_resolution, upscale_method) + normal_map, _ = resize_image_with_pad(normal_map, detect_resolution, upscale_method) + depth_map, normal_map = remove_pad(depth_map), remove_pad(normal_map) + + if output_type == "pil": + depth_map = Image.fromarray(depth_map) + normal_map = Image.fromarray(normal_map) + + if depth_and_normal: + return depth_map, normal_map + else: + return depth_map diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..37b91c80284d6db3df3017ec636f18198e42dc08 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (544, 1216), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..cdd9156b7f2f0921fb01b1adaf9a2a7447332d6e --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_kitti_convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (512, 1088), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py new file mode 100644 index 0000000000000000000000000000000000000000..6601f5cdfad07c5fad8b89fbf959e67039126dfa --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/test_nyu_convlarge.0.3_150.py @@ -0,0 +1,25 @@ +_base_=[ + '../_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model = dict( + backbone=dict( + pretrained=False, + ) +) + +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + img_size=(512, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.3, 150), + crop_size = (480, 1216), +) + +batchsize_per_gpu = 2 +thread_per_gpu = 4 diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py new file mode 100644 index 0000000000000000000000000000000000000000..a4a1238b6eb7ffcc21a237e748bbd7ed75bdf5aa --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.giant2.py @@ -0,0 +1,32 @@ +_base_=[ + '../_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model=dict( + decode_head=dict( + type='RAFTDepthNormalDPT5', + iters=8, + n_downsample=2, + detach=False, + ) +) + + +max_value = 200 +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + # img_size=(540, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.1, max_value), + crop_size = (616, 1064), # %28 = 0 + clip_depth_range=(0.1, 200), + vit_size=(616,1064) +) + +batchsize_per_gpu = 1 +thread_per_gpu = 1 diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py new file mode 100644 index 0000000000000000000000000000000000000000..4e81c2ea343ea7bcd9980c76ee18e439d68c04e8 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py @@ -0,0 +1,32 @@ +_base_=[ + '../_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model=dict( + decode_head=dict( + type='RAFTDepthNormalDPT5', + iters=8, + n_downsample=2, + detach=False, + ) +) + + +max_value = 200 +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + # img_size=(540, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.1, max_value), + crop_size = (616, 1064), # %28 = 0 + clip_depth_range=(0.1, 200), + vit_size=(616,1064) +) + +batchsize_per_gpu = 1 +thread_per_gpu = 1 diff --git a/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a169d49b38ab72b5f0c6b2eeb6e891431ad279 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py @@ -0,0 +1,32 @@ +_base_=[ + '../_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py', + '../_base_/datasets/_data_base_.py', + '../_base_/default_runtime.py', + ] + +model=dict( + decode_head=dict( + type='RAFTDepthNormalDPT5', + iters=4, + n_downsample=2, + detach=False, + ) +) + + +max_value = 200 +# configs of the canonical space +data_basic=dict( + canonical_space = dict( + # img_size=(540, 960), + focal_length=1000.0, + ), + depth_range=(0, 1), + depth_normalize=(0.1, max_value), + crop_size = (616, 1064), # %28 = 0 + clip_depth_range=(0.1, 200), + vit_size=(616,1064) +) + +batchsize_per_gpu = 1 +thread_per_gpu = 1 diff --git a/custom_controlnet_aux/metric3d/mono/configs/__init__.py b/custom_controlnet_aux/metric3d/mono/configs/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/__init__.py @@ -0,0 +1 @@ + diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py new file mode 100644 index 0000000000000000000000000000000000000000..35f3844f24191b6b9452e136ea3205b7622466d7 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/_data_base_.py @@ -0,0 +1,13 @@ +# canonical camera setting and basic data setting +# we set it same as the E300 camera (crop version) +# +data_basic=dict( + canonical_space = dict( + img_size=(540, 960), + focal_length=1196.0, + ), + depth_range=(0.9, 150), + depth_normalize=(0.006, 1.001), + crop_size = (512, 960), + clip_depth_range=(0.9, 150), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py new file mode 100644 index 0000000000000000000000000000000000000000..b554444e9b75b4519b862e726890dcf7859be0ec --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/datasets/_data_base_.py @@ -0,0 +1,12 @@ +# canonical camera setting and basic data setting +# +data_basic=dict( + canonical_space = dict( + img_size=(540, 960), + focal_length=1196.0, + ), + depth_range=(0.9, 150), + depth_normalize=(0.006, 1.001), + crop_size = (512, 960), + clip_depth_range=(0.9, 150), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..a690b491bf50aad5c2fd7e9ac387609123a4594a --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/default_runtime.py @@ -0,0 +1,4 @@ + +load_from = None +cudnn_benchmark = True +test_metrics = ['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3','rmse_log', 'log10', 'sq_rel'] diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5a22f7e1b53ca154bfae1672e6ee3b52028039b9 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/convnext_large.py @@ -0,0 +1,16 @@ +#_base_ = ['./_model_base_.py',] + +#'https://download.openmmlab.com/mmclassification/v0/convnext/downstream/convnext-large_3rdparty_in21k_20220301-e6e0ea0a.pth' +model = dict( + #type='EncoderDecoderAuxi', + backbone=dict( + type='convnext_large', + pretrained=True, + in_22k=True, + out_indices=[0, 1, 2, 3], + drop_path_rate=0.4, + layer_scale_init_value=1.0, + checkpoint='data/pretrained_weight_repo/convnext/convnext_large_22k_1k_384.pth', + prefix='backbones.', + out_channels=[192, 384, 768, 1536]), + ) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..3c1ebc96ceaa32ad9310d3b84d55d252be843c46 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_giant2_reg.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_giant2_reg', + prefix='backbones.', + out_channels=[1536, 1536, 1536, 1536], + drop_path_rate = 0.0), + ) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py new file mode 100644 index 0000000000000000000000000000000000000000..843178ed6e61d74070b971f01148f87fdf2a62cf --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_large', + prefix='backbones.', + out_channels=[1024, 1024, 1024, 1024], + drop_path_rate = 0.0), + ) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..25e96747d459d42df299f8a6a1e14044a0e56164 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_large_reg.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_large_reg', + prefix='backbones.', + out_channels=[1024, 1024, 1024, 1024], + drop_path_rate = 0.0), + ) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..0c8bd97dccb9cdee7517250f40e01bb3124144e6 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/backbones/dino_vit_small_reg.py @@ -0,0 +1,7 @@ +model = dict( + backbone=dict( + type='vit_small_reg', + prefix='backbones.', + out_channels=[384, 384, 384, 384], + drop_path_rate = 0.0), + ) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f262288c49e7ffccb6174b09b0daf80ff79dd684 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/convnext_large.hourglassdecoder.py @@ -0,0 +1,10 @@ +# model settings +_base_ = ['../backbones/convnext_large.py',] +model = dict( + type='DensePredModel', + decode_head=dict( + type='HourglassDecoder', + in_channels=[192, 384, 768, 1536], + decoder_channel=[128, 128, 256, 512], + prefix='decode_heads.'), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..73702d298c05979bcdf013e9c30ec56f4e36665b --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_giant2_reg.dpt_raft.py @@ -0,0 +1,19 @@ +# model settings +_base_ = ['../backbones/dino_vit_giant2_reg.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[1536, 1536, 1536, 1536], + use_cls_token=True, + feature_channels = [384, 768, 1536, 1536], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [192, 384, 768, 1536, 1536], # [4/7, 2/7, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[192, 192, 192, 192], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536] + n_gru_layers=3, + n_downsample=2, + iters=3, + slow_fast_gru=True, + num_register_tokens=4, + prefix='decode_heads.'), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..bd69efefab2c03de435996c6b7b65ff941db1e5d --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large.dpt_raft.py @@ -0,0 +1,20 @@ +# model settings +_base_ = ['../backbones/dino_vit_large.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[1024, 1024, 1024, 1024], + use_cls_token=True, + feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536] + n_gru_layers=3, + n_downsample=2, + iters=12, + slow_fast_gru=True, + corr_radius=4, + corr_levels=4, + prefix='decode_heads.'), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..26ab6dc090e9cdb840d84fab10587becb536dbb8 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_large_reg.dpt_raft.py @@ -0,0 +1,19 @@ +# model settings +_base_ = ['../backbones/dino_vit_large_reg.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[1024, 1024, 1024, 1024], + use_cls_token=True, + feature_channels = [256, 512, 1024, 1024], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [128, 256, 512, 1024, 1024], # [4/7, 2/7, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[128, 128, 128, 128], # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536] + n_gru_layers=3, + n_downsample=2, + iters=3, + slow_fast_gru=True, + num_register_tokens=4, + prefix='decode_heads.'), +) diff --git a/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..19466c191e9f2a83903e55ca4fc0827d9a11bcb9 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/configs/_base_/models/encoder_decoder/dino_vit_small_reg.dpt_raft.py @@ -0,0 +1,19 @@ +# model settings +_base_ = ['../backbones/dino_vit_small_reg.py'] +model = dict( + type='DensePredModel', + decode_head=dict( + type='RAFTDepthDPT', + in_channels=[384, 384, 384, 384], + use_cls_token=True, + feature_channels = [96, 192, 384, 768], # [2/7, 1/7, 1/14, 1/14] + decoder_channels = [48, 96, 192, 384, 384], # [-, 1/4, 1/7, 1/14, 1/14] + up_scale = 7, + hidden_channels=[48, 48, 48, 48], # [x_4, x_8, x_16, x_32] [1/4, 1/7, 1/14, -] + n_gru_layers=3, + n_downsample=2, + iters=3, + slow_fast_gru=True, + num_register_tokens=4, + prefix='decode_heads.'), +) diff --git a/custom_controlnet_aux/metric3d/mono/model/__init__.py b/custom_controlnet_aux/metric3d/mono/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9e1ea3d3e3b880e28ef880083b3c79e3b00cd119 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/__init__.py @@ -0,0 +1,5 @@ +from .monodepth_model import DepthModel +# from .__base_model__ import BaseDepthModel + + +__all__ = ['DepthModel', 'BaseDepthModel'] diff --git a/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py b/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py new file mode 100644 index 0000000000000000000000000000000000000000..f1c4be0e6463ae2b0dda6d20fc273a300afa5ebf --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/backbones/ConvNeXt.py @@ -0,0 +1,271 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import trunc_normal_, DropPath +from timm.models.registry import register_model + +class Block(nn.Module): + r""" ConvNeXt Block. There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + We use (2) as we find it slightly faster in PyTorch + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. Default: 0.0 + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + """ + def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6): + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.pwconv2 = nn.Linear(4 * dim, dim) + self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), + requires_grad=True) if layer_scale_init_value > 0 else None + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + if self.gamma is not None: + x = self.gamma * x + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + +class ConvNeXt(nn.Module): + r""" ConvNeXt + A PyTorch impl of : `A ConvNet for the 2020s` - + https://arxiv.org/pdf/2201.03545.pdf + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + def __init__(self, in_chans=3, num_classes=1000, + depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0., + layer_scale_init_value=1e-6, head_init_scale=1., + **kwargs,): + super().__init__() + + self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first") + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks + dp_rates=[x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.Sequential( + *[Block(dim=dims[i], drop_path=dp_rates[cur + j], + layer_scale_init_value=layer_scale_init_value) for j in range(depths[i])] + ) + self.stages.append(stage) + cur += depths[i] + + #self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer + #self.head = nn.Linear(dims[-1], num_classes) + + self.apply(self._init_weights) + #self.head.weight.data.mul_(head_init_scale) + #self.head.bias.data.mul_(head_init_scale) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=.02) + nn.init.constant_(m.bias, 0) + + def forward_features(self, x): + features = [] + for i in range(4): + x = self.downsample_layers[i](x) + x = self.stages[i](x) + features.append(x) + return features # global average pooling, (N, C, H, W) -> (N, C) + + def forward(self, x): + #x = self.forward_features(x) + #x = self.head(x) + features = self.forward_features(x) + return features + +class LayerNorm(nn.Module): + r""" LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape, ) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +model_urls = { + "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", + "convnext_small_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_1k_224_ema.pth", + "convnext_base_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_1k_224_ema.pth", + "convnext_large_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_1k_224_ema.pth", + "convnext_tiny_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_22k_224.pth", + "convnext_small_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_small_22k_224.pth", + "convnext_base_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_base_22k_224.pth", + "convnext_large_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_large_22k_224.pth", + "convnext_xlarge_22k": "https://dl.fbaipublicfiles.com/convnext/convnext_xlarge_22k_224.pth", +} + +def convnext_tiny(pretrained=True,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_tiny_22k'] if in_22k else model_urls['convnext_tiny_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True) + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_small(pretrained=True,in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[96, 192, 384, 768], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_small_22k'] if in_22k else model_urls['convnext_small_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_base(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[128, 256, 512, 1024], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_base_22k'] if in_22k else model_urls['convnext_base_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_large(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[192, 384, 768, 1536], **kwargs) + if pretrained: + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_large_22k'] if in_22k else model_urls['convnext_large_1k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +def convnext_xlarge(pretrained=True, in_22k=False, **kwargs): + model = ConvNeXt(depths=[3, 3, 27, 3], dims=[256, 512, 1024, 2048], **kwargs) + if pretrained: + assert in_22k, "only ImageNet-22K pre-trained ConvNeXt-XL is available; please set in_22k=True" + checkpoint = torch.load(kwargs['checkpoint'], map_location="cpu") + #url = model_urls['convnext_xlarge_22k'] + #checkpoint = torch.hub.load_state_dict_from_url(url=url, map_location="cpu") + model_dict = model.state_dict() + pretrained_dict = {} + unmatched_pretrained_dict = {} + for k, v in checkpoint['model'].items(): + if k in model_dict: + pretrained_dict[k] = v + else: + unmatched_pretrained_dict[k] = v + model_dict.update(pretrained_dict) + model.load_state_dict(model_dict) + print( + 'Successfully loaded pretrained %d params, and %d paras are unmatched.' + %(len(pretrained_dict.keys()), len(unmatched_pretrained_dict.keys()))) + print('Unmatched pretrained paras are :', unmatched_pretrained_dict.keys()) + return model + +if __name__ == '__main__': + import torch + model = convnext_base(True, in_22k=False).cuda() + + rgb = torch.rand((2, 3, 256, 256)).cuda() + out = model(rgb) + print(len(out)) + for i, ft in enumerate(out): + print(i, ft.shape) diff --git a/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py b/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py new file mode 100644 index 0000000000000000000000000000000000000000..b86c6b71e40817616a0b421b9ac9c9b80dde5f0b --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO.py @@ -0,0 +1,1489 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +class ConvBlock(nn.Module): + def __init__(self, channels): + super(ConvBlock, self).__init__() + + self.act = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm1 = nn.BatchNorm2d(channels) + self.conv2 = nn.Conv2d( + channels, + channels, + kernel_size=3, + stride=1, + padding=1 + ) + self.norm2 = nn.BatchNorm2d(channels) + + def forward(self, x): + + out = self.norm1(x) + out = self.act(out) + out = self.conv1(out) + out = self.norm2(out) + out = self.act(out) + out = self.conv2(out) + return x + out + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + +XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=37, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.window_size = window_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + features = [] + for blk in self.blocks: + x = blk(x) + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x) + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +class PosConv(nn.Module): + # PEG from https://arxiv.org/abs/2102.10882 + def __init__(self, in_chans, embed_dim=768, stride=1): + super(PosConv, self).__init__() + self.proj = nn.Sequential( + nn.Conv2d(in_chans, embed_dim, 37, stride, 18, bias=True, groups=embed_dim), + ) + self.stride = stride + + def forward(self, x, size): + B, N, C = x.shape + cnn_feat_token = x.transpose(1, 2).view(B, C, *size) + x = self.proj(cnn_feat_token) + if self.stride == 1: + x += cnn_feat_token + x = x.flatten(2).transpose(1, 2) + return x + + #def no_weight_decay(self): + #return ['proj.%d.weight' % i for i in range(4)] + +class DinoWindowVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + #init_values=None, # for layerscale: None or 0 => no layerscale + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=NestedTensorBlock, + ffn_layer="mlp", + block_chunks=1, + window_size=7, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + #self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + #self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) + + self.pos_conv = PosConv(self.embed_dim, self.embed_dim) + + self.window_size = window_size + #self.conv_block = nn.ModuleList([ConvBlock(embed_dim) for i in range(4)]) + #self.conv_block = nn.ModuleList([nn.Identity() for i in range(4)]) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.nh = -1 + self.nw = -1 + try: + H = cfg.data_basic['crop_size'][0] + W = cfg.data_basic['crop_size'][1] + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + self.nh = (H + pad_h) // self.patch_size + self.nw = (W + pad_w) // self.patch_size + self.prepare_attn_bias((self.nh, self.nw)) + except: + pass + self.init_weights() + + self.total_step = 10000 # For PE -> GPE transfer + self.start_step = 2000 + self.current_step = 20000 + + def init_weights(self): + #trunc_normal_(self.pos_embed, std=0.02) + #nn.init.normal_(self.cls_token, std=1e-6) + named_apply(init_weights_vit_timm, self) + for i in range(4): + try: + nn.init.constant_(self.conv_block[i].conv2.weight, 0.0) + except: + pass + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + #npatch = x.shape[1] - 1 + #N = self.pos_embed.shape[1] - 1 + npatch = x.shape[1] + N = self.pos_embed.shape[1] + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + #class_pos_embed = pos_embed[:, 0] + #patch_pos_embed = pos_embed[:, 1:] + patch_pos_embed = pos_embed + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + 0.1, h0 + 0.1 + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2), + scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)), + mode="bicubic", + ) + + assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return patch_pos_embed.to(previous_dtype) + #return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def window_partition(self, x: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + if conv_feature == False: + B, N, C = x.shape + H, W = hw[0], hw[1] + + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size * window_size, C) + else: + B, C, H, W = x.shape + + x = x.view(B, C, H // window_size, window_size, W // window_size, window_size) + + windows = x.permute(0, 2, 4, 3, 5, 1).contiguous().view(-1, window_size * window_size, C) + + #y = torch.cat((x_cls, windows), dim=1) + return windows #, (Hp, Wp) + + + def window_unpartition(self, + windows: torch.Tensor, window_size: int, hw: Tuple[int, int], conv_feature=False + ) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + H, W = hw + + B = windows.shape[0] // (H * W // window_size // window_size) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + + if conv_feature == False: + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp * Wp, -1) + else: + C = windows.shape[-1] + x = x.permute(0, 5, 1, 3, 2, 4).contiguous().view(B, C, H, W) + + # if Hp > H or Wp > W: + # x = x[:, :H, :W, :].contiguous() + return x + + def prepare_tokens_with_masks(self, x, masks=None, step=-1): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + #x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + if step == -1: + step = self.current_step + else: + self.current_step = step + + if step < self.start_step: + coef = 0.0 + elif step < self.total_step: + coef = (step - self.start_step) / (self.total_step - self.start_step) + else: + coef = 1.0 + + x = x + (1 - coef) * self.interpolate_pos_encoding(x, w, h) + coef * self.pos_conv(x, (self.nh, self.nw)) + + return x + + def prepare_attn_bias(self, shape): + window_size = self.window_size + if window_size <= 0: + return + + import xformers.components.attention.attention_patterns as AP + + nh, nw = shape + radius = (window_size-1)//2 + mask_ori = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() + + pad = (8 - (nh * nw) % 8) + if pad == 8: + pad = 0 + mask_pad = nn.functional.pad(mask_ori, (0, pad)).contiguous() + if pad > 0: + mask = mask_pad[:, :-pad].view(nh, nw, nh, nw) + else: + mask = mask_pad[:, :].view(nh, nw, nh, nw) + + # angle + mask[:radius+1, :radius+1, :window_size, :window_size] = True + mask[:radius+1, -radius-1:, :window_size, -window_size:] = True + mask[-radius-1:, :radius+1, -window_size:, :window_size] = True + mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True + + # edge + mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] + mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] + mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] + mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] + + mask = mask.view(nh*nw, nh*nw) + bias_pad = torch.log(mask_pad) + #bias = bias_pad[:, :-pad] + self.register_buffer('attn_bias', bias_pad) + + return bias_pad + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_patchtokens": x_norm[:, 1:], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None, **kwargs): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + nh = (H+pad_h)//self.patch_size + nw = (W+pad_w)//self.patch_size + + if self.window_size > 0: + if nh == self.nh and nw == self.nw: + attn_bias = self.attn_bias + else: + attn_bias = self.prepare_attn_bias(((H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size)) + self.nh = nh + self.nw = nw + attn_bias = attn_bias.unsqueeze(0).repeat(B * self.num_heads, 1, 1) + else: + attn_bias = None + + x = self.prepare_tokens_with_masks(x, masks) + #x = self.patch_embed(x) + + features = [] + #x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + for blk in self.blocks: + x = blk(x, attn_bias) + #x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size)) + + # for idx in range(len(self.blocks[0])): + # x = self.blocks[0][idx](x, attn_bias) + + # if (idx + 1) % (len(self.blocks[0]) // 4) == 0: + # x = self.window_unpartition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # x = self.conv_block[idx // (len(self.blocks[0]) // 4)](x) + # if idx + 1 != len(self.blocks[0]): + # x = self.window_partition(x, self.window_size, (H // self.patch_size, W // self.patch_size), conv_feature=True) + # else: + # b, c, h, w = x.size() + # x = x.permute(0, 2, 3, 1).contiguous().view(b, h, w, c) + #features.append(x) + + #return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + x_norm = self.norm(x) + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_patchtokens": x_norm[:, 1:], + # "x_prenorm": x, + # "masks": masks, + # } + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=14, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=14, **kwargs): + model = DinoWindowVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=14, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + #del model.norm + del model.mask_token + return model + + # model = DinoWindowVisionTransformer( + # img_size = 518, + # patch_size=patch_size, + # embed_dim=1024, + # depth=24, + # num_heads=16, + # mlp_ratio=4, + # block_fn=partial(NestedTensorBlock, attn_class=MemEffAttention), + # window_size=37, + # **kwargs, + # ) + + # if checkpoint is not None: + # with open(checkpoint, "rb") as f: + # state_dict = torch.load(f) + # try: + # model.load_state_dict(state_dict, strict=True) + # except: + # new_state_dict = {} + # for key, value in state_dict.items(): + # if 'blocks' in key: + # key_new = 'blocks.0' + key[len('blocks'):] + # else: + # key_new = key + # if 'pos_embed' in key: + # value = value[:, 1:, :] + # new_state_dict[key_new] = value + + # model.load_state_dict(new_state_dict, strict=False) + # #del model.norm + # del model.mask_token + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + +if __name__ == '__main__': + try: + from custom_mmpkg.custom_mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + cfg = Config.fromfile('mu.hu/monodepth/mono/configs/HourglassDecoder/pub12.convlarge.0.3_150.py') + + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + rgb = torch.zeros(1, 3, 1400, 1680).cuda() + model = vit_large(checkpoint="pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth", kwarg=cfg).cuda() + + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + + out1 = model(rgb) + #out2 = model2(rgb) + temp = 0 + + + +# import time +# window_size = 37 +# def prepare_window_masks(shape): +# if window_size <= 0: +# return None +# import xformers.components.attention.attention_patterns as AP + +# B, nh, nw, _, _ = shape +# radius = (window_size-1)//2 +# #time0 = time.time() +# d = AP.local_nd_distance(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# #mask = AP.local_2d_pattern(nh, nw, distance = radius + 0.1, p=torch.inf).cuda() +# # mask = mask.view(nh, nw, nh, nw) +# # #time1 = time.time() - time0 + +# # # angle +# # mask[:radius+1, :radius+1, :window_size, :window_size] = True +# # mask[:radius+1, -radius-1:, :window_size, -window_size:] = True +# # mask[-radius-1:, :radius+1, -window_size:, :window_size] = True +# # mask[-radius-1:, -radius-1:, -window_size:, -window_size:] = True +# # time2 = time.time() - time0 - time1 + +# # # edge +# # mask[radius+1:-radius-1, :radius+1, :, :] = mask[radius+1:-radius-1, radius:radius+1, :, :] +# # mask[radius+1:-radius-1, -radius-1:, :, :] = mask[radius+1:-radius-1, -radius-1:-radius, :, :] +# # mask[:radius+1, radius+1:-radius-1, :, :] = mask[radius:radius+1, radius+1:-radius-1, :, :] +# # mask[-radius-1:, radius+1:-radius-1, :, :] = mask[-radius-1:-radius, radius+1:-radius-1, :, :] +# # time3 = time.time() - time0 - time2 +# # print(time1, time2, time3) + +# # return mask.view(nw*nw, nh*nw).unsqueeze(0).repeat(B, 1) + +# shape = (1, 55, 55, None, None) +# mask = prepare_window_masks(shape) +# # temp = 1 \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py b/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py new file mode 100644 index 0000000000000000000000000000000000000000..346ed10c4689d5f321f8ada9f26d1929bb956182 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/backbones/ViT_DINO_reg.py @@ -0,0 +1,1303 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, Dict, Any, List + +import torch +import torch.nn as nn +from torch import Tensor +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ +import torch.nn.init +import torch.nn.functional as F + +#from dinov2.layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + +logger = logging.getLogger("dinov2") + +# SSF finetuning originally by dongzelian +def init_ssf_scale_shift(dim): + scale = nn.Parameter(torch.ones(dim)) + shift = nn.Parameter(torch.zeros(dim)) + + nn.init.normal_(scale, mean=1, std=.02) + nn.init.normal_(shift, std=.02) + + return scale, shift + +def ssf_ada(x, scale, shift): + assert scale.shape == shift.shape + if x.shape[-1] == scale.shape[0]: + return x * scale + shift + elif x.shape[1] == scale.shape[0]: + return x * scale.view(1, -1, 1, 1) + shift.view(1, -1, 1, 1) + else: + raise ValueError('the input tensor shape does not match the shape of the scale factor.') + +# LoRA finetuning originally by edwardjhu +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + +class LoRALinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + #nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode: bool = True): + # def T(w): + # return w.transpose(0, 1) if self.fan_in_fan_out else w + # nn.Linear.train(self, mode) + # if mode: + # if self.merge_weights and self.merged: + # # Make sure that the weights are not merged + # if self.r > 0: + # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # # Merge the weights and mark it + # if self.r > 0: + # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + tuning_mode: Optional[str] = None + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_1, self.ssf_shift_1) + + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.drop(x) + return x + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(2 * hidden_features) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(out_features) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + if self.tuning_mode == 'ssf': + x12 = ssf_ada(x12, self.ssf_scale_1, self.ssf_shift_1) + + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + out = self.w3(hidden) + + if self.tuning_mode == 'ssf': + out = ssf_ada(out, self.ssf_scale_2, self.ssf_scale_2) + + return out + + +try: + from xformers.ops import SwiGLU + #import numpy.bool + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) + + +XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + window_size: int = 0, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.qkv = LoRALinear(dim, dim * 3, bias=qkv_bias, r=8) + else: + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.attn_drop = nn.Dropout(attn_drop) + + if tuning_mode == 'lora': + self.tuning_mode = tuning_mode + self.proj = LoRALinear(dim, dim, bias=proj_bias, r=8) + else: + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim * 3) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + #if not self.training: + # + # self.attn = ScaledDotProduct() + #self.attn = MultiHeadDispatch(dim_model=EMB, residual_dropout=DROPOUT, num_heads=HEADS, attention=attn) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + if attn_bias is not None: + attn = attn + attn_bias[:, :, :N] + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + #if True: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x, attn_bias) + + B, N, C = x.shape + if self.tuning_mode == 'ssf': + qkv = ssf_ada(self.qkv(x), self.ssf_scale_1, self.ssf_shift_1).reshape(B, N, 3, self.num_heads, C // self.num_heads) + else: + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + if attn_bias is not None: + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias[:, :, :N]) + else: + x = memory_efficient_attention(q, k, v) + x = x.reshape([B, N, C]) + + x = self.proj(x) + if self.tuning_mode == 'ssf': + x = ssf_ada(x, self.ssf_scale_2, self.ssf_shift_2) + + x = self.proj_drop(x) + return x + +XFORMERS_AVAILABLE = False + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values = None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + tuning_mode: Optional[int] = None + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + tuning_mode=tuning_mode + ) + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(dim) + self.ssf_scale_2, self.ssf_shift_2 = init_ssf_scale_shift(dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + def attn_residual_func(x: Tensor, attn_bias) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls1(self.attn(ssf_ada(self.norm1(x), self.ssf_scale_1, self.ssf_shift_1), attn_bias)) + else: + return self.ls1(self.attn(self.norm1(x), attn_bias)) + + def ffn_residual_func(x: Tensor) -> Tensor: + if self.tuning_mode == 'ssf': + return self.ls2(self.mlp(ssf_ada(self.norm2(x), self.ssf_scale_2, self.ssf_shift_2))) + else: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + attn_bias=attn_bias + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x, attn_bias)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x, attn_bias) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, attn_bias=None +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset, attn_bias) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list, attn_bias=None): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list, attn_bias) + elif isinstance(x_or_x_list, list): + assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x, others=None): + for b in self: + if others == None: + x = b(x) + else: + x = b(x, others) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=518, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=1e-5, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + multi_output=False, + tuning_mode=None, + **kwargs + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + if tuning_mode != None: + self.tuning_mode = tuning_mode + if tuning_mode == 'ssf': + self.ssf_scale_1, self.ssf_shift_1 = init_ssf_scale_shift(embed_dim) + else: + pass + #raise NotImplementedError() + else: + self.tuning_mode = None + tuning_mode_list = [tuning_mode] * depth + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, tuning_mode=tuning_mode) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + self.multi_output = multi_output + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + tuning_mode=tuning_mode_list[i] + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + # we add a small number to avoid floating point error in the interpolation + # see discussion at https://github.com/facebookresearch/dino/issues/8 + w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset + + sqrt_N = math.sqrt(N) + sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2), + scale_factor=(sx, sy), + mode="bicubic", + antialias=self.interpolate_antialias, + ) + + assert int(w0) == patch_pos_embed.shape[-2] + assert int(h0) == patch_pos_embed.shape[-1] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + B, C, H, W = x.size() + pad_h = (self.patch_size - H % self.patch_size) + pad_w = (self.patch_size - W % self.patch_size) + if pad_h == self.patch_size: + pad_h = 0 + if pad_w == self.patch_size: + pad_w = 0 + #x = nn.functional.pad(x, (pad_h//2, pad_h-pad_h//2, pad_w//2, pad_w-pad_w//2)) + if pad_h + pad_w > 0: + x = torch.nn.functional.interpolate(x, (H+pad_h, W+pad_w), mode='bilinear') + + x = self.prepare_tokens_with_masks(x, masks) + + #for blk in self.blocks: + #x = blk(x) + + #x_norm = self.norm(x) + #if self.tuning_mode == 'ssf': + #x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) + + # return { + # "x_norm_clstoken": x_norm[:, 0], + # "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + # "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + # "x_prenorm": x, + # "masks": masks, + # } + # features = [] + # features.append(x_norm) + # features.append(x_norm) + # features.append(x_norm) + # features.append(x_norm) + # return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + + if self.multi_output == False: + for blk in self.blocks: + x = blk(x) + x_norm = self.norm(x) + if self.tuning_mode == 'ssf': + x_norm = ssf_ada(x_norm, self.ssf_scale_1, self.ssf_shift_1) + + features = [] + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + features.append(x_norm) + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + else: + features = [] + for blk in self.blocks: + for idx, sub_blk in enumerate(blk): + x = sub_blk(x) + if (idx + 1) % (len(blk) // 4) == 0: + features.append(x) + + return [features, (B, (H+pad_h)//self.patch_size, (W+pad_w)//self.patch_size, H, W, self.num_register_tokens)] + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1:] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + # if is_training: + # return ret + # else: + # return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def load_ckpt_dino(checkpoint, model): + if checkpoint is not None: + try: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + except: + print('NO pretrained imagenet ckpt available! Check your path!') + del model.mask_token + return + + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + del model.mask_token + return + else: + return + + +def vit_small(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + try: + model.load_state_dict(state_dict, strict=True) + except: + new_state_dict = {} + for key, value in state_dict.items(): + if 'blocks' in key: + key_new = 'blocks.0' + key[len('blocks'):] + else: + key_new = key + new_state_dict[key_new] = value + + model.load_state_dict(new_state_dict, strict=True) + del model.mask_token + return model + + +def vit_giant2(patch_size=14, num_register_tokens=0, checkpoint=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + **kwargs, + ) + return model + + + +def vit_small_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_base_reg(patch_size=14, num_register_tokens=4, checkpoint=None, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_large_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + model = DinoVisionTransformer( + img_size = 518, + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + tuning_mode=tuning_mode, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + + +def vit_giant2_reg(patch_size=14, num_register_tokens=4, checkpoint=None, tuning_mode=None, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + ffn_layer='swiglu', + tuning_mode=tuning_mode, + multi_output=True, + **kwargs, + ) + + load_ckpt_dino(checkpoint, model) + + return model + +if __name__ == '__main__': + try: + from custom_mmpkg.custom_mmcv.utils import Config + except: + from mmengine import Config + + #rgb = torch.rand((2, 3, 518, 518)).cuda() + + #cfg.data_basic['crop_size']['0'] + #cfg.data_basic['crop_size']['1'] + cfg = Config.fromfile('/mu.hu/projects/monodepth_vit/mono/configs/RAFTDecoder/vit.raft5.large.kitti.py') + + #rgb = torch.arange(0, 2*3*1036*1036, 1).cuda().float().view(2, 3, 1036, 1036) + rgb = torch.zeros(1, 3, 616, 1064).cuda() + cfg['tuning_mode'] = 'ssf' + #model = vit_large_reg(checkpoint="/cpfs02/shared/public/groups/local_map/yvan/pretrained_weight_repo/vit/dinov2_vitl14_reg4_pretrain.pth", kwarg=cfg).cuda() + model = vit_large_reg(tuning_mode='ssf').cuda() + + #import timm + #model2 = timm.models.vision_transformer.vit_large_patch14_dinov2().cuda() + #timm.models.load_checkpoint(model2, '/cpfs02/shared/public/yvan/pretrained_weight_repo/vit/dinov2_vitl14_pretrain.pth', filter_fn=timm.models.vision_transformer.checkpoint_filter_fn) + + out1 = model(rgb) + #out2 = model2(rgb) + temp = 0 + + diff --git a/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py b/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36204128dbb323426b0aa19e52674cbdc3a0f860 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/backbones/__init__.py @@ -0,0 +1,11 @@ +from .ConvNeXt import convnext_xlarge +from .ConvNeXt import convnext_small +from .ConvNeXt import convnext_base +from .ConvNeXt import convnext_large +from .ConvNeXt import convnext_tiny +from .ViT_DINO import vit_large +from .ViT_DINO_reg import vit_small_reg, vit_large_reg, vit_giant2_reg + +__all__ = [ + 'convnext_xlarge', 'convnext_small', 'convnext_base', 'convnext_large', 'convnext_tiny', 'vit_small_reg', 'vit_large_reg', 'vit_giant2_reg' +] diff --git a/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py b/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py new file mode 100644 index 0000000000000000000000000000000000000000..e084382601e21e6ce5144abbd6a65f563905b659 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/decode_heads/HourGlassDecoder.py @@ -0,0 +1,274 @@ +import torch +import torch.nn as nn +import numpy as np +import math +import torch.nn.functional as F + +def compute_depth_expectation(prob, depth_values): + depth_values = depth_values.view(*depth_values.shape, 1, 1) + depth = torch.sum(prob * depth_values, 1) + return depth + +class ConvBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvBlock, self).__init__() + + if kernel_size == 3: + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), + ) + elif kernel_size == 1: + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) + + self.nonlin = nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + return out + + +class ConvBlock_double(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3): + super(ConvBlock_double, self).__init__() + + if kernel_size == 3: + self.conv = nn.Sequential( + nn.ReflectionPad2d(1), + nn.Conv2d(in_channels, out_channels, 3, padding=0, stride=1), + ) + elif kernel_size == 1: + self.conv = nn.Conv2d(int(in_channels), int(out_channels), 1, padding=0, stride=1) + + self.nonlin = nn.ELU(inplace=True) + self.conv_2 = nn.Conv2d(out_channels, out_channels, 1, padding=0, stride=1) + self.nonlin_2 =nn.ELU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.nonlin(out) + out = self.conv_2(out) + out = self.nonlin_2(out) + return out + +class DecoderFeature(nn.Module): + def __init__(self, feat_channels, num_ch_dec=[64, 64, 128, 256]): + super(DecoderFeature, self).__init__() + self.num_ch_dec = num_ch_dec + self.feat_channels = feat_channels + + self.upconv_3_0 = ConvBlock(self.feat_channels[3], self.num_ch_dec[3], kernel_size=1) + self.upconv_3_1 = ConvBlock_double( + self.feat_channels[2] + self.num_ch_dec[3], + self.num_ch_dec[3], + kernel_size=1) + + self.upconv_2_0 = ConvBlock(self.num_ch_dec[3], self.num_ch_dec[2], kernel_size=3) + self.upconv_2_1 = ConvBlock_double( + self.feat_channels[1] + self.num_ch_dec[2], + self.num_ch_dec[2], + kernel_size=3) + + self.upconv_1_0 = ConvBlock(self.num_ch_dec[2], self.num_ch_dec[1], kernel_size=3) + self.upconv_1_1 = ConvBlock_double( + self.feat_channels[0] + self.num_ch_dec[1], + self.num_ch_dec[1], + kernel_size=3) + self.upsample = nn.Upsample(scale_factor=2, mode='nearest') + + def forward(self, ref_feature): + x = ref_feature[3] + + x = self.upconv_3_0(x) + x = torch.cat((self.upsample(x), ref_feature[2]), 1) + x = self.upconv_3_1(x) + + x = self.upconv_2_0(x) + x = torch.cat((self.upsample(x), ref_feature[1]), 1) + x = self.upconv_2_1(x) + + x = self.upconv_1_0(x) + x = torch.cat((self.upsample(x), ref_feature[0]), 1) + x = self.upconv_1_1(x) + return x + + +class UNet(nn.Module): + def __init__(self, inp_ch=32, output_chal=1, down_sample_times=3, channel_mode='v0'): + super(UNet, self).__init__() + basic_block = ConvBnReLU + num_depth = 128 + + self.conv0 = basic_block(inp_ch, num_depth) + if channel_mode == 'v0': + channels = [num_depth, num_depth//2, num_depth//4, num_depth//8, num_depth // 8] + elif channel_mode == 'v1': + channels = [num_depth, num_depth, num_depth, num_depth, num_depth, num_depth] + self.down_sample_times = down_sample_times + for i in range(down_sample_times): + setattr( + self, 'conv_%d' % i, + nn.Sequential( + basic_block(channels[i], channels[i+1], stride=2), + basic_block(channels[i+1], channels[i+1]) + ) + ) + for i in range(down_sample_times-1,-1,-1): + setattr(self, 'deconv_%d' % i, + nn.Sequential( + nn.ConvTranspose2d( + channels[i+1], + channels[i], + kernel_size=3, + padding=1, + output_padding=1, + stride=2, + bias=False), + nn.BatchNorm2d(channels[i]), + nn.ReLU(inplace=True) + ) + ) + self.prob = nn.Conv2d(num_depth, output_chal, 1, stride=1, padding=0) + + def forward(self, x): + features = {} + conv0 = self.conv0(x) + x = conv0 + features[0] = conv0 + for i in range(self.down_sample_times): + x = getattr(self, 'conv_%d' % i)(x) + features[i+1] = x + for i in range(self.down_sample_times-1,-1,-1): + x = features[i] + getattr(self, 'deconv_%d' % i)(x) + x = self.prob(x) + return x + +class ConvBnReLU(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): + super(ConvBnReLU, self).__init__() + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=pad, + bias=False + ) + self.bn = nn.BatchNorm2d(out_channels) + + def forward(self, x): + return F.relu(self.bn(self.conv(x)), inplace=True) + + +class HourglassDecoder(nn.Module): + def __init__(self, cfg): + super(HourglassDecoder, self).__init__() + self.inchannels = cfg.model.decode_head.in_channels # [256, 512, 1024, 2048] + self.decoder_channels = cfg.model.decode_head.decoder_channel # [64, 64, 128, 256] + self.min_val = cfg.data_basic.depth_normalize[0] + self.max_val = cfg.data_basic.depth_normalize[1] + + self.num_ch_dec = self.decoder_channels # [64, 64, 128, 256] + self.num_depth_regressor_anchor = 512 + self.feat_channels = self.inchannels + unet_in_channel = self.num_ch_dec[1] + unet_out_channel = 256 + + self.decoder_mono = DecoderFeature(self.feat_channels, self.num_ch_dec) + self.conv_out_2 = UNet(inp_ch=unet_in_channel, + output_chal=unet_out_channel + 1, + down_sample_times=3, + channel_mode='v0', + ) + + self.depth_regressor_2 = nn.Sequential( + nn.Conv2d(unet_out_channel, + self.num_depth_regressor_anchor, + kernel_size=3, + padding=1, + ), + nn.BatchNorm2d(self.num_depth_regressor_anchor), + nn.ReLU(inplace=True), + nn.Conv2d( + self.num_depth_regressor_anchor, + self.num_depth_regressor_anchor, + kernel_size=1, + ) + ) + self.residual_channel = 16 + self.conv_up_2 = nn.Sequential( + nn.Conv2d(1 + 2 + unet_out_channel, self.residual_channel, 3, padding=1), + nn.BatchNorm2d(self.residual_channel), + nn.ReLU(), + nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), + nn.Upsample(scale_factor=4), + nn.Conv2d(self.residual_channel, self.residual_channel, 3, padding=1), + nn.ReLU(), + nn.Conv2d(self.residual_channel, 1, 1, padding=0), + ) + + def get_bins(self, bins_num): + depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device='cuda') + depth_bins_vec = torch.exp(depth_bins_vec) + return depth_bins_vec + + def register_depth_expectation_anchor(self, bins_num, B): + depth_bins_vec = self.get_bins(bins_num) + depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) + self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) + + def upsample(self, x, scale_factor=2): + return F.interpolate(x, scale_factor=scale_factor, mode='nearest') + + def regress_depth_2(self, feature_map_d): + prob = self.depth_regressor_2(feature_map_d).softmax(dim=1) + B = prob.shape[0] + if "depth_expectation_anchor" not in self._buffers: + self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) + d = compute_depth_expectation( + prob, + self.depth_expectation_anchor[:B, ...] + ).unsqueeze(1) + return d + + def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): + y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), + torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') + meshgrid = torch.stack((x, y)) + meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) + return meshgrid + + def forward(self, features_mono, **kwargs): + ''' + trans_ref2src: list of transformation matrix from the reference view to source view. [B, 4, 4] + inv_intrinsic_pool: list of inverse intrinsic matrix. + features_mono: features of reference and source views. [[ref_f1, ref_f2, ref_f3, ref_f4],[src1_f1, src1_f2, src1_f3, src1_f4], ...]. + ''' + outputs = {} + # get encoder feature of the reference view + ref_feat = features_mono + + feature_map_mono = self.decoder_mono(ref_feat) + feature_map_mono_pred = self.conv_out_2(feature_map_mono) + confidence_map_2 = feature_map_mono_pred[:, -1:, :, :] + feature_map_d_2 = feature_map_mono_pred[:, :-1, :, :] + + depth_pred_2 = self.regress_depth_2(feature_map_d_2) + + B, _, H, W = depth_pred_2.shape + + meshgrid = self.create_mesh_grid(H, W, B) + + depth_pred_mono = self.upsample(depth_pred_2, scale_factor=4) + 1e-1 * \ + self.conv_up_2( + torch.cat((depth_pred_2, meshgrid[:B, ...], feature_map_d_2), 1) + ) + confidence_map_mono = self.upsample(confidence_map_2, scale_factor=4) + + outputs=dict( + prediction=depth_pred_mono, + confidence=confidence_map_mono, + pred_logit=None, + ) + return outputs \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py b/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py new file mode 100644 index 0000000000000000000000000000000000000000..7790d3a03a7573d255238aba399c14f74b50507f --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/decode_heads/RAFTDepthNormalDPTDecoder5.py @@ -0,0 +1,1031 @@ +import torch +import torch.nn as nn +import numpy as np +import math +import torch.nn.functional as F + +# LORA finetuning originally by edwardjhu +class LoRALayer(): + def __init__( + self, + r: int, + lora_alpha: int, + lora_dropout: float, + merge_weights: bool, + ): + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.merge_weights = merge_weights + +class LoRALinear(nn.Linear, LoRALayer): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0., + fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out) + merge_weights: bool = True, + **kwargs + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, + merge_weights=merge_weights) + + self.fan_in_fan_out = fan_in_fan_out + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features))) + self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r))) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + if fan_in_fan_out: + self.weight.data = self.weight.data.transpose(0, 1) + + def reset_parameters(self): + #nn.Linear.reset_parameters(self) + if hasattr(self, 'lora_A'): + # initialize B the same way as the default for nn.Linear and A to zero + # this is different than what is described in the paper but should not affect performance + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode: bool = True): + # def T(w): + # return w.transpose(0, 1) if self.fan_in_fan_out else w + # nn.Linear.train(self, mode) + # if mode: + # if self.merge_weights and self.merged: + # # Make sure that the weights are not merged + # if self.r > 0: + # self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # # Merge the weights and mark it + # if self.r > 0: + # self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling + # self.merged = True + + def forward(self, x: torch.Tensor): + def T(w): + return w.transpose(0, 1) if self.fan_in_fan_out else w + if self.r > 0 and not self.merged: + result = F.linear(x, T(self.weight), bias=self.bias) + result += (self.lora_dropout(x) @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling + return result + else: + return F.linear(x, T(self.weight), bias=self.bias) + +class ConvLoRA(nn.Conv2d, LoRALayer): + def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): + #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) + nn.Conv2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) + ) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + #self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode=True): + # super(ConvLoRA, self).train(mode) + # if mode: + # if self.merge_weights and self.merged: + # if self.r > 0: + # # Make sure that the weights are not merged + # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # if self.r > 0: + # # Merge the weights and mark it + # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + # return self.conv._conv_forward( + # x, + # self.conv.weight + (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling, + # self.conv.bias + # ) + weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + bias = self.bias + + return F.conv2d(x, weight, bias=bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + else: + return F.conv2d(x, self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups) + +class ConvTransposeLoRA(nn.ConvTranspose2d, LoRALayer): + def __init__(self, in_channels, out_channels, kernel_size, r=0, lora_alpha=1, lora_dropout=0., merge_weights=True, **kwargs): + #self.conv = conv_module(in_channels, out_channels, kernel_size, **kwargs) + nn.ConvTranspose2d.__init__(self, in_channels, out_channels, kernel_size, **kwargs) + LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=merge_weights) + assert isinstance(kernel_size, int) + + # Actual trainable parameters + if r > 0: + self.lora_A = nn.Parameter( + self.weight.new_zeros((r * kernel_size, in_channels * kernel_size)) + ) + self.lora_B = nn.Parameter( + self.weight.new_zeros((out_channels//self.groups*kernel_size, r*kernel_size)) + ) + self.scaling = self.lora_alpha / self.r + # Freezing the pre-trained weight matrix + self.weight.requires_grad = False + self.reset_parameters() + self.merged = False + + def reset_parameters(self): + #self.conv.reset_parameters() + if hasattr(self, 'lora_A'): + # initialize A the same way as the default for nn.Linear and B to zero + nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5)) + nn.init.zeros_(self.lora_B) + + # def train(self, mode=True): + # super(ConvTransposeLoRA, self).train(mode) + # if mode: + # if self.merge_weights and self.merged: + # if self.r > 0: + # # Make sure that the weights are not merged + # self.conv.weight.data -= (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = False + # else: + # if self.merge_weights and not self.merged: + # if self.r > 0: + # # Merge the weights and mark it + # self.conv.weight.data += (self.lora_B @ self.lora_A).view(self.conv.weight.shape) * self.scaling + # self.merged = True + + def forward(self, x): + if self.r > 0 and not self.merged: + weight = self.weight + (self.lora_B @ self.lora_A).view(self.weight.shape) * self.scaling + bias = self.bias + return F.conv_transpose2d(x, weight, + bias=bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, + groups=self.groups, dilation=self.dilation) + else: + return F.conv_transpose2d(x, self.weight, + bias=self.bias, stride=self.stride, padding=self.padding, output_padding=self.output_padding, + groups=self.groups, dilation=self.dilation) + #return self.conv(x) + +class Conv2dLoRA(ConvLoRA): + def __init__(self, *args, **kwargs): + super(Conv2dLoRA, self).__init__(*args, **kwargs) + +class ConvTranspose2dLoRA(ConvTransposeLoRA): + def __init__(self, *args, **kwargs): + super(ConvTranspose2dLoRA, self).__init__(*args, **kwargs) + + +def compute_depth_expectation(prob, depth_values): + depth_values = depth_values.view(*depth_values.shape, 1, 1) + depth = torch.sum(prob * depth_values, 1) + return depth + +def interpolate_float32(x, size=None, scale_factor=None, mode='nearest', align_corners=None): + return F.interpolate(x.float(), size=size, scale_factor=scale_factor, mode=mode, align_corners=align_corners) + +# def upflow8(flow, mode='bilinear'): +# new_size = (8 * flow.shape[2], 8 * flow.shape[3]) +# return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + +def upflow4(flow, mode='bilinear'): + new_size = (4 * flow.shape[2], 4 * flow.shape[3]) + return F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + +def coords_grid(batch, ht, wd): + # coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = (torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd)), torch.zeros((ht, wd))) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + +def norm_normalize(norm_out): + min_kappa = 0.01 + norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) + norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 + kappa = F.elu(kappa) + 1.0 + min_kappa + final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) + return final_out + +# uncertainty-guided sampling (only used during training) +@torch.no_grad() +def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): + device = init_normal.device + B, _, H, W = init_normal.shape + N = int(sampling_ratio * H * W) + beta = beta + + # uncertainty map + uncertainty_map = -1 * init_normal[:, -1, :, :] # B, H, W + + # gt_invalid_mask (B, H, W) + if gt_norm_mask is not None: + gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') + gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 + uncertainty_map[gt_invalid_mask] = -1e4 + + # (B, H*W) + _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) + + # importance sampling + if int(beta * N) > 0: + importance = idx[:, :int(beta * N)] # B, beta*N + + # remaining + remaining = idx[:, int(beta * N):] # B, H*W - beta*N + + # coverage + num_coverage = N - int(beta * N) + + if num_coverage <= 0: + samples = importance + else: + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = torch.cat((importance, coverage), dim=1) # B, N + + else: + # remaining + remaining = idx[:, :] # B, H*W + + # coverage + num_coverage = N + + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = coverage + + # point coordinates + rows_int = samples // W # 0 for first row, H-1 for last row + rows_float = rows_int / float(H-1) # 0 to 1.0 + rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 + + cols_int = samples % W # 0 for first column, W-1 for last column + cols_float = cols_int / float(W-1) # 0 to 1.0 + cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 + + point_coords = torch.zeros(B, 1, N, 2) + point_coords[:, 0, :, 0] = cols_float # x coord + point_coords[:, 0, :, 1] = rows_float # y coord + point_coords = point_coords.to(device) + return point_coords, rows_int, cols_int + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, output_dim_depth=2, output_dim_norm=4, tuning_mode=None): + super(FlowHead, self).__init__() + self.conv1d = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.conv2d = Conv2dLoRA(hidden_dim // 2, output_dim_depth, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + + self.conv1n = Conv2dLoRA(input_dim, hidden_dim // 2, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.conv2n = Conv2dLoRA(hidden_dim // 2, output_dim_norm, 3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + depth = self.conv2d(self.relu(self.conv1d(x))) + normal = self.conv2n(self.relu(self.conv1n(x))) + return torch.cat((depth, normal), dim=1) + + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim, input_dim, kernel_size=3, tuning_mode=None): + super(ConvGRU, self).__init__() + self.convz = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + self.convr = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + self.convq = Conv2dLoRA(hidden_dim+input_dim, hidden_dim, kernel_size, padding=kernel_size//2, r = 8 if tuning_mode == 'lora' else 0) + + def forward(self, h, cz, cr, cq, *x_list): + x = torch.cat(x_list, dim=1) + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid((self.convz(hx) + cz)) + r = torch.sigmoid((self.convr(hx) + cr)) + q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq)) + + # z = torch.sigmoid((self.convz(hx) + cz).float()) + # r = torch.sigmoid((self.convr(hx) + cr).float()) + # q = torch.tanh((self.convq(torch.cat([r*h, x], dim=1)) + cq).float()) + + h = (1-z) * h + z * q + return h + +def pool2x(x): + return F.avg_pool2d(x, 3, stride=2, padding=1) + +def pool4x(x): + return F.avg_pool2d(x, 5, stride=4, padding=1) + +def interp(x, dest): + interp_args = {'mode': 'bilinear', 'align_corners': True} + return interpolate_float32(x, dest.shape[2:], **interp_args) + +class BasicMultiUpdateBlock(nn.Module): + def __init__(self, args, hidden_dims=[], out_dims=2, tuning_mode=None): + super().__init__() + self.args = args + self.n_gru_layers = args.model.decode_head.n_gru_layers # 3 + self.n_downsample = args.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) + + # self.encoder = BasicMotionEncoder(args) + # encoder_output_dim = 128 # if there is corr volume + encoder_output_dim = 6 # no corr volume + + self.gru08 = ConvGRU(hidden_dims[2], encoder_output_dim + hidden_dims[1] * (self.n_gru_layers > 1), tuning_mode=tuning_mode) + self.gru16 = ConvGRU(hidden_dims[1], hidden_dims[0] * (self.n_gru_layers == 3) + hidden_dims[2], tuning_mode=tuning_mode) + self.gru32 = ConvGRU(hidden_dims[0], hidden_dims[1], tuning_mode=tuning_mode) + self.flow_head = FlowHead(hidden_dims[2], hidden_dim=2*hidden_dims[2], tuning_mode=tuning_mode) + factor = 2**self.n_downsample + + self.mask = nn.Sequential( + Conv2dLoRA(hidden_dims[2], hidden_dims[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0), + nn.ReLU(inplace=True), + Conv2dLoRA(hidden_dims[2], (factor**2)*9, 1, padding=0, r = 8 if tuning_mode == 'lora' else 0)) + + def forward(self, net, inp, corr=None, flow=None, iter08=True, iter16=True, iter32=True, update=True): + + if iter32: + net[2] = self.gru32(net[2], *(inp[2]), pool2x(net[1])) + if iter16: + if self.n_gru_layers > 2: + net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1]), interp(net[2], net[1])) + else: + net[1] = self.gru16(net[1], *(inp[1]), interp(pool2x(net[0]), net[1])) + if iter08: + if corr is not None: + motion_features = self.encoder(flow, corr) + else: + motion_features = flow + if self.n_gru_layers > 1: + net[0] = self.gru08(net[0], *(inp[0]), motion_features, interp(net[1], net[0])) + else: + net[0] = self.gru08(net[0], *(inp[0]), motion_features) + + if not update: + return net + + delta_flow = self.flow_head(net[0]) + + # scale mask to balence gradients + mask = .25 * self.mask(net[0]) + return net, mask, delta_flow + +class LayerNorm2d(nn.LayerNorm): + def __init__(self, dim): + super(LayerNorm2d, self).__init__(dim) + + def forward(self, x): + x = x.permute(0, 2, 3, 1).contiguous() + x = super(LayerNorm2d, self).forward(x) + x = x.permute(0, 3, 1, 2).contiguous() + return x + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1, tuning_mode=None): + super(ResidualBlock, self).__init__() + + self.conv1 = Conv2dLoRA(in_planes, planes, kernel_size=3, padding=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0) + self.conv2 = Conv2dLoRA(planes, planes, kernel_size=3, padding=1, r = 8 if tuning_mode == 'lora' else 0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'layer': + self.norm1 = LayerNorm2d(planes) + self.norm2 = LayerNorm2d(planes) + if not (stride == 1 and in_planes == planes): + self.norm3 = LayerNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not (stride == 1 and in_planes == planes): + self.norm3 = nn.Sequential() + + if stride == 1 and in_planes == planes: + self.downsample = None + + else: + self.downsample = nn.Sequential( + Conv2dLoRA(in_planes, planes, kernel_size=1, stride=stride, r = 8 if tuning_mode == 'lora' else 0), self.norm3) + + def forward(self, x): + y = x + y = self.conv1(y) + y = self.norm1(y) + y = self.relu(y) + y = self.conv2(y) + y = self.norm2(y) + y = self.relu(y) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + +class ContextFeatureEncoder(nn.Module): + ''' + Encoder features are used to: + 1. initialize the hidden state of the update operator + 2. and also injected into the GRU during each iteration of the update operator + ''' + def __init__(self, in_dim, output_dim, tuning_mode=None): + ''' + in_dim = [x4, x8, x16, x32] + output_dim = [hindden_dims, context_dims] + [[x4,x8,x16,x32],[x4,x8,x16,x32]] + ''' + super().__init__() + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[0], dim[0], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[0], dim[0], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs04 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[1], dim[1], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[1], dim[1], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs08 = nn.ModuleList(output_list) + + output_list = [] + for dim in output_dim: + conv_out = nn.Sequential( + ResidualBlock(in_dim[2], dim[2], 'layer', stride=1, tuning_mode=tuning_mode), + Conv2dLoRA(dim[2], dim[2], 3, padding=1, r = 8 if tuning_mode == 'lora' else 0)) + output_list.append(conv_out) + + self.outputs16 = nn.ModuleList(output_list) + + # output_list = [] + # for dim in output_dim: + # conv_out = Conv2dLoRA(in_dim[3], dim[3], 3, padding=1) + # output_list.append(conv_out) + + # self.outputs32 = nn.ModuleList(output_list) + + def forward(self, encoder_features): + x_4, x_8, x_16, x_32 = encoder_features + + outputs04 = [f(x_4) for f in self.outputs04] + outputs08 = [f(x_8) for f in self.outputs08] + outputs16 = [f(x_16)for f in self.outputs16] + # outputs32 = [f(x_32) for f in self.outputs32] + + return (outputs04, outputs08, outputs16) + +class ConvBlock(nn.Module): + # reimplementation of DPT + def __init__(self, channels, tuning_mode=None): + super(ConvBlock, self).__init__() + + self.act = nn.ReLU(inplace=True) + self.conv1 = Conv2dLoRA( + channels, + channels, + kernel_size=3, + stride=1, + padding=1, + r = 8 if tuning_mode == 'lora' else 0 + ) + self.conv2 = Conv2dLoRA( + channels, + channels, + kernel_size=3, + stride=1, + padding=1, + r = 8 if tuning_mode == 'lora' else 0 + ) + + def forward(self, x): + out = self.act(x) + out = self.conv1(out) + out = self.act(out) + out = self.conv2(out) + return x + out + +class FuseBlock(nn.Module): + # reimplementation of DPT + def __init__(self, in_channels, out_channels, fuse=True, upsample=True, scale_factor=2, tuning_mode=None): + super(FuseBlock, self).__init__() + + self.fuse = fuse + self.scale_factor = scale_factor + self.way_trunk = ConvBlock(in_channels, tuning_mode=tuning_mode) + if self.fuse: + self.way_branch = ConvBlock(in_channels, tuning_mode=tuning_mode) + + self.out_conv = Conv2dLoRA( + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + r = 8 if tuning_mode == 'lora' else 0 + ) + self.upsample = upsample + + def forward(self, x1, x2=None): + if x2 is not None: + x2 = self.way_branch(x2) + x1 = x1 + x2 + + out = self.way_trunk(x1) + + if self.upsample: + out = interpolate_float32( + out, scale_factor=self.scale_factor, mode="bilinear", align_corners=True + ) + out = self.out_conv(out) + return out + +class Readout(nn.Module): + # From DPT + def __init__(self, in_features, use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(Readout, self).__init__() + self.use_cls_token = use_cls_token + if self.use_cls_token == True: + self.project_patch = LoRALinear(in_features, in_features, r = 8 if tuning_mode == 'lora' else 0) + self.project_learn = LoRALinear((1 + num_register_tokens) * in_features, in_features, bias=False, r = 8 if tuning_mode == 'lora' else 0) + self.act = nn.GELU() + else: + self.project = nn.Identity() + + def forward(self, x): + + if self.use_cls_token == True: + x_patch = self.project_patch(x[0]) + x_learn = self.project_learn(x[1]) + x_learn = x_learn.expand_as(x_patch).contiguous() + features = x_patch + x_learn + return self.act(features) + else: + return self.project(x) + +class Token2Feature(nn.Module): + # From DPT + def __init__(self, vit_channel, feature_channel, scale_factor, use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(Token2Feature, self).__init__() + self.scale_factor = scale_factor + self.readoper = Readout(in_features=vit_channel, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + if scale_factor > 1 and isinstance(scale_factor, int): + self.sample = ConvTranspose2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=scale_factor, + stride=scale_factor, + padding=0, + ) + + elif scale_factor > 1: + self.sample = nn.Sequential( + # Upsample2(upscale=scale_factor), + # nn.Upsample(scale_factor=scale_factor), + Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=1, + stride=1, + padding=0, + ), + ) + + + elif scale_factor < 1: + scale_factor = int(1.0 / scale_factor) + self.sample = Conv2dLoRA(r = 8 if tuning_mode == 'lora' else 0, + in_channels=vit_channel, + out_channels=feature_channel, + kernel_size=scale_factor+1, + stride=scale_factor, + padding=1, + ) + + else: + self.sample = nn.Identity() + + def forward(self, x): + x = self.readoper(x) + #if use_cls_token == True: + x = x.permute(0, 3, 1, 2).contiguous() + if isinstance(self.scale_factor, float): + x = interpolate_float32(x.float(), scale_factor=self.scale_factor, mode='nearest') + x = self.sample(x) + return x + +class EncoderFeature(nn.Module): + def __init__(self, vit_channel, num_ch_dec=[256, 512, 1024, 1024], use_cls_token=True, num_register_tokens=0, tuning_mode=None): + super(EncoderFeature, self).__init__() + self.vit_channel = vit_channel + self.num_ch_dec = num_ch_dec + + self.read_3 = Token2Feature(self.vit_channel, self.num_ch_dec[3], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_2 = Token2Feature(self.vit_channel, self.num_ch_dec[2], scale_factor=1, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_1 = Token2Feature(self.vit_channel, self.num_ch_dec[1], scale_factor=2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + self.read_0 = Token2Feature(self.vit_channel, self.num_ch_dec[0], scale_factor=7/2, use_cls_token=use_cls_token, num_register_tokens=num_register_tokens, tuning_mode=tuning_mode) + + def forward(self, ref_feature): + x = self.read_3(ref_feature[3]) # 1/14 + x2 = self.read_2(ref_feature[2]) # 1/14 + x1 = self.read_1(ref_feature[1]) # 1/7 + x0 = self.read_0(ref_feature[0]) # 1/4 + + return x, x2, x1, x0 + +class DecoderFeature(nn.Module): + def __init__(self, vit_channel, num_ch_dec=[128, 256, 512, 1024, 1024], use_cls_token=True, tuning_mode=None): + super(DecoderFeature, self).__init__() + self.vit_channel = vit_channel + self.num_ch_dec = num_ch_dec + + self.upconv_3 = FuseBlock( + self.num_ch_dec[4], + self.num_ch_dec[3], + fuse=False, upsample=False, tuning_mode=tuning_mode) + + self.upconv_2 = FuseBlock( + self.num_ch_dec[3], + self.num_ch_dec[2], + tuning_mode=tuning_mode) + + self.upconv_1 = FuseBlock( + self.num_ch_dec[2], + self.num_ch_dec[1] + 2, + scale_factor=7/4, + tuning_mode=tuning_mode) + + # self.upconv_0 = FuseBlock( + # self.num_ch_dec[1], + # self.num_ch_dec[0] + 1, + # ) + + def forward(self, ref_feature): + x, x2, x1, x0 = ref_feature # 1/14 1/14 1/7 1/4 + + x = self.upconv_3(x) # 1/14 + x = self.upconv_2(x, x2) # 1/7 + x = self.upconv_1(x, x1) # 1/4 + # x = self.upconv_0(x, x0) # 4/7 + return x + +class RAFTDepthNormalDPT5(nn.Module): + def __init__(self, cfg): + super().__init__() + self.in_channels = cfg.model.decode_head.in_channels # [1024, 1024, 1024, 1024] + self.feature_channels = cfg.model.decode_head.feature_channels # [256, 512, 1024, 1024] [2/7, 1/7, 1/14, 1/14] + self.decoder_channels = cfg.model.decode_head.decoder_channels # [128, 256, 512, 1024, 1024] [-, 1/4, 1/7, 1/14, 1/14] + self.use_cls_token = cfg.model.decode_head.use_cls_token + self.up_scale = cfg.model.decode_head.up_scale + self.num_register_tokens = cfg.model.decode_head.num_register_tokens + self.min_val = cfg.data_basic.depth_normalize[0] + self.max_val = cfg.data_basic.depth_normalize[1] + self.regress_scale = 100.0\ + + try: + tuning_mode = cfg.model.decode_head.tuning_mode + except: + tuning_mode = None + self.tuning_mode = tuning_mode + + self.hidden_dims = self.context_dims = cfg.model.decode_head.hidden_channels # [128, 128, 128, 128] + self.n_gru_layers = cfg.model.decode_head.n_gru_layers # 3 + self.n_downsample = cfg.model.decode_head.n_downsample # 3, resolution of the disparity field (1/2^K) + self.iters = cfg.model.decode_head.iters # 22 + self.slow_fast_gru = cfg.model.decode_head.slow_fast_gru # True + + self.num_depth_regressor_anchor = 256 # 512 + self.used_res_channel = self.decoder_channels[1] # now, use 2/7 res + self.token2feature = EncoderFeature(self.in_channels[0], self.feature_channels, self.use_cls_token, self.num_register_tokens, tuning_mode=tuning_mode) + self.decoder_mono = DecoderFeature(self.in_channels, self.decoder_channels, tuning_mode=tuning_mode) + self.depth_regressor = nn.Sequential( + Conv2dLoRA(self.used_res_channel, + self.num_depth_regressor_anchor, + kernel_size=3, + padding=1, r = 8 if tuning_mode == 'lora' else 0), + # nn.BatchNorm2d(self.num_depth_regressor_anchor), + nn.ReLU(inplace=True), + Conv2dLoRA(self.num_depth_regressor_anchor, + self.num_depth_regressor_anchor, + kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), + ) + self.normal_predictor = nn.Sequential( + Conv2dLoRA(self.used_res_channel, + 128, + kernel_size=3, + padding=1, r = 8 if tuning_mode == 'lora' else 0,), + # nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), + Conv2dLoRA(128, 128, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), nn.ReLU(inplace=True), + Conv2dLoRA(128, 3, kernel_size=1, r = 8 if tuning_mode == 'lora' else 0), + ) + + self.context_feature_encoder = ContextFeatureEncoder(self.feature_channels, [self.hidden_dims, self.context_dims], tuning_mode=tuning_mode) + self.context_zqr_convs = nn.ModuleList([Conv2dLoRA(self.context_dims[i], self.hidden_dims[i]*3, 3, padding=3//2, r = 8 if tuning_mode == 'lora' else 0) for i in range(self.n_gru_layers)]) + self.update_block = BasicMultiUpdateBlock(cfg, hidden_dims=self.hidden_dims, out_dims=6, tuning_mode=tuning_mode) + + self.relu = nn.ReLU(inplace=True) + + def get_bins(self, bins_num): + depth_bins_vec = torch.linspace(math.log(self.min_val), math.log(self.max_val), bins_num, device=next(self.parameters()).device) + depth_bins_vec = torch.exp(depth_bins_vec) + return depth_bins_vec + + def register_depth_expectation_anchor(self, bins_num, B): + depth_bins_vec = self.get_bins(bins_num) + depth_bins_vec = depth_bins_vec.unsqueeze(0).repeat(B, 1) + self.register_buffer('depth_expectation_anchor', depth_bins_vec, persistent=False) + + def clamp(self, x): + y = self.relu(x - self.min_val) + self.min_val + y = self.max_val - self.relu(self.max_val - y) + return y + + def regress_depth(self, feature_map_d): + prob_feature = self.depth_regressor(feature_map_d) + prob = prob_feature.softmax(dim=1) + #prob = prob_feature.float().softmax(dim=1) + + ## Error logging + if torch.isnan(prob).any(): + print('prob_feat_nan!!!') + if torch.isinf(prob).any(): + print('prob_feat_inf!!!') + + # h = prob[0,:,0,0].cpu().numpy().reshape(-1) + # import matplotlib.pyplot as plt + # plt.bar(range(len(h)), h) + B = prob.shape[0] + if "depth_expectation_anchor" not in self._buffers: + self.register_depth_expectation_anchor(self.num_depth_regressor_anchor, B) + d = compute_depth_expectation( + prob, + self.depth_expectation_anchor[:B, ...]).unsqueeze(1) + + ## Error logging + if torch.isnan(d ).any(): + print('d_nan!!!') + if torch.isinf(d ).any(): + print('d_inf!!!') + + return (self.clamp(d) - self.max_val)/ self.regress_scale, prob_feature + + def pred_normal(self, feature_map, confidence): + normal_out = self.normal_predictor(feature_map) + + ## Error logging + if torch.isnan(normal_out).any(): + print('norm_nan!!!') + if torch.isinf(normal_out).any(): + print('norm_feat_inf!!!') + + return norm_normalize(torch.cat([normal_out, confidence], dim=1)) + #return norm_normalize(torch.cat([normal_out, confidence], dim=1).float()) + + def create_mesh_grid(self, height, width, batch, device="cuda", set_buffer=True): + y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=device), + torch.arange(0, width, dtype=torch.float32, device=device)], indexing='ij') + meshgrid = torch.stack((x, y)) + meshgrid = meshgrid.unsqueeze(0).repeat(batch, 1, 1, 1) + #self.register_buffer('meshgrid', meshgrid, persistent=False) + return meshgrid + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, D, H, W = flow.shape + factor = 2 ** self.n_downsample + mask = mask.view(N, 1, 9, factor, factor, H, W) + mask = torch.softmax(mask, dim=2) + #mask = torch.softmax(mask.float(), dim=2) + + #up_flow = F.unfold(factor * flow, [3,3], padding=1) + up_flow = F.unfold(flow, [3,3], padding=1) + up_flow = up_flow.view(N, D, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, D, factor*H, factor*W) + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, _, H, W = img.shape + + coords0 = coords_grid(N, H, W).to(img.device) + coords1 = coords_grid(N, H, W).to(img.device) + + return coords0, coords1 + + def upsample(self, x, scale_factor=2): + """Upsample input tensor by a factor of 2 + """ + return interpolate_float32(x, scale_factor=scale_factor*self.up_scale/8, mode="nearest") + + def forward(self, vit_features, **kwargs): + ## read vit token to multi-scale features + B, H, W, _, _, num_register_tokens = vit_features[1] + vit_features = vit_features[0] + + ## Error logging + if torch.isnan(vit_features[0]).any(): + print('vit_feature_nan!!!') + if torch.isinf(vit_features[0]).any(): + print('vit_feature_inf!!!') + + if self.use_cls_token == True: + vit_features = [[ft[:, 1+num_register_tokens:, :].view(B, H, W, self.in_channels[0]), \ + ft[:, 0:1+num_register_tokens, :].view(B, 1, 1, self.in_channels[0] * (1+num_register_tokens))] for ft in vit_features] + else: + vit_features = [ft.view(B, H, W, self.in_channels[0]) for ft in vit_features] + encoder_features = self.token2feature(vit_features) # 1/14, 1/14, 1/7, 1/4 + + ## Error logging + for en_ft in encoder_features: + if torch.isnan(en_ft).any(): + print('decoder_feature_nan!!!') + print(en_ft.shape) + if torch.isinf(en_ft).any(): + print('decoder_feature_inf!!!') + print(en_ft.shape) + + ## decode features to init-depth (and confidence) + ref_feat= self.decoder_mono(encoder_features) # now, 1/4 for depth + + ## Error logging + if torch.isnan(ref_feat).any(): + print('ref_feat_nan!!!') + if torch.isinf(ref_feat).any(): + print('ref_feat_inf!!!') + + feature_map = ref_feat[:, :-2, :, :] # feature map share of depth and normal prediction + depth_confidence_map = ref_feat[:, -2:-1, :, :] + normal_confidence_map = ref_feat[:, -1:, :, :] + depth_pred, binmap = self.regress_depth(feature_map) # regress bin for depth + normal_pred = self.pred_normal(feature_map, normal_confidence_map) # mlp for normal + + depth_init = torch.cat((depth_pred, depth_confidence_map, normal_pred), dim=1) # (N, 1+1+4, H, W) + + ## encoder features to context-feature for init-hidden-state and contex-features + cnet_list = self.context_feature_encoder(encoder_features[::-1]) + net_list = [torch.tanh(x[0]) for x in cnet_list] # x_4, x_8, x_16 of hidden state + inp_list = [torch.relu(x[1]) for x in cnet_list] # x_4, x_8, x_16 context features + + # Rather than running the GRU's conv layers on the context features multiple times, we do it once at the beginning + inp_list = [list(conv(i).split(split_size=conv.out_channels//3, dim=1)) for i,conv in zip(inp_list, self.context_zqr_convs)] + + coords0, coords1 = self.initialize_flow(net_list[0]) + if depth_init is not None: + coords1 = coords1 + depth_init + + if self.training: + low_resolution_init = [self.clamp(depth_init[:,:1] * self.regress_scale + self.max_val), depth_init[:,1:2], norm_normalize(depth_init[:,2:].clone())] + init_depth = upflow4(depth_init) + flow_predictions = [self.clamp(init_depth[:,:1] * self.regress_scale + self.max_val)] + conf_predictions = [init_depth[:,1:2]] + normal_outs = [norm_normalize(init_depth[:,2:].clone())] + + else: + flow_predictions = [] + conf_predictions = [] + samples_pred_list = [] + coord_list = [] + normal_outs = [] + low_resolution_init = [] + + for itr in range(self.iters): + # coords1 = coords1.detach() + flow = coords1 - coords0 + if self.n_gru_layers == 3 and self.slow_fast_gru: # Update low-res GRU + net_list = self.update_block(net_list, inp_list, iter32=True, iter16=False, iter08=False, update=False) + if self.n_gru_layers >= 2 and self.slow_fast_gru:# Update low-res GRU and mid-res GRU + net_list = self.update_block(net_list, inp_list, iter32=self.n_gru_layers==3, iter16=True, iter08=False, update=False) + net_list, up_mask, delta_flow = self.update_block(net_list, inp_list, None, flow, iter32=self.n_gru_layers==3, iter16=self.n_gru_layers>=2) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # We do not need to upsample or output intermediate results in test_mode + #if (not self.training) and itr < self.iters-1: + #continue + + # upsample predictions + if up_mask is None: + flow_up = self.upsample(coords1-coords0, 4) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + # flow_up = self.upsample(coords1-coords0, 4) + + flow_predictions.append(self.clamp(flow_up[:,:1] * self.regress_scale + self.max_val)) + conf_predictions.append(flow_up[:,1:2]) + normal_outs.append(norm_normalize(flow_up[:,2:].clone())) + + outputs=dict( + prediction=flow_predictions[-1], + predictions_list=flow_predictions, + confidence=conf_predictions[-1], + confidence_list=conf_predictions, + pred_logit=None, + # samples_pred_list=samples_pred_list, + # coord_list=coord_list, + prediction_normal=normal_outs[-1], + normal_out_list=normal_outs, + low_resolution_init=low_resolution_init, + ) + + return outputs + + +if __name__ == "__main__": + try: + from custom_mmpkg.custom_mmcv.utils import Config + except: + from mmengine import Config + cfg = Config.fromfile('/cpfs01/shared/public/users/mu.hu/monodepth/mono/configs/RAFTDecoder/vit.raft.full2t.py') + cfg.model.decode_head.in_channels = [384, 384, 384, 384] + cfg.model.decode_head.feature_channels = [96, 192, 384, 768] + cfg.model.decode_head.decoder_channels = [48, 96, 192, 384, 384] + cfg.model.decode_head.hidden_channels = [48, 48, 48, 48, 48] + cfg.model.decode_head.up_scale = 7 + + # cfg.model.decode_head.use_cls_token = True + # vit_feature = [[torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()], \ + # [torch.rand((2, 20, 60, 384)).cuda(), torch.rand(2, 384).cuda()]] + + cfg.model.decode_head.use_cls_token = True + cfg.model.decode_head.num_register_tokens = 4 + vit_feature = [[torch.rand((2, (74 * 74) + 5, 384)).cuda(),\ + torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ + torch.rand((2, (74 * 74) + 5, 384)).cuda(), \ + torch.rand((2, (74 * 74) + 5, 384)).cuda()], (2, 74, 74, 1036, 1036, 4)] + + decoder = RAFTDepthNormalDPT5(cfg).cuda() + output = decoder(vit_feature) + temp = 1 + + + + diff --git a/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py b/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92381a5fc3dad0ca8009c1ab0a153ce6b107c634 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/decode_heads/__init__.py @@ -0,0 +1,4 @@ +from .HourGlassDecoder import HourglassDecoder +from .RAFTDepthNormalDPTDecoder5 import RAFTDepthNormalDPT5 + +__all__=['HourglassDecoder', 'RAFTDepthNormalDPT5'] diff --git a/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py new file mode 100644 index 0000000000000000000000000000000000000000..1c9d0464228542274fbf248c7bdc821f8f2a8f62 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__base_model__.py @@ -0,0 +1,20 @@ +import torch +import torch.nn as nn +from custom_controlnet_aux.metric3d.mono.utils.comm import get_func + + +class BaseDepthModel(nn.Module): + def __init__(self, cfg, **kwargs) -> None: + super(BaseDepthModel, self).__init__() + model_type = cfg.model.type + self.depth_model = get_func('custom_controlnet_aux.metric3d.mono.model.model_pipelines.' + model_type)(cfg) + + def forward(self, data): + output = self.depth_model(**data) + + return output['prediction'], output['confidence'], output + + def inference(self, data): + with torch.no_grad(): + pred_depth, confidence, _ = self.forward(data) + return pred_depth, confidence \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b962a3f858573466e429219c4ad70951b545b637 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/__init__.py @@ -0,0 +1,6 @@ + +from .dense_pipeline import DensePredModel +from .__base_model__ import BaseDepthModel +__all__ = [ + 'DensePredModel', 'BaseDepthModel', +] \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9bae9f6de4d734ff935d9fb7eb353db423a6e8a6 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/model_pipelines/dense_pipeline.py @@ -0,0 +1,16 @@ +import torch +import torch.nn as nn +from custom_controlnet_aux.metric3d.mono.utils.comm import get_func + +class DensePredModel(nn.Module): + def __init__(self, cfg) -> None: + super(DensePredModel, self).__init__() + + self.encoder = get_func('custom_controlnet_aux.metric3d.mono.model.' + cfg.model.backbone.prefix + cfg.model.backbone.type)(**cfg.model.backbone) + self.decoder = get_func('custom_controlnet_aux.metric3d.mono.model.' + cfg.model.decode_head.prefix + cfg.model.decode_head.type)(cfg) + + def forward(self, input, **kwargs): + # [f_32, f_16, f_8, f_4] + features = self.encoder(input) + out = self.decoder(features, **kwargs) + return out \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py b/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..0b58b7643ee43f84fd4e621e5b3b61b1f3f85564 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/model/monodepth_model.py @@ -0,0 +1,37 @@ +import torch +import torch.nn as nn +from .model_pipelines.__base_model__ import BaseDepthModel + +class DepthModel(BaseDepthModel): + def __init__(self, cfg, **kwards): + super(DepthModel, self).__init__(cfg) + model_type = cfg.model.type + + def inference(self, data): + with torch.no_grad(): + pred_depth, confidence, output_dict = self.forward(data) + return pred_depth, confidence, output_dict + +def get_monodepth_model( + cfg : dict, + **kwargs + ) -> nn.Module: + # config depth model + model = DepthModel(cfg, **kwargs) + #model.init_weights(load_imagenet_model, imagenet_ckpt_fpath) + assert isinstance(model, nn.Module) + return model + +def get_configured_monodepth_model( + cfg: dict, + ) -> nn.Module: + """ + Args: + @ configs: configures for the network. + @ load_imagenet_model: whether to initialize from ImageNet-pretrained model. + @ imagenet_ckpt_fpath: string representing path to file with weights to initialize model with. + Returns: + # model: depth model. + """ + model = get_monodepth_model(cfg) + return model diff --git a/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py b/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py new file mode 100644 index 0000000000000000000000000000000000000000..9788d205ef61bf7fca7500ed03252de0b4d70728 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/tools/test_scale_cano.py @@ -0,0 +1,161 @@ +import os +import os.path as osp +import cv2 +import time +import sys +CODE_SPACE=os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +sys.path.append(CODE_SPACE) +import argparse +import custom_mmpkg.custom_mmcv as mmcv +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +try: + from custom_mmpkg.custom_mmcv.utils import Config, DictAction +except: + from mmengine import Config, DictAction +from datetime import timedelta +import random +import numpy as np +from custom_controlnet_aux.metric3d.mono.utils.logger import setup_logger +import glob +from custom_controlnet_aux.metric3d.mono.utils.comm import init_env +from custom_controlnet_aux.metric3d.mono.model.monodepth_model import get_configured_monodepth_model +from custom_controlnet_aux.metric3d.mono.utils.running import load_ckpt +from custom_controlnet_aux.metric3d.mono.utils.do_test import do_scalecano_test_with_custom_data +from custom_controlnet_aux.metric3d.mono.utils.mldb import load_data_info, reset_ckpt_path +from custom_controlnet_aux.metric3d.mono.utils.custom_data import load_from_annos, load_data + +def parse_args(): + parser = argparse.ArgumentParser(description='Train a segmentor') + parser.add_argument('config', help='train config file path') + parser.add_argument('--show-dir', help='the dir to save logs and visualization results') + parser.add_argument('--load-from', help='the checkpoint file to load weights from') + parser.add_argument('--node_rank', type=int, default=0) + parser.add_argument('--nnodes', type=int, default=1, help='number of nodes') + parser.add_argument('--options', nargs='+', action=DictAction, help='custom options') + parser.add_argument('--launcher', choices=['None', 'pytorch', 'slurm', 'mpi', 'ror'], default='slurm', help='job launcher') + parser.add_argument('--test_data_path', default='None', type=str, help='the path of test data') + parser.add_argument('--batch_size', default=1, type=int, help='the batch size for inference') + args = parser.parse_args() + return args + +def main(args): + os.chdir(CODE_SPACE) + cfg = Config.fromfile(args.config) + + if args.options is not None: + cfg.merge_from_dict(args.options) + + # show_dir is determined in this priority: CLI > segment in file > filename + if args.show_dir is not None: + # update configs according to CLI args if args.show_dir is not None + cfg.show_dir = args.show_dir + else: + # use condig filename + timestamp as default show_dir if args.show_dir is None + cfg.show_dir = osp.join('./show_dirs', + osp.splitext(osp.basename(args.config))[0], + args.timestamp) + + # ckpt path + if args.load_from is None: + raise RuntimeError('Please set model path!') + cfg.load_from = args.load_from + cfg.batch_size = args.batch_size + + # load data info + data_info = {} + load_data_info('data_info', data_info=data_info) + cfg.mldb_info = data_info + # update check point info + reset_ckpt_path(cfg.model, data_info) + + # create show dir + os.makedirs(osp.abspath(cfg.show_dir), exist_ok=True) + + # init the logger before other steps + cfg.log_file = osp.join(cfg.show_dir, f'{args.timestamp}.log') + logger = setup_logger(cfg.log_file) + + # log some basic info + logger.info(f'Config:\n{cfg.pretty_text}') + + # init distributed env dirst, since logger depends on the dist info + if args.launcher == 'None': + cfg.distributed = False + else: + cfg.distributed = True + init_env(args.launcher, cfg) + logger.info(f'Distributed training: {cfg.distributed}') + + # dump config + cfg.dump(osp.join(cfg.show_dir, osp.basename(args.config))) + test_data_path = args.test_data_path + if not os.path.isabs(test_data_path): + test_data_path = osp.join(CODE_SPACE, test_data_path) + + if 'json' in test_data_path: + test_data = load_from_annos(test_data_path) + else: + test_data = load_data(args.test_data_path) + + if not cfg.distributed: + main_worker(0, cfg, args.launcher, test_data) + else: + # distributed training + if args.launcher == 'ror': + local_rank = cfg.dist_params.local_rank + main_worker(local_rank, cfg, args.launcher, test_data) + else: + mp.spawn(main_worker, nprocs=cfg.dist_params.num_gpus_per_node, args=(cfg, args.launcher, test_data)) + +def main_worker(local_rank: int, cfg: dict, launcher: str, test_data: list): + if cfg.distributed: + cfg.dist_params.global_rank = cfg.dist_params.node_rank * cfg.dist_params.num_gpus_per_node + local_rank + cfg.dist_params.local_rank = local_rank + + if launcher == 'ror': + init_torch_process_group(use_hvd=False) + else: + torch.cuda.set_device(local_rank) + default_timeout = timedelta(minutes=30) + dist.init_process_group( + backend=cfg.dist_params.backend, + init_method=cfg.dist_params.dist_url, + world_size=cfg.dist_params.world_size, + rank=cfg.dist_params.global_rank, + timeout=default_timeout) + + logger = setup_logger(cfg.log_file) + # build model + model = get_configured_monodepth_model(cfg, ) + + # config distributed training + if cfg.distributed: + model = torch.nn.parallel.DistributedDataParallel(model.cuda(), + device_ids=[local_rank], + output_device=local_rank, + find_unused_parameters=True) + else: + model = torch.nn.DataParallel(model).cuda() + + # load ckpt + model, _, _, _ = load_ckpt(cfg.load_from, model, strict_match=False) + model.eval() + + do_scalecano_test_with_custom_data( + model, + cfg, + test_data, + logger, + cfg.distributed, + local_rank, + cfg.batch_size, + ) + +if __name__ == '__main__': + args = parse_args() + timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) + args.timestamp = timestamp + main(args) \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/__init__.py b/custom_controlnet_aux/metric3d/mono/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py b/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py new file mode 100644 index 0000000000000000000000000000000000000000..37ed9fffa7aa7be7eea094280102168993912f44 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/avg_meter.py @@ -0,0 +1,475 @@ +import numpy as np +import torch +import torch.distributed as dist +import torch.nn.functional as F +import matplotlib.pyplot as plt + + +class AverageMeter(object): + """Computes and stores the average and current value""" + def __init__(self) -> None: + self.reset() + + def reset(self) -> None: + self.val = np.longdouble(0.0) + self.avg = np.longdouble(0.0) + self.sum = np.longdouble(0.0) + self.count = np.longdouble(0.0) + + def update(self, val, n: float = 1) -> None: + self.val = val + self.sum += val + self.count += n + self.avg = self.sum / (self.count + 1e-6) + +class MetricAverageMeter(AverageMeter): + """ + An AverageMeter designed specifically for evaluating segmentation results. + """ + def __init__(self, metrics: list) -> None: + """ Initialize object. """ + # average meters for metrics + self.abs_rel = AverageMeter() + self.rmse = AverageMeter() + self.silog = AverageMeter() + self.delta1 = AverageMeter() + self.delta2 = AverageMeter() + self.delta3 = AverageMeter() + + self.metrics = metrics + + self.consistency = AverageMeter() + self.log10 = AverageMeter() + self.rmse_log = AverageMeter() + self.sq_rel = AverageMeter() + + # normal + self.normal_mean = AverageMeter() + self.normal_rmse = AverageMeter() + self.normal_a1 = AverageMeter() + self.normal_a2 = AverageMeter() + + self.normal_median = AverageMeter() + self.normal_a3 = AverageMeter() + self.normal_a4 = AverageMeter() + self.normal_a5 = AverageMeter() + + + def update_metrics_cpu(self, + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor,): + """ + Update metrics on cpu + """ + + assert pred.shape == target.shape + + if len(pred.shape) == 3: + pred = pred[:, None, :, :] + target = target[:, None, :, :] + mask = mask[:, None, :, :] + elif len(pred.shape) == 2: + pred = pred[None, None, :, :] + target = target[None, None, :, :] + mask = mask[None, None, :, :] + + + # Absolute relative error + abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask) + abs_rel_sum = abs_rel_sum.numpy() + valid_pics = valid_pics.numpy() + self.abs_rel.update(abs_rel_sum, valid_pics) + + # squared relative error + sqrel_sum, _ = get_sqrel_err(pred, target, mask) + sqrel_sum = sqrel_sum.numpy() + self.sq_rel.update(sqrel_sum, valid_pics) + + # root mean squared error + rmse_sum, _ = get_rmse_err(pred, target, mask) + rmse_sum = rmse_sum.numpy() + self.rmse.update(rmse_sum, valid_pics) + + # log root mean squared error + log_rmse_sum, _ = get_rmse_log_err(pred, target, mask) + log_rmse_sum = log_rmse_sum.numpy() + self.rmse.update(log_rmse_sum, valid_pics) + + # log10 error + log10_sum, _ = get_log10_err(pred, target, mask) + log10_sum = log10_sum.numpy() + self.rmse.update(log10_sum, valid_pics) + + # scale-invariant root mean squared error in log space + silog_sum, _ = get_silog_err(pred, target, mask) + silog_sum = silog_sum.numpy() + self.silog.update(silog_sum, valid_pics) + + # ratio error, delta1, .... + delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_error(pred, target, mask) + delta1_sum = delta1_sum.numpy() + delta2_sum = delta2_sum.numpy() + delta3_sum = delta3_sum.numpy() + + self.delta1.update(delta1_sum, valid_pics) + self.delta2.update(delta1_sum, valid_pics) + self.delta3.update(delta1_sum, valid_pics) + + + def update_metrics_gpu( + self, + pred: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + is_distributed: bool, + pred_next: torch.tensor = None, + pose_f1_to_f2: torch.tensor = None, + intrinsic: torch.tensor = None): + """ + Update metric on GPU. It supports distributed processing. If multiple machines are employed, please + set 'is_distributed' as True. + """ + assert pred.shape == target.shape + + if len(pred.shape) == 3: + pred = pred[:, None, :, :] + target = target[:, None, :, :] + mask = mask[:, None, :, :] + elif len(pred.shape) == 2: + pred = pred[None, None, :, :] + target = target[None, None, :, :] + mask = mask[None, None, :, :] + + + # Absolute relative error + abs_rel_sum, valid_pics = get_absrel_err(pred, target, mask) + if is_distributed: + dist.all_reduce(abs_rel_sum), dist.all_reduce(valid_pics) + abs_rel_sum = abs_rel_sum.cpu().numpy() + valid_pics = int(valid_pics) + self.abs_rel.update(abs_rel_sum, valid_pics) + + # root mean squared error + rmse_sum, _ = get_rmse_err(pred, target, mask) + if is_distributed: + dist.all_reduce(rmse_sum) + rmse_sum = rmse_sum.cpu().numpy() + self.rmse.update(rmse_sum, valid_pics) + + # log root mean squared error + log_rmse_sum, _ = get_rmse_log_err(pred, target, mask) + if is_distributed: + dist.all_reduce(log_rmse_sum) + log_rmse_sum = log_rmse_sum.cpu().numpy() + self.rmse_log.update(log_rmse_sum, valid_pics) + + # log10 error + log10_sum, _ = get_log10_err(pred, target, mask) + if is_distributed: + dist.all_reduce(log10_sum) + log10_sum = log10_sum.cpu().numpy() + self.log10.update(log10_sum, valid_pics) + + # scale-invariant root mean squared error in log space + silog_sum, _ = get_silog_err(pred, target, mask) + if is_distributed: + dist.all_reduce(silog_sum) + silog_sum = silog_sum.cpu().numpy() + self.silog.update(silog_sum, valid_pics) + + # ratio error, delta1, .... + delta1_sum, delta2_sum, delta3_sum, _ = get_ratio_err(pred, target, mask) + if is_distributed: + dist.all_reduce(delta1_sum), dist.all_reduce(delta2_sum), dist.all_reduce(delta3_sum) + delta1_sum = delta1_sum.cpu().numpy() + delta2_sum = delta2_sum.cpu().numpy() + delta3_sum = delta3_sum.cpu().numpy() + + self.delta1.update(delta1_sum, valid_pics) + self.delta2.update(delta2_sum, valid_pics) + self.delta3.update(delta3_sum, valid_pics) + + # video consistency error + # consistency_rel_sum, valid_warps = get_video_consistency_err(pred, pred_next, pose_f1_to_f2, intrinsic) + # if is_distributed: + # dist.all_reduce(consistency_rel_sum), dist.all_reduce(valid_warps) + # consistency_rel_sum = consistency_rel_sum.cpu().numpy() + # valid_warps = int(valid_warps) + # self.consistency.update(consistency_rel_sum, valid_warps) + + ## for surface normal + def update_normal_metrics_gpu( + self, + pred: torch.Tensor, # (B, 3, H, W) + target: torch.Tensor, # (B, 3, H, W) + mask: torch.Tensor, # (B, 1, H, W) + is_distributed: bool, + ): + """ + Update metric on GPU. It supports distributed processing. If multiple machines are employed, please + set 'is_distributed' as True. + """ + assert pred.shape == target.shape + + valid_pics = torch.sum(mask, dtype=torch.float32) + 1e-6 + + if valid_pics < 10: + return + + mean_error = rmse_error = a1_error = a2_error = dist_node_cnt = valid_pics + normal_error = torch.cosine_similarity(pred, target, dim=1) + normal_error = torch.clamp(normal_error, min=-1.0, max=1.0) + angle_error = torch.acos(normal_error) * 180.0 / torch.pi + angle_error = angle_error[:, None, :, :] + angle_error = angle_error[mask] + # Calculation error + mean_error = angle_error.sum() / valid_pics + rmse_error = torch.sqrt( torch.sum(torch.square(angle_error)) / valid_pics ) + median_error = angle_error.median() + a1_error = 100.0 * (torch.sum(angle_error < 5) / valid_pics) + a2_error = 100.0 * (torch.sum(angle_error < 7.5) / valid_pics) + + a3_error = 100.0 * (torch.sum(angle_error < 11.25) / valid_pics) + a4_error = 100.0 * (torch.sum(angle_error < 22.5) / valid_pics) + a5_error = 100.0 * (torch.sum(angle_error < 30) / valid_pics) + + # if valid_pics > 1e-5: + # If the current node gets data with valid normal + dist_node_cnt = (valid_pics - 1e-6) / valid_pics + + if is_distributed: + dist.all_reduce(dist_node_cnt) + dist.all_reduce(mean_error) + dist.all_reduce(rmse_error) + dist.all_reduce(a1_error) + dist.all_reduce(a2_error) + + dist.all_reduce(a3_error) + dist.all_reduce(a4_error) + dist.all_reduce(a5_error) + + dist_node_cnt = dist_node_cnt.cpu().numpy() + self.normal_mean.update(mean_error.cpu().numpy(), dist_node_cnt) + self.normal_rmse.update(rmse_error.cpu().numpy(), dist_node_cnt) + self.normal_a1.update(a1_error.cpu().numpy(), dist_node_cnt) + self.normal_a2.update(a2_error.cpu().numpy(), dist_node_cnt) + + self.normal_median.update(median_error.cpu().numpy(), dist_node_cnt) + self.normal_a3.update(a3_error.cpu().numpy(), dist_node_cnt) + self.normal_a4.update(a4_error.cpu().numpy(), dist_node_cnt) + self.normal_a5.update(a5_error.cpu().numpy(), dist_node_cnt) + + + def get_metrics(self,): + """ + """ + metrics_dict = {} + for metric in self.metrics: + metrics_dict[metric] = self.__getattribute__(metric).avg + return metrics_dict + + + def get_metrics(self,): + """ + """ + metrics_dict = {} + for metric in self.metrics: + metrics_dict[metric] = self.__getattribute__(metric).avg + return metrics_dict + +def get_absrel_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes absolute relative error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + # Mean Absolute Relative Error + rel = torch.abs(t_m - p_m) / (t_m + 1e-10) # compute errors + abs_rel_sum = torch.sum(rel.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + abs_err = abs_rel_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(abs_err), valid_pics + +def get_sqrel_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes squared relative error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + # squared Relative Error + sq_rel = torch.abs(t_m - p_m) ** 2 / (t_m + 1e-10) # compute errors + sq_rel_sum = torch.sum(sq_rel.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + sqrel_err = sq_rel_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(sqrel_err), valid_pics + +def get_log10_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log10 error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + log10_diff = torch.abs(diff_log) + log10_sum = torch.sum(log10_diff.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + log10_err = log10_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + return torch.sum(log10_err), valid_pics + +def get_rmse_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + square = (t_m - p_m) ** 2 + rmse_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + rmse = torch.sqrt(rmse_sum / (num + 1e-10)) + valid_pics = torch.sum(num > 0) + return torch.sum(rmse), valid_pics + +def get_rmse_log_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + square = diff_log ** 2 + rmse_log_sum = torch.sum(square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + rmse_log = torch.sqrt(rmse_log_sum / (num + 1e-10)) + valid_pics = torch.sum(num > 0) + return torch.sum(rmse_log), valid_pics + +def get_silog_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes log rmse error. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred * mask + + diff_log = (torch.log10(p_m+1e-10) - torch.log10(t_m+1e-10)) * mask + diff_log_sum = torch.sum(diff_log.reshape((b, c, -1)), dim=2) # [b, c] + diff_log_square = diff_log ** 2 + diff_log_square_sum = torch.sum(diff_log_square.reshape((b, c, -1)), dim=2) # [b, c] + num = torch.sum(mask.reshape((b, c, -1)), dim=2) # [b, c] + silog = torch.sqrt(diff_log_square_sum / (num + 1e-10) - (diff_log_sum / (num + 1e-10)) ** 2) + valid_pics = torch.sum(num > 0) + return torch.sum(silog), valid_pics + +def get_ratio_err(pred: torch.tensor, + target: torch.tensor, + mask: torch.tensor, + ): + """ + Computes the percentage of pixels for which the ratio of the two depth maps is less than a given threshold. + Tasks preprocessed depths (no nans, infs and non-positive values). + pred, target, and mask should be in the shape of [b, c, h, w] + """ + assert len(pred.shape) == 4, len(target.shape) == 4 + b, c, h, w = pred.shape + mask = mask.to(torch.float) + t_m = target * mask + p_m = pred + + gt_pred = t_m / (p_m + 1e-10) + pred_gt = p_m / (t_m + 1e-10) + gt_pred = gt_pred.reshape((b, c, -1)) + pred_gt = pred_gt.reshape((b, c, -1)) + gt_pred_gt = torch.cat((gt_pred, pred_gt), axis=1) + ratio_max = torch.amax(gt_pred_gt, axis=1) + + delta_1_sum = torch.sum((ratio_max < 1.25), dim=1) # [b, ] + delta_2_sum = torch.sum((ratio_max < 1.25 ** 2), dim=1) # [b, ] + delta_3_sum = torch.sum((ratio_max < 1.25 ** 3), dim=1) # [b, ] + num = torch.sum(mask.reshape((b, -1)), dim=1) # [b, ] + + delta_1 = delta_1_sum / (num + 1e-10) + delta_2 = delta_2_sum / (num + 1e-10) + delta_3 = delta_3_sum / (num + 1e-10) + valid_pics = torch.sum(num > 0) + + return torch.sum(delta_1), torch.sum(delta_2), torch.sum(delta_3), valid_pics + + +if __name__ == '__main__': + cfg = ['abs_rel', 'delta1'] + dam = MetricAverageMeter(cfg) + + pred_depth = np.random.random([2, 480, 640]) + gt_depth = np.random.random([2, 480, 640]) - 0.5 + intrinsic = [[100, 100, 200, 200], [200, 200, 300, 300]] + + pred = torch.from_numpy(pred_depth).cuda() + gt = torch.from_numpy(gt_depth).cuda() + + mask = gt > 0 + dam.update_metrics_gpu(pred, gt, mask, False) + eval_error = dam.get_metrics() + print(eval_error) + \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/comm.py b/custom_controlnet_aux/metric3d/mono/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..99d68a3e95f8b7b99edc1309961b8e5ce40ea9cd --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/comm.py @@ -0,0 +1,322 @@ +import importlib +import torch +import torch.distributed as dist +from .avg_meter import AverageMeter +from collections import defaultdict, OrderedDict +import os +import socket +from custom_mmpkg.custom_mmcv.utils import collect_env as collect_base_env +try: + from custom_mmpkg.custom_mmcv.utils import get_git_hash +except: + from mmengine.utils import get_git_hash +#import mono.mmseg as mmseg +# import mmseg +import time +import datetime +import logging + + +def main_process() -> bool: + return get_rank() == 0 + #return not cfg.distributed or \ + # (cfg.distributed and cfg.local_rank == 0) + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + +def _find_free_port(): + # refer to https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/launch.py # noqa: E501 + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(('', 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + +def _is_free_port(port): + ips = socket.gethostbyname_ex(socket.gethostname())[-1] + ips.append('localhost') + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + return all(s.connect_ex((ip, port)) != 0 for ip in ips) + + +# def collect_env(): +# """Collect the information of the running environments.""" +# env_info = collect_base_env() +# env_info['MMSegmentation'] = f'{mmseg.__version__}+{get_git_hash()[:7]}' + +# return env_info + +def init_env(launcher, cfg): + """Initialize distributed training environment. + If argument ``cfg.dist_params.dist_url`` is specified as 'env://', then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + """ + if launcher == 'slurm': + _init_dist_slurm(cfg) + elif launcher == 'ror': + _init_dist_ror(cfg) + elif launcher == 'None': + _init_none_dist(cfg) + else: + raise RuntimeError(f'{cfg.launcher} has not been supported!') + +def _init_none_dist(cfg): + cfg.dist_params.num_gpus_per_node = 1 + cfg.dist_params.world_size = 1 + cfg.dist_params.nnodes = 1 + cfg.dist_params.node_rank = 0 + cfg.dist_params.global_rank = 0 + cfg.dist_params.local_rank = 0 + os.environ["WORLD_SIZE"] = str(1) + +def _init_dist_ror(cfg): + from ac2.ror.comm import get_local_rank, get_world_rank, get_local_size, get_node_rank, get_world_size + cfg.dist_params.num_gpus_per_node = get_local_size() + cfg.dist_params.world_size = get_world_size() + cfg.dist_params.nnodes = (get_world_size()) // (get_local_size()) + cfg.dist_params.node_rank = get_node_rank() + cfg.dist_params.global_rank = get_world_rank() + cfg.dist_params.local_rank = get_local_rank() + os.environ["WORLD_SIZE"] = str(get_world_size()) + + +def _init_dist_slurm(cfg): + if 'NNODES' not in os.environ: + os.environ['NNODES'] = str(cfg.dist_params.nnodes) + if 'NODE_RANK' not in os.environ: + os.environ['NODE_RANK'] = str(cfg.dist_params.node_rank) + + #cfg.dist_params. + num_gpus = torch.cuda.device_count() + world_size = int(os.environ['NNODES']) * num_gpus + os.environ['WORLD_SIZE'] = str(world_size) + + # config port + if 'MASTER_PORT' in os.environ: + master_port = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable + else: + # if torch.distributed default port(29500) is available + # then use it, else find a free port + if _is_free_port(16500): + master_port = '16500' + else: + master_port = str(_find_free_port()) + os.environ['MASTER_PORT'] = master_port + + # config addr + if 'MASTER_ADDR' in os.environ: + master_addr = str(os.environ['MASTER_PORT']) # use MASTER_PORT in the environment variable + # elif cfg.dist_params.dist_url is not None: + # master_addr = ':'.join(str(cfg.dist_params.dist_url).split(':')[:2]) + else: + master_addr = '127.0.0.1' #'tcp://127.0.0.1' + os.environ['MASTER_ADDR'] = master_addr + + # set dist_url to 'env://' + cfg.dist_params.dist_url = 'env://' #f"{master_addr}:{master_port}" + + cfg.dist_params.num_gpus_per_node = num_gpus + cfg.dist_params.world_size = world_size + cfg.dist_params.nnodes = int(os.environ['NNODES']) + cfg.dist_params.node_rank = int(os.environ['NODE_RANK']) + + # if int(os.environ['NNODES']) > 1 and cfg.dist_params.dist_url.startswith("file://"): + # raise Warning("file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://") + + +def get_func(func_name): + """ + Helper to return a function object by name. func_name must identify + a function in this module or the path to a function relative to the base + module. + @ func_name: function name. + """ + if func_name == '': + return None + try: + parts = func_name.split('.') + # Refers to a function in this module + if len(parts) == 1: + return globals()[parts[0]] + # Otherwise, assume we're referencing a module under modeling + module_name = '.'.join(parts[:-1]) + module = importlib.import_module(module_name) + return getattr(module, parts[-1]) + except: + raise RuntimeError(f'Failed to find function: {func_name}') + +class Timer(object): + """A simple timer.""" + + def __init__(self): + self.reset() + + def tic(self): + # using time.time instead of time.clock because time time.clock + # does not normalize for multithreading + self.start_time = time.time() + + def toc(self, average=True): + self.diff = time.time() - self.start_time + self.total_time += self.diff + self.calls += 1 + self.average_time = self.total_time / self.calls + if average: + return self.average_time + else: + return self.diff + + def reset(self): + self.total_time = 0. + self.calls = 0 + self.start_time = 0. + self.diff = 0. + self.average_time = 0. + +class TrainingStats(object): + """Track vital training statistics.""" + def __init__(self, log_period, tensorboard_logger=None): + self.log_period = log_period + self.tblogger = tensorboard_logger + self.tb_ignored_keys = ['iter', 'eta', 'epoch', 'time'] + self.iter_timer = Timer() + # Window size for smoothing tracked values (with median filtering) + self.filter_size = log_period + def create_smoothed_value(): + return AverageMeter() + self.smoothed_losses = defaultdict(create_smoothed_value) + #self.smoothed_metrics = defaultdict(create_smoothed_value) + #self.smoothed_total_loss = AverageMeter() + + + def IterTic(self): + self.iter_timer.tic() + + def IterToc(self): + return self.iter_timer.toc(average=False) + + def reset_iter_time(self): + self.iter_timer.reset() + + def update_iter_stats(self, losses_dict): + """Update tracked iteration statistics.""" + for k, v in losses_dict.items(): + self.smoothed_losses[k].update(float(v), 1) + + def log_iter_stats(self, cur_iter, optimizer, max_iters, val_err={}): + """Log the tracked statistics.""" + if (cur_iter % self.log_period == 0): + stats = self.get_stats(cur_iter, optimizer, max_iters, val_err) + log_stats(stats) + if self.tblogger: + self.tb_log_stats(stats, cur_iter) + for k, v in self.smoothed_losses.items(): + v.reset() + + def tb_log_stats(self, stats, cur_iter): + """Log the tracked statistics to tensorboard""" + for k in stats: + # ignore some logs + if k not in self.tb_ignored_keys: + v = stats[k] + if isinstance(v, dict): + self.tb_log_stats(v, cur_iter) + else: + self.tblogger.add_scalar(k, v, cur_iter) + + + def get_stats(self, cur_iter, optimizer, max_iters, val_err = {}): + eta_seconds = self.iter_timer.average_time * (max_iters - cur_iter) + + eta = str(datetime.timedelta(seconds=int(eta_seconds))) + stats = OrderedDict( + iter=cur_iter, # 1-indexed + time=self.iter_timer.average_time, + eta=eta, + ) + optimizer_state_dict = optimizer.state_dict() + lr = {} + for i in range(len(optimizer_state_dict['param_groups'])): + lr_name = 'group%d_lr' % i + lr[lr_name] = optimizer_state_dict['param_groups'][i]['lr'] + + stats['lr'] = OrderedDict(lr) + for k, v in self.smoothed_losses.items(): + stats[k] = v.avg + + stats['val_err'] = OrderedDict(val_err) + stats['max_iters'] = max_iters + return stats + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + Args: + @input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + @average (bool): whether to do average or sum + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +def log_stats(stats): + logger = logging.getLogger() + """Log training statistics to terminal""" + lines = "[Step %d/%d]\n" % ( + stats['iter'], stats['max_iters']) + + lines += "\t\tloss: %.3f, time: %.6f, eta: %s\n" % ( + stats['total_loss'], stats['time'], stats['eta']) + + # log loss + lines += "\t\t" + for k, v in stats.items(): + if 'loss' in k.lower() and 'total_loss' not in k.lower(): + lines += "%s: %.3f" % (k, v) + ", " + lines = lines[:-3] + lines += '\n' + + # validate criteria + lines += "\t\tlast val err:" + ", ".join("%s: %.6f" % (k, v) for k, v in stats['val_err'].items()) + ", " + lines += '\n' + + # lr in different groups + lines += "\t\t" + ", ".join("%s: %.8f" % (k, v) for k, v in stats['lr'].items()) + lines += '\n' + logger.info(lines[:-1]) # remove last new linen_pxl + diff --git a/custom_controlnet_aux/metric3d/mono/utils/custom_data.py b/custom_controlnet_aux/metric3d/mono/utils/custom_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d9fab47478bc471c51b5454cc15550079ebec21b --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/custom_data.py @@ -0,0 +1,34 @@ +import glob +import os +import json +import cv2 + +def load_from_annos(anno_path): + with open(anno_path, 'r') as f: + annos = json.load(f)['files'] + + datas = [] + for i, anno in enumerate(annos): + rgb = anno['rgb'] + depth = anno['depth'] if 'depth' in anno else None + depth_scale = anno['depth_scale'] if 'depth_scale' in anno else 1.0 + intrinsic = anno['cam_in'] if 'cam_in' in anno else None + normal = anno['normal'] if 'normal' in anno else None + + data_i = { + 'rgb': rgb, + 'depth': depth, + 'depth_scale': depth_scale, + 'intrinsic': intrinsic, + 'filename': os.path.basename(rgb), + 'folder': rgb.split('/')[-3], + 'normal': normal + } + datas.append(data_i) + return datas + +def load_data(path: str): + rgbs = glob.glob(path + '/*.jpg') + glob.glob(path + '/*.png') + #intrinsic = [835.8179931640625, 835.8179931640625, 961.5419921875, 566.8090209960938] #[721.53769, 721.53769, 609.5593, 172.854] + data = [{'rgb': i, 'depth': None, 'intrinsic': None, 'filename': os.path.basename(i), 'folder': i.split('/')[-3]} for i in rgbs] + return data \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/do_test.py b/custom_controlnet_aux/metric3d/mono/utils/do_test.py new file mode 100644 index 0000000000000000000000000000000000000000..7a6bd5a474912304f4d9be56eb31b633168613fb --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/do_test.py @@ -0,0 +1,380 @@ +import torch +import torch.nn.functional as F +import logging +import os +import os.path as osp +from custom_controlnet_aux.metric3d.mono.utils.avg_meter import MetricAverageMeter +from custom_controlnet_aux.metric3d.mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs +import cv2 +from tqdm import tqdm +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt + +def to_cuda(data: dict): + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.cuda(non_blocking=True) + if isinstance(v, list) and len(v)>=1 and isinstance(v[0], torch.Tensor): + for i, l_i in enumerate(v): + data[k][i] = l_i.cuda(non_blocking=True) + return data + +def align_scale(pred: torch.tensor, target: torch.tensor): + mask = target > 0 + if torch.sum(mask) > 10: + scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) + else: + scale = 1 + pred_scaled = pred * scale + return pred_scaled, scale + +def align_scale_shift(pred: torch.tensor, target: torch.tensor): + mask = target > 0 + target_mask = target[mask].cpu().numpy() + pred_mask = pred[mask].cpu().numpy() + if torch.sum(mask) > 10: + scale, shift = np.polyfit(pred_mask, target_mask, deg=1) + if scale < 0: + scale = torch.median(target[mask]) / (torch.median(pred[mask]) + 1e-8) + shift = 0 + else: + scale = 1 + shift = 0 + pred = pred * scale + shift + return pred, scale + +def align_scale_shift_numpy(pred: np.array, target: np.array): + mask = target > 0 + target_mask = target[mask] + pred_mask = pred[mask] + if np.sum(mask) > 10: + scale, shift = np.polyfit(pred_mask, target_mask, deg=1) + if scale < 0: + scale = np.median(target[mask]) / (np.median(pred[mask]) + 1e-8) + shift = 0 + else: + scale = 1 + shift = 0 + pred = pred * scale + shift + return pred, scale + + +def build_camera_model(H : int, W : int, intrinsics : list) -> np.array: + """ + Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map. + """ + fx, fy, u0, v0 = intrinsics + f = (fx + fy) / 2.0 + # principle point location + x_row = np.arange(0, W).astype(np.float32) + x_row_center_norm = (x_row - u0) / W + x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W] + + y_col = np.arange(0, H).astype(np.float32) + y_col_center_norm = (y_col - v0) / H + y_center = np.tile(y_col_center_norm, (W, 1)).T # [H, W] + + # FoV + fov_x = np.arctan(x_center / (f / W)) + fov_y = np.arctan(y_center / (f / H)) + + cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2) + return cam_model + +def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio): + """ + Resize the input. + Resizing consists of two processed, i.e. 1) to the canonical space (adjust the camera model); 2) resize the image while the camera model holds. Thus the + label will be scaled with the resize factor. + """ + padding = [123.675, 116.28, 103.53] + h, w, _ = image.shape + resize_ratio_h = output_shape[0] / canonical_shape[0] + resize_ratio_w = output_shape[1] / canonical_shape[1] + to_scale_ratio = min(resize_ratio_h, resize_ratio_w) + + resize_ratio = to_canonical_ratio * to_scale_ratio + + reshape_h = int(resize_ratio * h) + reshape_w = int(resize_ratio * w) + + pad_h = max(output_shape[0] - reshape_h, 0) + pad_w = max(output_shape[1] - reshape_w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + + # resize + image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + # padding + image = cv2.copyMakeBorder( + image, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=padding) + + # Resize, adjust principle point + intrinsic[2] = intrinsic[2] * to_scale_ratio + intrinsic[3] = intrinsic[3] * to_scale_ratio + + cam_model = build_camera_model(reshape_h, reshape_w, intrinsic) + cam_model = cv2.copyMakeBorder( + cam_model, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=-1) + + pad=[pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + label_scale_factor=1/to_scale_ratio + return image, cam_model, pad, label_scale_factor + + +def get_prediction( + model: torch.nn.Module, + input: torch.tensor, + cam_model: torch.tensor, + pad_info: torch.tensor, + scale_info: torch.tensor, + gt_depth: torch.tensor, + normalize_scale: float, + ori_shape: list=[], +): + + data = dict( + input=input, + cam_model=cam_model, + ) + pred_depth, confidence, output_dict = model.inference(data) + + return pred_depth, confidence, output_dict + +def transform_test_data_scalecano(rgb, intrinsic, data_basic, device="cuda"): + """ + Pre-process the input for forwarding. Employ `label scale canonical transformation.' + Args: + rgb: input rgb image. [H, W, 3] + intrinsic: camera intrinsic parameter, [fx, fy, u0, v0] + data_basic: predefined canonical space in configs. + """ + canonical_space = data_basic['canonical_space'] + forward_size = data_basic.crop_size + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None] + std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None] + + # BGR to RGB + #rgb = cv2.cvtColor(rgb, cv2.COLOR_BGR2RGB) + + ori_h, ori_w, _ = rgb.shape + ori_focal = (intrinsic[0] + intrinsic[1]) / 2 + canonical_focal = canonical_space['focal_length'] + + cano_label_scale_ratio = canonical_focal / ori_focal + + canonical_intrinsic = [ + intrinsic[0] * cano_label_scale_ratio, + intrinsic[1] * cano_label_scale_ratio, + intrinsic[2], + intrinsic[3], + ] + + # resize + rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(rgb, forward_size, canonical_intrinsic, [ori_h, ori_w], 1.0) + + # label scale factor + label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio + + rgb = torch.from_numpy(rgb.transpose((2, 0, 1))).float() + rgb = torch.div((rgb - mean), std) + rgb = rgb.to(device) + + cam_model = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() + cam_model = cam_model[None, :, :, :].to(device) + cam_model_stacks = [ + torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2]//i, cam_model.shape[3]//i), mode='bilinear', align_corners=False) + for i in [2, 4, 8, 16, 32] + ] + return rgb, cam_model_stacks, pad, label_scale_factor + +def do_scalecano_test_with_custom_data( + model: torch.nn.Module, + cfg: dict, + test_data: list, + logger: logging.RootLogger, + is_distributed: bool = True, + local_rank: int = 0, + bs: int = 2, # Batch size parameter +): + + show_dir = cfg.show_dir + save_interval = 1 + save_imgs_dir = show_dir + '/vis' + os.makedirs(save_imgs_dir, exist_ok=True) + save_pcd_dir = show_dir + '/pcd' + os.makedirs(save_pcd_dir, exist_ok=True) + + normalize_scale = cfg.data_basic.depth_range[1] + dam = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + dam_median = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + dam_global = MetricAverageMeter(['abs_rel', 'rmse', 'silog', 'delta1', 'delta2', 'delta3']) + + # Process data in batches + for i in tqdm(range(0, len(test_data), bs)): + batch_data = test_data[i:i + bs] # Extract batch + rgb_inputs, pads, label_scale_factors, gt_depths, rgb_origins = [], [], [], [], [] + + for an in batch_data: + print(an['rgb']) + rgb_origin = cv2.imread(an['rgb'])[:, :, ::-1].copy() + rgb_origins.append(rgb_origin) + gt_depth = None + if an['depth'] is not None: + gt_depth = cv2.imread(an['depth'], -1) + gt_depth_scale = an['depth_scale'] + gt_depth = gt_depth / gt_depth_scale + gt_depths.append(gt_depth) + + intrinsic = an['intrinsic'] + if intrinsic is None: + intrinsic = [1000.0, 1000.0, rgb_origin.shape[1]/2, rgb_origin.shape[0]/2] + + rgb_input, _, pad, label_scale_factor = transform_test_data_scalecano(rgb_origin, intrinsic, cfg.data_basic) + rgb_inputs.append(rgb_input) + pads.append(pad) + label_scale_factors.append(label_scale_factor) + + # Process the batch + pred_depths, outputs = get_prediction( + model=model, + input=torch.stack(rgb_inputs), # Stack inputs for batch processing + cam_model=None, + pad_info=pads, + scale_info=None, + gt_depth=None, + normalize_scale=None, + ) + + for j, gt_depth in enumerate(gt_depths): + normal_out = None + if 'normal_out_list' in outputs.keys(): + normal_out = outputs['normal_out_list'][0][j, :] + + postprocess_per_image( + i*bs+j, + pred_depths[j, :], + gt_depth, + intrinsic, + rgb_origins[j], + normal_out, + pads[j], + batch_data[j], + dam, + dam_median, + dam_global, + is_distributed, + save_imgs_dir, + save_pcd_dir, + normalize_scale, + label_scale_factors[j], + ) + + #if gt_depth_flag: + if False: + eval_error = dam.get_metrics() + print('w/o match :', eval_error) + + eval_error_median = dam_median.get_metrics() + print('median match :', eval_error_median) + + eval_error_global = dam_global.get_metrics() + print('global match :', eval_error_global) + else: + print('missing gt_depth, only save visualizations...') + + +def postprocess_per_image(i, pred_depth, gt_depth, intrinsic, rgb_origin, normal_out, pad, an, dam, dam_median, dam_global, is_distributed, save_imgs_dir, save_pcd_dir, normalize_scale, scale_info): + + pred_depth = pred_depth.squeeze() + pred_depth = pred_depth[pad[0] : pred_depth.shape[0] - pad[1], pad[2] : pred_depth.shape[1] - pad[3]] + pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], [rgb_origin.shape[0], rgb_origin.shape[1]], mode='bilinear').squeeze() # to original size + pred_depth = pred_depth * normalize_scale / scale_info + + pred_depth = (pred_depth > 0) * (pred_depth < 300) * pred_depth + if gt_depth is not None: + + pred_depth = torch.nn.functional.interpolate(pred_depth[None, None, :, :], (gt_depth.shape[0], gt_depth.shape[1]), mode='bilinear').squeeze() # to original size + + gt_depth = torch.from_numpy(gt_depth).cuda() + + pred_depth_median = pred_depth * gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median() + pred_global, _ = align_scale_shift(pred_depth, gt_depth) + + mask = (gt_depth > 1e-8) + dam.update_metrics_gpu(pred_depth, gt_depth, mask, is_distributed) + dam_median.update_metrics_gpu(pred_depth_median, gt_depth, mask, is_distributed) + dam_global.update_metrics_gpu(pred_global, gt_depth, mask, is_distributed) + print(gt_depth[gt_depth != 0].median() / pred_depth[gt_depth != 0].median(), ) + + os.makedirs(osp.join(save_imgs_dir, an['folder']), exist_ok=True) + rgb_torch = torch.from_numpy(rgb_origin).to(pred_depth.device).permute(2, 0, 1) + mean = torch.tensor([123.675, 116.28, 103.53]).float()[:, None, None].to(rgb_torch.device) + std = torch.tensor([58.395, 57.12, 57.375]).float()[:, None, None].to(rgb_torch.device) + rgb_torch = torch.div((rgb_torch - mean), std) + + save_val_imgs( + i, + pred_depth, + gt_depth if gt_depth is not None else torch.ones_like(pred_depth, device=pred_depth.device), + rgb_torch, + osp.join(an['folder'], an['filename']), + save_imgs_dir, + ) + #save_raw_imgs(pred_depth.detach().cpu().numpy(), rgb_torch, osp.join(an['folder'], an['filename']), save_imgs_dir, 1000.0) + + # pcd + pred_depth = pred_depth.detach().cpu().numpy() + #pcd = reconstruct_pcd(pred_depth, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3]) + #os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True) + #save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4]+'.ply')) + + if an['intrinsic'] == None: + #for r in [0.9, 1.0, 1.1]: + for r in [1.0]: + #for f in [600, 800, 1000, 1250, 1500]: + for f in [1000]: + pcd = reconstruct_pcd(pred_depth, f * r, f * (2-r), intrinsic[2], intrinsic[3]) + fstr = '_fx_' + str(int(f * r)) + '_fy_' + str(int(f * (2-r))) + os.makedirs(osp.join(save_pcd_dir, an['folder']), exist_ok=True) + save_point_cloud(pcd.reshape((-1, 3)), rgb_origin.reshape(-1, 3), osp.join(save_pcd_dir, an['folder'], an['filename'][:-4] + fstr +'.ply')) + + if normal_out is not None: + pred_normal = normal_out[:3, :, :] # (3, H, W) + H, W = pred_normal.shape[1:] + pred_normal = pred_normal[ :, pad[0]:H-pad[1], pad[2]:W-pad[3]] + + gt_normal = None + #if gt_normal_flag: + if False: + pred_normal = torch.nn.functional.interpolate(pred_normal, size=gt_normal.shape[2:], mode='bilinear', align_corners=True) + gt_normal = cv2.imread(norm_path) + gt_normal = cv2.cvtColor(gt_normal, cv2.COLOR_BGR2RGB) + gt_normal = np.array(gt_normal).astype(np.uint8) + gt_normal = ((gt_normal.astype(np.float32) / 255.0) * 2.0) - 1.0 + norm_valid_mask = (np.linalg.norm(gt_normal, axis=2, keepdims=True) > 0.5) + gt_normal = gt_normal * norm_valid_mask + gt_normal_mask = ~torch.all(gt_normal == 0, dim=1, keepdim=True) + dam.update_normal_metrics_gpu(pred_normal, gt_normal, gt_normal_mask, cfg.distributed)# save valiad normal + + save_normal_val_imgs(iter, + pred_normal, + gt_normal if gt_normal is not None else torch.ones_like(pred_normal, device=pred_normal.device), + rgb_torch, # data['input'], + osp.join(an['folder'], 'normal_'+an['filename']), + save_imgs_dir, + ) + diff --git a/custom_controlnet_aux/metric3d/mono/utils/logger.py b/custom_controlnet_aux/metric3d/mono/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ca48c613b2fdc5352b13ccb7d0bfdc1df5e3b531 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/logger.py @@ -0,0 +1,102 @@ +import atexit +import logging +import os +import sys +import time +import torch +from termcolor import colored + +__all__ = ["setup_logger", ] + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + +def setup_logger( + output=None, distributed_rank=0, *, name='metricdepth', color=True, abbrev_name=None +): + """ + Initialize the detectron2 logger and set its verbosity level to "DEBUG". + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + abbrev_name (str): an abbreviation of the module, to avoid log names in logs. + Set to "" not log the root module in logs. + By default, will abbreviate "detectron2" to "d2" and leave other + modules unchanged. + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger() + logger.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + logger.propagate = False + + if abbrev_name is None: + abbrev_name = "d2" + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s %(message)s ", datefmt="%m/%d %H:%M:%S" + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + ".rank{}".format(distributed_rank) + os.makedirs(os.path.dirname(filename), exist_ok=True) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.INFO) # NOTE: if more detailed, change it to logging.DEBUG + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + + return logger + +from iopath.common.file_io import PathManager as PathManagerBase + + +PathManager = PathManagerBase() + +# cache the opened file object, so that different calls to 'setup_logger +# with the same file name can safely write to the same file. +def _cached_log_stream(filename): + # use 1K buffer if writting to cloud storage + io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1) + atexit.register(io.close) + return io + \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/mldb.py b/custom_controlnet_aux/metric3d/mono/utils/mldb.py new file mode 100644 index 0000000000000000000000000000000000000000..d74ac53fd0302e2e954105bade52e6de4c18e2f6 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/mldb.py @@ -0,0 +1,34 @@ +from types import ModuleType +import data_info + +def load_data_info(module_name, data_info={}, mldb_type='mldb_info', module=None): + if module is None: + module = globals().get(module_name, None) + if module: + for key, value in module.__dict__.items(): + if not (key.startswith('__')) and not (key.startswith('_')): + if key == 'mldb_info': + data_info.update(value) + elif isinstance(value, ModuleType): + load_data_info(module_name + '.' + key, data_info, module=value) + else: + raise RuntimeError(f'Try to access "mldb_info", but cannot find {module_name} module.') + +def reset_ckpt_path(cfg, data_info): + if isinstance(cfg, dict): + for key in cfg.keys(): + if key == 'backbone': + new_ckpt_path = data_info['checkpoint']['mldb_root'] + '/' + data_info['checkpoint'][cfg.backbone.type] + cfg.backbone.update(checkpoint=new_ckpt_path) + continue + elif isinstance(cfg.get(key), dict): + reset_ckpt_path(cfg.get(key), data_info) + else: + continue + else: + return + +if __name__ == '__main__': + mldb_info_tmp = {} + load_data_info('mldb_data_info', mldb_info_tmp) + print('results', mldb_info_tmp.keys()) \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py b/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..2d26314d806ea961f6bf09d1fb195bf5e364f181 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/pcd_filter.py @@ -0,0 +1,24 @@ +import open3d as o3d +import numpy as np + +def downsample_and_filter(pcd_file): + pcd = o3d.io.read_point_cloud(pcd_file, max_bound_div = 750, neighbor_num = 8) + point_num = len(pcd.points) + if (point_num > 10000000): + voxel_down_pcd = o3d.geometry.PointCloud.uniform_down_sample(pcd, int(point_num / 10000000)+1) + else: + voxel_down_pcd = pcd + max_bound = voxel_down_pcd.get_max_bound() + ball_radius = np.linalg.norm(max_bound) / max_bound_div + pcd_filter, _ = voxel_down_pcd.remove_radius_outlier(neighbor_num, ball_radius) + print('filtered size', len(pcd_filter.points), 'pre size:', len(pcd.points)) + o3d.io.write_point_cloud(pcd_file[:-4] + '_filtered.ply', pcd_filter) + + +if __name__ == "__main__": + import os + dir_path = './data/demo_pcd' + for pcd_file in os.listdir(dir_path): + #if 'jonathan' in pcd_file: set max_bound_div to 300 and neighbot_num to 8 + downsample_and_filter(os.path.join(dir_path, pcd_file)) + \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/running.py b/custom_controlnet_aux/metric3d/mono/utils/running.py new file mode 100644 index 0000000000000000000000000000000000000000..55efa56b69d1a7c728c9d9980ec9abda37757222 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/running.py @@ -0,0 +1,77 @@ +import os +import torch +import torch.nn as nn +from custom_controlnet_aux.metric3d.mono.utils.comm import main_process +import copy +import inspect +import logging +import glob + + +def load_ckpt(load_path, model, optimizer=None, scheduler=None, strict_match=True, loss_scaler=None): + """ + Load the check point for resuming training or finetuning. + """ + logger = logging.getLogger() + if os.path.isfile(load_path): + if main_process(): + logger.info(f"Loading weight '{load_path}'") + checkpoint = torch.load(load_path, map_location="cpu", weights_only=True) + ckpt_state_dict = checkpoint['model_state_dict'] + model.load_state_dict(ckpt_state_dict, strict=strict_match) + + if optimizer is not None: + optimizer.load_state_dict(checkpoint['optimizer']) + if scheduler is not None: + scheduler.load_state_dict(checkpoint['scheduler']) + if loss_scaler is not None and 'scaler' in checkpoint: + scheduler.load_state_dict(checkpoint['scaler']) + del ckpt_state_dict + del checkpoint + if main_process(): + logger.info(f"Successfully loaded weight: '{load_path}'") + if scheduler is not None and optimizer is not None: + logger.info(f"Resume training from: '{load_path}'") + else: + if main_process(): + raise RuntimeError(f"No weight found at '{load_path}'") + return model, optimizer, scheduler, loss_scaler + + +def save_ckpt(cfg, model, optimizer, scheduler, curr_iter=0, curr_epoch=None, loss_scaler=None): + """ + Save the model, optimizer, lr scheduler. + """ + logger = logging.getLogger() + + if 'IterBasedRunner' in cfg.runner.type: + max_iters = cfg.runner.max_iters + elif 'EpochBasedRunner' in cfg.runner.type: + max_iters = cfg.runner.max_epochs + else: + raise TypeError(f'{cfg.runner.type} is not supported') + + ckpt = dict( + model_state_dict=model.module.state_dict(), + optimizer=optimizer.state_dict(), + max_iter=cfg.runner.max_iters if 'max_iters' in cfg.runner \ + else cfg.runner.max_epochs, + scheduler=scheduler.state_dict(), + ) + + if loss_scaler is not None: + ckpt.update(dict(scaler=loss_scaler.state_dict())) + + ckpt_dir = os.path.join(cfg.work_dir, 'ckpt') + os.makedirs(ckpt_dir, exist_ok=True) + + save_name = os.path.join(ckpt_dir, 'step%08d.pth' %curr_iter) + saved_ckpts = glob.glob(ckpt_dir + '/step*.pth') + torch.save(ckpt, save_name) + + # keep the last 8 ckpts + if len(saved_ckpts) > 20: + saved_ckpts.sort() + os.remove(saved_ckpts.pop(0)) + + logger.info(f'Save model: {save_name}') diff --git a/custom_controlnet_aux/metric3d/mono/utils/transform.py b/custom_controlnet_aux/metric3d/mono/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..2af94efe754d6f72325db6fdc170f30fbfb8c2fe --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/transform.py @@ -0,0 +1,408 @@ +import collections +import cv2 +import math +import numpy as np +import numbers +import random +import torch + +import matplotlib +import matplotlib.cm + + +""" +Provides a set of Pytorch transforms that use OpenCV instead of PIL (Pytorch default) +for image manipulation. +""" + +class Compose(object): + # Composes transforms: transforms.Compose([transforms.RandScale([0.5, 2.0]), transforms.ToTensor()]) + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + for t in self.transforms: + images, labels, intrinsics, cam_models, other_labels, transform_paras = t(images, labels, intrinsics, cam_models, other_labels, transform_paras) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class ToTensor(object): + # Converts numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W). + def __init__(self, **kwargs): + return + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + if not isinstance(images, list) or not isinstance(labels, list) or not isinstance(intrinsics, list): + raise (RuntimeError("transform.ToTensor() only handle inputs/labels/intrinsics lists.")) + if len(images) != len(intrinsics): + raise (RuntimeError("Numbers of images and intrinsics are not matched.")) + if not isinstance(images[0], np.ndarray) or not isinstance(labels[0], np.ndarray): + raise (RuntimeError("transform.ToTensor() only handle np.ndarray for the input and label." + "[eg: data readed by cv2.imread()].\n")) + if not isinstance(intrinsics[0], list): + raise (RuntimeError("transform.ToTensor() only handle list for the camera intrinsics")) + + if len(images[0].shape) > 3 or len(images[0].shape) < 2: + raise (RuntimeError("transform.ToTensor() only handle image(np.ndarray) with 3 dims or 2 dims.\n")) + if len(labels[0].shape) > 3 or len(labels[0].shape) < 2: + raise (RuntimeError("transform.ToTensor() only handle label(np.ndarray) with 3 dims or 2 dims.\n")) + + if len(intrinsics[0]) >4 or len(intrinsics[0]) < 3: + raise (RuntimeError("transform.ToTensor() only handle intrinsic(list) with 3 sizes or 4 sizes.\n")) + + for i, img in enumerate(images): + if len(img.shape) == 2: + img = np.expand_dims(img, axis=2) + images[i] = torch.from_numpy(img.transpose((2, 0, 1))).float() + for i, lab in enumerate(labels): + if len(lab.shape) == 2: + lab = np.expand_dims(lab, axis=0) + labels[i] = torch.from_numpy(lab).float() + for i, intrinsic in enumerate(intrinsics): + if len(intrinsic) == 3: + intrinsic = [intrinsic[0],] + intrinsic + intrinsics[i] = torch.tensor(intrinsic, dtype=torch.float) + if cam_models is not None: + for i, cam_model in enumerate(cam_models): + cam_models[i] = torch.from_numpy(cam_model.transpose((2, 0, 1))).float() if cam_model is not None else None + if other_labels is not None: + for i, lab in enumerate(other_labels): + if len(lab.shape) == 2: + lab = np.expand_dims(lab, axis=0) + other_labels[i] = torch.from_numpy(lab).float() + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class Normalize(object): + # Normalize tensor with mean and standard deviation along channel: channel = (channel - mean) / std + def __init__(self, mean, std=None, **kwargs): + if std is None: + assert len(mean) > 0 + else: + assert len(mean) == len(std) + self.mean = torch.tensor(mean).float()[:, None, None] + self.std = torch.tensor(std).float()[:, None, None] if std is not None \ + else torch.tensor([1.0, 1.0, 1.0]).float()[:, None, None] + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + # if self.std is None: + # # for t, m in zip(image, self.mean): + # # t.sub(m) + # image = image - self.mean + # if ref_images is not None: + # for i, ref_i in enumerate(ref_images): + # ref_images[i] = ref_i - self.mean + # else: + # # for t, m, s in zip(image, self.mean, self.std): + # # t.sub(m).div(s) + # image = (image - self.mean) / self.std + # if ref_images is not None: + # for i, ref_i in enumerate(ref_images): + # ref_images[i] = (ref_i - self.mean) / self.std + for i, img in enumerate(images): + img = torch.div((img - self.mean), self.std) + images[i] = img + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class LableScaleCanonical(object): + """ + To solve the ambiguity observation for the mono branch, i.e. different focal length (object size) with the same depth, cameras are + mapped to a canonical space. To mimic this, we set the focal length to a canonical one and scale the depth value. NOTE: resize the image based on the ratio can also solve + Args: + images: list of RGB images. + labels: list of depth/disparity labels. + other labels: other labels, such as instance segmentations, semantic segmentations... + """ + def __init__(self, **kwargs): + self.canonical_focal = kwargs['focal_length'] + + def _get_scale_ratio(self, intrinsic): + target_focal_x = intrinsic[0] + label_scale_ratio = self.canonical_focal / target_focal_x + pose_scale_ratio = 1.0 + return label_scale_ratio, pose_scale_ratio + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + assert len(images[0].shape) == 3 and len(labels[0].shape) == 2 + assert labels[0].dtype == np.float32 + + label_scale_ratio = None + pose_scale_ratio = None + + for i in range(len(intrinsics)): + img_i = images[i] + label_i = labels[i] if i < len(labels) else None + intrinsic_i = intrinsics[i].copy() + cam_model_i = cam_models[i] if cam_models is not None and i < len(cam_models) else None + + label_scale_ratio, pose_scale_ratio = self._get_scale_ratio(intrinsic_i) + + # adjust the focal length, map the current camera to the canonical space + intrinsics[i] = [intrinsic_i[0] * label_scale_ratio, intrinsic_i[1] * label_scale_ratio, intrinsic_i[2], intrinsic_i[3]] + + # scale the label to the canonical space + if label_i is not None: + labels[i] = label_i * label_scale_ratio + + if cam_model_i is not None: + # As the focal length is adjusted (canonical focal length), the camera model should be re-built + ori_h, ori_w, _ = img_i.shape + cam_models[i] = build_camera_model(ori_h, ori_w, intrinsics[i]) + + + if transform_paras is not None: + transform_paras.update(label_scale_factor=label_scale_ratio, focal_scale_factor=label_scale_ratio) + + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class ResizeKeepRatio(object): + """ + Resize and pad to a given size. Hold the aspect ratio. + This resizing assumes that the camera model remains unchanged. + Args: + resize_size: predefined output size. + """ + def __init__(self, resize_size, padding=None, ignore_label=-1, **kwargs): + if isinstance(resize_size, int): + self.resize_h = resize_size + self.resize_w = resize_size + elif isinstance(resize_size, collections.Iterable) and len(resize_size) == 2 \ + and isinstance(resize_size[0], int) and isinstance(resize_size[1], int) \ + and resize_size[0] > 0 and resize_size[1] > 0: + self.resize_h = resize_size[0] + self.resize_w = resize_size[1] + else: + raise (RuntimeError("crop size error.\n")) + if padding is None: + self.padding = padding + elif isinstance(padding, list): + if all(isinstance(i, numbers.Number) for i in padding): + self.padding = padding + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if len(padding) != 3: + raise (RuntimeError("padding channel is not equal with 3\n")) + else: + raise (RuntimeError("padding in Crop() should be a number list\n")) + if isinstance(ignore_label, int): + self.ignore_label = ignore_label + else: + raise (RuntimeError("ignore_label should be an integer number\n")) + # self.crop_size = kwargs['crop_size'] + self.canonical_focal = kwargs['focal_length'] + + def main_data_transform(self, image, label, intrinsic, cam_model, resize_ratio, padding, to_scale_ratio): + """ + Resize data first and then do the padding. + 'label' will be scaled. + """ + h, w, _ = image.shape + reshape_h = int(resize_ratio * h) + reshape_w = int(resize_ratio * w) + + pad_h, pad_w, pad_h_half, pad_w_half = padding + + # resize + image = cv2.resize(image, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + # padding + image = cv2.copyMakeBorder( + image, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.padding) + + if label is not None: + # label = cv2.resize(label, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST) + label = resize_depth_preserve(label, (reshape_h, reshape_w)) + label = cv2.copyMakeBorder( + label, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + # scale the label + label = label / to_scale_ratio + + # Resize, adjust principle point + if intrinsic is not None: + intrinsic[0] = intrinsic[0] * resize_ratio / to_scale_ratio + intrinsic[1] = intrinsic[1] * resize_ratio / to_scale_ratio + intrinsic[2] = intrinsic[2] * resize_ratio + intrinsic[3] = intrinsic[3] * resize_ratio + + if cam_model is not None: + #cam_model = cv2.resize(cam_model, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_LINEAR) + cam_model = build_camera_model(reshape_h, reshape_w, intrinsic) + cam_model = cv2.copyMakeBorder( + cam_model, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + + # Pad, adjust the principle point + if intrinsic is not None: + intrinsic[2] = intrinsic[2] + pad_w_half + intrinsic[3] = intrinsic[3] + pad_h_half + return image, label, intrinsic, cam_model + + def get_label_scale_factor(self, image, intrinsic, resize_ratio): + ori_h, ori_w, _ = image.shape + # crop_h, crop_w = self.crop_size + ori_focal = intrinsic[0] + + to_canonical_ratio = self.canonical_focal / ori_focal + to_scale_ratio = resize_ratio / to_canonical_ratio + return to_scale_ratio + + def __call__(self, images, labels, intrinsics, cam_models=None, other_labels=None, transform_paras=None): + target_h, target_w, _ = images[0].shape + resize_ratio_h = self.resize_h / target_h + resize_ratio_w = self.resize_w / target_w + resize_ratio = min(resize_ratio_h, resize_ratio_w) + reshape_h = int(resize_ratio * target_h) + reshape_w = int(resize_ratio * target_w) + pad_h = max(self.resize_h - reshape_h, 0) + pad_w = max(self.resize_w - reshape_w, 0) + pad_h_half = int(pad_h / 2) + pad_w_half = int(pad_w / 2) + + pad_info = [pad_h, pad_w, pad_h_half, pad_w_half] + to_scale_ratio = self.get_label_scale_factor(images[0], intrinsics[0], resize_ratio) + + for i in range(len(images)): + img = images[i] + label = labels[i] if i < len(labels) else None + intrinsic = intrinsics[i] if i < len(intrinsics) else None + cam_model = cam_models[i] if cam_models is not None and i < len(cam_models) else None + img, label, intrinsic, cam_model = self.main_data_transform( + img, label, intrinsic, cam_model, resize_ratio, pad_info, to_scale_ratio) + images[i] = img + if label is not None: + labels[i] = label + if intrinsic is not None: + intrinsics[i] = intrinsic + if cam_model is not None: + cam_models[i] = cam_model + + if other_labels is not None: + + for i, other_lab in enumerate(other_labels): + # resize + other_lab = cv2.resize(other_lab, dsize=(reshape_w, reshape_h), interpolation=cv2.INTER_NEAREST) + # pad + other_labels[i] = cv2.copyMakeBorder( + other_lab, + pad_h_half, + pad_h - pad_h_half, + pad_w_half, + pad_w - pad_w_half, + cv2.BORDER_CONSTANT, + value=self.ignore_label) + + pad = [pad_h_half, pad_h - pad_h_half, pad_w_half, pad_w - pad_w_half] + if transform_paras is not None: + pad_old = transform_paras['pad'] if 'pad' in transform_paras else [0,0,0,0] + new_pad = [pad_old[0] + pad[0], pad_old[1] + pad[1], pad_old[2] + pad[2], pad_old[3] + pad[3]] + transform_paras.update(dict(pad=new_pad)) + if 'label_scale_factor' in transform_paras: + transform_paras['label_scale_factor'] = transform_paras['label_scale_factor'] * 1.0 / to_scale_ratio + else: + transform_paras.update(label_scale_factor=1.0/to_scale_ratio) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +class BGR2RGB(object): + # Converts image from BGR order to RGB order, for model initialized from Pytorch + def __init__(self, **kwargs): + return + def __call__(self, images, labels, intrinsics, cam_models=None,other_labels=None, transform_paras=None): + for i, img in enumerate(images): + images[i] = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return images, labels, intrinsics, cam_models, other_labels, transform_paras + + +def resize_depth_preserve(depth, shape): + """ + Resizes depth map preserving all valid depth pixels + Multiple downsampled points can be assigned to the same pixel. + + Parameters + ---------- + depth : np.array [h,w] + Depth map + shape : tuple (H,W) + Output shape + + Returns + ------- + depth : np.array [H,W,1] + Resized depth map + """ + # Store dimensions and reshapes to single column + depth = np.squeeze(depth) + h, w = depth.shape + x = depth.reshape(-1) + # Create coordinate grid + uv = np.mgrid[:h, :w].transpose(1, 2, 0).reshape(-1, 2) + # Filters valid points + idx = x > 0 + crd, val = uv[idx], x[idx] + # Downsamples coordinates + crd[:, 0] = (crd[:, 0] * (shape[0] / h) + 0.5).astype(np.int32) + crd[:, 1] = (crd[:, 1] * (shape[1] / w) + 0.5).astype(np.int32) + # Filters points inside image + idx = (crd[:, 0] < shape[0]) & (crd[:, 1] < shape[1]) + crd, val = crd[idx], val[idx] + # Creates downsampled depth image and assigns points + depth = np.zeros(shape) + depth[crd[:, 0], crd[:, 1]] = val + # Return resized depth map + return depth + + +def build_camera_model(H : int, W : int, intrinsics : list) -> np.array: + """ + Encode the camera intrinsic parameters (focal length and principle point) to a 4-channel map. + """ + fx, fy, u0, v0 = intrinsics + f = (fx + fy) / 2.0 + # principle point location + x_row = np.arange(0, W).astype(np.float32) + x_row_center_norm = (x_row - u0) / W + x_center = np.tile(x_row_center_norm, (H, 1)) # [H, W] + + y_col = np.arange(0, H).astype(np.float32) + y_col_center_norm = (y_col - v0) / H + y_center = np.tile(y_col_center_norm, (W, 1)).T + + # FoV + fov_x = np.arctan(x_center / (f / W)) + fov_y = np.arctan(y_center/ (f / H)) + + cam_model = np.stack([x_center, y_center, fov_x, fov_y], axis=2) + return cam_model + +def gray_to_colormap(img, cmap='rainbow'): + """ + Transfer gray map to matplotlib colormap + """ + assert img.ndim == 2 + + img[img<0] = 0 + mask_invalid = img < 1e-10 + img = img / (img.max() + 1e-8) + norm = matplotlib.colors.Normalize(vmin=0, vmax=1.1) + cmap_m = matplotlib.cm.get_cmap(cmap) + map = matplotlib.cm.ScalarMappable(norm=norm, cmap=cmap_m) + colormap = (map.to_rgba(img)[:, :, :3] * 255).astype(np.uint8) + colormap[mask_invalid] = 0 + return colormap \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py b/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py new file mode 100644 index 0000000000000000000000000000000000000000..a0986d482a2ec68be1dd65719adec662272b833c --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/unproj_pcd.py @@ -0,0 +1,88 @@ +import numpy as np +import torch +from plyfile import PlyData, PlyElement +import cv2 + + +def get_pcd_base(H, W, u0, v0, fx, fy): + x_row = np.arange(0, W) + x = np.tile(x_row, (H, 1)) + x = x.astype(np.float32) + u_m_u0 = x - u0 + + y_col = np.arange(0, H) # y_col = np.arange(0, height) + y = np.tile(y_col, (W, 1)).T + y = y.astype(np.float32) + v_m_v0 = y - v0 + + x = u_m_u0 / fx + y = v_m_v0 / fy + z = np.ones_like(x) + pw = np.stack([x, y, z], axis=2) # [h, w, c] + return pw + + +def reconstruct_pcd(depth, fx, fy, u0, v0, pcd_base=None, mask=None): + if type(depth) == torch.__name__: + depth = depth.cpu().numpy().squeeze() + depth = cv2.medianBlur(depth, 5) + if pcd_base is None: + H, W = depth.shape + pcd_base = get_pcd_base(H, W, u0, v0, fx, fy) + pcd = depth[:, :, None] * pcd_base + if mask: + pcd[mask] = 0 + return pcd + + +def save_point_cloud(pcd, rgb, filename, binary=True): + """Save an RGB point cloud as a PLY file. + :paras + @pcd: Nx3 matrix, the XYZ coordinates + @rgb: Nx3 matrix, the rgb colors for each 3D point + """ + assert pcd.shape[0] == rgb.shape[0] + + if rgb is None: + gray_concat = np.tile(np.array([128], dtype=np.uint8), + (pcd.shape[0], 3)) + points_3d = np.hstack((pcd, gray_concat)) + else: + points_3d = np.hstack((pcd, rgb)) + python_types = (float, float, float, int, int, int) + npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), + ('green', 'u1'), ('blue', 'u1')] + if binary is True: + # Format into Numpy structured array + vertices = [] + for row_idx in range(points_3d.shape[0]): + cur_point = points_3d[row_idx] + vertices.append( + tuple( + dtype(point) + for dtype, point in zip(python_types, cur_point))) + vertices_array = np.array(vertices, dtype=npy_types) + el = PlyElement.describe(vertices_array, 'vertex') + + # write + PlyData([el]).write(filename) + else: + x = np.squeeze(points_3d[:, 0]) + y = np.squeeze(points_3d[:, 1]) + z = np.squeeze(points_3d[:, 2]) + r = np.squeeze(points_3d[:, 3]) + g = np.squeeze(points_3d[:, 4]) + b = np.squeeze(points_3d[:, 5]) + + ply_head = 'ply\n' \ + 'format ascii 1.0\n' \ + 'element vertex %d\n' \ + 'property float x\n' \ + 'property float y\n' \ + 'property float z\n' \ + 'property uchar red\n' \ + 'property uchar green\n' \ + 'property uchar blue\n' \ + 'end_header' % r.shape[0] + # ---- Save ply data to disk + np.savetxt(filename, np.column_stack[x, y, z, r, g, b], fmt='%f %f %f %d %d %d', header=ply_head, comments='') \ No newline at end of file diff --git a/custom_controlnet_aux/metric3d/mono/utils/visualization.py b/custom_controlnet_aux/metric3d/mono/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..27e5620f433fb07ba4645d9ae6d20d9d9aa66ad1 --- /dev/null +++ b/custom_controlnet_aux/metric3d/mono/utils/visualization.py @@ -0,0 +1,139 @@ +import matplotlib.pyplot as plt +import os, cv2 +import numpy as np +from custom_controlnet_aux.metric3d.mono.utils.transform import gray_to_colormap +import shutil +import glob +from custom_controlnet_aux.metric3d.mono.utils.running import main_process +import torch + +def save_raw_imgs( + pred: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + scale: float=200.0, + target: torch.tensor=None, + ): + """ + Save raw GT, predictions, RGB in the same file. + """ + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_rgb.jpg'), rgb) + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_d.png'), (pred*scale).astype(np.uint16)) + if target is not None: + cv2.imwrite(os.path.join(save_dir, filename[:-4]+'_gt.png'), (target*scale).astype(np.uint16)) + + +def save_val_imgs( + iter: int, + pred: torch.tensor, + target: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + tb_logger=None + ): + """ + Save GT, predictions, RGB in the same file. + """ + rgb, pred_scale, target_scale, pred_color, target_color = get_data_for_log(pred, target, rgb) + rgb = rgb.transpose((1, 2, 0)) + cat_img = np.concatenate([rgb, pred_color, target_color], axis=0) + plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img) + + # save to tensorboard + if tb_logger is not None: + tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter) + +def save_normal_val_imgs( + iter: int, + pred: torch.tensor, + targ: torch.tensor, + rgb: torch.tensor, + filename: str, + save_dir: str, + tb_logger=None, + mask=None, + ): + """ + Save GT, predictions, RGB in the same file. + """ + mean = np.array([123.675, 116.28, 103.53])[np.newaxis, np.newaxis, :] + std= np.array([58.395, 57.12, 57.375])[np.newaxis, np.newaxis, :] + pred = pred.squeeze() + targ = targ.squeeze() + rgb = rgb.squeeze() + + if pred.size(0) == 3: + pred = pred.permute(1,2,0) + if targ.size(0) == 3: + targ = targ.permute(1,2,0) + if rgb.size(0) == 3: + rgb = rgb.permute(1,2,0) + + pred_color = vis_surface_normal(pred, mask) + targ_color = vis_surface_normal(targ, mask) + rgb_color = ((rgb.cpu().numpy() * std) + mean).astype(np.uint8) + + try: + cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) + except: + pred_color = cv2.resize(pred_color, (rgb.shape[1], rgb.shape[0])) + targ_color = cv2.resize(targ_color, (rgb.shape[1], rgb.shape[0])) + cat_img = np.concatenate([rgb_color, pred_color, targ_color], axis=0) + + plt.imsave(os.path.join(save_dir, filename[:-4]+'_merge.jpg'), cat_img) + # cv2.imwrite(os.path.join(save_dir, filename[:-4]+'.jpg'), pred_color) + # save to tensorboard + if tb_logger is not None: + tb_logger.add_image(f'{filename[:-4]}_merge.jpg', cat_img.transpose((2, 0, 1)), iter) + +def get_data_for_log(pred: torch.tensor, target: torch.tensor, rgb: torch.tensor): + mean = np.array([123.675, 116.28, 103.53])[:, np.newaxis, np.newaxis] + std= np.array([58.395, 57.12, 57.375])[:, np.newaxis, np.newaxis] + + pred = pred.squeeze().cpu().numpy() + target = target.squeeze().cpu().numpy() + rgb = rgb.squeeze().cpu().numpy() + + pred[pred<0] = 0 + target[target<0] = 0 + max_scale = max(pred.max(), target.max()) + pred_scale = (pred/max_scale * 10000).astype(np.uint16) + target_scale = (target/max_scale * 10000).astype(np.uint16) + pred_color = gray_to_colormap(pred) + target_color = gray_to_colormap(target) + pred_color = cv2.resize(pred_color, (rgb.shape[2], rgb.shape[1])) + target_color = cv2.resize(target_color, (rgb.shape[2], rgb.shape[1])) + + rgb = ((rgb * std) + mean).astype(np.uint8) + return rgb, pred_scale, target_scale, pred_color, target_color + + +def create_html(name2path, save_path='index.html', size=(256, 384)): + # table description + cols = [] + for k, v in name2path.items(): + col_i = Col('img', k, v) # specify image content for column + cols.append(col_i) + # html table generation + imagetable(cols, out_file=save_path, imsize=size) + +def vis_surface_normal(normal: torch.tensor, mask: torch.tensor=None) -> np.array: + """ + Visualize surface normal. Transfer surface normal value from [-1, 1] to [0, 255] + Aargs: + normal (torch.tensor, [h, w, 3]): surface normal + mask (torch.tensor, [h, w]): valid masks + """ + normal = normal.cpu().numpy().squeeze() + n_img_L2 = np.sqrt(np.sum(normal ** 2, axis=2, keepdims=True)) + n_img_norm = normal / (n_img_L2 + 1e-8) + normal_vis = n_img_norm * 127 + normal_vis += 128 + normal_vis = normal_vis.astype(np.uint8) + if mask is not None: + mask = mask.cpu().numpy().squeeze() + normal_vis[~mask] = 0 + return normal_vis + diff --git a/custom_controlnet_aux/midas/LICENSE b/custom_controlnet_aux/midas/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..277b5c11be103f028a8d10985139f1da10c2f08e --- /dev/null +++ b/custom_controlnet_aux/midas/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/custom_controlnet_aux/midas/__init__.py b/custom_controlnet_aux/midas/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6eab6928c3a0c58bc17e4a9e2996590645f0401a --- /dev/null +++ b/custom_controlnet_aux/midas/__init__.py @@ -0,0 +1,76 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +from .api import MiDaSInference + + +class MidasDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, model_type="dpt_hybrid", filename="dpt_hybrid-midas-501f0c75.pt"): + subfolder = "annotator/ckpts" if pretrained_model_or_path == "lllyasviel/ControlNet" else '' + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder) + model = MiDaSInference(model_type=model_type, model_path=model_path) + return cls(model) + + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, a=np.pi * 2.0, bg_th=0.1, depth_and_normal=False, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + image_depth = detected_map + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float() + image_depth = image_depth.to(self.device) + image_depth = image_depth / 127.5 - 1.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model(image_depth)[0] + + depth_pt = depth.clone() + depth_pt -= torch.min(depth_pt) + depth_pt /= torch.max(depth_pt) + depth_pt = depth_pt.cpu().numpy() + depth_image = (depth_pt * 255.0).clip(0, 255).astype(np.uint8) + + if depth_and_normal: + depth_np = depth.cpu().numpy() + x = cv2.Sobel(depth_np, cv2.CV_32F, 1, 0, ksize=3) + y = cv2.Sobel(depth_np, cv2.CV_32F, 0, 1, ksize=3) + z = np.ones_like(x) * a + x[depth_pt < bg_th] = 0 + y[depth_pt < bg_th] = 0 + normal = np.stack([x, y, z], axis=2) + normal /= np.sum(normal ** 2.0, axis=2, keepdims=True) ** 0.5 + normal_image = (normal * 127.5 + 127.5).clip(0, 255).astype(np.uint8)[:, :, ::-1] + + depth_image = HWC3(depth_image) + if depth_and_normal: + normal_image = HWC3(normal_image) + + + depth_image = remove_pad(depth_image) + if depth_and_normal: + normal_image = remove_pad(normal_image) + + if output_type == "pil": + depth_image = Image.fromarray(depth_image) + if depth_and_normal: + normal_image = Image.fromarray(normal_image) + + if depth_and_normal: + return depth_image, normal_image + else: + return depth_image diff --git a/custom_controlnet_aux/midas/api.py b/custom_controlnet_aux/midas/api.py new file mode 100644 index 0000000000000000000000000000000000000000..d59b4dbb6c1f3fe7dae49698f414d4895353a90c --- /dev/null +++ b/custom_controlnet_aux/midas/api.py @@ -0,0 +1,169 @@ +# based on https://github.com/isl-org/MiDaS + +import cv2 +import os +import torch +import torch.nn as nn +from torchvision.transforms import Compose + +from custom_midas_repo.midas.dpt_depth import DPTDepthModel +from custom_midas_repo.midas.midas_net import MidasNet +from custom_midas_repo.midas.midas_net_custom import MidasNet_small +from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet +from custom_controlnet_aux.util import annotator_ckpts_path + + +ISL_PATHS = { + "dpt_large": os.path.join(annotator_ckpts_path, "dpt_large-midas-2f21e586.pt"), + "dpt_hybrid": os.path.join(annotator_ckpts_path, "dpt_hybrid-midas-501f0c75.pt"), + "midas_v21": "", + "midas_v21_small": "", +} + +remote_model_path = "https://huggingface.co/lllyasviel/ControlNet/resolve/main/annotator/ckpts/dpt_hybrid-midas-501f0c75.pt" + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def load_midas_transform(model_type): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load transform only + if model_type == "dpt_large": # DPT-Large + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + elif model_type == "midas_v21_small": + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + + else: + assert False, f"model_type '{model_type}' not implemented, use: --model_type large" + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return transform + + +def load_model(model_type, model_path=None): + # https://github.com/isl-org/MiDaS/blob/master/run.py + # load network + model_path = model_path or ISL_PATHS[model_type] + if model_type == "dpt_large": # DPT-Large + model = DPTDepthModel( + path=model_path, + backbone="vitl16_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "dpt_hybrid": # DPT-Hybrid + if not os.path.exists(model_path): + from basicsr.utils.download_util import load_file_from_url + load_file_from_url(remote_model_path, model_dir=annotator_ckpts_path) + + model = DPTDepthModel( + path=model_path, + backbone="vitb_rn50_384", + non_negative=True, + ) + net_w, net_h = 384, 384 + resize_mode = "minimal" + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + elif model_type == "midas_v21": + model = MidasNet(model_path, non_negative=True) + net_w, net_h = 384, 384 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + elif model_type == "midas_v21_small": + model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True, + non_negative=True, blocks={'expand': True}) + net_w, net_h = 256, 256 + resize_mode = "upper_bound" + normalization = NormalizeImage( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ) + + else: + print(f"model_type '{model_type}' not implemented, use: --model_type large") + assert False + + transform = Compose( + [ + Resize( + net_w, + net_h, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method=resize_mode, + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet(), + ] + ) + + return model.eval(), transform + + +class MiDaSInference(nn.Module): + MODEL_TYPES_TORCH_HUB = [ + "DPT_Large", + "DPT_Hybrid", + "MiDaS_small" + ] + MODEL_TYPES_ISL = [ + "dpt_large", + "dpt_hybrid", + "midas_v21", + "midas_v21_small", + ] + + def __init__(self, model_type, model_path): + super().__init__() + assert (model_type in self.MODEL_TYPES_ISL) + model, _ = load_model(model_type, model_path) + self.model = model + self.model.train = disabled_train + + def forward(self, x): + with torch.no_grad(): + prediction = self.model(x) + return prediction + diff --git a/custom_controlnet_aux/midas/utils.py b/custom_controlnet_aux/midas/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9d3b5b66370fa98da9e067ba53ead848ea9a59 --- /dev/null +++ b/custom_controlnet_aux/midas/utils.py @@ -0,0 +1,189 @@ +"""Utils for monoDepth.""" +import sys +import re +import numpy as np +import cv2 +import torch + + +def read_pfm(path): + """Read pfm file. + + Args: + path (str): path to file + + Returns: + tuple: (data, scale) + """ + with open(path, "rb") as file: + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == "PF": + color = True + elif header.decode("ascii") == "Pf": + color = False + else: + raise Exception("Not a PFM file: " + path) + + dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception("Malformed PFM header.") + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: + # little-endian + endian = "<" + scale = -scale + else: + # big-endian + endian = ">" + + data = np.fromfile(file, endian + "f") + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + + return data, scale + + +def write_pfm(path, image, scale=1): + """Write pfm file. + + Args: + path (str): pathto file + image (array): data + scale (int, optional): Scale. Defaults to 1. + """ + + with open(path, "wb") as file: + color = None + + if image.dtype.name != "float32": + raise Exception("Image dtype must be float32.") + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif ( + len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 + ): # greyscale + color = False + else: + raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") + + file.write("PF\n" if color else "Pf\n".encode()) + file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == "<" or endian == "=" and sys.byteorder == "little": + scale = -scale + + file.write("%f\n".encode() % scale) + + image.tofile(file) + + +def read_image(path): + """Read image and output RGB image (0-1). + + Args: + path (str): path to file + + Returns: + array: RGB image (0-1) + """ + img = cv2.imread(path) + + if img.ndim == 2: + img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 + + return img + + +def resize_image(img): + """Resize image and make it fit for network. + + Args: + img (array): image + + Returns: + tensor: data ready for network + """ + height_orig = img.shape[0] + width_orig = img.shape[1] + + if width_orig > height_orig: + scale = width_orig / 384 + else: + scale = height_orig / 384 + + height = (np.ceil(height_orig / scale / 32) * 32).astype(int) + width = (np.ceil(width_orig / scale / 32) * 32).astype(int) + + img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) + + img_resized = ( + torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() + ) + img_resized = img_resized.unsqueeze(0) + + return img_resized + + +def resize_depth(depth, width, height): + """Resize depth map and bring to CPU (numpy). + + Args: + depth (tensor): depth + width (int): image width + height (int): image height + + Returns: + array: processed depth + """ + depth = torch.squeeze(depth[0, :, :, :]).to("cpu") + + depth_resized = cv2.resize( + depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC + ) + + return depth_resized + +def write_depth(path, depth, bits=1): + """Write depth map to pfm and png file. + + Args: + path (str): filepath without extension + depth (array): depth + """ + write_pfm(path + ".pfm", depth.astype(np.float32)) + + depth_min = depth.min() + depth_max = depth.max() + + max_val = (2**(8*bits))-1 + + if depth_max - depth_min > np.finfo("float").eps: + out = max_val * (depth - depth_min) / (depth_max - depth_min) + else: + out = np.zeros(depth.shape, dtype=depth.type) + + if bits == 1: + cv2.imwrite(path + ".png", out.astype("uint8")) + elif bits == 2: + cv2.imwrite(path + ".png", out.astype("uint16")) + + return diff --git a/custom_controlnet_aux/mlsd/LICENSE b/custom_controlnet_aux/mlsd/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..d855c6db44b4e873eedd750d34fa2eaf22e22363 --- /dev/null +++ b/custom_controlnet_aux/mlsd/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2021-present NAVER Corp. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/custom_controlnet_aux/mlsd/__init__.py b/custom_controlnet_aux/mlsd/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fe3b0a972c953894e9e4e60fc62847559d7c8022 --- /dev/null +++ b/custom_controlnet_aux/mlsd/__init__.py @@ -0,0 +1,51 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +from .models.mbv2_mlsd_large import MobileV2_MLSD_Large +from .utils import pred_lines + + +class MLSDdetector: + def __init__(self, model): + self.model = model + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="mlsd_large_512_fp32.pth"): + subfolder = "annotator/ckpts" if pretrained_model_or_path == "lllyasviel/ControlNet" else '' + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder) + model = MobileV2_MLSD_Large() + model.load_state_dict(torch.load(model_path), strict=True) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + return self + + def __call__(self, input_image, thr_v=0.1, thr_d=0.1, detect_resolution=512, output_type="pil", upscale_method="INTER_AREA", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + img = detected_map + img_output = np.zeros_like(img) + try: + with torch.no_grad(): + lines = pred_lines(img, self.model, [img.shape[0], img.shape[1]], thr_v, thr_d) + for line in lines: + x_start, y_start, x_end, y_end = [int(val) for val in line] + cv2.line(img_output, (x_start, y_start), (x_end, y_end), [255, 255, 255], 1) + except Exception as e: + pass + + detected_map = remove_pad(HWC3(img_output[:, :, 0])) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/mlsd/models/__init__.py b/custom_controlnet_aux/mlsd/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py b/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py new file mode 100644 index 0000000000000000000000000000000000000000..5b9799e7573ca41549b3c3b13ac47b906b369603 --- /dev/null +++ b/custom_controlnet_aux/mlsd/models/mbv2_mlsd_large.py @@ -0,0 +1,292 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + if self.upscale: + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + [6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + + self.features = nn.Sequential(*features) + self.fpn_selected = [1, 3, 6, 10, 13] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + if pretrained: + self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c1, c2, c3, c4, c5 = fpn_features + return c1, c2, c3, c4, c5 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Large(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Large, self).__init__() + + self.backbone = MobileNetV2(pretrained=False) + ## A, B + self.block15 = BlockTypeA(in_c1= 64, in_c2= 96, + out_c1= 64, out_c2=64, + upscale=False) + self.block16 = BlockTypeB(128, 64) + + ## A, B + self.block17 = BlockTypeA(in_c1 = 32, in_c2 = 64, + out_c1= 64, out_c2= 64) + self.block18 = BlockTypeB(128, 64) + + ## A, B + self.block19 = BlockTypeA(in_c1=24, in_c2=64, + out_c1=64, out_c2=64) + self.block20 = BlockTypeB(128, 64) + + ## A, B, C + self.block21 = BlockTypeA(in_c1=16, in_c2=64, + out_c1=64, out_c2=64) + self.block22 = BlockTypeB(128, 64) + + self.block23 = BlockTypeC(64, 16) + + def forward(self, x): + c1, c2, c3, c4, c5 = self.backbone(x) + + x = self.block15(c4, c5) + x = self.block16(x) + + x = self.block17(c3, x) + x = self.block18(x) + + x = self.block19(c2, x) + x = self.block20(x) + + x = self.block21(c1, x) + x = self.block22(x) + x = self.block23(x) + x = x[:, 7:, :, :] + + return x \ No newline at end of file diff --git a/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py b/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py new file mode 100644 index 0000000000000000000000000000000000000000..e3ed633f2cc23ea1829a627fdb879ab39f641f83 --- /dev/null +++ b/custom_controlnet_aux/mlsd/models/mbv2_mlsd_tiny.py @@ -0,0 +1,275 @@ +import os +import sys +import torch +import torch.nn as nn +import torch.utils.model_zoo as model_zoo +from torch.nn import functional as F + + +class BlockTypeA(nn.Module): + def __init__(self, in_c1, in_c2, out_c1, out_c2, upscale = True): + super(BlockTypeA, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c2, out_c2, kernel_size=1), + nn.BatchNorm2d(out_c2), + nn.ReLU(inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c1, out_c1, kernel_size=1), + nn.BatchNorm2d(out_c1), + nn.ReLU(inplace=True) + ) + self.upscale = upscale + + def forward(self, a, b): + b = self.conv1(b) + a = self.conv2(a) + b = F.interpolate(b, scale_factor=2.0, mode='bilinear', align_corners=True) + return torch.cat((a, b), dim=1) + + +class BlockTypeB(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeB, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, out_c, kernel_size=3, padding=1), + nn.BatchNorm2d(out_c), + nn.ReLU() + ) + + def forward(self, x): + x = self.conv1(x) + x + x = self.conv2(x) + return x + +class BlockTypeC(nn.Module): + def __init__(self, in_c, out_c): + super(BlockTypeC, self).__init__() + self.conv1 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=5, dilation=5), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv2 = nn.Sequential( + nn.Conv2d(in_c, in_c, kernel_size=3, padding=1), + nn.BatchNorm2d(in_c), + nn.ReLU() + ) + self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1) + + def forward(self, x): + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + return x + +def _make_divisible(v, divisor, min_value=None): + """ + This function is taken from the original tf repo. + It ensures that all layers have a channel number that is divisible by 8 + It can be seen here: + https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py + :param v: + :param divisor: + :param min_value: + :return: + """ + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class ConvBNReLU(nn.Sequential): + def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, groups=1): + self.channel_pad = out_planes - in_planes + self.stride = stride + #padding = (kernel_size - 1) // 2 + + # TFLite uses slightly different padding than PyTorch + if stride == 2: + padding = 0 + else: + padding = (kernel_size - 1) // 2 + + super(ConvBNReLU, self).__init__( + nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, groups=groups, bias=False), + nn.BatchNorm2d(out_planes), + nn.ReLU6(inplace=True) + ) + self.max_pool = nn.MaxPool2d(kernel_size=stride, stride=stride) + + + def forward(self, x): + # TFLite uses different padding + if self.stride == 2: + x = F.pad(x, (0, 1, 0, 1), "constant", 0) + #print(x.shape) + + for module in self: + if not isinstance(module, nn.MaxPool2d): + x = module(x) + return x + + +class InvertedResidual(nn.Module): + def __init__(self, inp, oup, stride, expand_ratio): + super(InvertedResidual, self).__init__() + self.stride = stride + assert stride in [1, 2] + + hidden_dim = int(round(inp * expand_ratio)) + self.use_res_connect = self.stride == 1 and inp == oup + + layers = [] + if expand_ratio != 1: + # pw + layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1)) + layers.extend([ + # dw + ConvBNReLU(hidden_dim, hidden_dim, stride=stride, groups=hidden_dim), + # pw-linear + nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False), + nn.BatchNorm2d(oup), + ]) + self.conv = nn.Sequential(*layers) + + def forward(self, x): + if self.use_res_connect: + return x + self.conv(x) + else: + return self.conv(x) + + +class MobileNetV2(nn.Module): + def __init__(self, pretrained=True): + """ + MobileNet V2 main class + Args: + num_classes (int): Number of classes + width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount + inverted_residual_setting: Network structure + round_nearest (int): Round the number of channels in each layer to be a multiple of this number + Set to 1 to turn off rounding + block: Module specifying inverted residual building block for mobilenet + """ + super(MobileNetV2, self).__init__() + + block = InvertedResidual + input_channel = 32 + last_channel = 1280 + width_mult = 1.0 + round_nearest = 8 + + inverted_residual_setting = [ + # t, c, n, s + [1, 16, 1, 1], + [6, 24, 2, 2], + [6, 32, 3, 2], + [6, 64, 4, 2], + #[6, 96, 3, 1], + #[6, 160, 3, 2], + #[6, 320, 1, 1], + ] + + # only check the first element, assuming user knows t,c,n,s are required + if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4: + raise ValueError("inverted_residual_setting should be non-empty " + "or a 4-element list, got {}".format(inverted_residual_setting)) + + # building first layer + input_channel = _make_divisible(input_channel * width_mult, round_nearest) + self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest) + features = [ConvBNReLU(4, input_channel, stride=2)] + # building inverted residual blocks + for t, c, n, s in inverted_residual_setting: + output_channel = _make_divisible(c * width_mult, round_nearest) + for i in range(n): + stride = s if i == 0 else 1 + features.append(block(input_channel, output_channel, stride, expand_ratio=t)) + input_channel = output_channel + self.features = nn.Sequential(*features) + + self.fpn_selected = [3, 6, 10] + # weight initialization + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out') + if m.bias is not None: + nn.init.zeros_(m.bias) + elif isinstance(m, nn.BatchNorm2d): + nn.init.ones_(m.weight) + nn.init.zeros_(m.bias) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.zeros_(m.bias) + + #if pretrained: + # self._load_pretrained_model() + + def _forward_impl(self, x): + # This exists since TorchScript doesn't support inheritance, so the superclass method + # (this one) needs to have a name other than `forward` that can be accessed in a subclass + fpn_features = [] + for i, f in enumerate(self.features): + if i > self.fpn_selected[-1]: + break + x = f(x) + if i in self.fpn_selected: + fpn_features.append(x) + + c2, c3, c4 = fpn_features + return c2, c3, c4 + + + def forward(self, x): + return self._forward_impl(x) + + def _load_pretrained_model(self): + pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/mobilenet_v2-b0353104.pth') + model_dict = {} + state_dict = self.state_dict() + for k, v in pretrain_dict.items(): + if k in state_dict: + model_dict[k] = v + state_dict.update(model_dict) + self.load_state_dict(state_dict) + + +class MobileV2_MLSD_Tiny(nn.Module): + def __init__(self): + super(MobileV2_MLSD_Tiny, self).__init__() + + self.backbone = MobileNetV2(pretrained=True) + + self.block12 = BlockTypeA(in_c1= 32, in_c2= 64, + out_c1= 64, out_c2=64) + self.block13 = BlockTypeB(128, 64) + + self.block14 = BlockTypeA(in_c1 = 24, in_c2 = 64, + out_c1= 32, out_c2= 32) + self.block15 = BlockTypeB(64, 64) + + self.block16 = BlockTypeC(64, 16) + + def forward(self, x): + c2, c3, c4 = self.backbone(x) + + x = self.block12(c3, c4) + x = self.block13(x) + x = self.block14(c2, x) + x = self.block15(x) + x = self.block16(x) + x = x[:, 7:, :, :] + #print(x.shape) + x = F.interpolate(x, scale_factor=2.0, mode='bilinear', align_corners=True) + + return x \ No newline at end of file diff --git a/custom_controlnet_aux/mlsd/utils.py b/custom_controlnet_aux/mlsd/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..28071cbf129a2bedb21a44f95d565aef7974e583 --- /dev/null +++ b/custom_controlnet_aux/mlsd/utils.py @@ -0,0 +1,584 @@ +''' +modified by lihaoweicv +pytorch version +''' + +''' +M-LSD +Copyright 2021-present NAVER Corp. +Apache License v2.0 +''' + +import os +import numpy as np +import cv2 +import torch +from torch.nn import functional as F + + +def deccode_output_score_and_ptss(tpMap, topk_n = 200, ksize = 5): + ''' + tpMap: + center: tpMap[1, 0, :, :] + displacement: tpMap[1, 1:5, :, :] + ''' + b, c, h, w = tpMap.shape + assert b==1, 'only support bsize==1' + displacement = tpMap[:, 1:5, :, :][0] + center = tpMap[:, 0, :, :] + heat = torch.sigmoid(center) + hmax = F.max_pool2d( heat, (ksize, ksize), stride=1, padding=(ksize-1)//2) + keep = (hmax == heat).float() + heat = heat * keep + heat = heat.reshape(-1, ) + + scores, indices = torch.topk(heat, topk_n, dim=-1, largest=True) + yy = torch.floor_divide(indices, w).unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + ptss = torch.cat((yy, xx),dim=-1) + + ptss = ptss.detach().cpu().numpy() + scores = scores.detach().cpu().numpy() + displacement = displacement.detach().cpu().numpy() + displacement = displacement.transpose((1,2,0)) + return ptss, scores, displacement + + +def pred_lines(image, model, + input_shape=[512, 512], + score_thr=0.10, + dist_thr=20.0): + h, w, _ = image.shape + + device = next(iter(model.parameters())).device + h_ratio, w_ratio = [h / input_shape[0], w / input_shape[1]] + + resized_image = np.concatenate([cv2.resize(image, (input_shape[1], input_shape[0]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + + resized_image = resized_image.transpose((2,0,1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float() + batch_image = batch_image.to(device) + outputs = model(batch_image) + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] + end = vmap[:, :, 2:] + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + segments_list = [] + for center, score in zip(pts, pts_score): + y, x = center + distance = dist_map[y, x] + if score > score_thr and distance > dist_thr: + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + x_start = x + disp_x_start + y_start = y + disp_y_start + x_end = x + disp_x_end + y_end = y + disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + lines = 2 * np.array(segments_list) # 256 > 512 + lines[:, 0] = lines[:, 0] * w_ratio + lines[:, 1] = lines[:, 1] * h_ratio + lines[:, 2] = lines[:, 2] * w_ratio + lines[:, 3] = lines[:, 3] * h_ratio + + return lines + + +def pred_squares(image, + model, + input_shape=[512, 512], + params={'score': 0.06, + 'outside_ratio': 0.28, + 'inside_ratio': 0.45, + 'w_overlap': 0.0, + 'w_degree': 1.95, + 'w_length': 0.0, + 'w_area': 1.86, + 'w_center': 0.14}): + ''' + shape = [height, width] + ''' + h, w, _ = image.shape + original_shape = [h, w] + device = next(iter(model.parameters())).device + + resized_image = np.concatenate([cv2.resize(image, (input_shape[0], input_shape[1]), interpolation=cv2.INTER_AREA), + np.ones([input_shape[0], input_shape[1], 1])], axis=-1) + resized_image = resized_image.transpose((2, 0, 1)) + batch_image = np.expand_dims(resized_image, axis=0).astype('float32') + batch_image = (batch_image / 127.5) - 1.0 + + batch_image = torch.from_numpy(batch_image).float().to(device) + outputs = model(batch_image) + + pts, pts_score, vmap = deccode_output_score_and_ptss(outputs, 200, 3) + start = vmap[:, :, :2] # (x, y) + end = vmap[:, :, 2:] # (x, y) + dist_map = np.sqrt(np.sum((start - end) ** 2, axis=-1)) + + junc_list = [] + segments_list = [] + for junc, score in zip(pts, pts_score): + y, x = junc + distance = dist_map[y, x] + if score > params['score'] and distance > 20.0: + junc_list.append([x, y]) + disp_x_start, disp_y_start, disp_x_end, disp_y_end = vmap[y, x, :] + d_arrow = 1.0 + x_start = x + d_arrow * disp_x_start + y_start = y + d_arrow * disp_y_start + x_end = x + d_arrow * disp_x_end + y_end = y + d_arrow * disp_y_end + segments_list.append([x_start, y_start, x_end, y_end]) + + segments = np.array(segments_list) + + ####### post processing for squares + # 1. get unique lines + point = np.array([[0, 0]]) + point = point[0] + start = segments[:, :2] + end = segments[:, 2:] + diff = start - end + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + + d = np.abs(a * point[0] + b * point[1] - c) / np.sqrt(a ** 2 + b ** 2 + 1e-10) + theta = np.arctan2(diff[:, 0], diff[:, 1]) * 180 / np.pi + theta[theta < 0.0] += 180 + hough = np.concatenate([d[:, None], theta[:, None]], axis=-1) + + d_quant = 1 + theta_quant = 2 + hough[:, 0] //= d_quant + hough[:, 1] //= theta_quant + _, indices, counts = np.unique(hough, axis=0, return_index=True, return_counts=True) + + acc_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='float32') + idx_map = np.zeros([512 // d_quant + 1, 360 // theta_quant + 1], dtype='int32') - 1 + yx_indices = hough[indices, :].astype('int32') + acc_map[yx_indices[:, 0], yx_indices[:, 1]] = counts + idx_map[yx_indices[:, 0], yx_indices[:, 1]] = indices + + acc_map_np = acc_map + # acc_map = acc_map[None, :, :, None] + # + # ### fast suppression using tensorflow op + # acc_map = tf.constant(acc_map, dtype=tf.float32) + # max_acc_map = tf.keras.layers.MaxPool2D(pool_size=(5, 5), strides=1, padding='same')(acc_map) + # acc_map = acc_map * tf.cast(tf.math.equal(acc_map, max_acc_map), tf.float32) + # flatten_acc_map = tf.reshape(acc_map, [1, -1]) + # topk_values, topk_indices = tf.math.top_k(flatten_acc_map, k=len(pts)) + # _, h, w, _ = acc_map.shape + # y = tf.expand_dims(topk_indices // w, axis=-1) + # x = tf.expand_dims(topk_indices % w, axis=-1) + # yx = tf.concat([y, x], axis=-1) + + ### fast suppression using pytorch op + acc_map = torch.from_numpy(acc_map_np).unsqueeze(0).unsqueeze(0) + _,_, h, w = acc_map.shape + max_acc_map = F.max_pool2d(acc_map,kernel_size=5, stride=1, padding=2) + acc_map = acc_map * ( (acc_map == max_acc_map).float() ) + flatten_acc_map = acc_map.reshape([-1, ]) + + scores, indices = torch.topk(flatten_acc_map, len(pts), dim=-1, largest=True) + yy = torch.div(indices, w, rounding_mode='floor').unsqueeze(-1) + xx = torch.fmod(indices, w).unsqueeze(-1) + yx = torch.cat((yy, xx), dim=-1) + + yx = yx.detach().cpu().numpy() + + topk_values = scores.detach().cpu().numpy() + indices = idx_map[yx[:, 0], yx[:, 1]] + basis = 5 // 2 + + merged_segments = [] + for yx_pt, max_indice, value in zip(yx, indices, topk_values): + y, x = yx_pt + if max_indice == -1 or value == 0: + continue + segment_list = [] + for y_offset in range(-basis, basis + 1): + for x_offset in range(-basis, basis + 1): + indice = idx_map[y + y_offset, x + x_offset] + cnt = int(acc_map_np[y + y_offset, x + x_offset]) + if indice != -1: + segment_list.append(segments[indice]) + if cnt > 1: + check_cnt = 1 + current_hough = hough[indice] + for new_indice, new_hough in enumerate(hough): + if (current_hough == new_hough).all() and indice != new_indice: + segment_list.append(segments[new_indice]) + check_cnt += 1 + if check_cnt == cnt: + break + group_segments = np.array(segment_list).reshape([-1, 2]) + sorted_group_segments = np.sort(group_segments, axis=0) + x_min, y_min = sorted_group_segments[0, :] + x_max, y_max = sorted_group_segments[-1, :] + + deg = theta[max_indice] + if deg >= 90: + merged_segments.append([x_min, y_max, x_max, y_min]) + else: + merged_segments.append([x_min, y_min, x_max, y_max]) + + # 2. get intersections + new_segments = np.array(merged_segments) # (x1, y1, x2, y2) + start = new_segments[:, :2] # (x1, y1) + end = new_segments[:, 2:] # (x2, y2) + new_centers = (start + end) / 2.0 + diff = start - end + dist_segments = np.sqrt(np.sum(diff ** 2, axis=-1)) + + # ax + by = c + a = diff[:, 1] + b = -diff[:, 0] + c = a * start[:, 0] + b * start[:, 1] + pre_det = a[:, None] * b[None, :] + det = pre_det - np.transpose(pre_det) + + pre_inter_y = a[:, None] * c[None, :] + inter_y = (pre_inter_y - np.transpose(pre_inter_y)) / (det + 1e-10) + pre_inter_x = c[:, None] * b[None, :] + inter_x = (pre_inter_x - np.transpose(pre_inter_x)) / (det + 1e-10) + inter_pts = np.concatenate([inter_x[:, :, None], inter_y[:, :, None]], axis=-1).astype('int32') + + # 3. get corner information + # 3.1 get distance + ''' + dist_segments: + | dist(0), dist(1), dist(2), ...| + dist_inter_to_segment1: + | dist(inter,0), dist(inter,0), dist(inter,0), ... | + | dist(inter,1), dist(inter,1), dist(inter,1), ... | + ... + dist_inter_to_semgnet2: + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + | dist(inter,0), dist(inter,1), dist(inter,2), ... | + ... + ''' + + dist_inter_to_segment1_start = np.sqrt( + np.sum(((inter_pts - start[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment1_end = np.sqrt( + np.sum(((inter_pts - end[:, None, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_start = np.sqrt( + np.sum(((inter_pts - start[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + dist_inter_to_segment2_end = np.sqrt( + np.sum(((inter_pts - end[None, :, :]) ** 2), axis=-1, keepdims=True)) # [n_batch, n_batch, 1] + + # sort ascending + dist_inter_to_segment1 = np.sort( + np.concatenate([dist_inter_to_segment1_start, dist_inter_to_segment1_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + dist_inter_to_segment2 = np.sort( + np.concatenate([dist_inter_to_segment2_start, dist_inter_to_segment2_end], axis=-1), + axis=-1) # [n_batch, n_batch, 2] + + # 3.2 get degree + inter_to_start = new_centers[:, None, :] - inter_pts + deg_inter_to_start = np.arctan2(inter_to_start[:, :, 1], inter_to_start[:, :, 0]) * 180 / np.pi + deg_inter_to_start[deg_inter_to_start < 0.0] += 360 + inter_to_end = new_centers[None, :, :] - inter_pts + deg_inter_to_end = np.arctan2(inter_to_end[:, :, 1], inter_to_end[:, :, 0]) * 180 / np.pi + deg_inter_to_end[deg_inter_to_end < 0.0] += 360 + + ''' + B -- G + | | + C -- R + B : blue / G: green / C: cyan / R: red + + 0 -- 1 + | | + 3 -- 2 + ''' + # rename variables + deg1_map, deg2_map = deg_inter_to_start, deg_inter_to_end + # sort deg ascending + deg_sort = np.sort(np.concatenate([deg1_map[:, :, None], deg2_map[:, :, None]], axis=-1), axis=-1) + + deg_diff_map = np.abs(deg1_map - deg2_map) + # we only consider the smallest degree of intersect + deg_diff_map[deg_diff_map > 180] = 360 - deg_diff_map[deg_diff_map > 180] + + # define available degree range + deg_range = [60, 120] + + corner_dict = {corner_info: [] for corner_info in range(4)} + inter_points = [] + for i in range(inter_pts.shape[0]): + for j in range(i + 1, inter_pts.shape[1]): + # i, j > line index, always i < j + x, y = inter_pts[i, j, :] + deg1, deg2 = deg_sort[i, j, :] + deg_diff = deg_diff_map[i, j] + + check_degree = deg_diff > deg_range[0] and deg_diff < deg_range[1] + + outside_ratio = params['outside_ratio'] # over ratio >>> drop it! + inside_ratio = params['inside_ratio'] # over ratio >>> drop it! + check_distance = ((dist_inter_to_segment1[i, j, 1] >= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * outside_ratio) or \ + (dist_inter_to_segment1[i, j, 1] <= dist_segments[i] and \ + dist_inter_to_segment1[i, j, 0] <= dist_segments[i] * inside_ratio)) and \ + ((dist_inter_to_segment2[i, j, 1] >= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * outside_ratio) or \ + (dist_inter_to_segment2[i, j, 1] <= dist_segments[j] and \ + dist_inter_to_segment2[i, j, 0] <= dist_segments[j] * inside_ratio)) + + if check_degree and check_distance: + corner_info = None + + if (deg1 >= 0 and deg1 <= 45 and deg2 >= 45 and deg2 <= 120) or \ + (deg2 >= 315 and deg1 >= 45 and deg1 <= 120): + corner_info, color_info = 0, 'blue' + elif (deg1 >= 45 and deg1 <= 125 and deg2 >= 125 and deg2 <= 225): + corner_info, color_info = 1, 'green' + elif (deg1 >= 125 and deg1 <= 225 and deg2 >= 225 and deg2 <= 315): + corner_info, color_info = 2, 'black' + elif (deg1 >= 0 and deg1 <= 45 and deg2 >= 225 and deg2 <= 315) or \ + (deg2 >= 315 and deg1 >= 225 and deg1 <= 315): + corner_info, color_info = 3, 'cyan' + else: + corner_info, color_info = 4, 'red' # we don't use it + continue + + corner_dict[corner_info].append([x, y, i, j]) + inter_points.append([x, y]) + + square_list = [] + connect_list = [] + segments_list = [] + for corner0 in corner_dict[0]: + for corner1 in corner_dict[1]: + connect01 = False + for corner0_line in corner0[2:]: + if corner0_line in corner1[2:]: + connect01 = True + break + if connect01: + for corner2 in corner_dict[2]: + connect12 = False + for corner1_line in corner1[2:]: + if corner1_line in corner2[2:]: + connect12 = True + break + if connect12: + for corner3 in corner_dict[3]: + connect23 = False + for corner2_line in corner2[2:]: + if corner2_line in corner3[2:]: + connect23 = True + break + if connect23: + for corner3_line in corner3[2:]: + if corner3_line in corner0[2:]: + # SQUARE!!! + ''' + 0 -- 1 + | | + 3 -- 2 + square_list: + order: 0 > 1 > 2 > 3 + | x0, y0, x1, y1, x2, y2, x3, y3 | + | x0, y0, x1, y1, x2, y2, x3, y3 | + ... + connect_list: + order: 01 > 12 > 23 > 30 + | line_idx01, line_idx12, line_idx23, line_idx30 | + | line_idx01, line_idx12, line_idx23, line_idx30 | + ... + segments_list: + order: 0 > 1 > 2 > 3 + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + | line_idx0_i, line_idx0_j, line_idx1_i, line_idx1_j, line_idx2_i, line_idx2_j, line_idx3_i, line_idx3_j | + ... + ''' + square_list.append(corner0[:2] + corner1[:2] + corner2[:2] + corner3[:2]) + connect_list.append([corner0_line, corner1_line, corner2_line, corner3_line]) + segments_list.append(corner0[2:] + corner1[2:] + corner2[2:] + corner3[2:]) + + def check_outside_inside(segments_info, connect_idx): + # return 'outside or inside', min distance, cover_param, peri_param + if connect_idx == segments_info[0]: + check_dist_mat = dist_inter_to_segment1 + else: + check_dist_mat = dist_inter_to_segment2 + + i, j = segments_info + min_dist, max_dist = check_dist_mat[i, j, :] + connect_dist = dist_segments[connect_idx] + if max_dist > connect_dist: + return 'outside', min_dist, 0, 1 + else: + return 'inside', min_dist, -1, -1 + + top_square = None + + try: + map_size = input_shape[0] / 2 + squares = np.array(square_list).reshape([-1, 4, 2]) + score_array = [] + connect_array = np.array(connect_list) + segments_array = np.array(segments_list).reshape([-1, 4, 2]) + + # get degree of corners: + squares_rollup = np.roll(squares, 1, axis=1) + squares_rolldown = np.roll(squares, -1, axis=1) + vec1 = squares_rollup - squares + normalized_vec1 = vec1 / (np.linalg.norm(vec1, axis=-1, keepdims=True) + 1e-10) + vec2 = squares_rolldown - squares + normalized_vec2 = vec2 / (np.linalg.norm(vec2, axis=-1, keepdims=True) + 1e-10) + inner_products = np.sum(normalized_vec1 * normalized_vec2, axis=-1) # [n_squares, 4] + squares_degree = np.arccos(inner_products) * 180 / np.pi # [n_squares, 4] + + # get square score + overlap_scores = [] + degree_scores = [] + length_scores = [] + + for connects, segments, square, degree in zip(connect_array, segments_array, squares, squares_degree): + ''' + 0 -- 1 + | | + 3 -- 2 + + # segments: [4, 2] + # connects: [4] + ''' + + ###################################### OVERLAP SCORES + cover = 0 + perimeter = 0 + # check 0 > 1 > 2 > 3 + square_length = [] + + for start_idx in range(4): + end_idx = (start_idx + 1) % 4 + + connect_idx = connects[start_idx] # segment idx of segment01 + start_segments = segments[start_idx] + end_segments = segments[end_idx] + + start_point = square[start_idx] + end_point = square[end_idx] + + # check whether outside or inside + start_position, start_min, start_cover_param, start_peri_param = check_outside_inside(start_segments, + connect_idx) + end_position, end_min, end_cover_param, end_peri_param = check_outside_inside(end_segments, connect_idx) + + cover += dist_segments[connect_idx] + start_cover_param * start_min + end_cover_param * end_min + perimeter += dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min + + square_length.append( + dist_segments[connect_idx] + start_peri_param * start_min + end_peri_param * end_min) + + overlap_scores.append(cover / perimeter) + ###################################### + ###################################### DEGREE SCORES + ''' + deg0 vs deg2 + deg1 vs deg3 + ''' + deg0, deg1, deg2, deg3 = degree + deg_ratio1 = deg0 / deg2 + if deg_ratio1 > 1.0: + deg_ratio1 = 1 / deg_ratio1 + deg_ratio2 = deg1 / deg3 + if deg_ratio2 > 1.0: + deg_ratio2 = 1 / deg_ratio2 + degree_scores.append((deg_ratio1 + deg_ratio2) / 2) + ###################################### + ###################################### LENGTH SCORES + ''' + len0 vs len2 + len1 vs len3 + ''' + len0, len1, len2, len3 = square_length + len_ratio1 = len0 / len2 if len2 > len0 else len2 / len0 + len_ratio2 = len1 / len3 if len3 > len1 else len3 / len1 + length_scores.append((len_ratio1 + len_ratio2) / 2) + + ###################################### + + overlap_scores = np.array(overlap_scores) + overlap_scores /= np.max(overlap_scores) + + degree_scores = np.array(degree_scores) + # degree_scores /= np.max(degree_scores) + + length_scores = np.array(length_scores) + + ###################################### AREA SCORES + area_scores = np.reshape(squares, [-1, 4, 2]) + area_x = area_scores[:, :, 0] + area_y = area_scores[:, :, 1] + correction = area_x[:, -1] * area_y[:, 0] - area_y[:, -1] * area_x[:, 0] + area_scores = np.sum(area_x[:, :-1] * area_y[:, 1:], axis=-1) - np.sum(area_y[:, :-1] * area_x[:, 1:], axis=-1) + area_scores = 0.5 * np.abs(area_scores + correction) + area_scores /= (map_size * map_size) # np.max(area_scores) + ###################################### + + ###################################### CENTER SCORES + centers = np.array([[256 // 2, 256 // 2]], dtype='float32') # [1, 2] + # squares: [n, 4, 2] + square_centers = np.mean(squares, axis=1) # [n, 2] + center2center = np.sqrt(np.sum((centers - square_centers) ** 2)) + center_scores = center2center / (map_size / np.sqrt(2.0)) + + ''' + score_w = [overlap, degree, area, center, length] + ''' + score_w = [0.0, 1.0, 10.0, 0.5, 1.0] + score_array = params['w_overlap'] * overlap_scores \ + + params['w_degree'] * degree_scores \ + + params['w_area'] * area_scores \ + - params['w_center'] * center_scores \ + + params['w_length'] * length_scores + + best_square = [] + + sorted_idx = np.argsort(score_array)[::-1] + score_array = score_array[sorted_idx] + squares = squares[sorted_idx] + + except Exception as e: + pass + + '''return list + merged_lines, squares, scores + ''' + + try: + new_segments[:, 0] = new_segments[:, 0] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 1] = new_segments[:, 1] * 2 / input_shape[0] * original_shape[0] + new_segments[:, 2] = new_segments[:, 2] * 2 / input_shape[1] * original_shape[1] + new_segments[:, 3] = new_segments[:, 3] * 2 / input_shape[0] * original_shape[0] + except: + new_segments = [] + + try: + squares[:, :, 0] = squares[:, :, 0] * 2 / input_shape[1] * original_shape[1] + squares[:, :, 1] = squares[:, :, 1] * 2 / input_shape[0] * original_shape[0] + except: + squares = [] + score_array = [] + + try: + inter_points = np.array(inter_points) + inter_points[:, 0] = inter_points[:, 0] * 2 / input_shape[1] * original_shape[1] + inter_points[:, 1] = inter_points[:, 1] * 2 / input_shape[0] * original_shape[0] + except: + inter_points = [] + + return new_segments, squares, score_array, inter_points diff --git a/custom_controlnet_aux/normalbae/LICENSE b/custom_controlnet_aux/normalbae/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..16a9d56a3d4c15e4f34ac5426459c58487b01520 --- /dev/null +++ b/custom_controlnet_aux/normalbae/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Caroline Chan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/normalbae/__init__.py b/custom_controlnet_aux/normalbae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae2f54abbdad2a75e4870fff21346ad2265a53ad --- /dev/null +++ b/custom_controlnet_aux/normalbae/__init__.py @@ -0,0 +1,85 @@ +import os +import types +import warnings + +import cv2 +import numpy as np +import torch +import torchvision.transforms as transforms +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +from .nets.NNET import NNET + + +# load model +def load_checkpoint(fpath, model): + ckpt = torch.load(fpath, map_location='cpu')['model'] + + load_dict = {} + for k, v in ckpt.items(): + if k.startswith('module.'): + k_ = k.replace('module.', '') + load_dict[k_] = v + else: + load_dict[k] = v + + model.load_state_dict(load_dict) + return model + +class NormalBaeDetector: + def __init__(self, model): + self.model = model + self.norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="scannet.pt"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + args = types.SimpleNamespace() + args.mode = 'client' + args.architecture = 'BN' + args.pretrained = 'scannet' + args.sampling_ratio = 0.4 + args.importance_ratio = 0.7 + model = NNET(args) + model = load_checkpoint(model_path, model) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + + def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + image_normal = detected_map + with torch.no_grad(): + image_normal = torch.from_numpy(image_normal).float().to(self.device) + image_normal = image_normal / 255.0 + image_normal = rearrange(image_normal, 'h w c -> 1 c h w') + image_normal = self.norm(image_normal) + + normal = self.model(image_normal) + normal = normal[0][-1][:, :3] + # d = torch.sum(normal ** 2.0, dim=1, keepdim=True) ** 0.5 + # d = torch.maximum(d, torch.ones_like(d) * 1e-5) + # normal /= d + normal = ((normal + 1) * 0.5).clip(0, 1) + + normal = rearrange(normal[0], 'c h w -> h w c').cpu().numpy() + normal_image = (normal * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = remove_pad(HWC3(normal_image)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + \ No newline at end of file diff --git a/custom_controlnet_aux/normalbae/nets/NNET.py b/custom_controlnet_aux/normalbae/nets/NNET.py new file mode 100644 index 0000000000000000000000000000000000000000..3ddbc50c3ac18aa4b7f16779fe3c0133981ecc7a --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/NNET.py @@ -0,0 +1,22 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules.encoder import Encoder +from .submodules.decoder import Decoder + + +class NNET(nn.Module): + def __init__(self, args): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(args) + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + return self.decoder.parameters() + + def forward(self, img, **kwargs): + return self.decoder(self.encoder(img), **kwargs) \ No newline at end of file diff --git a/custom_controlnet_aux/normalbae/nets/__init__.py b/custom_controlnet_aux/normalbae/nets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/normalbae/nets/baseline.py b/custom_controlnet_aux/normalbae/nets/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..602d0fbdac1acc9ede9bc1f2e10a5df78831ce9d --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/baseline.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .submodules.submodules import UpSampleBN, norm_normalize + + +# This is the baseline encoder-decoder we used in the ablation study +class NNET(nn.Module): + def __init__(self, args=None): + super(NNET, self).__init__() + self.encoder = Encoder() + self.decoder = Decoder(num_classes=4) + + def forward(self, x, **kwargs): + out = self.decoder(self.encoder(x), **kwargs) + + # Bilinearly upsample the output to match the input resolution + up_out = F.interpolate(out, size=[x.size(2), x.size(3)], mode='bilinear', align_corners=False) + + # L2-normalize the first three channels / ensure positive value for concentration parameters (kappa) + up_out = norm_normalize(up_out) + return up_out + + def get_1x_lr_params(self): # lr/10 learning rate + return self.encoder.parameters() + + def get_10x_lr_params(self): # lr learning rate + modules = [self.decoder] + for m in modules: + yield from m.parameters() + + +# Encoder +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + basemodel = torch.hub.load('rwightman/gen-efficientnet-pytorch', basemodel_name, pretrained=True) + + # Remove last layer + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + +# Decoder (no pixel-wise MLP, no uncertainty-guided sampling) +class Decoder(nn.Module): + def __init__(self, num_classes=4): + super(Decoder, self).__init__() + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + self.conv3 = nn.Conv2d(128, num_classes, kernel_size=3, stride=1, padding=1) + + def forward(self, features): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + x_d0 = self.conv2(x_block4) + x_d1 = self.up1(x_d0, x_block3) + x_d2 = self.up2(x_d1, x_block2) + x_d3 = self.up3(x_d2, x_block1) + x_d4 = self.up4(x_d3, x_block0) + out = self.conv3(x_d4) + return out + + +if __name__ == '__main__': + model = Baseline() + x = torch.rand(2, 3, 480, 640) + out = model(x) + print(out.shape) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/__init__.py b/custom_controlnet_aux/normalbae/nets/submodules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/normalbae/nets/submodules/decoder.py b/custom_controlnet_aux/normalbae/nets/submodules/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..993203d1792311f1c492091eaea3c1ac9088187f --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/decoder.py @@ -0,0 +1,202 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from .submodules import UpSampleBN, UpSampleGN, norm_normalize, sample_points + + +class Decoder(nn.Module): + def __init__(self, args): + super(Decoder, self).__init__() + + # hyper-parameter for sampling + self.sampling_ratio = args.sampling_ratio + self.importance_ratio = args.importance_ratio + + # feature-map + self.conv2 = nn.Conv2d(2048, 2048, kernel_size=1, stride=1, padding=0) + if args.architecture == 'BN': + self.up1 = UpSampleBN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleBN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleBN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleBN(skip_input=256 + 24, output_features=128) + + elif args.architecture == 'GN': + self.up1 = UpSampleGN(skip_input=2048 + 176, output_features=1024) + self.up2 = UpSampleGN(skip_input=1024 + 64, output_features=512) + self.up3 = UpSampleGN(skip_input=512 + 40, output_features=256) + self.up4 = UpSampleGN(skip_input=256 + 24, output_features=128) + + else: + raise Exception('invalid architecture') + + # produces 1/8 res output + self.out_conv_res8 = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1) + + # produces 1/4 res output + self.out_conv_res4 = nn.Sequential( + nn.Conv1d(512 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/2 res output + self.out_conv_res2 = nn.Sequential( + nn.Conv1d(256 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + # produces 1/1 res output + self.out_conv_res1 = nn.Sequential( + nn.Conv1d(128 + 4, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 128, kernel_size=1), nn.ReLU(), + nn.Conv1d(128, 4, kernel_size=1), + ) + + def forward(self, features, gt_norm_mask=None, mode='test'): + x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], features[8], features[11] + + # generate feature-map + + x_d0 = self.conv2(x_block4) # x_d0 : [2, 2048, 15, 20] 1/32 res + x_d1 = self.up1(x_d0, x_block3) # x_d1 : [2, 1024, 30, 40] 1/16 res + x_d2 = self.up2(x_d1, x_block2) # x_d2 : [2, 512, 60, 80] 1/8 res + x_d3 = self.up3(x_d2, x_block1) # x_d3: [2, 256, 120, 160] 1/4 res + x_d4 = self.up4(x_d3, x_block0) # x_d4: [2, 128, 240, 320] 1/2 res + + # 1/8 res output + out_res8 = self.out_conv_res8(x_d2) # out_res8: [2, 4, 60, 80] 1/8 res output + out_res8 = norm_normalize(out_res8) # out_res8: [2, 4, 60, 80] 1/8 res output + + ################################################################################################################ + # out_res4 + ################################################################################################################ + + if mode == 'train': + # upsampling ... out_res8: [2, 4, 60, 80] -> out_res8_res4: [2, 4, 120, 160] + out_res8_res4 = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res8_res4.shape + + # samples: [B, 1, N, 2] + point_coords_res4, rows_int, cols_int = sample_points(out_res8_res4.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res4 = out_res8_res4 + + # grid_sample feature-map + feat_res4 = F.grid_sample(x_d2, point_coords_res4, mode='bilinear', align_corners=True) # (B, 512, 1, N) + init_pred = F.grid_sample(out_res8, point_coords_res4, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res4 = torch.cat([feat_res4, init_pred], dim=1) # (B, 512+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res4 = self.out_conv_res4(feat_res4[:, :, 0, :]) # (B, 4, N) + samples_pred_res4 = norm_normalize(samples_pred_res4) # (B, 4, N) - normalized + + for i in range(B): + out_res4[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res4[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d2, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res8, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + # try all pixels + out_res4 = self.out_conv_res4(feat_map.view(B, 512 + 4, -1)) # (B, 4, N) + out_res4 = norm_normalize(out_res4) # (B, 4, N) - normalized + out_res4 = out_res4.view(B, 4, H, W) + samples_pred_res4 = point_coords_res4 = None + + ################################################################################################################ + # out_res2 + ################################################################################################################ + + if mode == 'train': + + # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res4_res2 = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res4_res2.shape + + # samples: [B, 1, N, 2] + point_coords_res2, rows_int, cols_int = sample_points(out_res4_res2.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res2 = out_res4_res2 + + # grid_sample feature-map + feat_res2 = F.grid_sample(x_d3, point_coords_res2, mode='bilinear', align_corners=True) # (B, 256, 1, N) + init_pred = F.grid_sample(out_res4, point_coords_res2, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res2 = torch.cat([feat_res2, init_pred], dim=1) # (B, 256+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res2 = self.out_conv_res2(feat_res2[:, :, 0, :]) # (B, 4, N) + samples_pred_res2 = norm_normalize(samples_pred_res2) # (B, 4, N) - normalized + + for i in range(B): + out_res2[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res2[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d3, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res4, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res2 = self.out_conv_res2(feat_map.view(B, 256 + 4, -1)) # (B, 4, N) + out_res2 = norm_normalize(out_res2) # (B, 4, N) - normalized + out_res2 = out_res2.view(B, 4, H, W) + samples_pred_res2 = point_coords_res2 = None + + ################################################################################################################ + # out_res1 + ################################################################################################################ + + if mode == 'train': + # upsampling ... out_res4: [2, 4, 120, 160] -> out_res4_res2: [2, 4, 240, 320] + out_res2_res1 = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + B, _, H, W = out_res2_res1.shape + + # samples: [B, 1, N, 2] + point_coords_res1, rows_int, cols_int = sample_points(out_res2_res1.detach(), gt_norm_mask, + sampling_ratio=self.sampling_ratio, + beta=self.importance_ratio) + + # output (needed for evaluation / visualization) + out_res1 = out_res2_res1 + + # grid_sample feature-map + feat_res1 = F.grid_sample(x_d4, point_coords_res1, mode='bilinear', align_corners=True) # (B, 128, 1, N) + init_pred = F.grid_sample(out_res2, point_coords_res1, mode='bilinear', align_corners=True) # (B, 4, 1, N) + feat_res1 = torch.cat([feat_res1, init_pred], dim=1) # (B, 128+4, 1, N) + + # prediction (needed to compute loss) + samples_pred_res1 = self.out_conv_res1(feat_res1[:, :, 0, :]) # (B, 4, N) + samples_pred_res1 = norm_normalize(samples_pred_res1) # (B, 4, N) - normalized + + for i in range(B): + out_res1[i, :, rows_int[i, :], cols_int[i, :]] = samples_pred_res1[i, :, :] + + else: + # grid_sample feature-map + feat_map = F.interpolate(x_d4, scale_factor=2, mode='bilinear', align_corners=True) + init_pred = F.interpolate(out_res2, scale_factor=2, mode='bilinear', align_corners=True) + feat_map = torch.cat([feat_map, init_pred], dim=1) # (B, 512+4, H, W) + B, _, H, W = feat_map.shape + + out_res1 = self.out_conv_res1(feat_map.view(B, 128 + 4, -1)) # (B, 4, N) + out_res1 = norm_normalize(out_res1) # (B, 4, N) - normalized + out_res1 = out_res1.view(B, 4, H, W) + samples_pred_res1 = point_coords_res1 = None + + return [out_res8, out_res4, out_res2, out_res1], \ + [out_res8, samples_pred_res4, samples_pred_res2, samples_pred_res1], \ + [None, point_coords_res4, point_coords_res2, point_coords_res1] + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md new file mode 100644 index 0000000000000000000000000000000000000000..6ead7171ce5a5bbd2702f6b5c825dc9808ba5658 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/BENCHMARK.md @@ -0,0 +1,555 @@ +# Model Performance Benchmarks + +All benchmarks run as per: + +``` +python onnx_export.py --model mobilenetv3_100 ./mobilenetv3_100.onnx +python onnx_optimize.py ./mobilenetv3_100.onnx --output mobilenetv3_100-opt.onnx +python onnx_to_caffe.py ./mobilenetv3_100.onnx --c2-prefix mobilenetv3 +python onnx_to_caffe.py ./mobilenetv3_100-opt.onnx --c2-prefix mobilenetv3-opt +python caffe2_benchmark.py --c2-init ./mobilenetv3.init.pb --c2-predict ./mobilenetv3.predict.pb +python caffe2_benchmark.py --c2-init ./mobilenetv3-opt.init.pb --c2-predict ./mobilenetv3-opt.predict.pb +``` + +## EfficientNet-B0 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 49.2862. Iters per second: 20.2897 +Time per operator type: + 29.7378 ms. 60.5145%. Conv + 12.1785 ms. 24.7824%. Sigmoid + 3.62811 ms. 7.38297%. SpatialBN + 2.98444 ms. 6.07314%. Mul + 0.326902 ms. 0.665225%. AveragePool + 0.197317 ms. 0.401528%. FC + 0.0852877 ms. 0.173555%. Add + 0.0032607 ms. 0.00663532%. Squeeze + 49.1416 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 95.2696%. Conv + 0.0269508 GFLOP. 3.33857%. SpatialBN + 0.00846444 GFLOP. 1.04855%. Mul + 0.002561 GFLOP. 0.317248%. FC + 0.000210112 GFLOP. 0.0260279%. Add + 0.807256 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 43.0891%. Mul + 43.2015 MB. 31.807%. Conv + 27.2869 MB. 20.0899%. SpatialBN + 5.12912 MB. 3.77631%. FC + 1.6809 MB. 1.23756%. Add + 135.824 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 38.1965%. Mul + 26.9881 MB. 30.4465%. Conv + 26.9508 MB. 30.4044%. SpatialBN + 0.840448 MB. 0.948147%. Add + 0.004 MB. 0.00451258%. FC + 88.6412 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 74.9391%. Conv + 5.124 MB. 24.265%. FC + 0.168064 MB. 0.795877%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 21.1168 MB in Total +``` +### Optimized +``` +Main run finished. Milliseconds per iter: 46.0838. Iters per second: 21.6996 +Time per operator type: + 29.776 ms. 65.002%. Conv + 12.2803 ms. 26.8084%. Sigmoid + 3.15073 ms. 6.87815%. Mul + 0.328651 ms. 0.717456%. AveragePool + 0.186237 ms. 0.406563%. FC + 0.0832429 ms. 0.181722%. Add + 0.0026184 ms. 0.00571606%. Squeeze + 45.8078 ms in Total +FLOP per operator type: + 0.76907 GFLOP. 98.5601%. Conv + 0.00846444 GFLOP. 1.08476%. Mul + 0.002561 GFLOP. 0.328205%. FC + 0.000210112 GFLOP. 0.0269269%. Add + 0.780305 GFLOP in Total +Feature Memory Read per operator type: + 58.5253 MB. 53.8803%. Mul + 43.2855 MB. 39.8501%. Conv + 5.12912 MB. 4.72204%. FC + 1.6809 MB. 1.54749%. Add + 108.621 MB in Total +Feature Memory Written per operator type: + 33.8578 MB. 54.8834%. Mul + 26.9881 MB. 43.7477%. Conv + 0.840448 MB. 1.36237%. Add + 0.004 MB. 0.00648399%. FC + 61.6904 MB in Total +Parameter Memory per operator type: + 15.8248 MB. 75.5403%. Conv + 5.124 MB. 24.4597%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 20.9488 MB in Total +``` + +## EfficientNet-B1 +### Optimized +``` +Main run finished. Milliseconds per iter: 71.8102. Iters per second: 13.9256 +Time per operator type: + 45.7915 ms. 66.3206%. Conv + 17.8718 ms. 25.8841%. Sigmoid + 4.44132 ms. 6.43244%. Mul + 0.51001 ms. 0.738658%. AveragePool + 0.233283 ms. 0.337868%. Add + 0.194986 ms. 0.282402%. FC + 0.00268255 ms. 0.00388519%. Squeeze + 69.0456 ms in Total +FLOP per operator type: + 1.37105 GFLOP. 98.7673%. Conv + 0.0138759 GFLOP. 0.99959%. Mul + 0.002561 GFLOP. 0.184489%. FC + 0.000674432 GFLOP. 0.0485847%. Add + 1.38816 GFLOP in Total +Feature Memory Read per operator type: + 94.624 MB. 54.0789%. Mul + 69.8255 MB. 39.9062%. Conv + 5.39546 MB. 3.08357%. Add + 5.12912 MB. 2.93136%. FC + 174.974 MB in Total +Feature Memory Written per operator type: + 55.5035 MB. 54.555%. Mul + 43.5333 MB. 42.7894%. Conv + 2.69773 MB. 2.65163%. Add + 0.004 MB. 0.00393165%. FC + 101.739 MB in Total +Parameter Memory per operator type: + 25.7479 MB. 83.4024%. Conv + 5.124 MB. 16.5976%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 30.8719 MB in Total +``` + +## EfficientNet-B2 +### Optimized +``` +Main run finished. Milliseconds per iter: 92.28. Iters per second: 10.8366 +Time per operator type: + 61.4627 ms. 67.5845%. Conv + 22.7458 ms. 25.0113%. Sigmoid + 5.59931 ms. 6.15701%. Mul + 0.642567 ms. 0.706568%. AveragePool + 0.272795 ms. 0.299965%. Add + 0.216178 ms. 0.237709%. FC + 0.00268895 ms. 0.00295677%. Squeeze + 90.942 ms in Total +FLOP per operator type: + 1.98431 GFLOP. 98.9343%. Conv + 0.0177039 GFLOP. 0.882686%. Mul + 0.002817 GFLOP. 0.140451%. FC + 0.000853984 GFLOP. 0.0425782%. Add + 2.00568 GFLOP in Total +Feature Memory Read per operator type: + 120.609 MB. 54.9637%. Mul + 86.3512 MB. 39.3519%. Conv + 6.83187 MB. 3.11341%. Add + 5.64163 MB. 2.571%. FC + 219.433 MB in Total +Feature Memory Written per operator type: + 70.8155 MB. 54.6573%. Mul + 55.3273 MB. 42.7031%. Conv + 3.41594 MB. 2.63651%. Add + 0.004 MB. 0.00308731%. FC + 129.563 MB in Total +Parameter Memory per operator type: + 30.4721 MB. 84.3913%. Conv + 5.636 MB. 15.6087%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 36.1081 MB in Total +``` + +## MixNet-M +### Optimized +``` +Main run finished. Milliseconds per iter: 63.1122. Iters per second: 15.8448 +Time per operator type: + 48.1139 ms. 75.2052%. Conv + 7.1341 ms. 11.1511%. Sigmoid + 2.63706 ms. 4.12189%. SpatialBN + 1.73186 ms. 2.70701%. Mul + 1.38707 ms. 2.16809%. Split + 1.29322 ms. 2.02139%. Concat + 1.00093 ms. 1.56452%. Relu + 0.235309 ms. 0.367803%. Add + 0.221579 ms. 0.346343%. FC + 0.219315 ms. 0.342803%. AveragePool + 0.00250145 ms. 0.00390993%. Squeeze + 63.9768 ms in Total +FLOP per operator type: + 0.675273 GFLOP. 95.5827%. Conv + 0.0221072 GFLOP. 3.12921%. SpatialBN + 0.00538445 GFLOP. 0.762152%. Mul + 0.003073 GFLOP. 0.434973%. FC + 0.000642488 GFLOP. 0.0909421%. Add + 0 GFLOP. 0%. Concat + 0 GFLOP. 0%. Relu + 0.70648 GFLOP in Total +Feature Memory Read per operator type: + 46.8424 MB. 30.502%. Conv + 36.8626 MB. 24.0036%. Mul + 22.3152 MB. 14.5309%. SpatialBN + 22.1074 MB. 14.3955%. Concat + 14.1496 MB. 9.21372%. Relu + 6.15414 MB. 4.00735%. FC + 5.1399 MB. 3.34692%. Add + 153.571 MB in Total +Feature Memory Written per operator type: + 32.7672 MB. 28.4331%. Conv + 22.1072 MB. 19.1831%. Concat + 22.1072 MB. 19.1831%. SpatialBN + 21.5378 MB. 18.689%. Mul + 14.1496 MB. 12.2781%. Relu + 2.56995 MB. 2.23003%. Add + 0.004 MB. 0.00347092%. FC + 115.243 MB in Total +Parameter Memory per operator type: + 13.7059 MB. 68.674%. Conv + 6.148 MB. 30.8049%. FC + 0.104 MB. 0.521097%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Concat + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 19.9579 MB in Total +``` + +## TF MobileNet-V3 Large 1.0 + +### Optimized +``` +Main run finished. Milliseconds per iter: 22.0495. Iters per second: 45.3525 +Time per operator type: + 17.437 ms. 80.0087%. Conv + 1.27662 ms. 5.8577%. Add + 1.12759 ms. 5.17387%. Div + 0.701155 ms. 3.21721%. Mul + 0.562654 ms. 2.58171%. Relu + 0.431144 ms. 1.97828%. Clip + 0.156902 ms. 0.719936%. FC + 0.0996858 ms. 0.457402%. AveragePool + 0.00112455 ms. 0.00515993%. Flatten + 21.7939 ms in Total +FLOP per operator type: + 0.43062 GFLOP. 98.1484%. Conv + 0.002561 GFLOP. 0.583713%. FC + 0.00210867 GFLOP. 0.480616%. Mul + 0.00193868 GFLOP. 0.441871%. Add + 0.00151532 GFLOP. 0.345377%. Div + 0 GFLOP. 0%. Relu + 0.438743 GFLOP in Total +Feature Memory Read per operator type: + 34.7967 MB. 43.9391%. Conv + 14.496 MB. 18.3046%. Mul + 9.44828 MB. 11.9307%. Add + 9.26157 MB. 11.6949%. Relu + 6.0614 MB. 7.65395%. Div + 5.12912 MB. 6.47673%. FC + 79.193 MB in Total +Feature Memory Written per operator type: + 17.6247 MB. 35.8656%. Conv + 9.26157 MB. 18.847%. Relu + 8.43469 MB. 17.1643%. Mul + 7.75472 MB. 15.7806%. Add + 6.06128 MB. 12.3345%. Div + 0.004 MB. 0.00813985%. FC + 49.1409 MB in Total +Parameter Memory per operator type: + 16.6851 MB. 76.5052%. Conv + 5.124 MB. 23.4948%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8091 MB in Total +``` + +## MobileNet-V3 (RW) + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 24.8316. Iters per second: 40.2712 +Time per operator type: + 15.9266 ms. 69.2624%. Conv + 2.36551 ms. 10.2873%. SpatialBN + 1.39102 ms. 6.04936%. Add + 1.30327 ms. 5.66773%. Div + 0.737014 ms. 3.20517%. Mul + 0.639697 ms. 2.78195%. Relu + 0.375681 ms. 1.63378%. Clip + 0.153126 ms. 0.665921%. FC + 0.0993787 ms. 0.432184%. AveragePool + 0.0032632 ms. 0.0141912%. Squeeze + 22.9946 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 94.4041%. Conv + 0.0175992 GFLOP. 3.85829%. SpatialBN + 0.002561 GFLOP. 0.561449%. FC + 0.00210961 GFLOP. 0.46249%. Mul + 0.00173891 GFLOP. 0.381223%. Add + 0.00151626 GFLOP. 0.33241%. Div + 0 GFLOP. 0%. Relu + 0.456141 GFLOP in Total +Feature Memory Read per operator type: + 34.7354 MB. 36.4363%. Conv + 17.7944 MB. 18.6658%. SpatialBN + 14.5035 MB. 15.2137%. Mul + 9.25778 MB. 9.71113%. Relu + 7.84641 MB. 8.23064%. Add + 6.06516 MB. 6.36216%. Div + 5.12912 MB. 5.38029%. FC + 95.3317 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 26.7264%. Conv + 17.5992 MB. 26.6878%. SpatialBN + 9.25778 MB. 14.0387%. Relu + 8.43843 MB. 12.7962%. Mul + 6.95565 MB. 10.5477%. Add + 6.06502 MB. 9.19713%. Div + 0.004 MB. 0.00606568%. FC + 65.9447 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.1564%. Conv + 5.124 MB. 23.3979%. FC + 0.0976 MB. 0.445674%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8994 MB in Total + +``` +### Optimized + +``` +Main run finished. Milliseconds per iter: 22.0981. Iters per second: 45.2527 +Time per operator type: + 17.146 ms. 78.8965%. Conv + 1.38453 ms. 6.37084%. Add + 1.30991 ms. 6.02749%. Div + 0.685417 ms. 3.15391%. Mul + 0.532589 ms. 2.45068%. Relu + 0.418263 ms. 1.92461%. Clip + 0.15128 ms. 0.696106%. FC + 0.102065 ms. 0.469648%. AveragePool + 0.0022143 ms. 0.010189%. Squeeze + 21.7323 ms in Total +FLOP per operator type: + 0.430616 GFLOP. 98.1927%. Conv + 0.002561 GFLOP. 0.583981%. FC + 0.00210961 GFLOP. 0.481051%. Mul + 0.00173891 GFLOP. 0.396522%. Add + 0.00151626 GFLOP. 0.34575%. Div + 0 GFLOP. 0%. Relu + 0.438542 GFLOP in Total +Feature Memory Read per operator type: + 34.7842 MB. 44.833%. Conv + 14.5035 MB. 18.6934%. Mul + 9.25778 MB. 11.9323%. Relu + 7.84641 MB. 10.1132%. Add + 6.06516 MB. 7.81733%. Div + 5.12912 MB. 6.61087%. FC + 77.5861 MB in Total +Feature Memory Written per operator type: + 17.6246 MB. 36.4556%. Conv + 9.25778 MB. 19.1492%. Relu + 8.43843 MB. 17.4544%. Mul + 6.95565 MB. 14.3874%. Add + 6.06502 MB. 12.5452%. Div + 0.004 MB. 0.00827378%. FC + 48.3455 MB in Total +Parameter Memory per operator type: + 16.6778 MB. 76.4973%. Conv + 5.124 MB. 23.5027%. FC + 0 MB. 0%. Add + 0 MB. 0%. Div + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 21.8018 MB in Total + +``` + +## MnasNet-A1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 30.0892. Iters per second: 33.2345 +Time per operator type: + 24.4656 ms. 79.0905%. Conv + 4.14958 ms. 13.4144%. SpatialBN + 1.60598 ms. 5.19169%. Relu + 0.295219 ms. 0.95436%. Mul + 0.187609 ms. 0.606486%. FC + 0.120556 ms. 0.389724%. AveragePool + 0.09036 ms. 0.292109%. Add + 0.015727 ms. 0.050841%. Sigmoid + 0.00306205 ms. 0.00989875%. Squeeze + 30.9337 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 95.6434%. Conv + 0.0248873 GFLOP. 3.8355%. SpatialBN + 0.002561 GFLOP. 0.394688%. FC + 0.000597408 GFLOP. 0.0920695%. Mul + 0.000222656 GFLOP. 0.0343146%. Add + 0 GFLOP. 0%. Relu + 0.648867 GFLOP in Total +Feature Memory Read per operator type: + 35.5457 MB. 38.4109%. Conv + 25.1552 MB. 27.1829%. SpatialBN + 22.5235 MB. 24.339%. Relu + 5.12912 MB. 5.54256%. FC + 2.40586 MB. 2.59978%. Mul + 1.78125 MB. 1.92483%. Add + 92.5406 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 32.9424%. Conv + 24.8873 MB. 32.92%. SpatialBN + 22.5235 MB. 29.7932%. Relu + 2.38963 MB. 3.16092%. Mul + 0.890624 MB. 1.17809%. Add + 0.004 MB. 0.00529106%. FC + 75.5993 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.1459%. Conv + 5.124 MB. 32.9917%. FC + 0.133952 MB. 0.86247%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.5312 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 24.2367. Iters per second: 41.2597 +Time per operator type: + 22.0547 ms. 91.1375%. Conv + 1.49096 ms. 6.16116%. Relu + 0.253417 ms. 1.0472%. Mul + 0.18506 ms. 0.76473%. FC + 0.112942 ms. 0.466717%. AveragePool + 0.086769 ms. 0.358559%. Add + 0.0127889 ms. 0.0528479%. Sigmoid + 0.0027346 ms. 0.0113003%. Squeeze + 24.1994 ms in Total +FLOP per operator type: + 0.620598 GFLOP. 99.4581%. Conv + 0.002561 GFLOP. 0.41043%. FC + 0.000597408 GFLOP. 0.0957417%. Mul + 0.000222656 GFLOP. 0.0356832%. Add + 0 GFLOP. 0%. Relu + 0.623979 GFLOP in Total +Feature Memory Read per operator type: + 35.6127 MB. 52.7968%. Conv + 22.5235 MB. 33.3917%. Relu + 5.12912 MB. 7.60406%. FC + 2.40586 MB. 3.56675%. Mul + 1.78125 MB. 2.64075%. Add + 67.4524 MB in Total +Feature Memory Written per operator type: + 24.9042 MB. 49.1092%. Conv + 22.5235 MB. 44.4145%. Relu + 2.38963 MB. 4.71216%. Mul + 0.890624 MB. 1.75624%. Add + 0.004 MB. 0.00788768%. FC + 50.712 MB in Total +Parameter Memory per operator type: + 10.2732 MB. 66.7213%. Conv + 5.124 MB. 33.2787%. FC + 0 MB. 0%. Add + 0 MB. 0%. Mul + 0 MB. 0%. Relu + 15.3972 MB in Total +``` +## MnasNet-B1 + +### Unoptimized +``` +Main run finished. Milliseconds per iter: 28.3109. Iters per second: 35.322 +Time per operator type: + 29.1121 ms. 83.3081%. Conv + 4.14959 ms. 11.8746%. SpatialBN + 1.35823 ms. 3.88675%. Relu + 0.186188 ms. 0.532802%. FC + 0.116244 ms. 0.332647%. Add + 0.018641 ms. 0.0533437%. AveragePool + 0.0040904 ms. 0.0117052%. Squeeze + 34.9451 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 96.2088%. Conv + 0.0218266 GFLOP. 3.35303%. SpatialBN + 0.002561 GFLOP. 0.393424%. FC + 0.000291648 GFLOP. 0.0448034%. Add + 0 GFLOP. 0%. Relu + 0.650951 GFLOP in Total +Feature Memory Read per operator type: + 34.4354 MB. 41.3788%. Conv + 22.1299 MB. 26.5921%. SpatialBN + 19.1923 MB. 23.0622%. Relu + 5.12912 MB. 6.16333%. FC + 2.33318 MB. 2.80364%. Add + 83.2199 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 34.0955%. Conv + 21.8266 MB. 34.0955%. SpatialBN + 19.1923 MB. 29.9805%. Relu + 1.16659 MB. 1.82234%. Add + 0.004 MB. 0.00624844%. FC + 64.016 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 69.9104%. Conv + 5.124 MB. 29.2245%. FC + 0.15168 MB. 0.865099%. SpatialBN + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.5332 MB in Total +``` + +### Optimized +``` +Main run finished. Milliseconds per iter: 26.6364. Iters per second: 37.5426 +Time per operator type: + 24.9888 ms. 94.0962%. Conv + 1.26147 ms. 4.75011%. Relu + 0.176234 ms. 0.663619%. FC + 0.113309 ms. 0.426672%. Add + 0.0138708 ms. 0.0522311%. AveragePool + 0.00295685 ms. 0.0111341%. Squeeze + 26.5566 ms in Total +FLOP per operator type: + 0.626272 GFLOP. 99.5466%. Conv + 0.002561 GFLOP. 0.407074%. FC + 0.000291648 GFLOP. 0.0463578%. Add + 0 GFLOP. 0%. Relu + 0.629124 GFLOP in Total +Feature Memory Read per operator type: + 34.5112 MB. 56.4224%. Conv + 19.1923 MB. 31.3775%. Relu + 5.12912 MB. 8.3856%. FC + 2.33318 MB. 3.81452%. Add + 61.1658 MB in Total +Feature Memory Written per operator type: + 21.8266 MB. 51.7346%. Conv + 19.1923 MB. 45.4908%. Relu + 1.16659 MB. 2.76513%. Add + 0.004 MB. 0.00948104%. FC + 42.1895 MB in Total +Parameter Memory per operator type: + 12.2576 MB. 70.5205%. Conv + 5.124 MB. 29.4795%. FC + 0 MB. 0%. Add + 0 MB. 0%. Relu + 17.3816 MB in Total +``` diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..80e7d15508202f3262a50db27f5198460d7f509f --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020 Ross Wightman + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md new file mode 100644 index 0000000000000000000000000000000000000000..463368280d6a5015060eb73d20fe6512f8e04c50 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/README.md @@ -0,0 +1,323 @@ +# (Generic) EfficientNets for PyTorch + +A 'generic' implementation of EfficientNet, MixNet, MobileNetV3, etc. that covers most of the compute/parameter efficient architectures derived from the MobileNet V1/V2 block sequence, including those found via automated neural architecture search. + +All models are implemented by GenEfficientNet or MobileNetV3 classes, with string based architecture definitions to configure the block layouts (idea from [here](https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py)) + +## What's New + +### Aug 19, 2020 +* Add updated PyTorch trained EfficientNet-B3 weights trained by myself with `timm` (82.1 top-1) +* Add PyTorch trained EfficientNet-Lite0 contributed by [@hal-314](https://github.com/hal-314) (75.5 top-1) +* Update ONNX and Caffe2 export / utility scripts to work with latest PyTorch / ONNX +* ONNX runtime based validation script added +* activations (mostly) brought in sync with `timm` equivalents + + +### April 5, 2020 +* Add some newly trained MobileNet-V2 models trained with latest h-params, rand augment. They compare quite favourably to EfficientNet-Lite + * 3.5M param MobileNet-V2 100 @ 73% + * 4.5M param MobileNet-V2 110d @ 75% + * 6.1M param MobileNet-V2 140 @ 76.5% + * 5.8M param MobileNet-V2 120d @ 77.3% + +### March 23, 2020 + * Add EfficientNet-Lite models w/ weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * Add PyTorch trained MobileNet-V3 Large weights with 75.77% top-1 + * IMPORTANT CHANGE (if training from scratch) - weight init changed to better match Tensorflow impl, set `fix_group_fanout=False` in `initialize_weight_goog` for old behavior + +### Feb 12, 2020 + * Add EfficientNet-L2 and B0-B7 NoisyStudent weights ported from [Tensorflow TPU](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet) + * Port new EfficientNet-B8 (RandAugment) weights from TF TPU, these are different than the B8 AdvProp, different input normalization. + * Add RandAugment PyTorch trained EfficientNet-ES (EdgeTPU-Small) weights with 78.1 top-1. Trained by [Andrew Lavin](https://github.com/andravin) + +### Jan 22, 2020 + * Update weights for EfficientNet B0, B2, B3 and MixNet-XL with latest RandAugment trained weights. Trained with (https://github.com/rwightman/pytorch-image-models) + * Fix torchscript compatibility for PyTorch 1.4, add torchscript support for MixedConv2d using ModuleDict + * Test models, torchscript, onnx export with PyTorch 1.4 -- no issues + +### Nov 22, 2019 + * New top-1 high! Ported official TF EfficientNet AdvProp (https://arxiv.org/abs/1911.09665) weights and B8 model spec. Created a new set of `ap` models since they use a different + preprocessing (Inception mean/std) from the original EfficientNet base/AA/RA weights. + +### Nov 15, 2019 + * Ported official TF MobileNet-V3 float32 large/small/minimalistic weights + * Modifications to MobileNet-V3 model and components to support some additional config needed for differences between TF MobileNet-V3 and mine + +### Oct 30, 2019 + * Many of the models will now work with torch.jit.script, MixNet being the biggest exception + * Improved interface for enabling torchscript or ONNX export compatible modes (via config) + * Add JIT optimized mem-efficient Swish/Mish autograd.fn in addition to memory-efficient autgrad.fn + * Activation factory to select best version of activation by name or override one globally + * Add pretrained checkpoint load helper that handles input conv and classifier changes + +### Oct 27, 2019 + * Add CondConv EfficientNet variants ported from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv + * Add RandAug weights for TF EfficientNet B5 and B7 from https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet + * Bring over MixNet-XL model and depth scaling algo from my pytorch-image-models code base + * Switch activations and global pooling to modules + * Add memory-efficient Swish/Mish impl + * Add as_sequential() method to all models and allow as an argument in entrypoint fns + * Move MobileNetV3 into own file since it has a different head + * Remove ChamNet, MobileNet V2/V1 since they will likely never be used here + +## Models + +Implemented models include: + * EfficientNet NoisyStudent (B0-B7, L2) (https://arxiv.org/abs/1911.04252) + * EfficientNet AdvProp (B0-B8) (https://arxiv.org/abs/1911.09665) + * EfficientNet (B0-B8) (https://arxiv.org/abs/1905.11946) + * EfficientNet-EdgeTPU (S, M, L) (https://ai.googleblog.com/2019/08/efficientnet-edgetpu-creating.html) + * EfficientNet-CondConv (https://arxiv.org/abs/1904.04971) + * EfficientNet-Lite (https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite) + * MixNet (https://arxiv.org/abs/1907.09595) + * MNASNet B1, A1 (Squeeze-Excite), and Small (https://arxiv.org/abs/1807.11626) + * MobileNet-V3 (https://arxiv.org/abs/1905.02244) + * FBNet-C (https://arxiv.org/abs/1812.03443) + * Single-Path NAS (https://arxiv.org/abs/1904.02877) + +I originally implemented and trained some these models with code [here](https://github.com/rwightman/pytorch-image-models), this repository contains just the GenEfficientNet models, validation, and associated ONNX/Caffe2 export code. + +## Pretrained + +I've managed to train several of the models to accuracies close to or above the originating papers and official impl. My training code is here: https://github.com/rwightman/pytorch-image-models + + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param#(M) | MAdds(M) | Image Scaling | Resolution | Crop | +|---|---|---|---|---|---|---|---| +| efficientnet_b3 | 82.240 (17.760) | 96.116 (3.884) | 12.23 | TBD | bicubic | 320 | 1.0 | +| efficientnet_b3 | 82.076 (17.924) | 96.020 (3.980) | 12.23 | TBD | bicubic | 300 | 0.904 | +| mixnet_xl | 81.074 (18.926) | 95.282 (4.718) | 11.90 | TBD | bicubic | 256 | 1.0 | +| efficientnet_b2 | 80.612 (19.388) | 95.318 (4.682) | 9.1 | TBD | bicubic | 288 | 1.0 | +| mixnet_xl | 80.476 (19.524) | 94.936 (5.064) | 11.90 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b2 | 80.288 (19.712) | 95.166 (4.834) | 9.1 | 1003 | bicubic | 260 | 0.890 | +| mixnet_l | 78.976 (21.024 | 94.184 (5.816) | 7.33 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b1 | 78.692 (21.308) | 94.086 (5.914) | 7.8 | 694 | bicubic | 240 | 0.882 | +| efficientnet_es | 78.066 (21.934) | 93.926 (6.074) | 5.44 | TBD | bicubic | 224 | 0.875 | +| efficientnet_b0 | 77.698 (22.302) | 93.532 (6.468) | 5.3 | 390 | bicubic | 224 | 0.875 | +| mobilenetv2_120d | 77.294 (22.706 | 93.502 (6.498) | 5.8 | TBD | bicubic | 224 | 0.875 | +| mixnet_m | 77.256 (22.744) | 93.418 (6.582) | 5.01 | 353 | bicubic | 224 | 0.875 | +| mobilenetv2_140 | 76.524 (23.476) | 92.990 (7.010) | 6.1 | TBD | bicubic | 224 | 0.875 | +| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5 | TBD | bicubic | 224 | 0.875 | +| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5 | 219 | bicubic | 224 | 0.875 | +| efficientnet_lite0 | 75.472 (24.528) | 92.520 (7.480) | 4.65 | TBD | bicubic | 224 | 0.875 | +| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.9 | 312 | bicubic | 224 | 0.875 | +| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6 | 385 | bilinear | 224 | 0.875 | +| mobilenetv2_110d | 75.052 (24.948) | 92.180 (7.820) | 4.5 | TBD | bicubic | 224 | 0.875 | +| mnasnet_b1 | 74.658 (25.342) | 92.114 (7.886) | 4.4 | 315 | bicubic | 224 | 0.875 | +| spnasnet_100 | 74.084 (25.916) | 91.818 (8.182) | 4.4 | TBD | bilinear | 224 | 0.875 | +| mobilenetv2_100 | 72.978 (27.022) | 91.016 (8.984) | 3.5 | TBD | bicubic | 224 | 0.875 | + + +More pretrained models to come... + + +## Ported Weights + +The weights ported from Tensorflow checkpoints for the EfficientNet models do pretty much match accuracy in Tensorflow once a SAME convolution padding equivalent is added, and the same crop factors, image scaling, etc (see table) are used via cmd line args. + +**IMPORTANT:** +* Tensorflow ported weights for EfficientNet AdvProp (AP), EfficientNet EdgeTPU, EfficientNet-CondConv, EfficientNet-Lite, and MobileNet-V3 models use Inception style (0.5, 0.5, 0.5) for mean and std. +* Enabling the Tensorflow preprocessing pipeline with `--tf-preprocessing` at validation time will improve scores by 0.1-0.5%, very close to original TF impl. + +To run validation for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --crop-pct 0.934 --interpolation bicubic` + +To run validation w/ TF preprocessing for tf_efficientnet_b5: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b5 -b 64 --img-size 456 --tf-preprocessing` + +To run validation for a model with Inception preprocessing, ie EfficientNet-B8 AdvProp: +`python validate.py /path/to/imagenet/validation/ --model tf_efficientnet_b8_ap -b 48 --num-gpu 2 --img-size 672 --crop-pct 0.954 --mean 0.5 --std 0.5` + +|Model | Prec@1 (Err) | Prec@5 (Err) | Param # | Image Scaling | Image Size | Crop | +|---|---|---|---|---|---|---| +| tf_efficientnet_l2_ns *tfp | 88.352 (11.648) | 98.652 (1.348) | 480 | bicubic | 800 | N/A | +| tf_efficientnet_l2_ns | TBD | TBD | 480 | bicubic | 800 | 0.961 | +| tf_efficientnet_l2_ns_475 | 88.234 (11.766) | 98.546 (1.454) | 480 | bicubic | 475 | 0.936 | +| tf_efficientnet_l2_ns_475 *tfp | 88.172 (11.828) | 98.566 (1.434) | 480 | bicubic | 475 | N/A | +| tf_efficientnet_b7_ns *tfp | 86.844 (13.156) | 98.084 (1.916) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ns | 86.840 (13.160) | 98.094 (1.906) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b6_ns | 86.452 (13.548) | 97.882 (2.118) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6_ns *tfp | 86.444 (13.556) | 97.880 (2.120) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ns *tfp | 86.064 (13.936) | 97.746 (2.254) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ns | 86.088 (13.912) | 97.752 (2.248) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b8_ap *tfp | 85.436 (14.564) | 97.272 (2.728) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 *tfp | 85.384 (14.616) | 97.394 (2.606) | 87.4 | bicubic | 672 | N/A | +| tf_efficientnet_b8 | 85.370 (14.630) | 97.390 (2.610) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b8_ap | 85.368 (14.632) | 97.294 (2.706) | 87.4 | bicubic | 672 | 0.954 | +| tf_efficientnet_b4_ns *tfp | 85.298 (14.702) | 97.504 (2.496) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ns | 85.162 (14.838) | 97.470 (2.530) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b7_ap *tfp | 85.154 (14.846) | 97.244 (2.756) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7_ap | 85.118 (14.882) | 97.252 (2.748) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b7 *tfp | 84.940 (15.060) | 97.214 (2.786) | 66.35 | bicubic | 600 | N/A | +| tf_efficientnet_b7 | 84.932 (15.068) | 97.208 (2.792) | 66.35 | bicubic | 600 | 0.949 | +| tf_efficientnet_b6_ap | 84.786 (15.214) | 97.138 (2.862) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b6_ap *tfp | 84.760 (15.240) | 97.124 (2.876) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b5_ap *tfp | 84.276 (15.724) | 96.932 (3.068) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5_ap | 84.254 (15.746) | 96.976 (3.024) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b6 *tfp | 84.140 (15.860) | 96.852 (3.148) | 43.04 | bicubic | 528 | N/A | +| tf_efficientnet_b6 | 84.110 (15.890) | 96.886 (3.114) | 43.04 | bicubic | 528 | 0.942 | +| tf_efficientnet_b3_ns *tfp | 84.054 (15.946) | 96.918 (3.082) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ns | 84.048 (15.952) | 96.910 (3.090) | 12.23 | bicubic | 300 | .904 | +| tf_efficientnet_b5 *tfp | 83.822 (16.178) | 96.756 (3.244) | 30.39 | bicubic | 456 | N/A | +| tf_efficientnet_b5 | 83.812 (16.188) | 96.748 (3.252) | 30.39 | bicubic | 456 | 0.934 | +| tf_efficientnet_b4_ap *tfp | 83.278 (16.722) | 96.376 (3.624) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b4_ap | 83.248 (16.752) | 96.388 (3.612) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 | 83.022 (16.978) | 96.300 (3.700) | 19.34 | bicubic | 380 | 0.922 | +| tf_efficientnet_b4 *tfp | 82.948 (17.052) | 96.308 (3.692) | 19.34 | bicubic | 380 | N/A | +| tf_efficientnet_b2_ns *tfp | 82.436 (17.564) | 96.268 (3.732) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ns | 82.380 (17.620) | 96.248 (3.752) | 9.11 | bicubic | 260 | 0.89 | +| tf_efficientnet_b3_ap *tfp | 81.882 (18.118) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_b3_ap | 81.828 (18.172) | 95.624 (4.376) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 | 81.636 (18.364) | 95.718 (4.282) | 12.23 | bicubic | 300 | 0.904 | +| tf_efficientnet_b3 *tfp | 81.576 (18.424) | 95.662 (4.338) | 12.23 | bicubic | 300 | N/A | +| tf_efficientnet_lite4 | 81.528 (18.472) | 95.668 (4.332) | 13.00 | bilinear | 380 | 0.92 | +| tf_efficientnet_b1_ns *tfp | 81.514 (18.486) | 95.776 (4.224) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_lite4 *tfp | 81.502 (18.498) | 95.676 (4.324) | 13.00 | bilinear | 380 | N/A | +| tf_efficientnet_b1_ns | 81.388 (18.612) | 95.738 (4.262) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_el | 80.534 (19.466) | 95.190 (4.810) | 10.59 | bicubic | 300 | 0.904 | +| tf_efficientnet_el *tfp | 80.476 (19.524) | 95.200 (4.800) | 10.59 | bicubic | 300 | N/A | +| tf_efficientnet_b2_ap *tfp | 80.420 (19.580) | 95.040 (4.960) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2_ap | 80.306 (19.694) | 95.028 (4.972) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_b2 *tfp | 80.188 (19.812) | 94.974 (5.026) | 9.11 | bicubic | 260 | N/A | +| tf_efficientnet_b2 | 80.086 (19.914) | 94.908 (5.092) | 9.11 | bicubic | 260 | 0.890 | +| tf_efficientnet_lite3 | 79.812 (20.188) | 94.914 (5.086) | 8.20 | bilinear | 300 | 0.904 | +| tf_efficientnet_lite3 *tfp | 79.734 (20.266) | 94.838 (5.162) | 8.20 | bilinear | 300 | N/A | +| tf_efficientnet_b1_ap *tfp | 79.532 (20.468) | 94.378 (5.622) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_cc_b1_8e *tfp | 79.464 (20.536)| 94.492 (5.508) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_cc_b1_8e | 79.298 (20.702) | 94.364 (5.636) | 39.7 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1_ap | 79.278 (20.722) | 94.308 (5.692) | 7.79 | bicubic | 240 | 0.88 | +| tf_efficientnet_b1 *tfp | 79.172 (20.828) | 94.450 (5.550) | 7.79 | bicubic | 240 | N/A | +| tf_efficientnet_em *tfp | 78.958 (21.042) | 94.458 (5.542) | 6.90 | bicubic | 240 | N/A | +| tf_efficientnet_b0_ns *tfp | 78.806 (21.194) | 94.496 (5.504) | 5.29 | bicubic | 224 | N/A | +| tf_mixnet_l *tfp | 78.846 (21.154) | 94.212 (5.788) | 7.33 | bilinear | 224 | N/A | +| tf_efficientnet_b1 | 78.826 (21.174) | 94.198 (5.802) | 7.79 | bicubic | 240 | 0.88 | +| tf_mixnet_l | 78.770 (21.230) | 94.004 (5.996) | 7.33 | bicubic | 224 | 0.875 | +| tf_efficientnet_em | 78.742 (21.258) | 94.332 (5.668) | 6.90 | bicubic | 240 | 0.875 | +| tf_efficientnet_b0_ns | 78.658 (21.342) | 94.376 (5.624) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e *tfp | 78.314 (21.686) | 93.790 (6.210) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_8e | 77.908 (22.092) | 93.656 (6.344) | 24.0 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e *tfp | 77.746 (22.254) | 93.552 (6.448) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_cc_b0_4e | 77.304 (22.696) | 93.332 (6.668) | 13.3 | bicubic | 224 | 0.875 | +| tf_efficientnet_es *tfp | 77.616 (22.384) | 93.750 (6.250) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_lite2 *tfp | 77.544 (22.456) | 93.800 (6.200) | 6.09 | bilinear | 260 | N/A | +| tf_efficientnet_lite2 | 77.460 (22.540) | 93.746 (6.254) | 6.09 | bicubic | 260 | 0.89 | +| tf_efficientnet_b0_ap *tfp | 77.514 (22.486) | 93.576 (6.424) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_es | 77.264 (22.736) | 93.600 (6.400) | 5.44 | bicubic | 224 | N/A | +| tf_efficientnet_b0 *tfp | 77.258 (22.742) | 93.478 (6.522) | 5.29 | bicubic | 224 | N/A | +| tf_efficientnet_b0_ap | 77.084 (22.916) | 93.254 (6.746) | 5.29 | bicubic | 224 | 0.875 | +| tf_mixnet_m *tfp | 77.072 (22.928) | 93.368 (6.632) | 5.01 | bilinear | 224 | N/A | +| tf_mixnet_m | 76.950 (23.050) | 93.156 (6.844) | 5.01 | bicubic | 224 | 0.875 | +| tf_efficientnet_b0 | 76.848 (23.152) | 93.228 (6.772) | 5.29 | bicubic | 224 | 0.875 | +| tf_efficientnet_lite1 *tfp | 76.764 (23.236) | 93.326 (6.674) | 5.42 | bilinear | 240 | N/A | +| tf_efficientnet_lite1 | 76.638 (23.362) | 93.232 (6.768) | 5.42 | bicubic | 240 | 0.882 | +| tf_mixnet_s *tfp | 75.800 (24.200) | 92.788 (7.212) | 4.13 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_100 *tfp | 75.768 (24.232) | 92.710 (7.290) | 5.48 | bilinear | 224 | N/A | +| tf_mixnet_s | 75.648 (24.352) | 92.636 (7.364) | 4.13 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_100 | 75.516 (24.484) | 92.600 (7.400) | 5.48 | bilinear | 224 | 0.875 | +| tf_efficientnet_lite0 *tfp | 75.074 (24.926) | 92.314 (7.686) | 4.65 | bilinear | 224 | N/A | +| tf_efficientnet_lite0 | 74.842 (25.158) | 92.170 (7.830) | 4.65 | bicubic | 224 | 0.875 | +| tf_mobilenetv3_large_075 *tfp | 73.730 (26.270) | 91.616 (8.384) | 3.99 | bilinear | 224 |N/A | +| tf_mobilenetv3_large_075 | 73.442 (26.558) | 91.352 (8.648) | 3.99 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_large_minimal_100 *tfp | 72.678 (27.322) | 90.860 (9.140) | 3.92 | bilinear | 224 | N/A | +| tf_mobilenetv3_large_minimal_100 | 72.244 (27.756) | 90.636 (9.364) | 3.92 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_100 *tfp | 67.918 (32.082) | 87.958 (12.042 | 2.54 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_100 | 67.918 (32.082) | 87.662 (12.338) | 2.54 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_075 *tfp | 66.142 (33.858) | 86.498 (13.502) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_075 | 65.718 (34.282) | 86.136 (13.864) | 2.04 | bilinear | 224 | 0.875 | +| tf_mobilenetv3_small_minimal_100 *tfp | 63.378 (36.622) | 84.802 (15.198) | 2.04 | bilinear | 224 | N/A | +| tf_mobilenetv3_small_minimal_100 | 62.898 (37.102) | 84.230 (15.770) | 2.04 | bilinear | 224 | 0.875 | + + +*tfp models validated with `tf-preprocessing` pipeline + +Google tf and tflite weights ported from official Tensorflow repositories +* https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet +* https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet +* https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet + +## Usage + +### Environment + +All development and testing has been done in Conda Python 3 environments on Linux x86-64 systems, specifically Python 3.6.x, 3.7.x, 3.8.x. + +Users have reported that a Python 3 Anaconda install in Windows works. I have not verified this myself. + +PyTorch versions 1.4, 1.5, 1.6 have been tested with this code. + +I've tried to keep the dependencies minimal, the setup is as per the PyTorch default install instructions for Conda: +``` +conda create -n torch-env +conda activate torch-env +conda install -c pytorch pytorch torchvision cudatoolkit=10.2 +``` + +### PyTorch Hub + +Models can be accessed via the PyTorch Hub API + +``` +>>> torch.hub.list('rwightman/gen-efficientnet-pytorch') +['efficientnet_b0', ...] +>>> model = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b0', pretrained=True) +>>> model.eval() +>>> output = model(torch.randn(1,3,224,224)) +``` + +### Pip +This package can be installed via pip. + +Install (after conda env/install): +``` +pip install geffnet +``` + +Eval use: +``` +>>> import geffnet +>>> m = geffnet.create_model('mobilenetv3_large_100', pretrained=True) +>>> m.eval() +``` + +Train use: +``` +>>> import geffnet +>>> # models can also be created by using the entrypoint directly +>>> m = geffnet.efficientnet_b2(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2) +>>> m.train() +``` + +Create in a nn.Sequential container, for fast.ai, etc: +``` +>>> import geffnet +>>> m = geffnet.mixnet_l(pretrained=True, drop_rate=0.25, drop_connect_rate=0.2, as_sequential=True) +``` + +### Exporting + +Scripts are included to +* export models to ONNX (`onnx_export.py`) +* optimized ONNX graph (`onnx_optimize.py` or `onnx_validate.py` w/ `--onnx-output-opt` arg) +* validate with ONNX runtime (`onnx_validate.py`) +* convert ONNX model to Caffe2 (`onnx_to_caffe.py`) +* validate in Caffe2 (`caffe2_validate.py`) +* benchmark in Caffe2 w/ FLOPs, parameters output (`caffe2_benchmark.py`) + +As an example, to export the MobileNet-V3 pretrained model and then run an Imagenet validation: +``` +python onnx_export.py --model mobilenetv3_large_100 ./mobilenetv3_100.onnx +python onnx_validate.py /imagenet/validation/ --onnx-input ./mobilenetv3_100.onnx +``` + +These scripts were tested to be working as of PyTorch 1.6 and ONNX 1.7 w/ ONNX runtime 1.4. Caffe2 compatible +export now requires additional args mentioned in the export script (not needed in earlier versions). + +#### Export Notes +1. The TF ported weights with the 'SAME' conv padding activated cannot be exported to ONNX unless `_EXPORTABLE` flag in `config.py` is set to True. Use `config.set_exportable(True)` as in the `onnx_export.py` script. +2. TF ported models with 'SAME' padding will have the padding fixed at export time to the resolution used for export. Even though dynamic padding is supported in opset >= 11, I can't get it working. +3. ONNX optimize facility doesn't work reliably in PyTorch 1.6 / ONNX 1.7. Fortunately, the onnxruntime based inference is working very well now and includes on the fly optimization. +3. ONNX / Caffe2 export/import frequently breaks with different PyTorch and ONNX version releases. Please check their respective issue trackers before filing issues here. + + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..93f28a1e63d9f7287ca02997c7991fe66dd0aeb9 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_benchmark.py @@ -0,0 +1,65 @@ +""" Caffe2 validation script + +This script runs Caffe2 benchmark on exported ONNX model. +It is a useful tool for reporting model FLOPS. + +Copyright 2020 Ross Wightman +""" +import argparse +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 + + +parser = argparse.ArgumentParser(description='Caffe2 Model Benchmark') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=224, type=int, + metavar='N', help='Input image dimension, uses model default if empty') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="le_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + # CUDA performance not impressive + #device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + #model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + #model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + input_blob = model.net.external_inputs[0] + model.param_init_net.GaussianFill( + [], + input_blob.GetUnscopedName(), + shape=(args.batch_size, 3, args.img_size, args.img_size), + mean=0.0, + std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + workspace.BenchmarkNet(model.net.Proto().name, 5, 20, True) + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..7cfaab38c095663fe32e4addbdf06b57bcb53614 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/caffe2_validate.py @@ -0,0 +1,138 @@ +""" Caffe2 validation script + +This script is created to verify exported ONNX models running in Caffe2 +It utilizes the same PyTorch dataloader/processing pipeline for a +fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +from caffe2.python import core, workspace, model_helper +from caffe2.proto import caffe2_pb2 +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--c2-prefix', default='', type=str, metavar='NAME', + help='caffe2 model pb name prefix') +parser.add_argument('--c2-init', default='', type=str, metavar='PATH', + help='caffe2 model init .pb') +parser.add_argument('--c2-predict', default='', type=str, metavar='PATH', + help='caffe2 model predict .pb') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + if args.c2_prefix: + args.c2_init = args.c2_prefix + '.init.pb' + args.c2_predict = args.c2_prefix + '.predict.pb' + + model = model_helper.ModelHelper(name="validation_net", init_params=False) + + # Bring in the init net from init_net.pb + init_net_proto = caffe2_pb2.NetDef() + with open(args.c2_init, "rb") as f: + init_net_proto.ParseFromString(f.read()) + model.param_init_net = core.Net(init_net_proto) + + # bring in the predict net from predict_net.pb + predict_net_proto = caffe2_pb2.NetDef() + with open(args.c2_predict, "rb") as f: + predict_net_proto.ParseFromString(f.read()) + model.net = core.Net(predict_net_proto) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + # this is so obvious, wonderful interface + input_blob = model.net.external_inputs[0] + output_blob = model.net.external_outputs[0] + + if True: + device_opts = None + else: + # CUDA is crashing, no idea why, awesome error message, give it a try for kicks + device_opts = core.DeviceOption(caffe2_pb2.PROTO_CUDA, args.gpu_id) + model.net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + model.param_init_net.RunAllOnGPU(gpu_id=args.gpu_id, use_cudnn=True) + + model.param_init_net.GaussianFill( + [], input_blob.GetUnscopedName(), + shape=(1,) + data_config['input_size'], mean=0.0, std=1.0) + workspace.RunNetOnce(model.param_init_net) + workspace.CreateNet(model.net, overwrite=True) + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + caffe2_in = input.data.numpy() + workspace.FeedBlob(input_blob, caffe2_in, device_opts) + workspace.RunNet(model.net, num_iter=1) + output = workspace.FetchBlob(output_blob) + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output.data, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e441a5838d1e972823b9668ac8d459445f6f6ce --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/__init__.py @@ -0,0 +1,5 @@ +from .gen_efficientnet import * +from .mobilenetv3 import * +from .model_factory import create_model +from .config import is_exportable, is_scriptable, set_exportable, set_scriptable +from .activations import * \ No newline at end of file diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..813421a743ffc33b8eb53ebf62dd4a03d831b654 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/__init__.py @@ -0,0 +1,137 @@ +from geffnet import config +from geffnet.activations.activations_me import * +from geffnet.activations.activations_jit import * +from geffnet.activations.activations import * +import torch + +_has_silu = 'silu' in dir(torch.nn.functional) + +_ACT_FN_DEFAULT = dict( + silu=F.silu if _has_silu else swish, + swish=F.silu if _has_silu else swish, + mish=mish, + relu=F.relu, + relu6=F.relu6, + sigmoid=sigmoid, + tanh=tanh, + hard_sigmoid=hard_sigmoid, + hard_swish=hard_swish, +) + +_ACT_FN_JIT = dict( + silu=F.silu if _has_silu else swish_jit, + swish=F.silu if _has_silu else swish_jit, + mish=mish_jit, +) + +_ACT_FN_ME = dict( + silu=F.silu if _has_silu else swish_me, + swish=F.silu if _has_silu else swish_me, + mish=mish_me, + hard_swish=hard_swish_me, + hard_sigmoid_jit=hard_sigmoid_me, +) + +_ACT_LAYER_DEFAULT = dict( + silu=nn.SiLU if _has_silu else Swish, + swish=nn.SiLU if _has_silu else Swish, + mish=Mish, + relu=nn.ReLU, + relu6=nn.ReLU6, + sigmoid=Sigmoid, + tanh=Tanh, + hard_sigmoid=HardSigmoid, + hard_swish=HardSwish, +) + +_ACT_LAYER_JIT = dict( + silu=nn.SiLU if _has_silu else SwishJit, + swish=nn.SiLU if _has_silu else SwishJit, + mish=MishJit, +) + +_ACT_LAYER_ME = dict( + silu=nn.SiLU if _has_silu else SwishMe, + swish=nn.SiLU if _has_silu else SwishMe, + mish=MishMe, + hard_swish=HardSwishMe, + hard_sigmoid=HardSigmoidMe +) + +_OVERRIDE_FN = dict() +_OVERRIDE_LAYER = dict() + + +def add_override_act_fn(name, fn): + global _OVERRIDE_FN + _OVERRIDE_FN[name] = fn + + +def update_override_act_fn(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_FN + _OVERRIDE_FN.update(overrides) + + +def clear_override_act_fn(): + global _OVERRIDE_FN + _OVERRIDE_FN = dict() + + +def add_override_act_layer(name, fn): + _OVERRIDE_LAYER[name] = fn + + +def update_override_act_layer(overrides): + assert isinstance(overrides, dict) + global _OVERRIDE_LAYER + _OVERRIDE_LAYER.update(overrides) + + +def clear_override_act_layer(): + global _OVERRIDE_LAYER + _OVERRIDE_LAYER = dict() + + +def get_act_fn(name='relu'): + """ Activation Function Factory + Fetching activation fns by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_FN: + return _OVERRIDE_FN[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_FN_ME: + # If not exporting or scripting the model, first look for a memory optimized version + # activation with custom autograd, then fallback to jit scripted, then a Python or Torch builtin + return _ACT_FN_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_FN_JIT[name] + return _ACT_FN_DEFAULT[name] + + +def get_act_layer(name='relu'): + """ Activation Layer Factory + Fetching activation layers by name with this function allows export or torch script friendly + functions to be returned dynamically based on current config. + """ + if name in _OVERRIDE_LAYER: + return _OVERRIDE_LAYER[name] + use_me = not (config.is_exportable() or config.is_scriptable() or config.is_no_jit()) + if use_me and name in _ACT_LAYER_ME: + return _ACT_LAYER_ME[name] + if config.is_exportable() and name in ('silu', 'swish'): + # FIXME PyTorch SiLU doesn't ONNX export, this is a temp hack + return Swish + use_jit = not (config.is_exportable() or config.is_no_jit()) + # NOTE: export tracing should work with jit scripted components, but I keep running into issues + if use_jit and name in _ACT_FN_JIT: # jit scripted models should be okay for export/scripting + return _ACT_LAYER_JIT[name] + return _ACT_LAYER_DEFAULT[name] + + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py new file mode 100644 index 0000000000000000000000000000000000000000..bdea692d1397673b2513d898c33edbcb37d94240 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations.py @@ -0,0 +1,102 @@ +""" Activations + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +Copyright 2020 Ross Wightman +""" +from torch import nn as nn +from torch.nn import functional as F + + +def swish(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul_(x.sigmoid()) if inplace else x.mul(x.sigmoid()) + + +class Swish(nn.Module): + def __init__(self, inplace: bool = False): + super(Swish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return swish(x, self.inplace) + + +def mish(x, inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class Mish(nn.Module): + def __init__(self, inplace: bool = False): + super(Mish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return mish(x, self.inplace) + + +def sigmoid(x, inplace: bool = False): + return x.sigmoid_() if inplace else x.sigmoid() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Sigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(Sigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.sigmoid_() if self.inplace else x.sigmoid() + + +def tanh(x, inplace: bool = False): + return x.tanh_() if inplace else x.tanh() + + +# PyTorch has this, but not with a consistent inplace argmument interface +class Tanh(nn.Module): + def __init__(self, inplace: bool = False): + super(Tanh, self).__init__() + self.inplace = inplace + + def forward(self, x): + return x.tanh_() if self.inplace else x.tanh() + + +def hard_swish(x, inplace: bool = False): + inner = F.relu6(x + 3.).div_(6.) + return x.mul_(inner) if inplace else x.mul(inner) + + +class HardSwish(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwish, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_swish(x, self.inplace) + + +def hard_sigmoid(x, inplace: bool = False): + if inplace: + return x.add_(3.).clamp_(0., 6.).div_(6.) + else: + return F.relu6(x + 3.) / 6. + + +class HardSigmoid(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoid, self).__init__() + self.inplace = inplace + + def forward(self, x): + return hard_sigmoid(x, self.inplace) + + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py new file mode 100644 index 0000000000000000000000000000000000000000..7176b05e779787528a47f20d55d64d4a0f219360 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_jit.py @@ -0,0 +1,79 @@ +""" Activations (jit) + +A collection of jit-scripted activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +All jit scripted activations are lacking in-place variations on purpose, scripted kernel fusion does not +currently work across in-place op boundaries, thus performance is equal to or less than the non-scripted +versions if they contain in-place ops. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + +__all__ = ['swish_jit', 'SwishJit', 'mish_jit', 'MishJit', + 'hard_sigmoid_jit', 'HardSigmoidJit', 'hard_swish_jit', 'HardSwishJit'] + + +@torch.jit.script +def swish_jit(x, inplace: bool = False): + """Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + return x.mul(x.sigmoid()) + + +@torch.jit.script +def mish_jit(x, _inplace: bool = False): + """Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + """ + return x.mul(F.softplus(x).tanh()) + + +class SwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishJit, self).__init__() + + def forward(self, x): + return swish_jit(x) + + +class MishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(MishJit, self).__init__() + + def forward(self, x): + return mish_jit(x) + + +@torch.jit.script +def hard_sigmoid_jit(x, inplace: bool = False): + # return F.relu6(x + 3.) / 6. + return (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSigmoidJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidJit, self).__init__() + + def forward(self, x): + return hard_sigmoid_jit(x) + + +@torch.jit.script +def hard_swish_jit(x, inplace: bool = False): + # return x * (F.relu6(x + 3.) / 6) + return x * (x + 3).clamp(min=0, max=6).div(6.) # clamp seems ever so slightly faster? + + +class HardSwishJit(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishJit, self).__init__() + + def forward(self, x): + return hard_swish_jit(x) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py new file mode 100644 index 0000000000000000000000000000000000000000..e91df5a50fdbe40bc386e2541a4fda743ad95e9a --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/activations/activations_me.py @@ -0,0 +1,174 @@ +""" Activations (memory-efficient w/ custom autograd) + +A collection of activations fn and modules with a common interface so that they can +easily be swapped. All have an `inplace` arg even if not used. + +These activations are not compatible with jit scripting or ONNX export of the model, please use either +the JIT or basic versions of the activations. + +Copyright 2020 Ross Wightman +""" + +import torch +from torch import nn as nn +from torch.nn import functional as F + + +__all__ = ['swish_me', 'SwishMe', 'mish_me', 'MishMe', + 'hard_sigmoid_me', 'HardSigmoidMe', 'hard_swish_me', 'HardSwishMe'] + + +@torch.jit.script +def swish_jit_fwd(x): + return x.mul(torch.sigmoid(x)) + + +@torch.jit.script +def swish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + return grad_output * (x_sigmoid * (1 + x * (1 - x_sigmoid))) + + +class SwishJitAutoFn(torch.autograd.Function): + """ torch.jit.script optimised Swish w/ memory-efficient checkpoint + Inspired by conversation btw Jeremy Howard & Adam Pazske + https://twitter.com/jeremyphoward/status/1188251041835315200 + + Swish - Described originally as SiLU (https://arxiv.org/abs/1702.03118v3) + and also as Swish (https://arxiv.org/abs/1710.05941). + + TODO Rename to SiLU with addition to PyTorch + """ + + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return swish_jit_bwd(x, grad_output) + + +def swish_me(x, inplace=False): + return SwishJitAutoFn.apply(x) + + +class SwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(SwishMe, self).__init__() + + def forward(self, x): + return SwishJitAutoFn.apply(x) + + +@torch.jit.script +def mish_jit_fwd(x): + return x.mul(torch.tanh(F.softplus(x))) + + +@torch.jit.script +def mish_jit_bwd(x, grad_output): + x_sigmoid = torch.sigmoid(x) + x_tanh_sp = F.softplus(x).tanh() + return grad_output.mul(x_tanh_sp + x * x_sigmoid * (1 - x_tanh_sp * x_tanh_sp)) + + +class MishJitAutoFn(torch.autograd.Function): + """ Mish: A Self Regularized Non-Monotonic Neural Activation Function - https://arxiv.org/abs/1908.08681 + A memory efficient, jit scripted variant of Mish + """ + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return mish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return mish_jit_bwd(x, grad_output) + + +def mish_me(x, inplace=False): + return MishJitAutoFn.apply(x) + + +class MishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(MishMe, self).__init__() + + def forward(self, x): + return MishJitAutoFn.apply(x) + + +@torch.jit.script +def hard_sigmoid_jit_fwd(x, inplace: bool = False): + return (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_sigmoid_jit_bwd(x, grad_output): + m = torch.ones_like(x) * ((x >= -3.) & (x <= 3.)) / 6. + return grad_output * m + + +class HardSigmoidJitAutoFn(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_sigmoid_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_sigmoid_jit_bwd(x, grad_output) + + +def hard_sigmoid_me(x, inplace: bool = False): + return HardSigmoidJitAutoFn.apply(x) + + +class HardSigmoidMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSigmoidMe, self).__init__() + + def forward(self, x): + return HardSigmoidJitAutoFn.apply(x) + + +@torch.jit.script +def hard_swish_jit_fwd(x): + return x * (x + 3).clamp(min=0, max=6).div(6.) + + +@torch.jit.script +def hard_swish_jit_bwd(x, grad_output): + m = torch.ones_like(x) * (x >= 3.) + m = torch.where((x >= -3.) & (x <= 3.), x / 3. + .5, m) + return grad_output * m + + +class HardSwishJitAutoFn(torch.autograd.Function): + """A memory efficient, jit-scripted HardSwish activation""" + @staticmethod + def forward(ctx, x): + ctx.save_for_backward(x) + return hard_swish_jit_fwd(x) + + @staticmethod + def backward(ctx, grad_output): + x = ctx.saved_tensors[0] + return hard_swish_jit_bwd(x, grad_output) + + +def hard_swish_me(x, inplace=False): + return HardSwishJitAutoFn.apply(x) + + +class HardSwishMe(nn.Module): + def __init__(self, inplace: bool = False): + super(HardSwishMe, self).__init__() + + def forward(self, x): + return HardSwishJitAutoFn.apply(x) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py new file mode 100644 index 0000000000000000000000000000000000000000..27d5307fd9ee0246f1e35f41520f17385d23f1dd --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/config.py @@ -0,0 +1,123 @@ +""" Global layer config state +""" +from typing import Any, Optional + +__all__ = [ + 'is_exportable', 'is_scriptable', 'is_no_jit', 'layer_config_kwargs', + 'set_exportable', 'set_scriptable', 'set_no_jit', 'set_layer_config' +] + +# Set to True if prefer to have layers with no jit optimization (includes activations) +_NO_JIT = False + +# Set to True if prefer to have activation layers with no jit optimization +# NOTE not currently used as no difference between no_jit and no_activation jit as only layers obeying +# the jit flags so far are activations. This will change as more layers are updated and/or added. +_NO_ACTIVATION_JIT = False + +# Set to True if exporting a model with Same padding via ONNX +_EXPORTABLE = False + +# Set to True if wanting to use torch.jit.script on a model +_SCRIPTABLE = False + + +def is_no_jit(): + return _NO_JIT + + +class set_no_jit: + def __init__(self, mode: bool) -> None: + global _NO_JIT + self.prev = _NO_JIT + _NO_JIT = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _NO_JIT + _NO_JIT = self.prev + return False + + +def is_exportable(): + return _EXPORTABLE + + +class set_exportable: + def __init__(self, mode: bool) -> None: + global _EXPORTABLE + self.prev = _EXPORTABLE + _EXPORTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _EXPORTABLE + _EXPORTABLE = self.prev + return False + + +def is_scriptable(): + return _SCRIPTABLE + + +class set_scriptable: + def __init__(self, mode: bool) -> None: + global _SCRIPTABLE + self.prev = _SCRIPTABLE + _SCRIPTABLE = mode + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + _SCRIPTABLE = self.prev + return False + + +class set_layer_config: + """ Layer config context manager that allows setting all layer config flags at once. + If a flag arg is None, it will not change the current value. + """ + def __init__( + self, + scriptable: Optional[bool] = None, + exportable: Optional[bool] = None, + no_jit: Optional[bool] = None, + no_activation_jit: Optional[bool] = None): + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + self.prev = _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT + if scriptable is not None: + _SCRIPTABLE = scriptable + if exportable is not None: + _EXPORTABLE = exportable + if no_jit is not None: + _NO_JIT = no_jit + if no_activation_jit is not None: + _NO_ACTIVATION_JIT = no_activation_jit + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> bool: + global _SCRIPTABLE + global _EXPORTABLE + global _NO_JIT + global _NO_ACTIVATION_JIT + _SCRIPTABLE, _EXPORTABLE, _NO_JIT, _NO_ACTIVATION_JIT = self.prev + return False + + +def layer_config_kwargs(kwargs): + """ Consume config kwargs and return contextmgr obj """ + return set_layer_config( + scriptable=kwargs.pop('scriptable', None), + exportable=kwargs.pop('exportable', None), + no_jit=kwargs.pop('no_jit', None)) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d8467460c4b36e54c83ce2dcd3ebe91d3432cad2 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/conv2d_layers.py @@ -0,0 +1,304 @@ +""" Conv2D w/ SAME padding, CondConv, MixedConv + +A collection of conv layers and padding helpers needed by EfficientNet, MixNet, and +MobileNetV3 models that maintain weight compatibility with original Tensorflow models. + +Copyright 2020 Ross Wightman +""" +import collections.abc +import math +from functools import partial +from itertools import repeat +from typing import Tuple, Optional + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .config import * + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable): + return x + return tuple(repeat(x, n)) + return parse + + +_single = _ntuple(1) +_pair = _ntuple(2) +_triple = _ntuple(3) +_quadruple = _ntuple(4) + + +def _is_static_pad(kernel_size, stride=1, dilation=1, **_): + return stride == 1 and (dilation * (kernel_size - 1)) % 2 == 0 + + +def _get_padding(kernel_size, stride=1, dilation=1, **_): + padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2 + return padding + + +def _calc_same_pad(i: int, k: int, s: int, d: int): + return max((-(i // -s) - 1) * s + (k - 1) * d + 1 - i, 0) + + +def _same_pad_arg(input_size, kernel_size, stride, dilation): + ih, iw = input_size + kh, kw = kernel_size + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + return [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2] + + +def _split_channels(num_chan, num_groups): + split = [num_chan // num_groups for _ in range(num_groups)] + split[0] += num_chan - sum(split) + return split + + +def conv2d_same( + x, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, stride: Tuple[int, int] = (1, 1), + padding: Tuple[int, int] = (0, 0), dilation: Tuple[int, int] = (1, 1), groups: int = 1): + ih, iw = x.size()[-2:] + kh, kw = weight.size()[-2:] + pad_h = _calc_same_pad(ih, kh, stride[0], dilation[0]) + pad_w = _calc_same_pad(iw, kw, stride[1], dilation[1]) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]) + return F.conv2d(x, weight, bias, stride, (0, 0), dilation, groups) + + +class Conv2dSame(nn.Conv2d): + """ Tensorflow like 'SAME' convolution wrapper for 2D convolutions + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSame, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + + def forward(self, x): + return conv2d_same(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +class Conv2dSameExport(nn.Conv2d): + """ ONNX export friendly Tensorflow like 'SAME' convolution wrapper for 2D convolutions + + NOTE: This does not currently work with torch.jit.script + """ + + # pylint: disable=unused-argument + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2dSameExport, self).__init__( + in_channels, out_channels, kernel_size, stride, 0, dilation, groups, bias) + self.pad = None + self.pad_input_size = (0, 0) + + def forward(self, x): + input_size = x.size()[-2:] + if self.pad is None: + pad_arg = _same_pad_arg(input_size, self.weight.size()[-2:], self.stride, self.dilation) + self.pad = nn.ZeroPad2d(pad_arg) + self.pad_input_size = input_size + + if self.pad is not None: + x = self.pad(x) + return F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + + +def get_padding_value(padding, kernel_size, **kwargs): + dynamic = False + if isinstance(padding, str): + # for any string padding, the padding will be calculated for you, one of three ways + padding = padding.lower() + if padding == 'same': + # TF compatible 'SAME' padding, has a performance and GPU memory allocation impact + if _is_static_pad(kernel_size, **kwargs): + # static case, no extra overhead + padding = _get_padding(kernel_size, **kwargs) + else: + # dynamic padding + padding = 0 + dynamic = True + elif padding == 'valid': + # 'VALID' padding, same as padding=0 + padding = 0 + else: + # Default to PyTorch style 'same'-ish symmetric padding + padding = _get_padding(kernel_size, **kwargs) + return padding, dynamic + + +def create_conv2d_pad(in_chs, out_chs, kernel_size, **kwargs): + padding = kwargs.pop('padding', '') + kwargs.setdefault('bias', False) + padding, is_dynamic = get_padding_value(padding, kernel_size, **kwargs) + if is_dynamic: + if is_exportable(): + assert not is_scriptable() + return Conv2dSameExport(in_chs, out_chs, kernel_size, **kwargs) + else: + return Conv2dSame(in_chs, out_chs, kernel_size, **kwargs) + else: + return nn.Conv2d(in_chs, out_chs, kernel_size, padding=padding, **kwargs) + + +class MixedConv2d(nn.ModuleDict): + """ Mixed Grouped Convolution + Based on MDConv and GroupedConv in MixNet impl: + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mixnet/custom_layers.py + """ + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, depthwise=False, **kwargs): + super(MixedConv2d, self).__init__() + + kernel_size = kernel_size if isinstance(kernel_size, list) else [kernel_size] + num_groups = len(kernel_size) + in_splits = _split_channels(in_channels, num_groups) + out_splits = _split_channels(out_channels, num_groups) + self.in_channels = sum(in_splits) + self.out_channels = sum(out_splits) + for idx, (k, in_ch, out_ch) in enumerate(zip(kernel_size, in_splits, out_splits)): + conv_groups = out_ch if depthwise else 1 + self.add_module( + str(idx), + create_conv2d_pad( + in_ch, out_ch, k, stride=stride, + padding=padding, dilation=dilation, groups=conv_groups, **kwargs) + ) + self.splits = in_splits + + def forward(self, x): + x_split = torch.split(x, self.splits, 1) + x_out = [conv(x_split[i]) for i, conv in enumerate(self.values())] + x = torch.cat(x_out, 1) + return x + + +def get_condconv_initializer(initializer, num_experts, expert_shape): + def condconv_initializer(weight): + """CondConv initializer function.""" + num_params = np.prod(expert_shape) + if (len(weight.shape) != 2 or weight.shape[0] != num_experts or + weight.shape[1] != num_params): + raise (ValueError( + 'CondConv variables must have shape [num_experts, num_params]')) + for i in range(num_experts): + initializer(weight[i].view(expert_shape)) + return condconv_initializer + + +class CondConv2d(nn.Module): + """ Conditional Convolution + Inspired by: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/condconv/condconv_layers.py + + Grouped convolution hackery for parallel execution of the per-sample kernel filters inspired by this discussion: + https://github.com/pytorch/pytorch/issues/17983 + """ + __constants__ = ['bias', 'in_channels', 'out_channels', 'dynamic_padding'] + + def __init__(self, in_channels, out_channels, kernel_size=3, + stride=1, padding='', dilation=1, groups=1, bias=False, num_experts=4): + super(CondConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + padding_val, is_padding_dynamic = get_padding_value( + padding, kernel_size, stride=stride, dilation=dilation) + self.dynamic_padding = is_padding_dynamic # if in forward to work with torchscript + self.padding = _pair(padding_val) + self.dilation = _pair(dilation) + self.groups = groups + self.num_experts = num_experts + + self.weight_shape = (self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight_num_param = 1 + for wd in self.weight_shape: + weight_num_param *= wd + self.weight = torch.nn.Parameter(torch.Tensor(self.num_experts, weight_num_param)) + + if bias: + self.bias_shape = (self.out_channels,) + self.bias = torch.nn.Parameter(torch.Tensor(self.num_experts, self.out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + init_weight = get_condconv_initializer( + partial(nn.init.kaiming_uniform_, a=math.sqrt(5)), self.num_experts, self.weight_shape) + init_weight(self.weight) + if self.bias is not None: + fan_in = np.prod(self.weight_shape[1:]) + bound = 1 / math.sqrt(fan_in) + init_bias = get_condconv_initializer( + partial(nn.init.uniform_, a=-bound, b=bound), self.num_experts, self.bias_shape) + init_bias(self.bias) + + def forward(self, x, routing_weights): + B, C, H, W = x.shape + weight = torch.matmul(routing_weights, self.weight) + new_weight_shape = (B * self.out_channels, self.in_channels // self.groups) + self.kernel_size + weight = weight.view(new_weight_shape) + bias = None + if self.bias is not None: + bias = torch.matmul(routing_weights, self.bias) + bias = bias.view(B * self.out_channels) + # move batch elements with channels so each batch element can be efficiently convolved with separate kernel + x = x.view(1, B * C, H, W) + if self.dynamic_padding: + out = conv2d_same( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + else: + out = F.conv2d( + x, weight, bias, stride=self.stride, padding=self.padding, + dilation=self.dilation, groups=self.groups * B) + out = out.permute([1, 0, 2, 3]).view(B, self.out_channels, out.shape[-2], out.shape[-1]) + + # Literal port (from TF definition) + # x = torch.split(x, 1, 0) + # weight = torch.split(weight, 1, 0) + # if self.bias is not None: + # bias = torch.matmul(routing_weights, self.bias) + # bias = torch.split(bias, 1, 0) + # else: + # bias = [None] * B + # out = [] + # for xi, wi, bi in zip(x, weight, bias): + # wi = wi.view(*self.weight_shape) + # if bi is not None: + # bi = bi.view(*self.bias_shape) + # out.append(self.conv_fn( + # xi, wi, bi, stride=self.stride, padding=self.padding, + # dilation=self.dilation, groups=self.groups)) + # out = torch.cat(out, 0) + return out + + +def select_conv2d(in_chs, out_chs, kernel_size, **kwargs): + assert 'groups' not in kwargs # only use 'depthwise' bool arg + if isinstance(kernel_size, list): + assert 'num_experts' not in kwargs # MixNet + CondConv combo not supported currently + # We're going to use only lists for defining the MixedConv2d kernel groups, + # ints, tuples, other iterables will continue to pass to normal conv and specify h, w. + m = MixedConv2d(in_chs, out_chs, kernel_size, **kwargs) + else: + depthwise = kwargs.pop('depthwise', False) + groups = out_chs if depthwise else 1 + if 'num_experts' in kwargs and kwargs['num_experts'] > 0: + m = CondConv2d(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + else: + m = create_conv2d_pad(in_chs, out_chs, kernel_size, groups=groups, **kwargs) + return m diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..95dd63d400e70d70664c5a433a2772363f865e61 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/efficientnet_builder.py @@ -0,0 +1,683 @@ +""" EfficientNet / MobileNetV3 Blocks and Builder + +Copyright 2020 Ross Wightman +""" +import re +from copy import deepcopy + +from .conv2d_layers import * +from geffnet.activations import * + +__all__ = ['get_bn_args_tf', 'resolve_bn_args', 'resolve_se_args', 'resolve_act_layer', 'make_divisible', + 'round_channels', 'drop_connect', 'SqueezeExcite', 'ConvBnAct', 'DepthwiseSeparableConv', + 'InvertedResidual', 'CondConvResidual', 'EdgeResidual', 'EfficientNetBuilder', 'decode_arch_def', + 'initialize_weight_default', 'initialize_weight_goog', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT' +] + +# Defaults used for Google/Tensorflow training of mobile networks /w RMSprop as per +# papers and TF reference implementations. PT momentum equiv for TF decay is (1 - TF decay) +# NOTE: momentum varies btw .99 and .9997 depending on source +# .99 in official TF TPU impl +# .9997 (/w .999 in search space) for paper +# +# PyTorch defaults are momentum = .1, eps = 1e-5 +# +BN_MOMENTUM_TF_DEFAULT = 1 - 0.99 +BN_EPS_TF_DEFAULT = 1e-3 +_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT) + + +def get_bn_args_tf(): + return _BN_ARGS_TF.copy() + + +def resolve_bn_args(kwargs): + bn_args = get_bn_args_tf() if kwargs.pop('bn_tf', False) else {} + bn_momentum = kwargs.pop('bn_momentum', None) + if bn_momentum is not None: + bn_args['momentum'] = bn_momentum + bn_eps = kwargs.pop('bn_eps', None) + if bn_eps is not None: + bn_args['eps'] = bn_eps + return bn_args + + +_SE_ARGS_DEFAULT = dict( + gate_fn=sigmoid, + act_layer=None, # None == use containing block's activation layer + reduce_mid=False, + divisor=1) + + +def resolve_se_args(kwargs, in_chs, act_layer=None): + se_kwargs = kwargs.copy() if kwargs is not None else {} + # fill in args that aren't specified with the defaults + for k, v in _SE_ARGS_DEFAULT.items(): + se_kwargs.setdefault(k, v) + # some models, like MobilNetV3, calculate SE reduction chs from the containing block's mid_ch instead of in_ch + if not se_kwargs.pop('reduce_mid'): + se_kwargs['reduced_base_chs'] = in_chs + # act_layer override, if it remains None, the containing block's act_layer will be used + if se_kwargs['act_layer'] is None: + assert act_layer is not None + se_kwargs['act_layer'] = act_layer + return se_kwargs + + +def resolve_act_layer(kwargs, default='relu'): + act_layer = kwargs.pop('act_layer', default) + if isinstance(act_layer, str): + act_layer = get_act_layer(act_layer) + return act_layer + + +def make_divisible(v: int, divisor: int = 8, min_value: int = None): + min_value = min_value or divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + if new_v < 0.9 * v: # ensure round down does not go down by more than 10%. + new_v += divisor + return new_v + + +def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None): + """Round number of filters based on depth multiplier.""" + if not multiplier: + return channels + channels *= multiplier + return make_divisible(channels, divisor, channel_min) + + +def drop_connect(inputs, training: bool = False, drop_connect_rate: float = 0.): + """Apply drop connect.""" + if not training: + return inputs + + keep_prob = 1 - drop_connect_rate + random_tensor = keep_prob + torch.rand( + (inputs.size()[0], 1, 1, 1), dtype=inputs.dtype, device=inputs.device) + random_tensor.floor_() # binarize + output = inputs.div(keep_prob) * random_tensor + return output + + +class SqueezeExcite(nn.Module): + + def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None, act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1): + super(SqueezeExcite, self).__init__() + reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor) + self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True) + self.act1 = act_layer(inplace=True) + self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True) + self.gate_fn = gate_fn + + def forward(self, x): + x_se = x.mean((2, 3), keepdim=True) + x_se = self.conv_reduce(x_se) + x_se = self.act1(x_se) + x_se = self.conv_expand(x_se) + x = x * self.gate_fn(x_se) + return x + + +class ConvBnAct(nn.Module): + def __init__(self, in_chs, out_chs, kernel_size, + stride=1, pad_type='', act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, norm_kwargs=None): + super(ConvBnAct, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.conv = select_conv2d(in_chs, out_chs, kernel_size, stride=stride, padding=pad_type) + self.bn1 = norm_layer(out_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn1(x) + x = self.act1(x) + return x + + +class DepthwiseSeparableConv(nn.Module): + """ DepthwiseSeparable block + Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion + factor of 1.0. This is an alternative to having a IR with optional first pw conv. + """ + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + pw_kernel_size=1, pw_act=False, se_ratio=0., se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(DepthwiseSeparableConv, self).__init__() + assert stride in [1, 2] + norm_kwargs = norm_kwargs or {} + self.has_residual = (stride == 1 and in_chs == out_chs) and not noskip + self.drop_connect_rate = drop_connect_rate + + self.conv_dw = select_conv2d( + in_chs, in_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True) + self.bn1 = norm_layer(in_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(in_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + self.conv_pw = select_conv2d(in_chs, out_chs, pw_kernel_size, padding=pad_type) + self.bn2 = norm_layer(out_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) if pw_act else nn.Identity() + + def forward(self, x): + residual = x + + x = self.conv_dw(x) + x = self.bn1(x) + x = self.act1(x) + + x = self.se(x) + + x = self.conv_pw(x) + x = self.bn2(x) + x = self.act2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class InvertedResidual(nn.Module): + """ Inverted residual block w/ optional SE""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + conv_kwargs=None, drop_connect_rate=0.): + super(InvertedResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + conv_kwargs = conv_kwargs or {} + mid_chs: int = make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Point-wise expansion + self.conv_pw = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type, **conv_kwargs) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Depth-wise convolution + self.conv_dw = select_conv2d( + mid_chs, mid_chs, dw_kernel_size, stride=stride, padding=pad_type, depthwise=True, **conv_kwargs) + self.bn2 = norm_layer(mid_chs, **norm_kwargs) + self.act2 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() # for jit.script compat + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, padding=pad_type, **conv_kwargs) + self.bn3 = norm_layer(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Point-wise expansion + x = self.conv_pw(x) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class CondConvResidual(InvertedResidual): + """ Inverted residual block w/ CondConv routing""" + + def __init__(self, in_chs, out_chs, dw_kernel_size=3, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, + exp_ratio=1.0, exp_kernel_size=1, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + num_experts=0, drop_connect_rate=0.): + + self.num_experts = num_experts + conv_kwargs = dict(num_experts=self.num_experts) + + super(CondConvResidual, self).__init__( + in_chs, out_chs, dw_kernel_size=dw_kernel_size, stride=stride, pad_type=pad_type, + act_layer=act_layer, noskip=noskip, exp_ratio=exp_ratio, exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, se_ratio=se_ratio, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, conv_kwargs=conv_kwargs, + drop_connect_rate=drop_connect_rate) + + self.routing_fn = nn.Linear(in_chs, self.num_experts) + + def forward(self, x): + residual = x + + # CondConv routing + pooled_inputs = F.adaptive_avg_pool2d(x, 1).flatten(1) + routing_weights = torch.sigmoid(self.routing_fn(pooled_inputs)) + + # Point-wise expansion + x = self.conv_pw(x, routing_weights) + x = self.bn1(x) + x = self.act1(x) + + # Depth-wise convolution + x = self.conv_dw(x, routing_weights) + x = self.bn2(x) + x = self.act2(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x, routing_weights) + x = self.bn3(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + return x + + +class EdgeResidual(nn.Module): + """ EdgeTPU Residual block with expansion convolution followed by pointwise-linear w/ stride""" + + def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0, + stride=1, pad_type='', act_layer=nn.ReLU, noskip=False, pw_kernel_size=1, + se_ratio=0., se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + super(EdgeResidual, self).__init__() + norm_kwargs = norm_kwargs or {} + mid_chs = make_divisible(fake_in_chs * exp_ratio) if fake_in_chs > 0 else make_divisible(in_chs * exp_ratio) + self.has_residual = (in_chs == out_chs and stride == 1) and not noskip + self.drop_connect_rate = drop_connect_rate + + # Expansion convolution + self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type) + self.bn1 = norm_layer(mid_chs, **norm_kwargs) + self.act1 = act_layer(inplace=True) + + # Squeeze-and-excitation + if se_ratio is not None and se_ratio > 0.: + se_kwargs = resolve_se_args(se_kwargs, in_chs, act_layer) + self.se = SqueezeExcite(mid_chs, se_ratio=se_ratio, **se_kwargs) + else: + self.se = nn.Identity() + + # Point-wise linear projection + self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type) + self.bn2 = nn.BatchNorm2d(out_chs, **norm_kwargs) + + def forward(self, x): + residual = x + + # Expansion convolution + x = self.conv_exp(x) + x = self.bn1(x) + x = self.act1(x) + + # Squeeze-and-excitation + x = self.se(x) + + # Point-wise linear projection + x = self.conv_pwl(x) + x = self.bn2(x) + + if self.has_residual: + if self.drop_connect_rate > 0.: + x = drop_connect(x, self.training, self.drop_connect_rate) + x += residual + + return x + + +class EfficientNetBuilder: + """ Build Trunk Blocks for Efficient/Mobile Networks + + This ended up being somewhat of a cross between + https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_models.py + and + https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_builder.py + + """ + + def __init__(self, channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=None, se_kwargs=None, + norm_layer=nn.BatchNorm2d, norm_kwargs=None, drop_connect_rate=0.): + self.channel_multiplier = channel_multiplier + self.channel_divisor = channel_divisor + self.channel_min = channel_min + self.pad_type = pad_type + self.act_layer = act_layer + self.se_kwargs = se_kwargs + self.norm_layer = norm_layer + self.norm_kwargs = norm_kwargs + self.drop_connect_rate = drop_connect_rate + + # updated during build + self.in_chs = None + self.block_idx = 0 + self.block_count = 0 + + def _round_channels(self, chs): + return round_channels(chs, self.channel_multiplier, self.channel_divisor, self.channel_min) + + def _make_block(self, ba): + bt = ba.pop('block_type') + ba['in_chs'] = self.in_chs + ba['out_chs'] = self._round_channels(ba['out_chs']) + if 'fake_in_chs' in ba and ba['fake_in_chs']: + # FIXME this is a hack to work around mismatch in origin impl input filters for EdgeTPU + ba['fake_in_chs'] = self._round_channels(ba['fake_in_chs']) + ba['norm_layer'] = self.norm_layer + ba['norm_kwargs'] = self.norm_kwargs + ba['pad_type'] = self.pad_type + # block act fn overrides the model default + ba['act_layer'] = ba['act_layer'] if ba['act_layer'] is not None else self.act_layer + assert ba['act_layer'] is not None + if bt == 'ir': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + if ba.get('num_experts', 0) > 0: + block = CondConvResidual(**ba) + else: + block = InvertedResidual(**ba) + elif bt == 'ds' or bt == 'dsa': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = DepthwiseSeparableConv(**ba) + elif bt == 'er': + ba['drop_connect_rate'] = self.drop_connect_rate * self.block_idx / self.block_count + ba['se_kwargs'] = self.se_kwargs + block = EdgeResidual(**ba) + elif bt == 'cn': + block = ConvBnAct(**ba) + else: + assert False, 'Uknkown block type (%s) while building model.' % bt + self.in_chs = ba['out_chs'] # update in_chs for arg of next block + return block + + def _make_stack(self, stack_args): + blocks = [] + # each stack (stage) contains a list of block arguments + for i, ba in enumerate(stack_args): + if i >= 1: + # only the first block in any stack can have a stride > 1 + ba['stride'] = 1 + block = self._make_block(ba) + blocks.append(block) + self.block_idx += 1 # incr global idx (across all stacks) + return nn.Sequential(*blocks) + + def __call__(self, in_chs, block_args): + """ Build the blocks + Args: + in_chs: Number of input-channels passed to first block + block_args: A list of lists, outer list defines stages, inner + list contains strings defining block configuration(s) + Return: + List of block stacks (each stack wrapped in nn.Sequential) + """ + self.in_chs = in_chs + self.block_count = sum([len(x) for x in block_args]) + self.block_idx = 0 + blocks = [] + # outer list of block_args defines the stacks ('stages' by some conventions) + for stack_idx, stack in enumerate(block_args): + assert isinstance(stack, list) + stack = self._make_stack(stack) + blocks.append(stack) + return blocks + + +def _parse_ksize(ss): + if ss.isdigit(): + return int(ss) + else: + return [int(k) for k in ss.split('.')] + + +def _decode_block_str(block_str): + """ Decode block definition string + + Gets a list of block arg (dicts) through a string notation of arguments. + E.g. ir_r2_k3_s2_e1_i32_o16_se0.25_noskip + + All args can exist in any order with the exception of the leading string which + is assumed to indicate the block type. + + leading string - block type ( + ir = InvertedResidual, ds = DepthwiseSep, dsa = DeptwhiseSep with pw act, cn = ConvBnAct) + r - number of repeat blocks, + k - kernel size, + s - strides (1-9), + e - expansion ratio, + c - output channels, + se - squeeze/excitation ratio + n - activation fn ('re', 'r6', 'hs', or 'sw') + Args: + block_str: a string representation of block arguments. + Returns: + A list of block args (dicts) + Raises: + ValueError: if the string def not properly specified (TODO) + """ + assert isinstance(block_str, str) + ops = block_str.split('_') + block_type = ops[0] # take the block type off the front + ops = ops[1:] + options = {} + noskip = False + for op in ops: + # string options being checked on individual basis, combine if they grow + if op == 'noskip': + noskip = True + elif op.startswith('n'): + # activation fn + key = op[0] + v = op[1:] + if v == 're': + value = get_act_layer('relu') + elif v == 'r6': + value = get_act_layer('relu6') + elif v == 'hs': + value = get_act_layer('hard_swish') + elif v == 'sw': + value = get_act_layer('swish') + else: + continue + options[key] = value + else: + # all numeric options + splits = re.split(r'(\d.*)', op) + if len(splits) >= 2: + key, value = splits[:2] + options[key] = value + + # if act_layer is None, the model default (passed to model init) will be used + act_layer = options['n'] if 'n' in options else None + exp_kernel_size = _parse_ksize(options['a']) if 'a' in options else 1 + pw_kernel_size = _parse_ksize(options['p']) if 'p' in options else 1 + fake_in_chs = int(options['fc']) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def + + num_repeat = int(options['r']) + # each type of block has different valid arguments, fill accordingly + if block_type == 'ir': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + exp_kernel_size=exp_kernel_size, + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + if 'cc' in options: + block_args['num_experts'] = int(options['cc']) + elif block_type == 'ds' or block_type == 'dsa': + block_args = dict( + block_type=block_type, + dw_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + pw_act=block_type == 'dsa', + noskip=block_type == 'dsa' or noskip, + ) + elif block_type == 'er': + block_args = dict( + block_type=block_type, + exp_kernel_size=_parse_ksize(options['k']), + pw_kernel_size=pw_kernel_size, + out_chs=int(options['c']), + exp_ratio=float(options['e']), + fake_in_chs=fake_in_chs, + se_ratio=float(options['se']) if 'se' in options else None, + stride=int(options['s']), + act_layer=act_layer, + noskip=noskip, + ) + elif block_type == 'cn': + block_args = dict( + block_type=block_type, + kernel_size=int(options['k']), + out_chs=int(options['c']), + stride=int(options['s']), + act_layer=act_layer, + ) + else: + assert False, 'Unknown block type (%s)' % block_type + + return block_args, num_repeat + + +def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'): + """ Per-stage depth scaling + Scales the block repeats in each stage. This depth scaling impl maintains + compatibility with the EfficientNet scaling method, while allowing sensible + scaling for other models that may have multiple block arg definitions in each stage. + """ + + # We scale the total repeat count for each stage, there may be multiple + # block arg defs per stage so we need to sum. + num_repeat = sum(repeats) + if depth_trunc == 'round': + # Truncating to int by rounding allows stages with few repeats to remain + # proportionally smaller for longer. This is a good choice when stage definitions + # include single repeat stages that we'd prefer to keep that way as long as possible + num_repeat_scaled = max(1, round(num_repeat * depth_multiplier)) + else: + # The default for EfficientNet truncates repeats to int via 'ceil'. + # Any multiplier > 1.0 will result in an increased depth for every stage. + num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier)) + + # Proportionally distribute repeat count scaling to each block definition in the stage. + # Allocation is done in reverse as it results in the first block being less likely to be scaled. + # The first block makes less sense to repeat in most of the arch definitions. + repeats_scaled = [] + for r in repeats[::-1]: + rs = max(1, round((r / num_repeat * num_repeat_scaled))) + repeats_scaled.append(rs) + num_repeat -= r + num_repeat_scaled -= rs + repeats_scaled = repeats_scaled[::-1] + + # Apply the calculated scaling to each block arg in the stage + sa_scaled = [] + for ba, rep in zip(stack_args, repeats_scaled): + sa_scaled.extend([deepcopy(ba) for _ in range(rep)]) + return sa_scaled + + +def decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil', experts_multiplier=1, fix_first_last=False): + arch_args = [] + for stack_idx, block_strings in enumerate(arch_def): + assert isinstance(block_strings, list) + stack_args = [] + repeats = [] + for block_str in block_strings: + assert isinstance(block_str, str) + ba, rep = _decode_block_str(block_str) + if ba.get('num_experts', 0) > 0 and experts_multiplier > 1: + ba['num_experts'] *= experts_multiplier + stack_args.append(ba) + repeats.append(rep) + if fix_first_last and (stack_idx == 0 or stack_idx == len(arch_def) - 1): + arch_args.append(_scale_stage_depth(stack_args, repeats, 1.0, depth_trunc)) + else: + arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc)) + return arch_args + + +def initialize_weight_goog(m, n='', fix_group_fanout=True): + # weight init as per Tensorflow Official impl + # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py + if isinstance(m, CondConv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + init_weight_fn = get_condconv_initializer( + lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape) + init_weight_fn(m.weight) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.Conv2d): + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + if fix_group_fanout: + fan_out //= m.groups + m.weight.data.normal_(0, math.sqrt(2.0 / fan_out)) + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + fan_out = m.weight.size(0) # fan-out + fan_in = 0 + if 'routing_fn' in n: + fan_in = m.weight.size(1) + init_range = 1.0 / math.sqrt(fan_in + fan_out) + m.weight.data.uniform_(-init_range, init_range) + m.bias.data.zero_() + + +def initialize_weight_default(m, n=''): + if isinstance(m, CondConv2d): + init_fn = get_condconv_initializer(partial( + nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape) + init_fn(m.weight) + elif isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear') diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py new file mode 100644 index 0000000000000000000000000000000000000000..cd170d4cc5bed6ca82b61539902b470d3320c691 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/gen_efficientnet.py @@ -0,0 +1,1450 @@ +""" Generic Efficient Networks + +A generic MobileNet class with building blocks to support a variety of models: + +* EfficientNet (B0-B8, L2 + Tensorflow pretrained AutoAug/RandAug/AdvProp/NoisyStudent ports) + - EfficientNet: Rethinking Model Scaling for CNNs - https://arxiv.org/abs/1905.11946 + - CondConv: Conditionally Parameterized Convolutions for Efficient Inference - https://arxiv.org/abs/1904.04971 + - Adversarial Examples Improve Image Recognition - https://arxiv.org/abs/1911.09665 + - Self-training with Noisy Student improves ImageNet classification - https://arxiv.org/abs/1911.04252 + +* EfficientNet-Lite + +* MixNet (Small, Medium, and Large) + - MixConv: Mixed Depthwise Convolutional Kernels - https://arxiv.org/abs/1907.09595 + +* MNasNet B1, A1 (SE), Small + - MnasNet: Platform-Aware Neural Architecture Search for Mobile - https://arxiv.org/abs/1807.11626 + +* FBNet-C + - FBNet: Hardware-Aware Efficient ConvNet Design via Differentiable NAS - https://arxiv.org/abs/1812.03443 + +* Single-Path NAS Pixel1 + - Single-Path NAS: Designing Hardware-Efficient ConvNets - https://arxiv.org/abs/1904.02877 + +* And likely more... + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .config import layer_config_kwargs, is_scriptable +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['GenEfficientNet', 'mnasnet_050', 'mnasnet_075', 'mnasnet_100', 'mnasnet_b1', 'mnasnet_140', + 'semnasnet_050', 'semnasnet_075', 'semnasnet_100', 'mnasnet_a1', 'semnasnet_140', 'mnasnet_small', + 'mobilenetv2_100', 'mobilenetv2_140', 'mobilenetv2_110d', 'mobilenetv2_120d', + 'fbnetc_100', 'spnasnet_100', 'efficientnet_b0', 'efficientnet_b1', 'efficientnet_b2', 'efficientnet_b3', + 'efficientnet_b4', 'efficientnet_b5', 'efficientnet_b6', 'efficientnet_b7', 'efficientnet_b8', + 'efficientnet_l2', 'efficientnet_es', 'efficientnet_em', 'efficientnet_el', + 'efficientnet_cc_b0_4e', 'efficientnet_cc_b0_8e', 'efficientnet_cc_b1_8e', + 'efficientnet_lite0', 'efficientnet_lite1', 'efficientnet_lite2', 'efficientnet_lite3', 'efficientnet_lite4', + 'tf_efficientnet_b0', 'tf_efficientnet_b1', 'tf_efficientnet_b2', 'tf_efficientnet_b3', + 'tf_efficientnet_b4', 'tf_efficientnet_b5', 'tf_efficientnet_b6', 'tf_efficientnet_b7', 'tf_efficientnet_b8', + 'tf_efficientnet_b0_ap', 'tf_efficientnet_b1_ap', 'tf_efficientnet_b2_ap', 'tf_efficientnet_b3_ap', + 'tf_efficientnet_b4_ap', 'tf_efficientnet_b5_ap', 'tf_efficientnet_b6_ap', 'tf_efficientnet_b7_ap', + 'tf_efficientnet_b8_ap', 'tf_efficientnet_b0_ns', 'tf_efficientnet_b1_ns', 'tf_efficientnet_b2_ns', + 'tf_efficientnet_b3_ns', 'tf_efficientnet_b4_ns', 'tf_efficientnet_b5_ns', 'tf_efficientnet_b6_ns', + 'tf_efficientnet_b7_ns', 'tf_efficientnet_l2_ns', 'tf_efficientnet_l2_ns_475', + 'tf_efficientnet_es', 'tf_efficientnet_em', 'tf_efficientnet_el', + 'tf_efficientnet_cc_b0_4e', 'tf_efficientnet_cc_b0_8e', 'tf_efficientnet_cc_b1_8e', + 'tf_efficientnet_lite0', 'tf_efficientnet_lite1', 'tf_efficientnet_lite2', 'tf_efficientnet_lite3', + 'tf_efficientnet_lite4', + 'mixnet_s', 'mixnet_m', 'mixnet_l', 'mixnet_xl', 'tf_mixnet_s', 'tf_mixnet_m', 'tf_mixnet_l'] + + +model_urls = { + 'mnasnet_050': None, + 'mnasnet_075': None, + 'mnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_b1-74cb7081.pth', + 'mnasnet_140': None, + 'mnasnet_small': None, + + 'semnasnet_050': None, + 'semnasnet_075': None, + 'semnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mnasnet_a1-d9418771.pth', + 'semnasnet_140': None, + + 'mobilenetv2_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_100_ra-b33bc2c4.pth', + 'mobilenetv2_110d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_110d_ra-77090ade.pth', + 'mobilenetv2_120d': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_120d_ra-5987e2ed.pth', + 'mobilenetv2_140': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv2_140_ra-21a4e913.pth', + + 'fbnetc_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/fbnetc_100-c345b898.pth', + 'spnasnet_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/spnasnet_100-048bc3f4.pth', + + 'efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth', + 'efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth', + 'efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth', + 'efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth', + 'efficientnet_b4': None, + 'efficientnet_b5': None, + 'efficientnet_b6': None, + 'efficientnet_b7': None, + 'efficientnet_b8': None, + 'efficientnet_l2': None, + + 'efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_es_ra-f111e99c.pth', + 'efficientnet_em': None, + 'efficientnet_el': None, + + 'efficientnet_cc_b0_4e': None, + 'efficientnet_cc_b0_8e': None, + 'efficientnet_cc_b1_8e': None, + + 'efficientnet_lite0': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth', + 'efficientnet_lite1': None, + 'efficientnet_lite2': None, + 'efficientnet_lite3': None, + 'efficientnet_lite4': None, + + 'tf_efficientnet_b0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth', + 'tf_efficientnet_b1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_aa-ea7a6ee0.pth', + 'tf_efficientnet_b2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_aa-60c94f97.pth', + 'tf_efficientnet_b3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_aa-84b4657e.pth', + 'tf_efficientnet_b4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_aa-818f208c.pth', + 'tf_efficientnet_b5': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ra-9a3e5369.pth', + 'tf_efficientnet_b6': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_aa-80ba17e4.pth', + 'tf_efficientnet_b7': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ra-6c08e654.pth', + 'tf_efficientnet_b8': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ra-572d5dd9.pth', + + 'tf_efficientnet_b0_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ap-f262efe1.pth', + 'tf_efficientnet_b1_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ap-44ef0a3d.pth', + 'tf_efficientnet_b2_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ap-2f8e7636.pth', + 'tf_efficientnet_b3_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ap-aad25bdd.pth', + 'tf_efficientnet_b4_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ap-dedb23e6.pth', + 'tf_efficientnet_b5_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ap-9e82fae8.pth', + 'tf_efficientnet_b6_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ap-4ffb161f.pth', + 'tf_efficientnet_b7_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ap-ddb28fec.pth', + 'tf_efficientnet_b8_ap': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b8_ap-00e169fa.pth', + + 'tf_efficientnet_b0_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_ns-c0e6a31c.pth', + 'tf_efficientnet_b1_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b1_ns-99dd0c41.pth', + 'tf_efficientnet_b2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b2_ns-00306e48.pth', + 'tf_efficientnet_b3_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b3_ns-9d44bf68.pth', + 'tf_efficientnet_b4_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b4_ns-d6313a46.pth', + 'tf_efficientnet_b5_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b5_ns-6f26d0cf.pth', + 'tf_efficientnet_b6_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth', + 'tf_efficientnet_b7_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_ns-1dbc32de.pth', + 'tf_efficientnet_l2_ns_475': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns_475-bebbd00a.pth', + 'tf_efficientnet_l2_ns': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_l2_ns-df73bb44.pth', + + 'tf_efficientnet_es': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth', + 'tf_efficientnet_em': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth', + 'tf_efficientnet_el': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth', + + 'tf_efficientnet_cc_b0_4e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_4e-4362b6b2.pth', + 'tf_efficientnet_cc_b0_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b0_8e-66184a25.pth', + 'tf_efficientnet_cc_b1_8e': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_cc_b1_8e-f7c79ae1.pth', + + 'tf_efficientnet_lite0': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite0-0aa007d2.pth', + 'tf_efficientnet_lite1': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite1-bde8b488.pth', + 'tf_efficientnet_lite2': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite2-dcccb7df.pth', + 'tf_efficientnet_lite3': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite3-b733e338.pth', + 'tf_efficientnet_lite4': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_lite4-741542c3.pth', + + 'mixnet_s': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth', + 'mixnet_m': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth', + 'mixnet_l': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth', + 'mixnet_xl': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_xl_ra-aac3c00c.pth', + + 'tf_mixnet_s': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth', + 'tf_mixnet_m': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_m-0f4d8805.pth', + 'tf_mixnet_l': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_l-6c92e0c8.pth', +} + + +class GenEfficientNet(nn.Module): + """ Generic EfficientNets + + An implementation of mobile optimized networks that covers: + * EfficientNet (B0-B8, L2, CondConv, EdgeTPU) + * MixNet (Small, Medium, and Large, XL) + * MNASNet A1, B1, and small + * FBNet C + * Single-Path NAS Pixel1 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, num_features=1280, stem_size=32, fix_stem=False, + channel_multiplier=1.0, channel_divisor=8, channel_min=None, + pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, + weight_init='goog'): + super(GenEfficientNet, self).__init__() + self.drop_rate = drop_rate + + if not fix_stem: + stem_size = round_channels(stem_size, channel_multiplier, channel_divisor, channel_min) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = norm_layer(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, channel_divisor, channel_min, + pad_type, act_layer, se_kwargs, norm_layer, norm_kwargs, drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type) + self.bn2 = norm_layer(num_features, **norm_kwargs) + self.act2 = act_layer(inplace=True) + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.classifier = nn.Linear(num_features, num_classes) + + for n, m in self.named_modules(): + if weight_init == 'goog': + initialize_weight_goog(m, n) + else: + initialize_weight_default(m, n) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.conv_head(x) + x = self.bn2(x) + x = self.act2(x) + return x + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.conv_head, self.bn2, self.act2, + self.global_pool, nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def forward(self, x): + x = self.features(x) + x = self.global_pool(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = GenEfficientNet(**model_kwargs) + if pretrained: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mnasnet_a1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-a1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r2_k3_s2_e6_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25'], + # stage 3, 28x28 in + ['ir_r4_k3_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_b1(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r3_k5_s2_e6_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a mnasnet-b1 model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet + Paper: https://arxiv.org/pdf/1807.11626.pdf. + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + ['ds_r1_k3_s1_c8'], + ['ir_r1_k3_s2_e3_c16'], + ['ir_r2_k3_s2_e6_c16'], + ['ir_r4_k5_s2_e6_c32_se0.25'], + ['ir_r3_k3_s1_e6_c32_se0.25'], + ['ir_r3_k5_s2_e6_c88_se0.25'], + ['ir_r1_k3_s1_e6_c144'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=8, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v2( + variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs): + """ Generate MobileNet-V2 network + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py + Paper: https://arxiv.org/abs/1801.04381 + """ + arch_def = [ + ['ds_r1_k3_s1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r3_k3_s2_e6_c32'], + ['ir_r4_k3_s2_e6_c64'], + ['ir_r3_k3_s1_e6_c96'], + ['ir_r3_k3_s2_e6_c160'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head), + num_features=1280 if fix_stem_head else round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + fix_stem=fix_stem_head, + channel_multiplier=channel_multiplier, + norm_kwargs=resolve_bn_args(kwargs), + act_layer=nn.ReLU6, + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_fbnetc(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """ FBNet-C + + Paper: https://arxiv.org/abs/1812.03443 + Ref Impl: https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/maskrcnn_benchmark/modeling/backbone/fbnet_modeldef.py + + NOTE: the impl above does not relate to the 'C' variant here, that was derived from paper, + it was used to confirm some building block details + """ + arch_def = [ + ['ir_r1_k3_s1_e1_c16'], + ['ir_r1_k3_s2_e6_c24', 'ir_r2_k3_s1_e1_c24'], + ['ir_r1_k5_s2_e6_c32', 'ir_r1_k5_s1_e3_c32', 'ir_r1_k5_s1_e6_c32', 'ir_r1_k3_s1_e6_c32'], + ['ir_r1_k5_s2_e6_c64', 'ir_r1_k5_s1_e3_c64', 'ir_r2_k5_s1_e6_c64'], + ['ir_r3_k5_s1_e6_c112', 'ir_r1_k5_s1_e3_c112'], + ['ir_r4_k5_s2_e6_c184'], + ['ir_r1_k3_s1_e6_c352'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=16, + num_features=1984, # paper suggests this, but is not 100% clear + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates the Single-Path NAS model from search targeted for Pixel1 phone. + + Paper: https://arxiv.org/abs/1904.02877 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_c16_noskip'], + # stage 1, 112x112 in + ['ir_r3_k3_s2_e3_c24'], + # stage 2, 56x56 in + ['ir_r1_k5_s2_e6_c40', 'ir_r3_k3_s1_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k5_s2_e6_c80', 'ir_r3_k3_s1_e3_c80'], + # stage 4, 14x14in + ['ir_r1_k5_s1_e6_c96', 'ir_r3_k5_s1_e3_c96'], + # stage 5, 14x14in + ['ir_r4_k5_s2_e6_c192'], + # stage 6, 7x7 in + ['ir_r1_k3_s1_e6_c320_noskip'] + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet model. + + Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), + 'efficientnet-b8': (2.2, 3.6, 672, 0.5), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25'], + ['ir_r4_k5_s2_e6_c192_se0.25'], + ['ir_r1_k3_s1_e6_c320_se0.25'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_edge(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + arch_def = [ + # NOTE `fc` is present to override a mismatch between stem channels and in chs not + # present in other models + ['er_r1_k3_s1_e4_c24_fc24_noskip'], + ['er_r2_k3_s2_e8_c32'], + ['er_r4_k3_s2_e8_c48'], + ['ir_r5_k5_s2_e8_c96'], + ['ir_r4_k5_s1_e8_c144'], + ['ir_r2_k5_s2_e8_c192'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_condconv( + variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs): + """Creates an efficientnet-condconv model.""" + arch_def = [ + ['ds_r1_k3_s1_e1_c16_se0.25'], + ['ir_r2_k3_s2_e6_c24_se0.25'], + ['ir_r2_k5_s2_e6_c40_se0.25'], + ['ir_r3_k3_s2_e6_c80_se0.25'], + ['ir_r3_k5_s1_e6_c112_se0.25_cc4'], + ['ir_r4_k5_s2_e6_c192_se0.25_cc4'], + ['ir_r1_k3_s1_e6_c320_se0.25_cc4'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, experts_multiplier=experts_multiplier), + num_features=round_channels(1280, channel_multiplier, 8, None), + stem_size=32, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'swish'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates an EfficientNet-Lite model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/lite + Paper: https://arxiv.org/abs/1905.11946 + + EfficientNet params + name: (channel_multiplier, depth_multiplier, resolution, dropout_rate) + 'efficientnet-lite0': (1.0, 1.0, 224, 0.2), + 'efficientnet-lite1': (1.0, 1.1, 240, 0.2), + 'efficientnet-lite2': (1.1, 1.2, 260, 0.3), + 'efficientnet-lite3': (1.2, 1.4, 280, 0.3), + 'efficientnet-lite4': (1.4, 1.8, 300, 0.3), + + Args: + channel_multiplier: multiplier to number of channels per layer + depth_multiplier: multiplier to number of repeats per stage + """ + arch_def = [ + ['ds_r1_k3_s1_e1_c16'], + ['ir_r2_k3_s2_e6_c24'], + ['ir_r2_k5_s2_e6_c40'], + ['ir_r3_k3_s2_e6_c80'], + ['ir_r3_k5_s1_e6_c112'], + ['ir_r4_k5_s2_e6_c192'], + ['ir_r1_k3_s1_e6_c320'], + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, fix_first_last=True), + num_features=1280, + stem_size=32, + fix_stem=True, + channel_multiplier=channel_multiplier, + act_layer=nn.ReLU6, + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_s(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Small model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_a1.1_p1.1_s2_e6_c24', 'ir_r1_k3_a1.1_p1.1_s1_e3_c24'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_p1.1_s2_e6_c80_se0.25_nsw', 'ir_r2_k3.5_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3.5.7_a1.1_p1.1_s1_e6_c120_se0.5_nsw', 'ir_r2_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9.11_s2_e6_c200_se0.5_nsw', 'ir_r2_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=1536, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MixNet Medium-Large model. + + Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet + Paper: https://arxiv.org/abs/1907.09595 + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c24'], # relu + # stage 1, 112x112 in + ['ir_r1_k3.5.7_a1.1_p1.1_s2_e6_c32', 'ir_r1_k3_a1.1_p1.1_s1_e3_c32'], # relu + # stage 2, 56x56 in + ['ir_r1_k3.5.7.9_s2_e6_c40_se0.5_nsw', 'ir_r3_k3.5_a1.1_p1.1_s1_e6_c40_se0.5_nsw'], # swish + # stage 3, 28x28 in + ['ir_r1_k3.5.7_s2_e6_c80_se0.25_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e6_c80_se0.25_nsw'], # swish + # stage 4, 14x14in + ['ir_r1_k3_s1_e6_c120_se0.5_nsw', 'ir_r3_k3.5.7.9_a1.1_p1.1_s1_e3_c120_se0.5_nsw'], # swish + # stage 5, 14x14in + ['ir_r1_k3.5.7.9_s2_e6_c200_se0.5_nsw', 'ir_r3_k3.5.7.9_p1.1_s1_e6_c200_se0.5_nsw'], # swish + # 7x7 + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def, depth_multiplier, depth_trunc='round'), + num_features=1536, + stem_size=24, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'relu'), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mnasnet_050(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.5. """ + model = _gen_mnasnet_b1('mnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_075(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 0.75. """ + model = _gen_mnasnet_b1('mnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_100(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + model = _gen_mnasnet_b1('mnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_b1(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.0. """ + return mnasnet_100(pretrained, **kwargs) + + +def mnasnet_140(pretrained=False, **kwargs): + """ MNASNet B1, depth multiplier of 1.4 """ + model = _gen_mnasnet_b1('mnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_050(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.5 """ + model = _gen_mnasnet_a1('semnasnet_050', 0.5, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_075(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 0.75. """ + model = _gen_mnasnet_a1('semnasnet_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def semnasnet_100(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + model = _gen_mnasnet_a1('semnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_a1(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.0. """ + return semnasnet_100(pretrained, **kwargs) + + +def semnasnet_140(pretrained=False, **kwargs): + """ MNASNet A1 (w/ SE), depth multiplier of 1.4. """ + model = _gen_mnasnet_a1('semnasnet_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mnasnet_small(pretrained=False, **kwargs): + """ MNASNet Small, depth multiplier of 1.0. """ + model = _gen_mnasnet_small('mnasnet_small', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_100(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.0 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_140(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.4 channel multiplier """ + model = _gen_mobilenet_v2('mobilenetv2_140', 1.4, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_110d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.1 channel, 1.2 depth multipliers""" + model = _gen_mobilenet_v2( + 'mobilenetv2_110d', 1.1, depth_multiplier=1.2, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv2_120d(pretrained=False, **kwargs): + """ MobileNet V2 w/ 1.2 channel, 1.4 depth multipliers """ + model = _gen_mobilenet_v2( + 'mobilenetv2_120d', 1.2, depth_multiplier=1.4, fix_stem_head=True, pretrained=pretrained, **kwargs) + return model + + +def fbnetc_100(pretrained=False, **kwargs): + """ FBNet-C """ + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_fbnetc('fbnetc_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def spnasnet_100(pretrained=False, **kwargs): + """ Single-Path NAS Pixel1""" + model = _gen_spnasnet('spnasnet_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 """ + # NOTE for train set drop_rate=0.2, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 """ + # NOTE for train set drop_rate=0.4, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 """ + # NOTE for train set drop_rate=0.5, drop_connect_rate=0.2 + model = _gen_efficientnet( + 'efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_l2(pretrained=False, **kwargs): + """ EfficientNet-L2. """ + # NOTE for train, drop_rate should be 0.5 + model = _gen_efficientnet( + 'efficientnet_l2', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. """ + model = _gen_efficientnet_edge( + 'efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. """ + model = _gen_efficientnet_edge( + 'efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. """ + model = _gen_efficientnet_edge( + 'efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_efficientnet_condconv( + 'efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4 """ + model = _gen_efficientnet_lite( + 'efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0(pretrained=False, **kwargs): + """ EfficientNet-B0 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1(pretrained=False, **kwargs): + """ EfficientNet-B1 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2(pretrained=False, **kwargs): + """ EfficientNet-B2 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3(pretrained=False, **kwargs): + """ EfficientNet-B3 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4(pretrained=False, **kwargs): + """ EfficientNet-B4 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5(pretrained=False, **kwargs): + """ EfficientNet-B5 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6(pretrained=False, **kwargs): + """ EfficientNet-B6 AutoAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7(pretrained=False, **kwargs): + """ EfficientNet-B7 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8(pretrained=False, **kwargs): + """ EfficientNet-B8 RandAug. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ap(pretrained=False, **kwargs): + """ EfficientNet-B0 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ap', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ap(pretrained=False, **kwargs): + """ EfficientNet-B1 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ap', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ap(pretrained=False, **kwargs): + """ EfficientNet-B2 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ap', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ap(pretrained=False, **kwargs): + """ EfficientNet-B3 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ap', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ap(pretrained=False, **kwargs): + """ EfficientNet-B4 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ap', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ap(pretrained=False, **kwargs): + """ EfficientNet-B5 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ap', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ap(pretrained=False, **kwargs): + """ EfficientNet-B6 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ap', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ap(pretrained=False, **kwargs): + """ EfficientNet-B7 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ap', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b8_ap(pretrained=False, **kwargs): + """ EfficientNet-B8 AdvProp. Tensorflow compatible variant + Paper: Adversarial Examples Improve Image Recognition (https://arxiv.org/abs/1911.09665) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b8_ap', channel_multiplier=2.2, depth_multiplier=3.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b0_ns(pretrained=False, **kwargs): + """ EfficientNet-B0 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b0_ns', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b1_ns(pretrained=False, **kwargs): + """ EfficientNet-B1 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b1_ns', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b2_ns(pretrained=False, **kwargs): + """ EfficientNet-B2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b2_ns', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b3_ns(pretrained=False, **kwargs): + """ EfficientNet-B3 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b3_ns', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b4_ns(pretrained=False, **kwargs): + """ EfficientNet-B4 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b4_ns', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b5_ns(pretrained=False, **kwargs): + """ EfficientNet-B5 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b5_ns', channel_multiplier=1.6, depth_multiplier=2.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b6_ns(pretrained=False, **kwargs): + """ EfficientNet-B6 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b6_ns', channel_multiplier=1.8, depth_multiplier=2.6, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_b7_ns(pretrained=False, **kwargs): + """ EfficientNet-B7 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_b7_ns', channel_multiplier=2.0, depth_multiplier=3.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns_475(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent @ 475x475. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns_475', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_l2_ns(pretrained=False, **kwargs): + """ EfficientNet-L2 NoisyStudent. Tensorflow compatible variant + Paper: Self-training with Noisy Student improves ImageNet classification (https://arxiv.org/abs/1911.04252) + """ + # NOTE for train, drop_rate should be 0.5 + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet( + 'tf_efficientnet_l2_ns', channel_multiplier=4.3, depth_multiplier=5.3, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_es(pretrained=False, **kwargs): + """ EfficientNet-Edge Small. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_es', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_em(pretrained=False, **kwargs): + """ EfficientNet-Edge-Medium. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_em', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_el(pretrained=False, **kwargs): + """ EfficientNet-Edge-Large. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_edge( + 'tf_efficientnet_el', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_4e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 4 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_4e', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b0_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B0 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b0_8e', channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_cc_b1_8e(pretrained=False, **kwargs): + """ EfficientNet-CondConv-B1 w/ 8 Experts """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_condconv( + 'tf_efficientnet_cc_b1_8e', channel_multiplier=1.0, depth_multiplier=1.1, experts_multiplier=2, + pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite0(pretrained=False, **kwargs): + """ EfficientNet-Lite0. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite0', channel_multiplier=1.0, depth_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite1(pretrained=False, **kwargs): + """ EfficientNet-Lite1. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite1', channel_multiplier=1.0, depth_multiplier=1.1, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite2(pretrained=False, **kwargs): + """ EfficientNet-Lite2. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite2', channel_multiplier=1.1, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite3(pretrained=False, **kwargs): + """ EfficientNet-Lite3. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite3', channel_multiplier=1.2, depth_multiplier=1.4, pretrained=pretrained, **kwargs) + return model + + +def tf_efficientnet_lite4(pretrained=False, **kwargs): + """ EfficientNet-Lite4. Tensorflow compatible variant """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_efficientnet_lite( + 'tf_efficientnet_lite4', channel_multiplier=1.4, depth_multiplier=1.8, pretrained=pretrained, **kwargs) + return model + + +def mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. + """ + # NOTE for train set drop_rate=0.2 + model = _gen_mixnet_s( + 'mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. + """ + # NOTE for train set drop_rate=0.25 + model = _gen_mixnet_m( + 'mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xl(pretrained=False, **kwargs): + """Creates a MixNet Extra-Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.25, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xl', channel_multiplier=1.6, depth_multiplier=1.2, pretrained=pretrained, **kwargs) + return model + + +def mixnet_xxl(pretrained=False, **kwargs): + """Creates a MixNet Double Extra Large model. + Not a paper spec, experimental def by RW w/ depth scaling. + """ + # NOTE for train set drop_rate=0.3, drop_connect_rate=0.2 + model = _gen_mixnet_m( + 'mixnet_xxl', channel_multiplier=2.4, depth_multiplier=1.3, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_s(pretrained=False, **kwargs): + """Creates a MixNet Small model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_s( + 'tf_mixnet_s', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_m(pretrained=False, **kwargs): + """Creates a MixNet Medium model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_m', channel_multiplier=1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mixnet_l(pretrained=False, **kwargs): + """Creates a MixNet Large model. Tensorflow compatible variant + """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mixnet_m( + 'tf_mixnet_l', channel_multiplier=1.3, pretrained=pretrained, **kwargs) + return model diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..3f83a07d690c7ad681c777c19b1e7a5bb95da007 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/helpers.py @@ -0,0 +1,71 @@ +""" Checkpoint loading / state_dict helpers +Copyright 2020 Ross Wightman +""" +import torch +import os +from collections import OrderedDict +try: + from torch.hub import load_state_dict_from_url +except ImportError: + from torch.utils.model_zoo import load_url as load_state_dict_from_url + + +def load_checkpoint(model, checkpoint_path): + if checkpoint_path and os.path.isfile(checkpoint_path): + print("=> Loading checkpoint '{}'".format(checkpoint_path)) + checkpoint = torch.load(checkpoint_path) + if isinstance(checkpoint, dict) and 'state_dict' in checkpoint: + new_state_dict = OrderedDict() + for k, v in checkpoint['state_dict'].items(): + if k.startswith('module'): + name = k[7:] # remove `module.` + else: + name = k + new_state_dict[name] = v + model.load_state_dict(new_state_dict) + else: + model.load_state_dict(checkpoint) + print("=> Loaded checkpoint '{}'".format(checkpoint_path)) + else: + print("=> Error: No checkpoint found at '{}'".format(checkpoint_path)) + raise FileNotFoundError() + + +def load_pretrained(model, url, filter_fn=None, strict=True): + if not url: + print("=> Warning: Pretrained model URL is empty, using random initialization.") + return + + state_dict = load_state_dict_from_url(url, progress=False, map_location='cpu') + + input_conv = 'conv_stem' + classifier = 'classifier' + in_chans = getattr(model, input_conv).weight.shape[1] + num_classes = getattr(model, classifier).weight.shape[0] + + input_conv_weight = input_conv + '.weight' + pretrained_in_chans = state_dict[input_conv_weight].shape[1] + if in_chans != pretrained_in_chans: + if in_chans == 1: + print('=> Converting pretrained input conv {} from {} to 1 channel'.format( + input_conv_weight, pretrained_in_chans)) + conv1_weight = state_dict[input_conv_weight] + state_dict[input_conv_weight] = conv1_weight.sum(dim=1, keepdim=True) + else: + print('=> Discarding pretrained input conv {} since input channel count != {}'.format( + input_conv_weight, pretrained_in_chans)) + del state_dict[input_conv_weight] + strict = False + + classifier_weight = classifier + '.weight' + pretrained_num_classes = state_dict[classifier_weight].shape[0] + if num_classes != pretrained_num_classes: + print('=> Discarding pretrained classifier since num_classes != {}'.format(pretrained_num_classes)) + del state_dict[classifier_weight] + del state_dict[classifier + '.bias'] + strict = False + + if filter_fn is not None: + state_dict = filter_fn(state_dict) + + model.load_state_dict(state_dict, strict=strict) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py new file mode 100644 index 0000000000000000000000000000000000000000..b5966c28f7207e98ee50745b1bc8f3663c650f9d --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/mobilenetv3.py @@ -0,0 +1,364 @@ +""" MobileNet-V3 + +A PyTorch impl of MobileNet-V3, compatible with TF weights from official impl. + +Paper: Searching for MobileNetV3 - https://arxiv.org/abs/1905.02244 + +Hacked together by / Copyright 2020 Ross Wightman +""" +import torch.nn as nn +import torch.nn.functional as F + +from .activations import get_act_fn, get_act_layer, HardSwish +from .config import layer_config_kwargs +from .conv2d_layers import select_conv2d +from .helpers import load_pretrained +from .efficientnet_builder import * + +__all__ = ['mobilenetv3_rw', 'mobilenetv3_large_075', 'mobilenetv3_large_100', 'mobilenetv3_large_minimal_100', + 'mobilenetv3_small_075', 'mobilenetv3_small_100', 'mobilenetv3_small_minimal_100', + 'tf_mobilenetv3_large_075', 'tf_mobilenetv3_large_100', 'tf_mobilenetv3_large_minimal_100', + 'tf_mobilenetv3_small_075', 'tf_mobilenetv3_small_100', 'tf_mobilenetv3_small_minimal_100'] + +model_urls = { + 'mobilenetv3_rw': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_100-35495452.pth', + 'mobilenetv3_large_075': None, + 'mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth', + 'mobilenetv3_large_minimal_100': None, + 'mobilenetv3_small_075': None, + 'mobilenetv3_small_100': None, + 'mobilenetv3_small_minimal_100': None, + 'tf_mobilenetv3_large_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_075-150ee8b0.pth', + 'tf_mobilenetv3_large_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_100-427764d5.pth', + 'tf_mobilenetv3_large_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_large_minimal_100-8596ae28.pth', + 'tf_mobilenetv3_small_075': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_075-da427f52.pth', + 'tf_mobilenetv3_small_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_100-37f49e2b.pth', + 'tf_mobilenetv3_small_minimal_100': + 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mobilenetv3_small_minimal_100-922a7843.pth', +} + + +class MobileNetV3(nn.Module): + """ MobileNet-V3 + + A this model utilizes the MobileNet-v3 specific 'efficient head', where global pooling is done before the + head convolution without a final batch-norm layer before the classifier. + + Paper: https://arxiv.org/abs/1905.02244 + """ + + def __init__(self, block_args, num_classes=1000, in_chans=3, stem_size=16, num_features=1280, head_bias=True, + channel_multiplier=1.0, pad_type='', act_layer=HardSwish, drop_rate=0., drop_connect_rate=0., + se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'): + super(MobileNetV3, self).__init__() + self.drop_rate = drop_rate + + stem_size = round_channels(stem_size, channel_multiplier) + self.conv_stem = select_conv2d(in_chans, stem_size, 3, stride=2, padding=pad_type) + self.bn1 = nn.BatchNorm2d(stem_size, **norm_kwargs) + self.act1 = act_layer(inplace=True) + in_chs = stem_size + + builder = EfficientNetBuilder( + channel_multiplier, pad_type=pad_type, act_layer=act_layer, se_kwargs=se_kwargs, + norm_layer=norm_layer, norm_kwargs=norm_kwargs, drop_connect_rate=drop_connect_rate) + self.blocks = nn.Sequential(*builder(in_chs, block_args)) + in_chs = builder.in_chs + + self.global_pool = nn.AdaptiveAvgPool2d(1) + self.conv_head = select_conv2d(in_chs, num_features, 1, padding=pad_type, bias=head_bias) + self.act2 = act_layer(inplace=True) + self.classifier = nn.Linear(num_features, num_classes) + + for m in self.modules(): + if weight_init == 'goog': + initialize_weight_goog(m) + else: + initialize_weight_default(m) + + def as_sequential(self): + layers = [self.conv_stem, self.bn1, self.act1] + layers.extend(self.blocks) + layers.extend([ + self.global_pool, self.conv_head, self.act2, + nn.Flatten(), nn.Dropout(self.drop_rate), self.classifier]) + return nn.Sequential(*layers) + + def features(self, x): + x = self.conv_stem(x) + x = self.bn1(x) + x = self.act1(x) + x = self.blocks(x) + x = self.global_pool(x) + x = self.conv_head(x) + x = self.act2(x) + return x + + def forward(self, x): + x = self.features(x) + x = x.flatten(1) + if self.drop_rate > 0.: + x = F.dropout(x, p=self.drop_rate, training=self.training) + return self.classifier(x) + + +def _create_model(model_kwargs, variant, pretrained=False): + as_sequential = model_kwargs.pop('as_sequential', False) + model = MobileNetV3(**model_kwargs) + if pretrained and model_urls[variant]: + load_pretrained(model, model_urls[variant]) + if as_sequential: + model = model.as_sequential() + return model + + +def _gen_mobilenet_v3_rw(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 model (RW variant). + + Paper: https://arxiv.org/abs/1905.02244 + + This was my first attempt at reproducing the MobileNet-V3 from paper alone. It came close to the + eventual Tensorflow reference impl but has a few differences: + 1. This model has no bias on the head convolution + 2. This model forces no residual (noskip) on the first DWS block, this is different than MnasNet + 3. This model always uses ReLU for the SE activation layer, other models in the family inherit their act layer + from their parent block + 4. This model does not enforce divisible by 8 limitation on the SE reduction channel count + + Overall the changes are fairly minor and result in a very small parameter count difference and no + top-1/5 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre_noskip'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + head_bias=False, # one of my mistakes + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, 'hard_swish'), + se_kwargs=dict(gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def _gen_mobilenet_v3(variant, channel_multiplier=1.0, pretrained=False, **kwargs): + """Creates a MobileNet-V3 large/small/minimal models. + + Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v3.py + Paper: https://arxiv.org/abs/1905.02244 + + Args: + channel_multiplier: multiplier to number of channels per layer. + """ + if 'small' in variant: + num_features = 1024 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16'], + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24', 'ir_r1_k3_s1_e3.67_c24'], + # stage 2, 28x28 in + ['ir_r1_k3_s2_e4_c40', 'ir_r2_k3_s1_e6_c40'], + # stage 3, 14x14 in + ['ir_r2_k3_s1_e3_c48'], + # stage 4, 14x14in + ['ir_r3_k3_s2_e6_c96'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s2_e1_c16_se0.25_nre'], # relu + # stage 1, 56x56 in + ['ir_r1_k3_s2_e4.5_c24_nre', 'ir_r1_k3_s1_e3.67_c24_nre'], # relu + # stage 2, 28x28 in + ['ir_r1_k5_s2_e4_c40_se0.25', 'ir_r2_k5_s1_e6_c40_se0.25'], # hard-swish + # stage 3, 14x14 in + ['ir_r2_k5_s1_e3_c48_se0.25'], # hard-swish + # stage 4, 14x14in + ['ir_r3_k5_s2_e6_c96_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c576'], # hard-swish + ] + else: + num_features = 1280 + if 'minimal' in variant: + act_layer = 'relu' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16'], + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24', 'ir_r1_k3_s1_e3_c24'], + # stage 2, 56x56 in + ['ir_r3_k3_s2_e3_c40'], + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112'], + # stage 5, 14x14in + ['ir_r3_k3_s2_e6_c160'], + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], + ] + else: + act_layer = 'hard_swish' + arch_def = [ + # stage 0, 112x112 in + ['ds_r1_k3_s1_e1_c16_nre'], # relu + # stage 1, 112x112 in + ['ir_r1_k3_s2_e4_c24_nre', 'ir_r1_k3_s1_e3_c24_nre'], # relu + # stage 2, 56x56 in + ['ir_r3_k5_s2_e3_c40_se0.25_nre'], # relu + # stage 3, 28x28 in + ['ir_r1_k3_s2_e6_c80', 'ir_r1_k3_s1_e2.5_c80', 'ir_r2_k3_s1_e2.3_c80'], # hard-swish + # stage 4, 14x14in + ['ir_r2_k3_s1_e6_c112_se0.25'], # hard-swish + # stage 5, 14x14in + ['ir_r3_k5_s2_e6_c160_se0.25'], # hard-swish + # stage 6, 7x7 in + ['cn_r1_k1_s1_c960'], # hard-swish + ] + with layer_config_kwargs(kwargs): + model_kwargs = dict( + block_args=decode_arch_def(arch_def), + num_features=num_features, + stem_size=16, + channel_multiplier=channel_multiplier, + act_layer=resolve_act_layer(kwargs, act_layer), + se_kwargs=dict( + act_layer=get_act_layer('relu'), gate_fn=get_act_fn('hard_sigmoid'), reduce_mid=True, divisor=8), + norm_kwargs=resolve_bn_args(kwargs), + **kwargs, + ) + model = _create_model(model_kwargs, variant, pretrained) + return model + + +def mobilenetv3_rw(pretrained=False, **kwargs): + """ MobileNet-V3 RW + Attn: See note in gen function for this variant. + """ + # NOTE for train set drop_rate=0.2 + if pretrained: + # pretrained model trained with non-default BN epsilon + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + model = _gen_mobilenet_v3_rw('mobilenetv3_rw', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75""" + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large (Minimalistic) 1.0 """ + # NOTE for train set drop_rate=0.2 + model = _gen_mobilenet_v3('mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75 """ + model = _gen_mobilenet_v3('mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small (Minimalistic) 1.0 """ + model = _gen_mobilenet_v3('mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_075(pretrained=False, **kwargs): + """ MobileNet V3 Large 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_100(pretrained=False, **kwargs): + """ MobileNet V3 Large 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_large_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Large Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_large_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_075(pretrained=False, **kwargs): + """ MobileNet V3 Small 0.75. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_075', 0.75, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_100(pretrained=False, **kwargs): + """ MobileNet V3 Small 1.0. Tensorflow compat variant.""" + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_100', 1.0, pretrained=pretrained, **kwargs) + return model + + +def tf_mobilenetv3_small_minimal_100(pretrained=False, **kwargs): + """ MobileNet V3 Small Minimalistic 1.0. Tensorflow compat variant. """ + kwargs['bn_eps'] = BN_EPS_TF_DEFAULT + kwargs['pad_type'] = 'same' + model = _gen_mobilenet_v3('tf_mobilenetv3_small_minimal_100', 1.0, pretrained=pretrained, **kwargs) + return model diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4d46ea8baedaf3d787826eb3bb314b4230514647 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/model_factory.py @@ -0,0 +1,27 @@ +from .config import set_layer_config +from .helpers import load_checkpoint + +from .gen_efficientnet import * +from .mobilenetv3 import * + + +def create_model( + model_name='mnasnet_100', + pretrained=None, + num_classes=1000, + in_chans=3, + checkpoint_path='', + **kwargs): + + model_kwargs = dict(num_classes=num_classes, in_chans=in_chans, pretrained=pretrained, **kwargs) + + if model_name in globals(): + create_fn = globals()[model_name] + model = create_fn(**model_kwargs) + else: + raise RuntimeError('Unknown model (%s)' % model_name) + + if checkpoint_path and not pretrained: + load_checkpoint(model, checkpoint_path) + + return model diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py new file mode 100644 index 0000000000000000000000000000000000000000..a6221b3de7b1490c5e712e8b5fcc94c3d9d04295 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/geffnet/version.py @@ -0,0 +1 @@ +__version__ = '1.0.2' diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..45b17b99bbeba34596569e6e50f6e8a2ebc45c54 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/hubconf.py @@ -0,0 +1,84 @@ +dependencies = ['torch', 'math'] + +from geffnet import efficientnet_b0 +from geffnet import efficientnet_b1 +from geffnet import efficientnet_b2 +from geffnet import efficientnet_b3 + +from geffnet import efficientnet_es + +from geffnet import efficientnet_lite0 + +from geffnet import mixnet_s +from geffnet import mixnet_m +from geffnet import mixnet_l +from geffnet import mixnet_xl + +from geffnet import mobilenetv2_100 +from geffnet import mobilenetv2_110d +from geffnet import mobilenetv2_120d +from geffnet import mobilenetv2_140 + +from geffnet import mobilenetv3_large_100 +from geffnet import mobilenetv3_rw +from geffnet import mnasnet_a1 +from geffnet import mnasnet_b1 +from geffnet import fbnetc_100 +from geffnet import spnasnet_100 + +from geffnet import tf_efficientnet_b0 +from geffnet import tf_efficientnet_b1 +from geffnet import tf_efficientnet_b2 +from geffnet import tf_efficientnet_b3 +from geffnet import tf_efficientnet_b4 +from geffnet import tf_efficientnet_b5 +from geffnet import tf_efficientnet_b6 +from geffnet import tf_efficientnet_b7 +from geffnet import tf_efficientnet_b8 + +from geffnet import tf_efficientnet_b0_ap +from geffnet import tf_efficientnet_b1_ap +from geffnet import tf_efficientnet_b2_ap +from geffnet import tf_efficientnet_b3_ap +from geffnet import tf_efficientnet_b4_ap +from geffnet import tf_efficientnet_b5_ap +from geffnet import tf_efficientnet_b6_ap +from geffnet import tf_efficientnet_b7_ap +from geffnet import tf_efficientnet_b8_ap + +from geffnet import tf_efficientnet_b0_ns +from geffnet import tf_efficientnet_b1_ns +from geffnet import tf_efficientnet_b2_ns +from geffnet import tf_efficientnet_b3_ns +from geffnet import tf_efficientnet_b4_ns +from geffnet import tf_efficientnet_b5_ns +from geffnet import tf_efficientnet_b6_ns +from geffnet import tf_efficientnet_b7_ns +from geffnet import tf_efficientnet_l2_ns_475 +from geffnet import tf_efficientnet_l2_ns + +from geffnet import tf_efficientnet_es +from geffnet import tf_efficientnet_em +from geffnet import tf_efficientnet_el + +from geffnet import tf_efficientnet_cc_b0_4e +from geffnet import tf_efficientnet_cc_b0_8e +from geffnet import tf_efficientnet_cc_b1_8e + +from geffnet import tf_efficientnet_lite0 +from geffnet import tf_efficientnet_lite1 +from geffnet import tf_efficientnet_lite2 +from geffnet import tf_efficientnet_lite3 +from geffnet import tf_efficientnet_lite4 + +from geffnet import tf_mixnet_s +from geffnet import tf_mixnet_m +from geffnet import tf_mixnet_l + +from geffnet import tf_mobilenetv3_large_075 +from geffnet import tf_mobilenetv3_large_100 +from geffnet import tf_mobilenetv3_large_minimal_100 +from geffnet import tf_mobilenetv3_small_075 +from geffnet import tf_mobilenetv3_small_100 +from geffnet import tf_mobilenetv3_small_minimal_100 + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py new file mode 100644 index 0000000000000000000000000000000000000000..7a5162ce214830df501bdb81edb66c095122f69d --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_export.py @@ -0,0 +1,120 @@ +""" ONNX export script + +Export PyTorch models as ONNX graphs. + +This export script originally started as an adaptation of code snippets found at +https://pytorch.org/tutorials/advanced/super_resolution_with_onnxruntime.html + +The default parameters work with PyTorch 1.6 and ONNX 1.7 and produce an optimal ONNX graph +for hosting in the ONNX runtime (see onnx_validate.py). To export an ONNX model compatible +with caffe2 (see caffe2_benchmark.py and caffe2_validate.py), the --keep-init and --aten-fallback +flags are currently required. + +Older versions of PyTorch/ONNX (tested PyTorch 1.4, ONNX 1.5) do not need extra flags for +caffe2 compatibility, but they produce a model that isn't as fast running on ONNX runtime. + +Most new release of PyTorch and ONNX cause some sort of breakage in the export / usage of ONNX models. +Please do your research and search ONNX and PyTorch issue tracker before asking me. Thanks. + +Copyright 2020 Ross Wightman +""" +import argparse +import torch +import numpy as np + +import onnx +import geffnet + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('output', metavar='ONNX_FILE', + help='output model filename') +parser.add_argument('--model', '-m', metavar='MODEL', default='mobilenetv3_large_100', + help='model architecture (default: mobilenetv3_large_100)') +parser.add_argument('--opset', type=int, default=10, + help='ONNX opset to use (default: 10)') +parser.add_argument('--keep-init', action='store_true', default=False, + help='Keep initializers as input. Needed for Caffe2 compatible export in newer PyTorch/ONNX.') +parser.add_argument('--aten-fallback', action='store_true', default=False, + help='Fallback to ATEN ops. Helps fix AdaptiveAvgPool issue with Caffe2 in newer PyTorch/ONNX.') +parser.add_argument('--dynamic-size', action='store_true', default=False, + help='Export model width dynamic width/height. Not recommended for "tf" models with SAME padding.') +parser.add_argument('-b', '--batch-size', default=1, type=int, + metavar='N', help='mini-batch size (default: 1)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to checkpoint (default: none)') + + +def main(): + args = parser.parse_args() + + args.pretrained = True + if args.checkpoint: + args.pretrained = False + + print("==> Creating PyTorch {} model".format(args.model)) + # NOTE exportable=True flag disables autofn/jit scripted activations and uses Conv2dSameExport layers + # for models using SAME padding + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + exportable=True) + + model.eval() + + example_input = torch.randn((args.batch_size, 3, args.img_size or 224, args.img_size or 224), requires_grad=True) + + # Run model once before export trace, sets padding for models with Conv2dSameExport. This means + # that the padding for models with Conv2dSameExport (most models with tf_ prefix) is fixed for + # the input img_size specified in this script. + # Opset >= 11 should allow for dynamic padding, however I cannot get it to work due to + # issues in the tracing of the dynamic padding or errors attempting to export the model after jit + # scripting it (an approach that should work). Perhaps in a future PyTorch or ONNX versions... + model(example_input) + + print("==> Exporting model to ONNX format at '{}'".format(args.output)) + input_names = ["input0"] + output_names = ["output0"] + dynamic_axes = {'input0': {0: 'batch'}, 'output0': {0: 'batch'}} + if args.dynamic_size: + dynamic_axes['input0'][2] = 'height' + dynamic_axes['input0'][3] = 'width' + if args.aten_fallback: + export_type = torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK + else: + export_type = torch.onnx.OperatorExportTypes.ONNX + + torch_out = torch.onnx._export( + model, example_input, args.output, export_params=True, verbose=True, input_names=input_names, + output_names=output_names, keep_initializers_as_inputs=args.keep_init, dynamic_axes=dynamic_axes, + opset_version=args.opset, operator_export_type=export_type) + + print("==> Loading and checking exported model from '{}'".format(args.output)) + onnx_model = onnx.load(args.output) + onnx.checker.check_model(onnx_model) # assuming throw on error + print("==> Passed") + + if args.keep_init and args.aten_fallback: + import caffe2.python.onnx.backend as onnx_caffe2 + # Caffe2 loading only works properly in newer PyTorch/ONNX combos when + # keep_initializers_as_inputs and aten_fallback are set to True. + print("==> Loading model into Caffe2 backend and comparing forward pass.".format(args.output)) + caffe2_backend = onnx_caffe2.prepare(onnx_model) + B = {onnx_model.graph.input[0].name: x.data.numpy()} + c2_out = caffe2_backend.run(B)[0] + np.testing.assert_almost_equal(torch_out.data.numpy(), c2_out, decimal=5) + print("==> Passed") + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py new file mode 100644 index 0000000000000000000000000000000000000000..ee20bbf9f0f9473370489512eb96ca0b570b5388 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_optimize.py @@ -0,0 +1,84 @@ +""" ONNX optimization script + +Run ONNX models through the optimizer to prune unneeded nodes, fuse batchnorm layers into conv, etc. + +NOTE: This isn't working consistently in recent PyTorch/ONNX combos (ie PyTorch 1.6 and ONNX 1.7), +it seems time to switch to using the onnxruntime online optimizer (can also be saved for offline). + +Copyright 2020 Ross Wightman +""" +import argparse +import warnings + +import onnx +from onnx import optimizer + + +parser = argparse.ArgumentParser(description="Optimize ONNX model") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--output", required=True, help="The optimized model output filename") + + +def traverse_graph(graph, prefix=''): + content = [] + indent = prefix + ' ' + graphs = [] + num_nodes = 0 + for node in graph.node: + pn, gs = onnx.helper.printable_node(node, indent, subgraphs=True) + assert isinstance(gs, list) + content.append(pn) + graphs.extend(gs) + num_nodes += 1 + for g in graphs: + g_count, g_str = traverse_graph(g) + content.append('\n' + g_str) + num_nodes += g_count + return num_nodes, '\n'.join(content) + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + num_original_nodes, original_graph_str = traverse_graph(onnx_model.graph) + + # Optimizer passes to perform + passes = [ + #'eliminate_deadend', + 'eliminate_identity', + 'eliminate_nop_dropout', + 'eliminate_nop_pad', + 'eliminate_nop_transpose', + 'eliminate_unused_initializer', + 'extract_constant_to_initializer', + 'fuse_add_bias_into_conv', + 'fuse_bn_into_conv', + 'fuse_consecutive_concats', + 'fuse_consecutive_reduce_unsqueeze', + 'fuse_consecutive_squeezes', + 'fuse_consecutive_transposes', + #'fuse_matmul_add_bias_into_gemm', + 'fuse_pad_into_conv', + #'fuse_transpose_into_gemm', + #'lift_lexical_references', + ] + + # Apply the optimization on the original serialized model + # WARNING I've had issues with optimizer in recent versions of PyTorch / ONNX causing + # 'duplicate definition of name' errors, see: https://github.com/onnx/onnx/issues/2401 + # It may be better to rely on onnxruntime optimizations, see onnx_validate.py script. + warnings.warn("I've had issues with optimizer in recent versions of PyTorch / ONNX." + "Try onnxruntime optimization if this doesn't work.") + optimized_model = optimizer.optimize(onnx_model, passes) + + num_optimized_nodes, optimzied_graph_str = traverse_graph(optimized_model.graph) + print('==> The model after optimization:\n{}\n'.format(optimzied_graph_str)) + print('==> The optimized model has {} nodes, the original had {}.'.format(num_optimized_nodes, num_original_nodes)) + + # Save the ONNX model + onnx.save(optimized_model, args.output) + + +if __name__ == "__main__": + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py new file mode 100644 index 0000000000000000000000000000000000000000..44399aafababcdf6b84147a0613eb0909730db4b --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_to_caffe.py @@ -0,0 +1,27 @@ +import argparse + +import onnx +from caffe2.python.onnx.backend import Caffe2Backend + + +parser = argparse.ArgumentParser(description="Convert ONNX to Caffe2") + +parser.add_argument("model", help="The ONNX model") +parser.add_argument("--c2-prefix", required=True, + help="The output file prefix for the caffe2 model init and predict file. ") + + +def main(): + args = parser.parse_args() + onnx_model = onnx.load(args.model) + caffe2_init, caffe2_predict = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) + caffe2_init_str = caffe2_init.SerializeToString() + with open(args.c2_prefix + '.init.pb', "wb") as f: + f.write(caffe2_init_str) + caffe2_predict_str = caffe2_predict.SerializeToString() + with open(args.c2_prefix + '.predict.pb', "wb") as f: + f.write(caffe2_predict_str) + + +if __name__ == "__main__": + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py new file mode 100644 index 0000000000000000000000000000000000000000..ab3e4fb141b6ef660dcc5b447fd9f368a2ea19a0 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/onnx_validate.py @@ -0,0 +1,112 @@ +""" ONNX-runtime validation script + +This script was created to verify accuracy and performance of exported ONNX +models running with the onnxruntime. It utilizes the PyTorch dataloader/processing +pipeline for a fair comparison against the originals. + +Copyright 2020 Ross Wightman +""" +import argparse +import numpy as np +import onnxruntime +from data import create_loader, resolve_data_config, Dataset +from utils import AverageMeter +import time + +parser = argparse.ArgumentParser(description='Caffe2 ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--onnx-input', default='', type=str, metavar='PATH', + help='path to onnx model/weights file') +parser.add_argument('--onnx-output-opt', default='', type=str, metavar='PATH', + help='path to output optimized onnx graph') +parser.add_argument('--profile', action='store_true', default=False, + help='Enable profiler output.') +parser.add_argument('-j', '--workers', default=2, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') + + +def main(): + args = parser.parse_args() + args.gpu_id = 0 + + # Set graph optimization level + sess_options = onnxruntime.SessionOptions() + sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + if args.profile: + sess_options.enable_profiling = True + if args.onnx_output_opt: + sess_options.optimized_model_filepath = args.onnx_output_opt + + session = onnxruntime.InferenceSession(args.onnx_input, sess_options) + + data_config = resolve_data_config(None, args) + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=False, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + input_name = session.get_inputs()[0].name + + batch_time = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + end = time.time() + for i, (input, target) in enumerate(loader): + # run the net and return prediction + output = session.run([], {input_name: input.data.numpy()}) + output = output[0] + + # measure accuracy and record loss + prec1, prec5 = accuracy_np(output, target.numpy()) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s, {ms_avg:.3f} ms/sample) \t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, rate_avg=input.size(0) / batch_time.avg, + ms_avg=100 * batch_time.avg / input.size(0), top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +def accuracy_np(output, target): + max_indices = np.argsort(output, axis=1)[:, ::-1] + top5 = 100 * np.equal(max_indices[:, :5], target[:, np.newaxis]).sum(axis=1).mean() + top1 = 100 * np.equal(max_indices[:, 0], target).mean() + return top1, top5 + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ac3ffc13bae15f9b11f7cbe3705760056ecd7f13 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.2.0 +torchvision>=0.4.0 diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..023e4c30f98164595964423e3a83eefaf7ffdad6 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/setup.py @@ -0,0 +1,47 @@ +""" Setup +""" +from setuptools import setup, find_packages +from codecs import open +from os import path + +here = path.abspath(path.dirname(__file__)) + +# Get the long description from the README file +with open(path.join(here, 'README.md'), encoding='utf-8') as f: + long_description = f.read() + +exec(open('geffnet/version.py').read()) +setup( + name='geffnet', + version=__version__, + description='(Generic) EfficientNets for PyTorch', + long_description=long_description, + long_description_content_type='text/markdown', + url='https://github.com/rwightman/gen-efficientnet-pytorch', + author='Ross Wightman', + author_email='hello@rwightman.com', + classifiers=[ + # How mature is this project? Common values are + # 3 - Alpha + # 4 - Beta + # 5 - Production/Stable + 'Development Status :: 3 - Alpha', + 'Intended Audience :: Education', + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Programming Language :: Python :: 3.8', + 'Topic :: Scientific/Engineering', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + 'Topic :: Software Development', + 'Topic :: Software Development :: Libraries', + 'Topic :: Software Development :: Libraries :: Python Modules', + ], + + # Note that this is a string of words separated by whitespace, not a list. + keywords='pytorch pretrained models efficientnet mixnet mobilenetv3 mnasnet', + packages=find_packages(exclude=['data']), + install_requires=['torch >= 1.4', 'torchvision'], + python_requires='>=3.6', +) diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d327e8bd8120c5cd09ae6c15c3991ccbe27f6c1f --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/utils.py @@ -0,0 +1,52 @@ +import os + + +class AverageMeter: + """Computes and stores the average and current value""" + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def accuracy(output, target, topk=(1,)): + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def get_outdir(path, *paths, inc=False): + outdir = os.path.join(path, *paths) + if not os.path.exists(outdir): + os.makedirs(outdir) + elif inc: + count = 1 + outdir_inc = outdir + '-' + str(count) + while os.path.exists(outdir_inc): + count = count + 1 + outdir_inc = outdir + '-' + str(count) + assert count < 100 + outdir = outdir_inc + os.makedirs(outdir) + return outdir + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py new file mode 100644 index 0000000000000000000000000000000000000000..5fd44fbb3165ef81ef81251b6299f6aaa80bf2c2 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/efficientnet_repo/validate.py @@ -0,0 +1,166 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import time +import torch +import torch.nn as nn +import torch.nn.parallel +from contextlib import suppress + +import geffnet +from data import Dataset, create_loader, resolve_data_config +from utils import accuracy, AverageMeter + +has_native_amp = False +try: + if getattr(torch.cuda.amp, 'autocast') is not None: + has_native_amp = True +except AttributeError: + pass + +torch.backends.cudnn.benchmark = True + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation') +parser.add_argument('data', metavar='DIR', + help='path to dataset') +parser.add_argument('--model', '-m', metavar='MODEL', default='spnasnet1_00', + help='model architecture (default: dpn92)') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 2)') +parser.add_argument('-b', '--batch-size', default=256, type=int, + metavar='N', help='mini-batch size (default: 256)') +parser.add_argument('--img-size', default=None, type=int, + metavar='N', help='Input image dimension, uses model default if empty') +parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', + help='Override mean pixel value of dataset') +parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', + help='Override std deviation of of dataset') +parser.add_argument('--crop-pct', type=float, default=None, metavar='PCT', + help='Override default crop pct of 0.875') +parser.add_argument('--interpolation', default='', type=str, metavar='NAME', + help='Image resize interpolation type (overrides model)') +parser.add_argument('--num-classes', type=int, default=1000, + help='Number classes in dataset') +parser.add_argument('--print-freq', '-p', default=10, type=int, + metavar='N', help='print frequency (default: 10)') +parser.add_argument('--checkpoint', default='', type=str, metavar='PATH', + help='path to latest checkpoint (default: none)') +parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='use pre-trained model') +parser.add_argument('--torchscript', dest='torchscript', action='store_true', + help='convert model torchscript for inference') +parser.add_argument('--num-gpu', type=int, default=1, + help='Number of GPUS to use') +parser.add_argument('--tf-preprocessing', dest='tf_preprocessing', action='store_true', + help='use tensorflow mnasnet preporcessing') +parser.add_argument('--no-cuda', dest='no_cuda', action='store_true', + help='') +parser.add_argument('--channels-last', action='store_true', default=False, + help='Use channels_last memory layout') +parser.add_argument('--amp', action='store_true', default=False, + help='Use native Torch AMP mixed precision.') + + +def main(): + args = parser.parse_args() + + if not args.checkpoint and not args.pretrained: + args.pretrained = True + + amp_autocast = suppress # do nothing + if args.amp: + if not has_native_amp: + print("Native Torch AMP is not available (requires torch >= 1.6), using FP32.") + else: + amp_autocast = torch.cuda.amp.autocast + + # create model + model = geffnet.create_model( + args.model, + num_classes=args.num_classes, + in_chans=3, + pretrained=args.pretrained, + checkpoint_path=args.checkpoint, + scriptable=args.torchscript) + + if args.channels_last: + model = model.to(memory_format=torch.channels_last) + + if args.torchscript: + torch.jit.optimized_execution(True) + model = torch.jit.script(model) + + print('Model %s created, param count: %d' % + (args.model, sum([m.numel() for m in model.parameters()]))) + + data_config = resolve_data_config(model, args) + + criterion = nn.CrossEntropyLoss() + + if not args.no_cuda: + if args.num_gpu > 1: + model = torch.nn.DataParallel(model, device_ids=list(range(args.num_gpu))).cuda() + else: + model = model.cuda() + criterion = criterion.cuda() + + loader = create_loader( + Dataset(args.data, load_bytes=args.tf_preprocessing), + input_size=data_config['input_size'], + batch_size=args.batch_size, + use_prefetcher=not args.no_cuda, + interpolation=data_config['interpolation'], + mean=data_config['mean'], + std=data_config['std'], + num_workers=args.workers, + crop_pct=data_config['crop_pct'], + tensorflow_preprocessing=args.tf_preprocessing) + + batch_time = AverageMeter() + losses = AverageMeter() + top1 = AverageMeter() + top5 = AverageMeter() + + model.eval() + end = time.time() + with torch.no_grad(): + for i, (input, target) in enumerate(loader): + if not args.no_cuda: + target = target.cuda() + input = input.cuda() + if args.channels_last: + input = input.contiguous(memory_format=torch.channels_last) + + # compute output + with amp_autocast(): + output = model(input) + loss = criterion(output, target) + + # measure accuracy and record loss + prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) + losses.update(loss.item(), input.size(0)) + top1.update(prec1.item(), input.size(0)) + top5.update(prec5.item(), input.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % args.print_freq == 0: + print('Test: [{0}/{1}]\t' + 'Time {batch_time.val:.3f} ({batch_time.avg:.3f}, {rate_avg:.3f}/s) \t' + 'Loss {loss.val:.4f} ({loss.avg:.4f})\t' + 'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t' + 'Prec@5 {top5.val:.3f} ({top5.avg:.3f})'.format( + i, len(loader), batch_time=batch_time, + rate_avg=input.size(0) / batch_time.avg, + loss=losses, top1=top1, top5=top5)) + + print(' * Prec@1 {top1.avg:.3f} ({top1a:.3f}) Prec@5 {top5.avg:.3f} ({top5a:.3f})'.format( + top1=top1, top1a=100-top1.avg, top5=top5, top5a=100.-top5.avg)) + + +if __name__ == '__main__': + main() diff --git a/custom_controlnet_aux/normalbae/nets/submodules/encoder.py b/custom_controlnet_aux/normalbae/nets/submodules/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7149ca3c0cf2b6e019105af7e645cfbb3eda11 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/encoder.py @@ -0,0 +1,34 @@ +import os +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + + basemodel_name = 'tf_efficientnet_b5_ap' + print('Loading base model ()...'.format(basemodel_name), end='') + repo_path = os.path.join(os.path.dirname(__file__), 'efficientnet_repo') + basemodel = torch.hub.load(repo_path, basemodel_name, pretrained=False, source='local') + print('Done.') + + # Remove last layer + print('Removing last two layers (global_pool & classifier).') + basemodel.global_pool = nn.Identity() + basemodel.classifier = nn.Identity() + + self.original_model = basemodel + + def forward(self, x): + features = [x] + for k, v in self.original_model._modules.items(): + if (k == 'blocks'): + for ki, vi in v._modules.items(): + features.append(vi(features[-1])) + else: + features.append(v(features[-1])) + return features + + diff --git a/custom_controlnet_aux/normalbae/nets/submodules/submodules.py b/custom_controlnet_aux/normalbae/nets/submodules/submodules.py new file mode 100644 index 0000000000000000000000000000000000000000..409733351bd6ab5d191c800aff1bc05bfa4cb6f8 --- /dev/null +++ b/custom_controlnet_aux/normalbae/nets/submodules/submodules.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +######################################################################################################################## + + +# Upsample + BatchNorm +class UpSampleBN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleBN, self).__init__() + + self._net = nn.Sequential(nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU(), + nn.Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.BatchNorm2d(output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Upsample + GroupNorm + Weight Standardization +class UpSampleGN(nn.Module): + def __init__(self, skip_input, output_features): + super(UpSampleGN, self).__init__() + + self._net = nn.Sequential(Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU(), + Conv2d(output_features, output_features, kernel_size=3, stride=1, padding=1), + nn.GroupNorm(8, output_features), + nn.LeakyReLU()) + + def forward(self, x, concat_with): + up_x = F.interpolate(x, size=[concat_with.size(2), concat_with.size(3)], mode='bilinear', align_corners=True) + f = torch.cat([up_x, concat_with], dim=1) + return self._net(f) + + +# Conv2d with weight standardization +class Conv2d(nn.Conv2d): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, + padding=0, dilation=1, groups=1, bias=True): + super(Conv2d, self).__init__(in_channels, out_channels, kernel_size, stride, + padding, dilation, groups, bias) + + def forward(self, x): + weight = self.weight + weight_mean = weight.mean(dim=1, keepdim=True).mean(dim=2, + keepdim=True).mean(dim=3, keepdim=True) + weight = weight - weight_mean + std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 + weight = weight / std.expand_as(weight) + return F.conv2d(x, weight, self.bias, self.stride, + self.padding, self.dilation, self.groups) + + +# normalize +def norm_normalize(norm_out): + min_kappa = 0.01 + norm_x, norm_y, norm_z, kappa = torch.split(norm_out, 1, dim=1) + norm = torch.sqrt(norm_x ** 2.0 + norm_y ** 2.0 + norm_z ** 2.0) + 1e-10 + kappa = F.elu(kappa) + 1.0 + min_kappa + final_out = torch.cat([norm_x / norm, norm_y / norm, norm_z / norm, kappa], dim=1) + return final_out + + +# uncertainty-guided sampling (only used during training) +@torch.no_grad() +def sample_points(init_normal, gt_norm_mask, sampling_ratio, beta): + device = init_normal.device + B, _, H, W = init_normal.shape + N = int(sampling_ratio * H * W) + beta = beta + + # uncertainty map + uncertainty_map = -1 * init_normal[:, 3, :, :] # B, H, W + + # gt_invalid_mask (B, H, W) + if gt_norm_mask is not None: + gt_invalid_mask = F.interpolate(gt_norm_mask.float(), size=[H, W], mode='nearest') + gt_invalid_mask = gt_invalid_mask[:, 0, :, :] < 0.5 + uncertainty_map[gt_invalid_mask] = -1e4 + + # (B, H*W) + _, idx = uncertainty_map.view(B, -1).sort(1, descending=True) + + # importance sampling + if int(beta * N) > 0: + importance = idx[:, :int(beta * N)] # B, beta*N + + # remaining + remaining = idx[:, int(beta * N):] # B, H*W - beta*N + + # coverage + num_coverage = N - int(beta * N) + + if num_coverage <= 0: + samples = importance + else: + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = torch.cat((importance, coverage), dim=1) # B, N + + else: + # remaining + remaining = idx[:, :] # B, H*W + + # coverage + num_coverage = N + + coverage_list = [] + for i in range(B): + idx_c = torch.randperm(remaining.size()[1]) # shuffles "H*W - beta*N" + coverage_list.append(remaining[i, :][idx_c[:num_coverage]].view(1, -1)) # 1, N-beta*N + coverage = torch.cat(coverage_list, dim=0) # B, N-beta*N + samples = coverage + + # point coordinates + rows_int = samples // W # 0 for first row, H-1 for last row + rows_float = rows_int / float(H-1) # 0 to 1.0 + rows_float = (rows_float * 2.0) - 1.0 # -1.0 to 1.0 + + cols_int = samples % W # 0 for first column, W-1 for last column + cols_float = cols_int / float(W-1) # 0 to 1.0 + cols_float = (cols_float * 2.0) - 1.0 # -1.0 to 1.0 + + point_coords = torch.zeros(B, 1, N, 2) + point_coords[:, 0, :, 0] = cols_float # x coord + point_coords[:, 0, :, 1] = rows_float # y coord + point_coords = point_coords.to(device) + return point_coords, rows_int, cols_int \ No newline at end of file diff --git a/custom_controlnet_aux/oneformer/__init__.py b/custom_controlnet_aux/oneformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..073d5a5d6bcd7a3df993524946c40f6b5df70fc4 --- /dev/null +++ b/custom_controlnet_aux/oneformer/__init__.py @@ -0,0 +1,48 @@ +import os +from .api import make_detectron2_model, semantic_run +from pathlib import Path +import warnings +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +import numpy as np +import cv2 +from PIL import Image + +DEFAULT_CONFIGS = { + "coco": { + "name": "150_16_swin_l_oneformer_coco_100ep.pth", + "config": os.path.join(os.path.dirname(__file__), 'configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml') + }, + "ade20k": { + "name": "250_16_swin_l_oneformer_ade20k_160k.pth", + "config": os.path.join(os.path.dirname(__file__), 'configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml') + } +} +class OneformerSegmentor: + def __init__(self, model, metadata): + self.model = model + self.metadata = metadata + + def to(self, device): + self.model.model.to(device) + return self + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="250_16_swin_l_oneformer_ade20k_160k.pth", config_path = None): + config_path = config_path or DEFAULT_CONFIGS["ade20k" if "ade20k" in filename else "coco"]["config"] + model_path = custom_hf_download(pretrained_model_or_path, filename) + + model, metadata = make_detectron2_model(config_path, model_path) + + return cls(model, metadata) + + def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + detected_map = semantic_run(input_image, self.model, self.metadata) + detected_map = remove_pad(HWC3(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/oneformer/api.py b/custom_controlnet_aux/oneformer/api.py new file mode 100644 index 0000000000000000000000000000000000000000..3502e9404416867ade0063f00bfc44955435a75c --- /dev/null +++ b/custom_controlnet_aux/oneformer/api.py @@ -0,0 +1,39 @@ +import os +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import torch + +from custom_detectron2.config import get_cfg +from custom_detectron2.projects.deeplab import add_deeplab_config +from custom_detectron2.data import MetadataCatalog + +from custom_oneformer import ( + add_oneformer_config, + add_common_config, + add_swin_config, + add_dinat_config, +) + +from custom_oneformer.demo.defaults import DefaultPredictor +from custom_oneformer.demo.visualizer import Visualizer, ColorMode + + +def make_detectron2_model(config_path, ckpt_path): + cfg = get_cfg() + add_deeplab_config(cfg) + add_common_config(cfg) + add_swin_config(cfg) + add_oneformer_config(cfg) + add_dinat_config(cfg) + cfg.merge_from_file(config_path) + cfg.MODEL.WEIGHTS = ckpt_path + cfg.freeze() + metadata = MetadataCatalog.get(cfg.DATASETS.TEST_PANOPTIC[0] if len(cfg.DATASETS.TEST_PANOPTIC) else "__unused") + return DefaultPredictor(cfg), metadata + + +def semantic_run(img, predictor, metadata): + predictions = predictor(img[:, :, ::-1], "semantic") # Predictor of OneFormer must use BGR image !!! + visualizer_map = Visualizer(img, is_img=False, metadata=metadata, instance_mode=ColorMode.IMAGE) + out_map = visualizer_map.draw_sem_seg(predictions["sem_seg"].argmax(dim=0).cpu(), alpha=1, is_text=False).get_image() + return out_map \ No newline at end of file diff --git a/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml b/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..31eab45b878433fc844a13dbdd54f97c936d9b89 --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/ade20k/Base-ADE20K-UnifiedSegmentation.yaml @@ -0,0 +1,68 @@ +MODEL: + BACKBONE: + FREEZE_AT: 0 + NAME: "build_resnet_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STEM_TYPE: "basic" # not used + STEM_OUT_CHANNELS: 64 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 1, 1] # not used +DATASETS: + TRAIN: ("ade20k_panoptic_train",) + TEST_PANOPTIC: ("ade20k_panoptic_val",) + TEST_INSTANCE: ("ade20k_instance_val",) + TEST_SEMANTIC: ("ade20k_sem_seg_val",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.0001 + MAX_ITER: 160000 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 0 + WEIGHT_DECAY: 0.05 + OPTIMIZER: "ADAMW" + LR_SCHEDULER_NAME: "WarmupPolyLR" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 + AMP: + ENABLED: True +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 512) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 512 + MAX_SIZE_TRAIN: 2048 + MAX_SIZE_TEST: 2048 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (512, 512) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 512 # used in dataset mapper + FORMAT: "RGB" + DATASET_MAPPER_NAME: "oneformer_unified" + MAX_SEQ_LEN: 77 + TASK_SEQ_LEN: 77 + TASK_PROB: + SEMANTIC: 0.33 + INSTANCE: 0.66 +TEST: + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [256, 384, 512, 640, 768, 896] + MAX_SIZE: 3584 + FLIP: True +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: True + NUM_WORKERS: 4 +VERSION: 2 \ No newline at end of file diff --git a/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml b/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..770ffc81907f8d7c7520e079b1c46060707254b8 --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_R50_bs16_160k.yaml @@ -0,0 +1,58 @@ +_BASE_: Base-ADE20K-UnifiedSegmentation.yaml +MODEL: + META_ARCHITECTURE: "OneFormer" + SEM_SEG_HEAD: + NAME: "OneFormerHead" + IGNORE_VALUE: 255 + NUM_CLASSES: 150 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 256 + MASK_DIM: 256 + NORM: "GN" + # pixel decoder + PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"] + COMMON_STRIDE: 4 + TRANSFORMER_ENC_LAYERS: 6 + ONE_FORMER: + TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder" + TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" + DEEP_SUPERVISION: True + NO_OBJECT_WEIGHT: 0.1 + CLASS_WEIGHT: 2.0 + MASK_WEIGHT: 5.0 + DICE_WEIGHT: 5.0 + CONTRASTIVE_WEIGHT: 0.5 + CONTRASTIVE_TEMPERATURE: 0.07 + HIDDEN_DIM: 256 + NUM_OBJECT_QUERIES: 150 + USE_TASK_NORM: True + NHEADS: 8 + DROPOUT: 0.1 + DIM_FEEDFORWARD: 2048 + ENC_LAYERS: 0 + PRE_NORM: False + ENFORCE_INPUT_PROJ: False + SIZE_DIVISIBILITY: 32 + CLASS_DEC_LAYERS: 2 + DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query + TRAIN_NUM_POINTS: 12544 + OVERSAMPLE_RATIO: 3.0 + IMPORTANCE_SAMPLE_RATIO: 0.75 + TEXT_ENCODER: + WIDTH: 256 + CONTEXT_LENGTH: 77 + NUM_LAYERS: 6 + VOCAB_SIZE: 49408 + PROJ_NUM_LAYERS: 2 + N_CTX: 16 + TEST: + SEMANTIC_ON: True + INSTANCE_ON: True + PANOPTIC_ON: True + OVERLAP_THRESHOLD: 0.8 + OBJECT_MASK_THRESHOLD: 0.8 + TASK: "panoptic" +TEST: + DETECTIONS_PER_IMAGE: 150 diff --git a/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml b/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml new file mode 100644 index 0000000000000000000000000000000000000000..69c44ade144e4504077c0fe04fa8bb3491a679ed --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/ade20k/oneformer_swin_large_IN21k_384_bs16_160k.yaml @@ -0,0 +1,40 @@ +_BASE_: oneformer_R50_bs16_160k.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 192 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [6, 12, 24, 48] + WINDOW_SIZE: 12 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PRETRAIN_IMG_SIZE: 384 + WEIGHTS: "swin_large_patch4_window12_384_22k.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 250 +INPUT: + MIN_SIZE_TRAIN: !!python/object/apply:eval ["[int(x * 0.1 * 640) for x in range(5, 21)]"] + MIN_SIZE_TRAIN_SAMPLING: "choice" + MIN_SIZE_TEST: 640 + MAX_SIZE_TRAIN: 2560 + MAX_SIZE_TEST: 2560 + CROP: + ENABLED: True + TYPE: "absolute" + SIZE: (640, 640) + SINGLE_CATEGORY_MAX_AREA: 1.0 + COLOR_AUG_SSD: True + SIZE_DIVISIBILITY: 640 # used in dataset mapper + FORMAT: "RGB" +TEST: + DETECTIONS_PER_IMAGE: 250 + EVAL_PERIOD: 5000 + AUG: + ENABLED: False + MIN_SIZES: [320, 480, 640, 800, 960, 1120] + MAX_SIZE: 4480 + FLIP: True diff --git a/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml b/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ccd24f348f9bc7d60dcdc4b74d887708e57cb8a8 --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/coco/Base-COCO-UnifiedSegmentation.yaml @@ -0,0 +1,54 @@ +MODEL: + BACKBONE: + FREEZE_AT: 0 + NAME: "build_resnet_backbone" + WEIGHTS: "detectron2://ImageNetPretrained/torchvision/R-50.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + RESNETS: + DEPTH: 50 + STEM_TYPE: "basic" # not used + STEM_OUT_CHANNELS: 64 + STRIDE_IN_1X1: False + OUT_FEATURES: ["res2", "res3", "res4", "res5"] + # NORM: "SyncBN" + RES5_MULTI_GRID: [1, 1, 1] # not used +DATASETS: + TRAIN: ("coco_2017_train_panoptic_with_sem_seg",) + TEST_PANOPTIC: ("coco_2017_val_panoptic_with_sem_seg",) # to evaluate instance and semantic performance as well + TEST_INSTANCE: ("coco_2017_val",) + TEST_SEMANTIC: ("coco_2017_val_panoptic_with_sem_seg",) +SOLVER: + IMS_PER_BATCH: 16 + BASE_LR: 0.0001 + STEPS: (327778, 355092) + MAX_ITER: 368750 + WARMUP_FACTOR: 1.0 + WARMUP_ITERS: 10 + WEIGHT_DECAY: 0.05 + OPTIMIZER: "ADAMW" + BACKBONE_MULTIPLIER: 0.1 + CLIP_GRADIENTS: + ENABLED: True + CLIP_TYPE: "full_model" + CLIP_VALUE: 0.01 + NORM_TYPE: 2.0 + AMP: + ENABLED: True +INPUT: + IMAGE_SIZE: 1024 + MIN_SCALE: 0.1 + MAX_SCALE: 2.0 + FORMAT: "RGB" + DATASET_MAPPER_NAME: "coco_unified_lsj" + MAX_SEQ_LEN: 77 + TASK_SEQ_LEN: 77 + TASK_PROB: + SEMANTIC: 0.33 + INSTANCE: 0.66 +TEST: + EVAL_PERIOD: 5000 +DATALOADER: + FILTER_EMPTY_ANNOTATIONS: True + NUM_WORKERS: 4 +VERSION: 2 diff --git a/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml b/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f768c8fa8b5e4fc1121e65e050053e0d8870cd73 --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/coco/oneformer_R50_bs16_50ep.yaml @@ -0,0 +1,59 @@ +_BASE_: Base-COCO-UnifiedSegmentation.yaml +MODEL: + META_ARCHITECTURE: "OneFormer" + SEM_SEG_HEAD: + NAME: "OneFormerHead" + IGNORE_VALUE: 255 + NUM_CLASSES: 133 + LOSS_WEIGHT: 1.0 + CONVS_DIM: 256 + MASK_DIM: 256 + NORM: "GN" + # pixel decoder + PIXEL_DECODER_NAME: "MSDeformAttnPixelDecoder" + IN_FEATURES: ["res2", "res3", "res4", "res5"] + DEFORMABLE_TRANSFORMER_ENCODER_IN_FEATURES: ["res3", "res4", "res5"] + COMMON_STRIDE: 4 + TRANSFORMER_ENC_LAYERS: 6 + ONE_FORMER: + TRANSFORMER_DECODER_NAME: "ContrastiveMultiScaleMaskedTransformerDecoder" + TRANSFORMER_IN_FEATURE: "multi_scale_pixel_decoder" + DEEP_SUPERVISION: True + NO_OBJECT_WEIGHT: 0.1 + CLASS_WEIGHT: 2.0 + MASK_WEIGHT: 5.0 + DICE_WEIGHT: 5.0 + CONTRASTIVE_WEIGHT: 0.5 + CONTRASTIVE_TEMPERATURE: 0.07 + HIDDEN_DIM: 256 + NUM_OBJECT_QUERIES: 150 + USE_TASK_NORM: True + NHEADS: 8 + DROPOUT: 0.1 + DIM_FEEDFORWARD: 2048 + ENC_LAYERS: 0 + PRE_NORM: False + ENFORCE_INPUT_PROJ: False + SIZE_DIVISIBILITY: 32 + CLASS_DEC_LAYERS: 2 + DEC_LAYERS: 10 # 9 decoder layers, add one for the loss on learnable query + TRAIN_NUM_POINTS: 12544 + OVERSAMPLE_RATIO: 3.0 + IMPORTANCE_SAMPLE_RATIO: 0.75 + TEXT_ENCODER: + WIDTH: 256 + CONTEXT_LENGTH: 77 + NUM_LAYERS: 6 + VOCAB_SIZE: 49408 + PROJ_NUM_LAYERS: 2 + N_CTX: 16 + TEST: + SEMANTIC_ON: True + INSTANCE_ON: True + PANOPTIC_ON: True + DETECTION_ON: False + OVERLAP_THRESHOLD: 0.8 + OBJECT_MASK_THRESHOLD: 0.8 + TASK: "panoptic" +TEST: + DETECTIONS_PER_IMAGE: 150 diff --git a/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml b/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml new file mode 100644 index 0000000000000000000000000000000000000000..faae655317c52d90b9f756417f8b1a1adcbe78f2 --- /dev/null +++ b/custom_controlnet_aux/oneformer/configs/coco/oneformer_swin_large_IN21k_384_bs16_100ep.yaml @@ -0,0 +1,25 @@ +_BASE_: oneformer_R50_bs16_50ep.yaml +MODEL: + BACKBONE: + NAME: "D2SwinTransformer" + SWIN: + EMBED_DIM: 192 + DEPTHS: [2, 2, 18, 2] + NUM_HEADS: [6, 12, 24, 48] + WINDOW_SIZE: 12 + APE: False + DROP_PATH_RATE: 0.3 + PATCH_NORM: True + PRETRAIN_IMG_SIZE: 384 + WEIGHTS: "swin_large_patch4_window12_384_22k.pkl" + PIXEL_MEAN: [123.675, 116.280, 103.530] + PIXEL_STD: [58.395, 57.120, 57.375] + ONE_FORMER: + NUM_OBJECT_QUERIES: 150 +SOLVER: + STEPS: (655556, 735184) + MAX_ITER: 737500 + AMP: + ENABLED: False +TEST: + DETECTIONS_PER_IMAGE: 150 diff --git a/custom_controlnet_aux/open_pose/LICENSE b/custom_controlnet_aux/open_pose/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..6f60b76d35fa1012809985780964a5068adce4fd --- /dev/null +++ b/custom_controlnet_aux/open_pose/LICENSE @@ -0,0 +1,108 @@ +OPENPOSE: MULTIPERSON KEYPOINT DETECTION +SOFTWARE LICENSE AGREEMENT +ACADEMIC OR NON-PROFIT ORGANIZATION NONCOMMERCIAL RESEARCH USE ONLY + +BY USING OR DOWNLOADING THE SOFTWARE, YOU ARE AGREEING TO THE TERMS OF THIS LICENSE AGREEMENT. IF YOU DO NOT AGREE WITH THESE TERMS, YOU MAY NOT USE OR DOWNLOAD THE SOFTWARE. + +This is a license agreement ("Agreement") between your academic institution or non-profit organization or self (called "Licensee" or "You" in this Agreement) and Carnegie Mellon University (called "Licensor" in this Agreement). All rights not specifically granted to you in this Agreement are reserved for Licensor. + +RESERVATION OF OWNERSHIP AND GRANT OF LICENSE: +Licensor retains exclusive ownership of any copy of the Software (as defined below) licensed under this Agreement and hereby grants to Licensee a personal, non-exclusive, +non-transferable license to use the Software for noncommercial research purposes, without the right to sublicense, pursuant to the terms and conditions of this Agreement. As used in this Agreement, the term "Software" means (i) the actual copy of all or any portion of code for program routines made accessible to Licensee by Licensor pursuant to this Agreement, inclusive of backups, updates, and/or merged copies permitted hereunder or subsequently supplied by Licensor, including all or any file structures, programming instructions, user interfaces and screen formats and sequences as well as any and all documentation and instructions related to it, and (ii) all or any derivatives and/or modifications created or made by You to any of the items specified in (i). + +CONFIDENTIALITY: Licensee acknowledges that the Software is proprietary to Licensor, and as such, Licensee agrees to receive all such materials in confidence and use the Software only in accordance with the terms of this Agreement. Licensee agrees to use reasonable effort to protect the Software from unauthorized use, reproduction, distribution, or publication. + +COPYRIGHT: The Software is owned by Licensor and is protected by United +States copyright laws and applicable international treaties and/or conventions. + +PERMITTED USES: The Software may be used for your own noncommercial internal research purposes. You understand and agree that Licensor is not obligated to implement any suggestions and/or feedback you might provide regarding the Software, but to the extent Licensor does so, you are not entitled to any compensation related thereto. + +DERIVATIVES: You may create derivatives of or make modifications to the Software, however, You agree that all and any such derivatives and modifications will be owned by Licensor and become a part of the Software licensed to You under this Agreement. You may only use such derivatives and modifications for your own noncommercial internal research purposes, and you may not otherwise use, distribute or copy such derivatives and modifications in violation of this Agreement. + +BACKUPS: If Licensee is an organization, it may make that number of copies of the Software necessary for internal noncommercial use at a single site within its organization provided that all information appearing in or on the original labels, including the copyright and trademark notices are copied onto the labels of the copies. + +USES NOT PERMITTED: You may not distribute, copy or use the Software except as explicitly permitted herein. Licensee has not been granted any trademark license as part of this Agreement and may not use the name or mark “OpenPose", "Carnegie Mellon" or any renditions thereof without the prior written permission of Licensor. + +You may not sell, rent, lease, sublicense, lend, time-share or transfer, in whole or in part, or provide third parties access to prior or present versions (or any parts thereof) of the Software. + +ASSIGNMENT: You may not assign this Agreement or your rights hereunder without the prior written consent of Licensor. Any attempted assignment without such consent shall be null and void. + +TERM: The term of the license granted by this Agreement is from Licensee's acceptance of this Agreement by downloading the Software or by using the Software until terminated as provided below. + +The Agreement automatically terminates without notice if you fail to comply with any provision of this Agreement. Licensee may terminate this Agreement by ceasing using the Software. Upon any termination of this Agreement, Licensee will delete any and all copies of the Software. You agree that all provisions which operate to protect the proprietary rights of Licensor shall remain in force should breach occur and that the obligation of confidentiality described in this Agreement is binding in perpetuity and, as such, survives the term of the Agreement. + +FEE: Provided Licensee abides completely by the terms and conditions of this Agreement, there is no fee due to Licensor for Licensee's use of the Software in accordance with this Agreement. + +DISCLAIMER OF WARRANTIES: THE SOFTWARE IS PROVIDED "AS-IS" WITHOUT WARRANTY OF ANY KIND INCLUDING ANY WARRANTIES OF PERFORMANCE OR MERCHANTABILITY OR FITNESS FOR A PARTICULAR USE OR PURPOSE OR OF NON-INFRINGEMENT. LICENSEE BEARS ALL RISK RELATING TO QUALITY AND PERFORMANCE OF THE SOFTWARE AND RELATED MATERIALS. + +SUPPORT AND MAINTENANCE: No Software support or training by the Licensor is provided as part of this Agreement. + +EXCLUSIVE REMEDY AND LIMITATION OF LIABILITY: To the maximum extent permitted under applicable law, Licensor shall not be liable for direct, indirect, special, incidental, or consequential damages or lost profits related to Licensee's use of and/or inability to use the Software, even if Licensor is advised of the possibility of such damage. + +EXPORT REGULATION: Licensee agrees to comply with any and all applicable +U.S. export control laws, regulations, and/or other laws related to embargoes and sanction programs administered by the Office of Foreign Assets Control. + +SEVERABILITY: If any provision(s) of this Agreement shall be held to be invalid, illegal, or unenforceable by a court or other tribunal of competent jurisdiction, the validity, legality and enforceability of the remaining provisions shall not in any way be affected or impaired thereby. + +NO IMPLIED WAIVERS: No failure or delay by Licensor in enforcing any right or remedy under this Agreement shall be construed as a waiver of any future or other exercise of such right or remedy by Licensor. + +GOVERNING LAW: This Agreement shall be construed and enforced in accordance with the laws of the Commonwealth of Pennsylvania without reference to conflict of laws principles. You consent to the personal jurisdiction of the courts of this County and waive their rights to venue outside of Allegheny County, Pennsylvania. + +ENTIRE AGREEMENT AND AMENDMENTS: This Agreement constitutes the sole and entire agreement between Licensee and Licensor as to the matter set forth herein and supersedes any previous agreements, understandings, and arrangements between the parties relating hereto. + + + +************************************************************************ + +THIRD-PARTY SOFTWARE NOTICES AND INFORMATION + +This project incorporates material from the project(s) listed below (collectively, "Third Party Code"). This Third Party Code is licensed to you under their original license terms set forth below. We reserves all other rights not expressly granted, whether by implication, estoppel or otherwise. + +1. Caffe, version 1.0.0, (https://github.com/BVLC/caffe/) + +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014-2017 The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014-2017, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. + +************END OF THIRD-PARTY SOFTWARE NOTICES AND INFORMATION********** \ No newline at end of file diff --git a/custom_controlnet_aux/open_pose/__init__.py b/custom_controlnet_aux/open_pose/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a05efecf45f6db5c9332675394898cb35cb8ff4 --- /dev/null +++ b/custom_controlnet_aux/open_pose/__init__.py @@ -0,0 +1,238 @@ +# Openpose +# Original from CMU https://github.com/CMU-Perceptual-Computing-Lab/openpose +# 2nd Edited by https://github.com/Hzzone/pytorch-openpose +# 3rd Edited by ControlNet +# 4th Edited by ControlNet (added face and correct hands) +# 5th Edited by ControlNet (Improved JSON serialization/deserialization, and lots of bug fixs) +# This preprocessor is licensed by CMU for non-commercial use only. + + +import os + +os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" + +import json +import warnings +from typing import Callable, List, NamedTuple, Tuple, Union + +import cv2 +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +from . import util +from .body import Body, BodyResult, Keypoint +from .face import Face +from .hand import Hand + +HandResult = List[Keypoint] +FaceResult = List[Keypoint] + +class PoseResult(NamedTuple): + body: BodyResult + left_hand: Union[HandResult, None] + right_hand: Union[HandResult, None] + face: Union[FaceResult, None] + +def draw_poses(poses: List[PoseResult], H, W, draw_body=True, draw_hand=True, draw_face=True): + """ + Draw the detected poses on an empty canvas. + + Args: + poses (List[PoseResult]): A list of PoseResult objects containing the detected poses. + H (int): The height of the canvas. + W (int): The width of the canvas. + draw_body (bool, optional): Whether to draw body keypoints. Defaults to True. + draw_hand (bool, optional): Whether to draw hand keypoints. Defaults to True. + draw_face (bool, optional): Whether to draw face keypoints. Defaults to True. + + Returns: + numpy.ndarray: A 3D numpy array representing the canvas with the drawn poses. + """ + canvas = np.zeros(shape=(H, W, 3), dtype=np.uint8) + + for pose in poses: + if draw_body: + canvas = util.draw_bodypose(canvas, pose.body.keypoints) + + if draw_hand: + canvas = util.draw_handpose(canvas, pose.left_hand) + canvas = util.draw_handpose(canvas, pose.right_hand) + + if draw_face: + canvas = util.draw_facepose(canvas, pose.face) + + return canvas + +def encode_poses_as_dict(poses: List[PoseResult], canvas_height: int, canvas_width: int) -> str: + """ Encode the pose as a dict following openpose JSON output format: + https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/02_output.md + """ + def compress_keypoints(keypoints: Union[List[Keypoint], None]) -> Union[List[float], None]: + if not keypoints: + return None + + return [ + value + for keypoint in keypoints + for value in ( + [float(keypoint.x), float(keypoint.y), 1.0] + if keypoint is not None + else [0.0, 0.0, 0.0] + ) + ] + + return { + 'people': [ + { + 'pose_keypoints_2d': compress_keypoints(pose.body.keypoints), + "face_keypoints_2d": compress_keypoints(pose.face), + "hand_left_keypoints_2d": compress_keypoints(pose.left_hand), + "hand_right_keypoints_2d":compress_keypoints(pose.right_hand), + } + for pose in poses + ], + 'canvas_height': canvas_height, + 'canvas_width': canvas_width, + } + +class OpenposeDetector: + """ + A class for detecting human poses in images using the Openpose model. + + Attributes: + model_dir (str): Path to the directory where the pose models are stored. + """ + def __init__(self, body_estimation, hand_estimation=None, face_estimation=None): + self.body_estimation = body_estimation + self.hand_estimation = hand_estimation + self.face_estimation = face_estimation + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="body_pose_model.pth", hand_filename="hand_pose_model.pth", face_filename="facenet.pth"): + if pretrained_model_or_path == "lllyasviel/ControlNet": + subfolder = "annotator/ckpts" + face_pretrained_model_or_path = "lllyasviel/Annotators" + + else: + subfolder = '' + face_pretrained_model_or_path = pretrained_model_or_path + + body_model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder) + hand_model_path = custom_hf_download(pretrained_model_or_path, hand_filename, subfolder=subfolder) + face_model_path = custom_hf_download(face_pretrained_model_or_path, face_filename, subfolder=subfolder) + + body_estimation = Body(body_model_path) + hand_estimation = Hand(hand_model_path) + face_estimation = Face(face_model_path) + + return cls(body_estimation, hand_estimation, face_estimation) + + def to(self, device): + self.body_estimation.to(device) + self.hand_estimation.to(device) + self.face_estimation.to(device) + return self + + def detect_hands(self, body: BodyResult, oriImg) -> Tuple[Union[HandResult, None], Union[HandResult, None]]: + left_hand = None + right_hand = None + H, W, _ = oriImg.shape + for x, y, w, is_left in util.handDetect(body, oriImg): + peaks = self.hand_estimation(oriImg[y:y+w, x:x+w, :]).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + + hand_result = [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + if is_left: + left_hand = hand_result + else: + right_hand = hand_result + + return left_hand, right_hand + + def detect_face(self, body: BodyResult, oriImg) -> Union[FaceResult, None]: + face = util.faceDetect(body, oriImg) + if face is None: + return None + + x, y, w = face + H, W, _ = oriImg.shape + heatmaps = self.face_estimation(oriImg[y:y+w, x:x+w, :]) + peaks = self.face_estimation.compute_peaks_from_heatmaps(heatmaps).astype(np.float32) + if peaks.ndim == 2 and peaks.shape[1] == 2: + peaks[:, 0] = np.where(peaks[:, 0] < 1e-6, -1, peaks[:, 0] + x) / float(W) + peaks[:, 1] = np.where(peaks[:, 1] < 1e-6, -1, peaks[:, 1] + y) / float(H) + return [ + Keypoint(x=peak[0], y=peak[1]) + for peak in peaks + ] + + return None + + def detect_poses(self, oriImg, include_hand=False, include_face=False) -> List[PoseResult]: + """ + Detect poses in the given image. + Args: + oriImg (numpy.ndarray): The input image for pose detection. + include_hand (bool, optional): Whether to include hand detection. Defaults to False. + include_face (bool, optional): Whether to include face detection. Defaults to False. + + Returns: + List[PoseResult]: A list of PoseResult objects containing the detected poses. + """ + oriImg = oriImg[:, :, ::-1].copy() + H, W, C = oriImg.shape + with torch.no_grad(): + candidate, subset = self.body_estimation(oriImg) + bodies = self.body_estimation.format_body_result(candidate, subset) + + results = [] + for body in bodies: + left_hand, right_hand, face = (None,) * 3 + if include_hand: + left_hand, right_hand = self.detect_hands(body, oriImg) + if include_face: + face = self.detect_face(body, oriImg) + + results.append(PoseResult(BodyResult( + keypoints=[ + Keypoint( + x=keypoint.x / float(W), + y=keypoint.y / float(H) + ) if keypoint is not None else None + for keypoint in body.keypoints + ], + total_score=body.total_score, + total_parts=body.total_parts + ), left_hand, right_hand, face)) + + return results + + def __call__(self, input_image, detect_resolution=512, include_body=True, include_hand=False, include_face=False, hand_and_face=None, output_type="pil", image_and_json=False, upscale_method="INTER_CUBIC", **kwargs): + if hand_and_face is not None: + warnings.warn("hand_and_face is deprecated. Use include_hand and include_face instead.", DeprecationWarning) + include_hand = hand_and_face + include_face = hand_and_face + + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + poses = self.detect_poses(input_image, include_hand=include_hand, include_face=include_face) + canvas = draw_poses(poses, input_image.shape[0], input_image.shape[1], draw_body=include_body, draw_hand=include_hand, draw_face=include_face) + detected_map = HWC3(remove_pad(canvas)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + if image_and_json: + return (detected_map, encode_poses_as_dict(poses, detected_map.shape[0], detected_map.shape[1])) + + return detected_map diff --git a/custom_controlnet_aux/open_pose/body.py b/custom_controlnet_aux/open_pose/body.py new file mode 100644 index 0000000000000000000000000000000000000000..edd744869b603c4e544f7a522a375a6ac3794aa4 --- /dev/null +++ b/custom_controlnet_aux/open_pose/body.py @@ -0,0 +1,278 @@ +import math +from typing import List, NamedTuple, Union + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter + +from . import util +from .model import bodypose_model + + +class Keypoint(NamedTuple): + x: float + y: float + score: float = 1.0 + id: int = -1 + + +class BodyResult(NamedTuple): + # Note: Using `Union` instead of `|` operator as the ladder is a Python + # 3.10 feature. + # Annotator code should be Python 3.8 Compatible, as controlnet repo uses + # Python 3.8 environment. + # https://github.com/lllyasviel/ControlNet/blob/d3284fcd0972c510635a4f5abe2eeb71dc0de524/environment.yaml#L6 + keypoints: List[Union[Keypoint, None]] + total_score: float + total_parts: int + + +class Body(object): + def __init__(self, model_path): + self.model = bodypose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + self.device = "cpu" + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, oriImg): + # scale_search = [0.5, 1.0, 1.5, 2.0] + scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre1 = 0.1 + thre2 = 0.05 + multiplier = [x * boxsize / oriImg.shape[0] for x in scale_search] + heatmap_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 19)) + paf_avg = np.zeros((oriImg.shape[0], oriImg.shape[1], 38)) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize_k(oriImg, fx=scale, fy=scale) + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(self.device) + # data = data.permute([2, 0, 1]).unsqueeze(0).float() + with torch.no_grad(): + Mconv7_stage6_L1, Mconv7_stage6_L2 = self.model(data) + Mconv7_stage6_L1 = Mconv7_stage6_L1.cpu().numpy() + Mconv7_stage6_L2 = Mconv7_stage6_L2.cpu().numpy() + + # extract outputs, resize, and remove padding + # heatmap = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[1]].data), (1, 2, 0)) # output 1 is heatmaps + heatmap = np.transpose(np.squeeze(Mconv7_stage6_L2), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (oriImg.shape[0], oriImg.shape[1])) + + # paf = np.transpose(np.squeeze(net.blobs[output_blobs.keys()[0]].data), (1, 2, 0)) # output 0 is PAFs + paf = np.transpose(np.squeeze(Mconv7_stage6_L1), (1, 2, 0)) # output 0 is PAFs + paf = util.smart_resize_k(paf, fx=stride, fy=stride) + paf = paf[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + paf = util.smart_resize(paf, (oriImg.shape[0], oriImg.shape[1])) + + heatmap_avg += heatmap_avg + heatmap / len(multiplier) + paf_avg += + paf / len(multiplier) + + all_peaks = [] + peak_counter = 0 + + for part in range(18): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + + map_left = np.zeros(one_heatmap.shape) + map_left[1:, :] = one_heatmap[:-1, :] + map_right = np.zeros(one_heatmap.shape) + map_right[:-1, :] = one_heatmap[1:, :] + map_up = np.zeros(one_heatmap.shape) + map_up[:, 1:] = one_heatmap[:, :-1] + map_down = np.zeros(one_heatmap.shape) + map_down[:, :-1] = one_heatmap[:, 1:] + + peaks_binary = np.logical_and.reduce( + (one_heatmap >= map_left, one_heatmap >= map_right, one_heatmap >= map_up, one_heatmap >= map_down, one_heatmap > thre1)) + peaks = list(zip(np.nonzero(peaks_binary)[1], np.nonzero(peaks_binary)[0])) # note reverse + peaks_with_score = [x + (map_ori[x[1], x[0]],) for x in peaks] + peak_id = range(peak_counter, peak_counter + len(peaks)) + peaks_with_score_and_id = [peaks_with_score[i] + (peak_id[i],) for i in range(len(peak_id))] + + all_peaks.append(peaks_with_score_and_id) + peak_counter += len(peaks) + + # find connection in the specified sequence, center 29 is in the position 15 + limbSeq = [[2, 3], [2, 6], [3, 4], [4, 5], [6, 7], [7, 8], [2, 9], [9, 10], \ + [10, 11], [2, 12], [12, 13], [13, 14], [2, 1], [1, 15], [15, 17], \ + [1, 16], [16, 18], [3, 17], [6, 18]] + # the middle joints heatmap correpondence + mapIdx = [[31, 32], [39, 40], [33, 34], [35, 36], [41, 42], [43, 44], [19, 20], [21, 22], \ + [23, 24], [25, 26], [27, 28], [29, 30], [47, 48], [49, 50], [53, 54], [51, 52], \ + [55, 56], [37, 38], [45, 46]] + + connection_all = [] + special_k = [] + mid_num = 10 + + for k in range(len(mapIdx)): + score_mid = paf_avg[:, :, [x - 19 for x in mapIdx[k]]] + candA = all_peaks[limbSeq[k][0] - 1] + candB = all_peaks[limbSeq[k][1] - 1] + nA = len(candA) + nB = len(candB) + indexA, indexB = limbSeq[k] + if (nA != 0 and nB != 0): + connection_candidate = [] + for i in range(nA): + for j in range(nB): + vec = np.subtract(candB[j][:2], candA[i][:2]) + norm = math.sqrt(vec[0] * vec[0] + vec[1] * vec[1]) + norm = max(0.001, norm) + vec = np.divide(vec, norm) + + startend = list(zip(np.linspace(candA[i][0], candB[j][0], num=mid_num), \ + np.linspace(candA[i][1], candB[j][1], num=mid_num))) + + vec_x = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 0] \ + for I in range(len(startend))]) + vec_y = np.array([score_mid[int(round(startend[I][1])), int(round(startend[I][0])), 1] \ + for I in range(len(startend))]) + + score_midpts = np.multiply(vec_x, vec[0]) + np.multiply(vec_y, vec[1]) + score_with_dist_prior = sum(score_midpts) / len(score_midpts) + min( + 0.5 * oriImg.shape[0] / norm - 1, 0) + criterion1 = len(np.nonzero(score_midpts > thre2)[0]) > 0.8 * len(score_midpts) + criterion2 = score_with_dist_prior > 0 + if criterion1 and criterion2: + connection_candidate.append( + [i, j, score_with_dist_prior, score_with_dist_prior + candA[i][2] + candB[j][2]]) + + connection_candidate = sorted(connection_candidate, key=lambda x: x[2], reverse=True) + connection = np.zeros((0, 5)) + for c in range(len(connection_candidate)): + i, j, s = connection_candidate[c][0:3] + if (i not in connection[:, 3] and j not in connection[:, 4]): + connection = np.vstack([connection, [candA[i][3], candB[j][3], s, i, j]]) + if (len(connection) >= min(nA, nB)): + break + + connection_all.append(connection) + else: + special_k.append(k) + connection_all.append([]) + + # last number in each row is the total parts number of that person + # the second last number in each row is the score of the overall configuration + subset = -1 * np.ones((0, 20)) + candidate = np.array([item for sublist in all_peaks for item in sublist]) + + for k in range(len(mapIdx)): + if k not in special_k: + partAs = connection_all[k][:, 0] + partBs = connection_all[k][:, 1] + indexA, indexB = np.array(limbSeq[k]) - 1 + + for i in range(len(connection_all[k])): # = 1:size(temp,1) + found = 0 + subset_idx = [-1, -1] + for j in range(len(subset)): # 1:size(subset,1): + if subset[j][indexA] == partAs[i] or subset[j][indexB] == partBs[i]: + subset_idx[found] = j + found += 1 + + if found == 1: + j = subset_idx[0] + if subset[j][indexB] != partBs[i]: + subset[j][indexB] = partBs[i] + subset[j][-1] += 1 + subset[j][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + elif found == 2: # if found 2 and disjoint, merge them + j1, j2 = subset_idx + membership = ((subset[j1] >= 0).astype(int) + (subset[j2] >= 0).astype(int))[:-2] + if len(np.nonzero(membership == 2)[0]) == 0: # merge + subset[j1][:-2] += (subset[j2][:-2] + 1) + subset[j1][-2:] += subset[j2][-2:] + subset[j1][-2] += connection_all[k][i][2] + subset = np.delete(subset, j2, 0) + else: # as like found == 1 + subset[j1][indexB] = partBs[i] + subset[j1][-1] += 1 + subset[j1][-2] += candidate[partBs[i].astype(int), 2] + connection_all[k][i][2] + + # if find no partA in the subset, create a new subset + elif not found and k < 17: + row = -1 * np.ones(20) + row[indexA] = partAs[i] + row[indexB] = partBs[i] + row[-1] = 2 + row[-2] = sum(candidate[connection_all[k][i, :2].astype(int), 2]) + connection_all[k][i][2] + subset = np.vstack([subset, row]) + # delete some rows of subset which has few parts occur + deleteIdx = [] + for i in range(len(subset)): + if subset[i][-1] < 4 or subset[i][-2] / subset[i][-1] < 0.4: + deleteIdx.append(i) + subset = np.delete(subset, deleteIdx, axis=0) + + # subset: n*20 array, 0-17 is the index in candidate, 18 is the total score, 19 is the total parts + # candidate: x, y, score, id + return candidate, subset + + @staticmethod + def format_body_result(candidate: np.ndarray, subset: np.ndarray) -> List[BodyResult]: + """ + Format the body results from the candidate and subset arrays into a list of BodyResult objects. + + Args: + candidate (np.ndarray): An array of candidates containing the x, y coordinates, score, and id + for each body part. + subset (np.ndarray): An array of subsets containing indices to the candidate array for each + person detected. The last two columns of each row hold the total score and total parts + of the person. + + Returns: + List[BodyResult]: A list of BodyResult objects, where each object represents a person with + detected keypoints, total score, and total parts. + """ + return [ + BodyResult( + keypoints=[ + Keypoint( + x=candidate[candidate_index][0], + y=candidate[candidate_index][1], + score=candidate[candidate_index][2], + id=candidate[candidate_index][3] + ) if candidate_index != -1 else None + for candidate_index in person[:18].astype(int) + ], + total_score=person[18], + total_parts=person[19] + ) + for person in subset + ] + + +if __name__ == "__main__": + body_estimation = Body('../model/body_pose_model.pth') + + test_image = '../images/ski.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + candidate, subset = body_estimation(oriImg) + bodies = body_estimation.format_body_result(candidate, subset) + + canvas = oriImg + for body in bodies: + canvas = util.draw_bodypose(canvas, body) + + plt.imshow(canvas[:, :, [2, 1, 0]]) + plt.show() diff --git a/custom_controlnet_aux/open_pose/face.py b/custom_controlnet_aux/open_pose/face.py new file mode 100644 index 0000000000000000000000000000000000000000..44e9ee235931bf67582cde56c6b525cb261060d2 --- /dev/null +++ b/custom_controlnet_aux/open_pose/face.py @@ -0,0 +1,365 @@ +import logging + +import numpy as np +import torch +import torch.nn.functional as F +from torch.nn import Conv2d, MaxPool2d, Module, ReLU, init +from torchvision.transforms import ToPILImage, ToTensor + +from . import util + + +class FaceNet(Module): + """Model the cascading heatmaps. """ + def __init__(self): + super(FaceNet, self).__init__() + # cnn to make feature map + self.relu = ReLU() + self.max_pooling_2d = MaxPool2d(kernel_size=2, stride=2) + self.conv1_1 = Conv2d(in_channels=3, out_channels=64, + kernel_size=3, stride=1, padding=1) + self.conv1_2 = Conv2d( + in_channels=64, out_channels=64, kernel_size=3, stride=1, + padding=1) + self.conv2_1 = Conv2d( + in_channels=64, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv2_2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=3, stride=1, + padding=1) + self.conv3_1 = Conv2d( + in_channels=128, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_2 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_3 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv3_4 = Conv2d( + in_channels=256, out_channels=256, kernel_size=3, stride=1, + padding=1) + self.conv4_1 = Conv2d( + in_channels=256, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_3 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv4_4 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_1 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_2 = Conv2d( + in_channels=512, out_channels=512, kernel_size=3, stride=1, + padding=1) + self.conv5_3_CPM = Conv2d( + in_channels=512, out_channels=128, kernel_size=3, stride=1, + padding=1) + + # stage1 + self.conv6_1_CPM = Conv2d( + in_channels=128, out_channels=512, kernel_size=1, stride=1, + padding=0) + self.conv6_2_CPM = Conv2d( + in_channels=512, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage2 + self.Mconv1_stage2 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage2 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage2 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage3 + self.Mconv1_stage3 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage3 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage3 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage4 + self.Mconv1_stage4 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage4 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage4 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage5 + self.Mconv1_stage5 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage5 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage5 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + # stage6 + self.Mconv1_stage6 = Conv2d( + in_channels=199, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv2_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv3_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv4_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv5_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=7, stride=1, + padding=3) + self.Mconv6_stage6 = Conv2d( + in_channels=128, out_channels=128, kernel_size=1, stride=1, + padding=0) + self.Mconv7_stage6 = Conv2d( + in_channels=128, out_channels=71, kernel_size=1, stride=1, + padding=0) + + for m in self.modules(): + if isinstance(m, Conv2d): + init.constant_(m.bias, 0) + + def forward(self, x): + """Return a list of heatmaps.""" + heatmaps = [] + + h = self.relu(self.conv1_1(x)) + h = self.relu(self.conv1_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv2_1(h)) + h = self.relu(self.conv2_2(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv3_1(h)) + h = self.relu(self.conv3_2(h)) + h = self.relu(self.conv3_3(h)) + h = self.relu(self.conv3_4(h)) + h = self.max_pooling_2d(h) + h = self.relu(self.conv4_1(h)) + h = self.relu(self.conv4_2(h)) + h = self.relu(self.conv4_3(h)) + h = self.relu(self.conv4_4(h)) + h = self.relu(self.conv5_1(h)) + h = self.relu(self.conv5_2(h)) + h = self.relu(self.conv5_3_CPM(h)) + feature_map = h + + # stage1 + h = self.relu(self.conv6_1_CPM(h)) + h = self.conv6_2_CPM(h) + heatmaps.append(h) + + # stage2 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage2(h)) + h = self.relu(self.Mconv2_stage2(h)) + h = self.relu(self.Mconv3_stage2(h)) + h = self.relu(self.Mconv4_stage2(h)) + h = self.relu(self.Mconv5_stage2(h)) + h = self.relu(self.Mconv6_stage2(h)) + h = self.Mconv7_stage2(h) + heatmaps.append(h) + + # stage3 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage3(h)) + h = self.relu(self.Mconv2_stage3(h)) + h = self.relu(self.Mconv3_stage3(h)) + h = self.relu(self.Mconv4_stage3(h)) + h = self.relu(self.Mconv5_stage3(h)) + h = self.relu(self.Mconv6_stage3(h)) + h = self.Mconv7_stage3(h) + heatmaps.append(h) + + # stage4 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage4(h)) + h = self.relu(self.Mconv2_stage4(h)) + h = self.relu(self.Mconv3_stage4(h)) + h = self.relu(self.Mconv4_stage4(h)) + h = self.relu(self.Mconv5_stage4(h)) + h = self.relu(self.Mconv6_stage4(h)) + h = self.Mconv7_stage4(h) + heatmaps.append(h) + + # stage5 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage5(h)) + h = self.relu(self.Mconv2_stage5(h)) + h = self.relu(self.Mconv3_stage5(h)) + h = self.relu(self.Mconv4_stage5(h)) + h = self.relu(self.Mconv5_stage5(h)) + h = self.relu(self.Mconv6_stage5(h)) + h = self.Mconv7_stage5(h) + heatmaps.append(h) + + # stage6 + h = torch.cat([h, feature_map], dim=1) # channel concat + h = self.relu(self.Mconv1_stage6(h)) + h = self.relu(self.Mconv2_stage6(h)) + h = self.relu(self.Mconv3_stage6(h)) + h = self.relu(self.Mconv4_stage6(h)) + h = self.relu(self.Mconv5_stage6(h)) + h = self.relu(self.Mconv6_stage6(h)) + h = self.Mconv7_stage6(h) + heatmaps.append(h) + + return heatmaps + + +LOG = logging.getLogger(__name__) +TOTEN = ToTensor() +TOPIL = ToPILImage() + + +params = { + 'gaussian_sigma': 2.5, + 'inference_img_size': 736, # 368, 736, 1312 + 'heatmap_peak_thresh': 0.1, + 'crop_scale': 1.5, + 'line_indices': [ + [0, 1], [1, 2], [2, 3], [3, 4], [4, 5], [5, 6], + [6, 7], [7, 8], [8, 9], [9, 10], [10, 11], [11, 12], [12, 13], + [13, 14], [14, 15], [15, 16], + [17, 18], [18, 19], [19, 20], [20, 21], + [22, 23], [23, 24], [24, 25], [25, 26], + [27, 28], [28, 29], [29, 30], + [31, 32], [32, 33], [33, 34], [34, 35], + [36, 37], [37, 38], [38, 39], [39, 40], [40, 41], [41, 36], + [42, 43], [43, 44], [44, 45], [45, 46], [46, 47], [47, 42], + [48, 49], [49, 50], [50, 51], [51, 52], [52, 53], [53, 54], + [54, 55], [55, 56], [56, 57], [57, 58], [58, 59], [59, 48], + [60, 61], [61, 62], [62, 63], [63, 64], [64, 65], [65, 66], + [66, 67], [67, 60] + ], +} + + +class Face(object): + """ + The OpenPose face landmark detector model. + + Args: + inference_size: set the size of the inference image size, suggested: + 368, 736, 1312, default 736 + gaussian_sigma: blur the heatmaps, default 2.5 + heatmap_peak_thresh: return landmark if over threshold, default 0.1 + + """ + def __init__(self, face_model_path, + inference_size=None, + gaussian_sigma=None, + heatmap_peak_thresh=None): + self.inference_size = inference_size or params["inference_img_size"] + self.sigma = gaussian_sigma or params['gaussian_sigma'] + self.threshold = heatmap_peak_thresh or params["heatmap_peak_thresh"] + self.model = FaceNet() + self.model.load_state_dict(torch.load(face_model_path)) + self.model.eval() + self.device = "cpu" + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, face_img): + H, W, C = face_img.shape + + w_size = 384 + x_data = torch.from_numpy(util.smart_resize(face_img, (w_size, w_size))).permute([2, 0, 1]) / 256.0 - 0.5 + + x_data = x_data.to(self.device) + + with torch.no_grad(): + hs = self.model(x_data[None, ...]) + heatmaps = F.interpolate( + hs[-1], + (H, W), + mode='bilinear', align_corners=True).cpu().numpy()[0] + return heatmaps + + def compute_peaks_from_heatmaps(self, heatmaps): + all_peaks = [] + for part in range(heatmaps.shape[0]): + map_ori = heatmaps[part].copy() + binary = np.ascontiguousarray(map_ori > 0.05, dtype=np.uint8) + + if np.sum(binary) == 0: + continue + + positions = np.where(binary > 0.5) + intensities = map_ori[positions] + mi = np.argmax(intensities) + y, x = positions[0][mi], positions[1][mi] + all_peaks.append([x, y]) + + return np.array(all_peaks) \ No newline at end of file diff --git a/custom_controlnet_aux/open_pose/hand.py b/custom_controlnet_aux/open_pose/hand.py new file mode 100644 index 0000000000000000000000000000000000000000..ba641875540e54045c599818337a9cfef961b431 --- /dev/null +++ b/custom_controlnet_aux/open_pose/hand.py @@ -0,0 +1,91 @@ +import cv2 +import numpy as np +import torch +from scipy.ndimage.filters import gaussian_filter +from skimage.measure import label + +from . import util +from .model import handpose_model + + +class Hand(object): + def __init__(self, model_path): + self.model = handpose_model() + model_dict = util.transfer(self.model, torch.load(model_path)) + self.model.load_state_dict(model_dict) + self.model.eval() + self.device = "cpu" + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, oriImgRaw): + scale_search = [0.5, 1.0, 1.5, 2.0] + # scale_search = [0.5] + boxsize = 368 + stride = 8 + padValue = 128 + thre = 0.05 + multiplier = [x * boxsize for x in scale_search] + + wsize = 128 + heatmap_avg = np.zeros((wsize, wsize, 22)) + + Hr, Wr, Cr = oriImgRaw.shape + + oriImg = cv2.GaussianBlur(oriImgRaw, (0, 0), 0.8) + + for m in range(len(multiplier)): + scale = multiplier[m] + imageToTest = util.smart_resize(oriImg, (scale, scale)) + + imageToTest_padded, pad = util.padRightDownCorner(imageToTest, stride, padValue) + im = np.transpose(np.float32(imageToTest_padded[:, :, :, np.newaxis]), (3, 2, 0, 1)) / 256 - 0.5 + im = np.ascontiguousarray(im) + + data = torch.from_numpy(im).float() + data = data.to(self.device) + + with torch.no_grad(): + output = self.model(data).cpu().numpy() + + # extract outputs, resize, and remove padding + heatmap = np.transpose(np.squeeze(output), (1, 2, 0)) # output 1 is heatmaps + heatmap = util.smart_resize_k(heatmap, fx=stride, fy=stride) + heatmap = heatmap[:imageToTest_padded.shape[0] - pad[2], :imageToTest_padded.shape[1] - pad[3], :] + heatmap = util.smart_resize(heatmap, (wsize, wsize)) + + heatmap_avg += heatmap / len(multiplier) + + all_peaks = [] + for part in range(21): + map_ori = heatmap_avg[:, :, part] + one_heatmap = gaussian_filter(map_ori, sigma=3) + binary = np.ascontiguousarray(one_heatmap > thre, dtype=np.uint8) + + if np.sum(binary) == 0: + all_peaks.append([0, 0]) + continue + label_img, label_numbers = label(binary, return_num=True, connectivity=binary.ndim) + max_index = np.argmax([np.sum(map_ori[label_img == i]) for i in range(1, label_numbers + 1)]) + 1 + label_img[label_img != max_index] = 0 + map_ori[label_img == 0] = 0 + + y, x = util.npmax(map_ori) + y = int(float(y) * float(Hr) / float(wsize)) + x = int(float(x) * float(Wr) / float(wsize)) + all_peaks.append([x, y]) + return np.array(all_peaks) + +if __name__ == "__main__": + hand_estimation = Hand('../model/hand_pose_model.pth') + + # test_image = '../images/hand.jpg' + test_image = '../images/hand.jpg' + oriImg = cv2.imread(test_image) # B,G,R order + peaks = hand_estimation(oriImg) + canvas = util.draw_handpose(oriImg, peaks, True) + cv2.imshow('', canvas) + cv2.waitKey(0) \ No newline at end of file diff --git a/custom_controlnet_aux/open_pose/model.py b/custom_controlnet_aux/open_pose/model.py new file mode 100644 index 0000000000000000000000000000000000000000..6c3d47268986f8018b2c75307a7725d364b175fe --- /dev/null +++ b/custom_controlnet_aux/open_pose/model.py @@ -0,0 +1,217 @@ +import torch +from collections import OrderedDict + +import torch +import torch.nn as nn + +def make_layers(block, no_relu_layers): + layers = [] + for layer_name, v in block.items(): + if 'pool' in layer_name: + layer = nn.MaxPool2d(kernel_size=v[0], stride=v[1], + padding=v[2]) + layers.append((layer_name, layer)) + else: + conv2d = nn.Conv2d(in_channels=v[0], out_channels=v[1], + kernel_size=v[2], stride=v[3], + padding=v[4]) + layers.append((layer_name, conv2d)) + if layer_name not in no_relu_layers: + layers.append(('relu_'+layer_name, nn.ReLU(inplace=True))) + + return nn.Sequential(OrderedDict(layers)) + +class bodypose_model(nn.Module): + def __init__(self): + super(bodypose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv5_5_CPM_L1', 'conv5_5_CPM_L2', 'Mconv7_stage2_L1',\ + 'Mconv7_stage2_L2', 'Mconv7_stage3_L1', 'Mconv7_stage3_L2',\ + 'Mconv7_stage4_L1', 'Mconv7_stage4_L2', 'Mconv7_stage5_L1',\ + 'Mconv7_stage5_L2', 'Mconv7_stage6_L1', 'Mconv7_stage6_L1'] + blocks = {} + block0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3_CPM', [512, 256, 3, 1, 1]), + ('conv4_4_CPM', [256, 128, 3, 1, 1]) + ]) + + + # Stage 1 + block1_1 = OrderedDict([ + ('conv5_1_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L1', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L1', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L1', [512, 38, 1, 1, 0]) + ]) + + block1_2 = OrderedDict([ + ('conv5_1_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_2_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_3_CPM_L2', [128, 128, 3, 1, 1]), + ('conv5_4_CPM_L2', [128, 512, 1, 1, 0]), + ('conv5_5_CPM_L2', [512, 19, 1, 1, 0]) + ]) + blocks['block1_1'] = block1_1 + blocks['block1_2'] = block1_2 + + self.model0 = make_layers(block0, no_relu_layers) + + # Stages 2 - 6 + for i in range(2, 7): + blocks['block%d_1' % i] = OrderedDict([ + ('Mconv1_stage%d_L1' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L1' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L1' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L1' % i, [128, 38, 1, 1, 0]) + ]) + + blocks['block%d_2' % i] = OrderedDict([ + ('Mconv1_stage%d_L2' % i, [185, 128, 7, 1, 3]), + ('Mconv2_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d_L2' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d_L2' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d_L2' % i, [128, 19, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_1 = blocks['block1_1'] + self.model2_1 = blocks['block2_1'] + self.model3_1 = blocks['block3_1'] + self.model4_1 = blocks['block4_1'] + self.model5_1 = blocks['block5_1'] + self.model6_1 = blocks['block6_1'] + + self.model1_2 = blocks['block1_2'] + self.model2_2 = blocks['block2_2'] + self.model3_2 = blocks['block3_2'] + self.model4_2 = blocks['block4_2'] + self.model5_2 = blocks['block5_2'] + self.model6_2 = blocks['block6_2'] + + + def forward(self, x): + + out1 = self.model0(x) + + out1_1 = self.model1_1(out1) + out1_2 = self.model1_2(out1) + out2 = torch.cat([out1_1, out1_2, out1], 1) + + out2_1 = self.model2_1(out2) + out2_2 = self.model2_2(out2) + out3 = torch.cat([out2_1, out2_2, out1], 1) + + out3_1 = self.model3_1(out3) + out3_2 = self.model3_2(out3) + out4 = torch.cat([out3_1, out3_2, out1], 1) + + out4_1 = self.model4_1(out4) + out4_2 = self.model4_2(out4) + out5 = torch.cat([out4_1, out4_2, out1], 1) + + out5_1 = self.model5_1(out5) + out5_2 = self.model5_2(out5) + out6 = torch.cat([out5_1, out5_2, out1], 1) + + out6_1 = self.model6_1(out6) + out6_2 = self.model6_2(out6) + + return out6_1, out6_2 + +class handpose_model(nn.Module): + def __init__(self): + super(handpose_model, self).__init__() + + # these layers have no relu layer + no_relu_layers = ['conv6_2_CPM', 'Mconv7_stage2', 'Mconv7_stage3',\ + 'Mconv7_stage4', 'Mconv7_stage5', 'Mconv7_stage6'] + # stage 1 + block1_0 = OrderedDict([ + ('conv1_1', [3, 64, 3, 1, 1]), + ('conv1_2', [64, 64, 3, 1, 1]), + ('pool1_stage1', [2, 2, 0]), + ('conv2_1', [64, 128, 3, 1, 1]), + ('conv2_2', [128, 128, 3, 1, 1]), + ('pool2_stage1', [2, 2, 0]), + ('conv3_1', [128, 256, 3, 1, 1]), + ('conv3_2', [256, 256, 3, 1, 1]), + ('conv3_3', [256, 256, 3, 1, 1]), + ('conv3_4', [256, 256, 3, 1, 1]), + ('pool3_stage1', [2, 2, 0]), + ('conv4_1', [256, 512, 3, 1, 1]), + ('conv4_2', [512, 512, 3, 1, 1]), + ('conv4_3', [512, 512, 3, 1, 1]), + ('conv4_4', [512, 512, 3, 1, 1]), + ('conv5_1', [512, 512, 3, 1, 1]), + ('conv5_2', [512, 512, 3, 1, 1]), + ('conv5_3_CPM', [512, 128, 3, 1, 1]) + ]) + + block1_1 = OrderedDict([ + ('conv6_1_CPM', [128, 512, 1, 1, 0]), + ('conv6_2_CPM', [512, 22, 1, 1, 0]) + ]) + + blocks = {} + blocks['block1_0'] = block1_0 + blocks['block1_1'] = block1_1 + + # stage 2-6 + for i in range(2, 7): + blocks['block%d' % i] = OrderedDict([ + ('Mconv1_stage%d' % i, [150, 128, 7, 1, 3]), + ('Mconv2_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv3_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv4_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv5_stage%d' % i, [128, 128, 7, 1, 3]), + ('Mconv6_stage%d' % i, [128, 128, 1, 1, 0]), + ('Mconv7_stage%d' % i, [128, 22, 1, 1, 0]) + ]) + + for k in blocks.keys(): + blocks[k] = make_layers(blocks[k], no_relu_layers) + + self.model1_0 = blocks['block1_0'] + self.model1_1 = blocks['block1_1'] + self.model2 = blocks['block2'] + self.model3 = blocks['block3'] + self.model4 = blocks['block4'] + self.model5 = blocks['block5'] + self.model6 = blocks['block6'] + + def forward(self, x): + out1_0 = self.model1_0(x) + out1_1 = self.model1_1(out1_0) + concat_stage2 = torch.cat([out1_1, out1_0], 1) + out_stage2 = self.model2(concat_stage2) + concat_stage3 = torch.cat([out_stage2, out1_0], 1) + out_stage3 = self.model3(concat_stage3) + concat_stage4 = torch.cat([out_stage3, out1_0], 1) + out_stage4 = self.model4(concat_stage4) + concat_stage5 = torch.cat([out_stage4, out1_0], 1) + out_stage5 = self.model5(concat_stage5) + concat_stage6 = torch.cat([out_stage5, out1_0], 1) + out_stage6 = self.model6(concat_stage6) + return out_stage6 diff --git a/custom_controlnet_aux/open_pose/util.py b/custom_controlnet_aux/open_pose/util.py new file mode 100644 index 0000000000000000000000000000000000000000..a0851ca409863dcee4bf731a47b472992569dd68 --- /dev/null +++ b/custom_controlnet_aux/open_pose/util.py @@ -0,0 +1,383 @@ +import math +import numpy as np +import matplotlib +import cv2 +from typing import List, Tuple, Union + +from .body import BodyResult, Keypoint + +eps = 0.01 + + +def smart_resize(x, s): + Ht, Wt = s + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize(x[:, :, i], s) for i in range(Co)], axis=2) + + +def smart_resize_k(x, fx, fy): + if x.ndim == 2: + Ho, Wo = x.shape + Co = 1 + else: + Ho, Wo, Co = x.shape + Ht, Wt = Ho * fy, Wo * fx + if Co == 3 or Co == 1: + k = float(Ht + Wt) / float(Ho + Wo) + return cv2.resize(x, (int(Wt), int(Ht)), interpolation=cv2.INTER_AREA if k < 1 else cv2.INTER_LANCZOS4) + else: + return np.stack([smart_resize_k(x[:, :, i], fx, fy) for i in range(Co)], axis=2) + + +def padRightDownCorner(img, stride, padValue): + h = img.shape[0] + w = img.shape[1] + + pad = 4 * [None] + pad[0] = 0 # up + pad[1] = 0 # left + pad[2] = 0 if (h % stride == 0) else stride - (h % stride) # down + pad[3] = 0 if (w % stride == 0) else stride - (w % stride) # right + + img_padded = img + pad_up = np.tile(img_padded[0:1, :, :]*0 + padValue, (pad[0], 1, 1)) + img_padded = np.concatenate((pad_up, img_padded), axis=0) + pad_left = np.tile(img_padded[:, 0:1, :]*0 + padValue, (1, pad[1], 1)) + img_padded = np.concatenate((pad_left, img_padded), axis=1) + pad_down = np.tile(img_padded[-2:-1, :, :]*0 + padValue, (pad[2], 1, 1)) + img_padded = np.concatenate((img_padded, pad_down), axis=0) + pad_right = np.tile(img_padded[:, -2:-1, :]*0 + padValue, (1, pad[3], 1)) + img_padded = np.concatenate((img_padded, pad_right), axis=1) + + return img_padded, pad + + +def transfer(model, model_weights): + transfered_model_weights = {} + for weights_name in model.state_dict().keys(): + transfered_model_weights[weights_name] = model_weights['.'.join(weights_name.split('.')[1:])] + return transfered_model_weights + + +def draw_bodypose(canvas: np.ndarray, keypoints: List[Keypoint]) -> np.ndarray: + """ + Draw keypoints and limbs representing body pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the body pose. + keypoints (List[Keypoint]): A list of Keypoint objects representing the body keypoints to be drawn. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn body pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + H, W, C = canvas.shape + stickwidth = 4 + + limbSeq = [ + [2, 3], [2, 6], [3, 4], [4, 5], + [6, 7], [7, 8], [2, 9], [9, 10], + [10, 11], [2, 12], [12, 13], [13, 14], + [2, 1], [1, 15], [15, 17], [1, 16], + [16, 18], + ] + + colors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0], \ + [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255], \ + [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]] + + for (k1_index, k2_index), color in zip(limbSeq, colors): + keypoint1 = keypoints[k1_index - 1] + keypoint2 = keypoints[k2_index - 1] + + if keypoint1 is None or keypoint2 is None: + continue + + Y = np.array([keypoint1.x, keypoint2.x]) * float(W) + X = np.array([keypoint1.y, keypoint2.y]) * float(H) + mX = np.mean(X) + mY = np.mean(Y) + length = ((X[0] - X[1]) ** 2 + (Y[0] - Y[1]) ** 2) ** 0.5 + angle = math.degrees(math.atan2(X[0] - X[1], Y[0] - Y[1])) + polygon = cv2.ellipse2Poly((int(mY), int(mX)), (int(length / 2), stickwidth), int(angle), 0, 360, 1) + cv2.fillConvexPoly(canvas, polygon, [int(float(c) * 0.6) for c in color]) + + for keypoint, color in zip(keypoints, colors): + if keypoint is None: + continue + + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + cv2.circle(canvas, (int(x), int(y)), 4, color, thickness=-1) + + return canvas + + +def draw_handpose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints and connections representing hand pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the hand pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the hand keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn hand pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, C = canvas.shape + + edges = [[0, 1], [1, 2], [2, 3], [3, 4], [0, 5], [5, 6], [6, 7], [7, 8], [0, 9], [9, 10], \ + [10, 11], [11, 12], [0, 13], [13, 14], [14, 15], [15, 16], [0, 17], [17, 18], [18, 19], [19, 20]] + + for ie, (e1, e2) in enumerate(edges): + k1 = keypoints[e1] + k2 = keypoints[e2] + if k1 is None or k2 is None: + continue + + x1 = int(k1.x * W) + y1 = int(k1.y * H) + x2 = int(k2.x * W) + y2 = int(k2.y * H) + if x1 > eps and y1 > eps and x2 > eps and y2 > eps: + cv2.line(canvas, (x1, y1), (x2, y2), matplotlib.colors.hsv_to_rgb([ie / float(len(edges)), 1.0, 1.0]) * 255, thickness=2) + + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 4, (0, 0, 255), thickness=-1) + return canvas + + +def draw_facepose(canvas: np.ndarray, keypoints: Union[List[Keypoint], None]) -> np.ndarray: + """ + Draw keypoints representing face pose on a given canvas. + + Args: + canvas (np.ndarray): A 3D numpy array representing the canvas (image) on which to draw the face pose. + keypoints (List[Keypoint]| None): A list of Keypoint objects representing the face keypoints to be drawn + or None if no keypoints are present. + + Returns: + np.ndarray: A 3D numpy array representing the modified canvas with the drawn face pose. + + Note: + The function expects the x and y coordinates of the keypoints to be normalized between 0 and 1. + """ + if not keypoints: + return canvas + + H, W, C = canvas.shape + for keypoint in keypoints: + x, y = keypoint.x, keypoint.y + x = int(x * W) + y = int(y * H) + if x > eps and y > eps: + cv2.circle(canvas, (x, y), 3, (255, 255, 255), thickness=-1) + return canvas + + +# detect hand according to body pose keypoints +# please refer to https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/src/openpose/hand/handDetector.cpp +def handDetect(body: BodyResult, oriImg) -> List[Tuple[int, int, int, bool]]: + """ + Detect hands in the input body pose keypoints and calculate the bounding box for each hand. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + List[Tuple[int, int, int, bool]]: A list of tuples, each containing the coordinates (x, y) of the top-left + corner of the bounding box, the width (height) of the bounding box, and + a boolean flag indicating whether the hand is a left hand (True) or a + right hand (False). + + Notes: + - The width and height of the bounding boxes are equal since the network requires squared input. + - The minimum bounding box size is 20 pixels. + """ + ratioWristElbow = 0.33 + detect_result = [] + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + # right hand: wrist 4, elbow 3, shoulder 2 + # left hand: wrist 7, elbow 6, shoulder 5 + left_shoulder = keypoints[5] + left_elbow = keypoints[6] + left_wrist = keypoints[7] + right_shoulder = keypoints[2] + right_elbow = keypoints[3] + right_wrist = keypoints[4] + + # if any of three not detected + has_left = all(keypoint is not None for keypoint in (left_shoulder, left_elbow, left_wrist)) + has_right = all(keypoint is not None for keypoint in (right_shoulder, right_elbow, right_wrist)) + if not (has_left or has_right): + return [] + + hands = [] + #left hand + if has_left: + hands.append([ + left_shoulder.x, left_shoulder.y, + left_elbow.x, left_elbow.y, + left_wrist.x, left_wrist.y, + True + ]) + # right hand + if has_right: + hands.append([ + right_shoulder.x, right_shoulder.y, + right_elbow.x, right_elbow.y, + right_wrist.x, right_wrist.y, + False + ]) + + for x1, y1, x2, y2, x3, y3, is_left in hands: + # pos_hand = pos_wrist + ratio * (pos_wrist - pos_elbox) = (1 + ratio) * pos_wrist - ratio * pos_elbox + # handRectangle.x = posePtr[wrist*3] + ratioWristElbow * (posePtr[wrist*3] - posePtr[elbow*3]); + # handRectangle.y = posePtr[wrist*3+1] + ratioWristElbow * (posePtr[wrist*3+1] - posePtr[elbow*3+1]); + # const auto distanceWristElbow = getDistance(poseKeypoints, person, wrist, elbow); + # const auto distanceElbowShoulder = getDistance(poseKeypoints, person, elbow, shoulder); + # handRectangle.width = 1.5f * fastMax(distanceWristElbow, 0.9f * distanceElbowShoulder); + x = x3 + ratioWristElbow * (x3 - x2) + y = y3 + ratioWristElbow * (y3 - y2) + distanceWristElbow = math.sqrt((x3 - x2) ** 2 + (y3 - y2) ** 2) + distanceElbowShoulder = math.sqrt((x2 - x1) ** 2 + (y2 - y1) ** 2) + width = 1.5 * max(distanceWristElbow, 0.9 * distanceElbowShoulder) + # x-y refers to the center --> offset to topLeft point + # handRectangle.x -= handRectangle.width / 2.f; + # handRectangle.y -= handRectangle.height / 2.f; + x -= width / 2 + y -= width / 2 # width = height + # overflow the image + if x < 0: x = 0 + if y < 0: y = 0 + width1 = width + width2 = width + if x + width > image_width: width1 = image_width - x + if y + width > image_height: width2 = image_height - y + width = min(width1, width2) + # the max hand box value is 20 pixels + if width >= 20: + detect_result.append((int(x), int(y), int(width), is_left)) + + ''' + return value: [[x, y, w, True if left hand else False]]. + width=height since the network require squared input. + x, y is the coordinate of top left + ''' + return detect_result + + +# Written by Lvmin +def faceDetect(body: BodyResult, oriImg) -> Union[Tuple[int, int, int], None]: + """ + Detect the face in the input body pose keypoints and calculate the bounding box for the face. + + Args: + body (BodyResult): A BodyResult object containing the detected body pose keypoints. + oriImg (numpy.ndarray): A 3D numpy array representing the original input image. + + Returns: + Tuple[int, int, int] | None: A tuple containing the coordinates (x, y) of the top-left corner of the + bounding box and the width (height) of the bounding box, or None if the + face is not detected or the bounding box width is less than 20 pixels. + + Notes: + - The width and height of the bounding box are equal. + - The minimum bounding box size is 20 pixels. + """ + # left right eye ear 14 15 16 17 + image_height, image_width = oriImg.shape[0:2] + + keypoints = body.keypoints + head = keypoints[0] + left_eye = keypoints[14] + right_eye = keypoints[15] + left_ear = keypoints[16] + right_ear = keypoints[17] + + if head is None or all(keypoint is None for keypoint in (left_eye, right_eye, left_ear, right_ear)): + return None + + width = 0.0 + x0, y0 = head.x, head.y + + if left_eye is not None: + x1, y1 = left_eye.x, left_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if right_eye is not None: + x1, y1 = right_eye.x, right_eye.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 3.0) + + if left_ear is not None: + x1, y1 = left_ear.x, left_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + if right_ear is not None: + x1, y1 = right_ear.x, right_ear.y + d = max(abs(x0 - x1), abs(y0 - y1)) + width = max(width, d * 1.5) + + x, y = x0, y0 + + x -= width + y -= width + + if x < 0: + x = 0 + + if y < 0: + y = 0 + + width1 = width * 2 + width2 = width * 2 + + if x + width > image_width: + width1 = image_width - x + + if y + width > image_height: + width2 = image_height - y + + width = min(width1, width2) + + if width >= 20: + return int(x), int(y), int(width) + else: + return None + + +# get max index of 2d array +def npmax(array): + arrayindex = array.argmax(1) + arrayvalue = array.max(1) + i = arrayvalue.argmax() + j = arrayindex[i] + return i, j \ No newline at end of file diff --git a/custom_controlnet_aux/pidi/LICENSE b/custom_controlnet_aux/pidi/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..913b6cf92c19d37b6ee4f7bc99c65f655e7f840c --- /dev/null +++ b/custom_controlnet_aux/pidi/LICENSE @@ -0,0 +1,21 @@ +It is just for research purpose, and commercial use should be contacted with authors first. + +Copyright (c) 2021 Zhuo Su + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/pidi/__init__.py b/custom_controlnet_aux/pidi/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81fb25ac6bd2a2341c0d7163ab4c08f971820193 --- /dev/null +++ b/custom_controlnet_aux/pidi/__init__.py @@ -0,0 +1,64 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, nms, resize_image_with_pad, safe_step,common_input_validate, custom_hf_download, HF_MODEL_NAME +from .model import pidinet + + +class PidiNetDetector: + def __init__(self, netNetwork): + self.netNetwork = netNetwork + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="table5_pidinet.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + netNetwork = pidinet() + netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) + netNetwork.eval() + + return cls(netNetwork) + + def to(self, device): + self.netNetwork.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, safe=False, output_type="pil", scribble=False, apply_filter=False, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + detected_map = detected_map[:, :, ::-1].copy() + with torch.no_grad(): + image_pidi = torch.from_numpy(detected_map).float().to(self.device) + image_pidi = image_pidi / 255.0 + image_pidi = rearrange(image_pidi, 'h w c -> 1 c h w') + edge = self.netNetwork(image_pidi)[-1] + edge = edge.cpu().numpy() + if apply_filter: + edge = edge > 0.5 + if safe: + edge = safe_step(edge) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = edge[0, 0] + + if scribble: + detected_map = nms(detected_map, 127, 3.0) + detected_map = cv2.GaussianBlur(detected_map, (0, 0), 3.0) + detected_map[detected_map > 4] = 255 + detected_map[detected_map < 255] = 0 + + detected_map = HWC3(remove_pad(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/pidi/model.py b/custom_controlnet_aux/pidi/model.py new file mode 100644 index 0000000000000000000000000000000000000000..16595b35a4f75a6d2b0e832e24b6e11706d77326 --- /dev/null +++ b/custom_controlnet_aux/pidi/model.py @@ -0,0 +1,681 @@ +""" +Author: Zhuo Su, Wenzhe Liu +Date: Feb 18, 2021 +""" + +import math + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + +nets = { + 'baseline': { + 'layer0': 'cv', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'c-v15': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'a-v15': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'r-v15': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cv', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cv', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cv', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cvvv4': { + 'layer0': 'cd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'avvv4': { + 'layer0': 'ad', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'rvvv4': { + 'layer0': 'rd', + 'layer1': 'cv', + 'layer2': 'cv', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'cv', + 'layer6': 'cv', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'cv', + 'layer10': 'cv', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'cv', + 'layer14': 'cv', + 'layer15': 'cv', + }, + 'cccv4': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cv', + }, + 'aaav4': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'cv', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'cv', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'cv', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'cv', + }, + 'rrrv4': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'cv', + }, + 'c16': { + 'layer0': 'cd', + 'layer1': 'cd', + 'layer2': 'cd', + 'layer3': 'cd', + 'layer4': 'cd', + 'layer5': 'cd', + 'layer6': 'cd', + 'layer7': 'cd', + 'layer8': 'cd', + 'layer9': 'cd', + 'layer10': 'cd', + 'layer11': 'cd', + 'layer12': 'cd', + 'layer13': 'cd', + 'layer14': 'cd', + 'layer15': 'cd', + }, + 'a16': { + 'layer0': 'ad', + 'layer1': 'ad', + 'layer2': 'ad', + 'layer3': 'ad', + 'layer4': 'ad', + 'layer5': 'ad', + 'layer6': 'ad', + 'layer7': 'ad', + 'layer8': 'ad', + 'layer9': 'ad', + 'layer10': 'ad', + 'layer11': 'ad', + 'layer12': 'ad', + 'layer13': 'ad', + 'layer14': 'ad', + 'layer15': 'ad', + }, + 'r16': { + 'layer0': 'rd', + 'layer1': 'rd', + 'layer2': 'rd', + 'layer3': 'rd', + 'layer4': 'rd', + 'layer5': 'rd', + 'layer6': 'rd', + 'layer7': 'rd', + 'layer8': 'rd', + 'layer9': 'rd', + 'layer10': 'rd', + 'layer11': 'rd', + 'layer12': 'rd', + 'layer13': 'rd', + 'layer14': 'rd', + 'layer15': 'rd', + }, + 'carv4': { + 'layer0': 'cd', + 'layer1': 'ad', + 'layer2': 'rd', + 'layer3': 'cv', + 'layer4': 'cd', + 'layer5': 'ad', + 'layer6': 'rd', + 'layer7': 'cv', + 'layer8': 'cd', + 'layer9': 'ad', + 'layer10': 'rd', + 'layer11': 'cv', + 'layer12': 'cd', + 'layer13': 'ad', + 'layer14': 'rd', + 'layer15': 'cv', + }, + } + +def createConvFunc(op_type): + assert op_type in ['cv', 'cd', 'ad', 'rd'], 'unknown op type: %s' % str(op_type) + if op_type == 'cv': + return F.conv2d + + if op_type == 'cd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for cd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for cd_conv should be 3x3' + assert padding == dilation, 'padding for cd_conv set wrong' + + weights_c = weights.sum(dim=[2, 3], keepdim=True) + yc = F.conv2d(x, weights_c, stride=stride, padding=0, groups=groups) + y = F.conv2d(x, weights, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y - yc + return func + elif op_type == 'ad': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for ad_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for ad_conv should be 3x3' + assert padding == dilation, 'padding for ad_conv set wrong' + + shape = weights.shape + weights = weights.view(shape[0], shape[1], -1) + weights_conv = (weights - weights[:, :, [3, 0, 1, 6, 4, 2, 7, 8, 5]]).view(shape) # clock-wise + y = F.conv2d(x, weights_conv, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + return func + elif op_type == 'rd': + def func(x, weights, bias=None, stride=1, padding=0, dilation=1, groups=1): + assert dilation in [1, 2], 'dilation for rd_conv should be in 1 or 2' + assert weights.size(2) == 3 and weights.size(3) == 3, 'kernel size for rd_conv should be 3x3' + padding = 2 * dilation + + shape = weights.shape + if weights.is_cuda: + buffer = torch.cuda.FloatTensor(shape[0], shape[1], 5 * 5).fill_(0) + else: + buffer = torch.zeros(shape[0], shape[1], 5 * 5).to(weights.device) + weights = weights.view(shape[0], shape[1], -1) + buffer[:, :, [0, 2, 4, 10, 14, 20, 22, 24]] = weights[:, :, 1:] + buffer[:, :, [6, 7, 8, 11, 13, 16, 17, 18]] = -weights[:, :, 1:] + buffer[:, :, 12] = 0 + buffer = buffer.view(shape[0], shape[1], 5, 5) + y = F.conv2d(x, buffer, bias, stride=stride, padding=padding, dilation=dilation, groups=groups) + return y + return func + else: + print('impossible to be here unless you force that') + return None + +class Conv2d(nn.Module): + def __init__(self, pdc, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=False): + super(Conv2d, self).__init__() + if in_channels % groups != 0: + raise ValueError('in_channels must be divisible by groups') + if out_channels % groups != 0: + raise ValueError('out_channels must be divisible by groups') + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, kernel_size, kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + self.pdc = pdc + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input): + + return self.pdc(input, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups) + +class CSAM(nn.Module): + """ + Compact Spatial Attention Module + """ + def __init__(self, channels): + super(CSAM, self).__init__() + + mid_channels = 4 + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(channels, mid_channels, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(mid_channels, 1, kernel_size=3, padding=1, bias=False) + self.sigmoid = nn.Sigmoid() + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + y = self.relu1(x) + y = self.conv1(y) + y = self.conv2(y) + y = self.sigmoid(y) + + return x * y + +class CDCM(nn.Module): + """ + Compact Dilation Convolution based Module + """ + def __init__(self, in_channels, out_channels): + super(CDCM, self).__init__() + + self.relu1 = nn.ReLU() + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) + self.conv2_1 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=5, padding=5, bias=False) + self.conv2_2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=7, padding=7, bias=False) + self.conv2_3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=9, padding=9, bias=False) + self.conv2_4 = nn.Conv2d(out_channels, out_channels, kernel_size=3, dilation=11, padding=11, bias=False) + nn.init.constant_(self.conv1.bias, 0) + + def forward(self, x): + x = self.relu1(x) + x = self.conv1(x) + x1 = self.conv2_1(x) + x2 = self.conv2_2(x) + x3 = self.conv2_3(x) + x4 = self.conv2_4(x) + return x1 + x2 + x3 + x4 + + +class MapReduce(nn.Module): + """ + Reduce feature maps into a single edge map + """ + def __init__(self, channels): + super(MapReduce, self).__init__() + self.conv = nn.Conv2d(channels, 1, kernel_size=1, padding=0) + nn.init.constant_(self.conv.bias, 0) + + def forward(self, x): + return self.conv(x) + + +class PDCBlock(nn.Module): + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock, self).__init__() + self.stride=stride + + self.stride=stride + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + self.conv1 = Conv2d(pdc, inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PDCBlock_converted(nn.Module): + """ + CPDC, APDC can be converted to vanilla 3x3 convolution + RPDC can be converted to vanilla 5x5 convolution + """ + def __init__(self, pdc, inplane, ouplane, stride=1): + super(PDCBlock_converted, self).__init__() + self.stride=stride + + if self.stride > 1: + self.pool = nn.MaxPool2d(kernel_size=2, stride=2) + self.shortcut = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0) + if pdc == 'rd': + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=5, padding=2, groups=inplane, bias=False) + else: + self.conv1 = nn.Conv2d(inplane, inplane, kernel_size=3, padding=1, groups=inplane, bias=False) + self.relu2 = nn.ReLU() + self.conv2 = nn.Conv2d(inplane, ouplane, kernel_size=1, padding=0, bias=False) + + def forward(self, x): + if self.stride > 1: + x = self.pool(x) + y = self.conv1(x) + y = self.relu2(y) + y = self.conv2(y) + if self.stride > 1: + x = self.shortcut(x) + y = y + x + return y + +class PiDiNet(nn.Module): + def __init__(self, inplane, pdcs, dil=None, sa=False, convert=False): + super(PiDiNet, self).__init__() + self.sa = sa + if dil is not None: + assert isinstance(dil, int), 'dil should be an int' + self.dil = dil + + self.fuseplanes = [] + + self.inplane = inplane + if convert: + if pdcs[0] == 'rd': + init_kernel_size = 5 + init_padding = 2 + else: + init_kernel_size = 3 + init_padding = 1 + self.init_block = nn.Conv2d(3, self.inplane, + kernel_size=init_kernel_size, padding=init_padding, bias=False) + block_class = PDCBlock_converted + else: + self.init_block = Conv2d(pdcs[0], 3, self.inplane, kernel_size=3, padding=1) + block_class = PDCBlock + + self.block1_1 = block_class(pdcs[1], self.inplane, self.inplane) + self.block1_2 = block_class(pdcs[2], self.inplane, self.inplane) + self.block1_3 = block_class(pdcs[3], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block2_1 = block_class(pdcs[4], inplane, self.inplane, stride=2) + self.block2_2 = block_class(pdcs[5], self.inplane, self.inplane) + self.block2_3 = block_class(pdcs[6], self.inplane, self.inplane) + self.block2_4 = block_class(pdcs[7], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 2C + + inplane = self.inplane + self.inplane = self.inplane * 2 + self.block3_1 = block_class(pdcs[8], inplane, self.inplane, stride=2) + self.block3_2 = block_class(pdcs[9], self.inplane, self.inplane) + self.block3_3 = block_class(pdcs[10], self.inplane, self.inplane) + self.block3_4 = block_class(pdcs[11], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.block4_1 = block_class(pdcs[12], self.inplane, self.inplane, stride=2) + self.block4_2 = block_class(pdcs[13], self.inplane, self.inplane) + self.block4_3 = block_class(pdcs[14], self.inplane, self.inplane) + self.block4_4 = block_class(pdcs[15], self.inplane, self.inplane) + self.fuseplanes.append(self.inplane) # 4C + + self.conv_reduces = nn.ModuleList() + if self.sa and self.dil is not None: + self.attentions = nn.ModuleList() + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.attentions.append(CSAM(self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + elif self.sa: + self.attentions = nn.ModuleList() + for i in range(4): + self.attentions.append(CSAM(self.fuseplanes[i])) + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + elif self.dil is not None: + self.dilations = nn.ModuleList() + for i in range(4): + self.dilations.append(CDCM(self.fuseplanes[i], self.dil)) + self.conv_reduces.append(MapReduce(self.dil)) + else: + for i in range(4): + self.conv_reduces.append(MapReduce(self.fuseplanes[i])) + + self.classifier = nn.Conv2d(4, 1, kernel_size=1) # has bias + nn.init.constant_(self.classifier.weight, 0.25) + nn.init.constant_(self.classifier.bias, 0) + + # print('initialization done') + + def get_weights(self): + conv_weights = [] + bn_weights = [] + relu_weights = [] + for pname, p in self.named_parameters(): + if 'bn' in pname: + bn_weights.append(p) + elif 'relu' in pname: + relu_weights.append(p) + else: + conv_weights.append(p) + + return conv_weights, bn_weights, relu_weights + + def forward(self, x): + H, W = x.size()[2:] + + x = self.init_block(x) + + x1 = self.block1_1(x) + x1 = self.block1_2(x1) + x1 = self.block1_3(x1) + + x2 = self.block2_1(x1) + x2 = self.block2_2(x2) + x2 = self.block2_3(x2) + x2 = self.block2_4(x2) + + x3 = self.block3_1(x2) + x3 = self.block3_2(x3) + x3 = self.block3_3(x3) + x3 = self.block3_4(x3) + + x4 = self.block4_1(x3) + x4 = self.block4_2(x4) + x4 = self.block4_3(x4) + x4 = self.block4_4(x4) + + x_fuses = [] + if self.sa and self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](self.dilations[i](xi))) + elif self.sa: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.attentions[i](xi)) + elif self.dil is not None: + for i, xi in enumerate([x1, x2, x3, x4]): + x_fuses.append(self.dilations[i](xi)) + else: + x_fuses = [x1, x2, x3, x4] + + e1 = self.conv_reduces[0](x_fuses[0]) + e1 = F.interpolate(e1, (H, W), mode="bilinear", align_corners=False) + + e2 = self.conv_reduces[1](x_fuses[1]) + e2 = F.interpolate(e2, (H, W), mode="bilinear", align_corners=False) + + e3 = self.conv_reduces[2](x_fuses[2]) + e3 = F.interpolate(e3, (H, W), mode="bilinear", align_corners=False) + + e4 = self.conv_reduces[3](x_fuses[3]) + e4 = F.interpolate(e4, (H, W), mode="bilinear", align_corners=False) + + outputs = [e1, e2, e3, e4] + + output = self.classifier(torch.cat(outputs, dim=1)) + #if not self.training: + # return torch.sigmoid(output) + + outputs.append(output) + outputs = [torch.sigmoid(r) for r in outputs] + return outputs + +def config_model(model): + model_options = list(nets.keys()) + assert model in model_options, \ + 'unrecognized model, please choose from %s' % str(model_options) + + # print(str(nets[model])) + + pdcs = [] + for i in range(16): + layer_name = 'layer%d' % i + op = nets[model][layer_name] + pdcs.append(createConvFunc(op)) + + return pdcs + +def pidinet(): + pdcs = config_model('carv4') + dil = 24 #if args.dil else None + return PiDiNet(60, pdcs, dil=dil, sa=True) + + +if __name__ == '__main__': + model = pidinet() + ckp = torch.load('table5_pidinet.pth')['state_dict'] + model.load_state_dict({k.replace('module.',''):v for k, v in ckp.items()}) + im = cv2.imread('examples/test_my/cat_v4.png') + im = img2tensor(im).unsqueeze(0)/255. + res = model(im)[-1] + res = res>0.5 + res = res.float() + res = (res[0,0].cpu().data.numpy()*255.).astype(np.uint8) + print(res.shape) + cv2.imwrite('edge.png', res) diff --git a/custom_controlnet_aux/processor.py b/custom_controlnet_aux/processor.py new file mode 100644 index 0000000000000000000000000000000000000000..e8e087e2179adc7e03b115744eb9ef1e6f747bdb --- /dev/null +++ b/custom_controlnet_aux/processor.py @@ -0,0 +1,147 @@ +""" +This file contains a Processor that can be used to process images with controlnet aux processors +""" +import io +import logging +from typing import Dict, Optional, Union + +from PIL import Image + +from custom_controlnet_aux import (CannyDetector, ContentShuffleDetector, HEDdetector, + LeresDetector, LineartAnimeDetector, + LineartDetector, MediapipeFaceDetector, + MidasDetector, MLSDdetector, NormalBaeDetector, + OpenposeDetector, PidiNetDetector, ZoeDetector, TileDetector) + +LOGGER = logging.getLogger(__name__) + + +MODELS = { + # checkpoint models + 'scribble_hed': {'class': HEDdetector, 'checkpoint': True}, + 'softedge_hed': {'class': HEDdetector, 'checkpoint': True}, + 'scribble_hedsafe': {'class': HEDdetector, 'checkpoint': True}, + 'softedge_hedsafe': {'class': HEDdetector, 'checkpoint': True}, + 'depth_midas': {'class': MidasDetector, 'checkpoint': True}, + 'mlsd': {'class': MLSDdetector, 'checkpoint': True}, + 'openpose': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_face': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_faceonly': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_full': {'class': OpenposeDetector, 'checkpoint': True}, + 'openpose_hand': {'class': OpenposeDetector, 'checkpoint': True}, + 'scribble_pidinet': {'class': PidiNetDetector, 'checkpoint': True}, + 'softedge_pidinet': {'class': PidiNetDetector, 'checkpoint': True}, + 'scribble_pidsafe': {'class': PidiNetDetector, 'checkpoint': True}, + 'softedge_pidsafe': {'class': PidiNetDetector, 'checkpoint': True}, + 'normal_bae': {'class': NormalBaeDetector, 'checkpoint': True}, + 'lineart_coarse': {'class': LineartDetector, 'checkpoint': True}, + 'lineart_realistic': {'class': LineartDetector, 'checkpoint': True}, + 'lineart_anime': {'class': LineartAnimeDetector, 'checkpoint': True}, + 'depth_zoe': {'class': ZoeDetector, 'checkpoint': True}, + 'depth_leres': {'class': LeresDetector, 'checkpoint': True}, + 'depth_leres++': {'class': LeresDetector, 'checkpoint': True}, + # instantiate + 'shuffle': {'class': ContentShuffleDetector, 'checkpoint': False}, + 'mediapipe_face': {'class': MediapipeFaceDetector, 'checkpoint': False}, + 'canny': {'class': CannyDetector, 'checkpoint': False}, + 'tile': {'class': TileDetector, 'checkpoint': False}, +} + + +MODEL_PARAMS = { + 'scribble_hed': {'scribble': True}, + 'softedge_hed': {'scribble': False}, + 'scribble_hedsafe': {'scribble': True, 'safe': True}, + 'softedge_hedsafe': {'scribble': False, 'safe': True}, + 'depth_midas': {}, + 'mlsd': {}, + 'openpose': {'include_body': True, 'include_hand': False, 'include_face': False}, + 'openpose_face': {'include_body': True, 'include_hand': False, 'include_face': True}, + 'openpose_faceonly': {'include_body': False, 'include_hand': False, 'include_face': True}, + 'openpose_full': {'include_body': True, 'include_hand': True, 'include_face': True}, + 'openpose_hand': {'include_body': False, 'include_hand': True, 'include_face': False}, + 'scribble_pidinet': {'safe': False, 'scribble': True}, + 'softedge_pidinet': {'safe': False, 'scribble': False}, + 'scribble_pidsafe': {'safe': True, 'scribble': True}, + 'softedge_pidsafe': {'safe': True, 'scribble': False}, + 'normal_bae': {}, + 'lineart_realistic': {'coarse': False}, + 'lineart_coarse': {'coarse': True}, + 'lineart_anime': {}, + 'canny': {}, + 'shuffle': {}, + 'depth_zoe': {}, + 'depth_leres': {'boost': False}, + 'depth_leres++': {'boost': True}, + 'mediapipe_face': {}, + 'tile': {}, +} + +CHOICES = f"Choices for the processor are {list(MODELS.keys())}" + + +class Processor: + def __init__(self, processor_id: str, params: Optional[Dict] = None) -> None: + """Processor that can be used to process images with controlnet aux processors + + Args: + processor_id (str): processor name, options are 'hed, midas, mlsd, openpose, + pidinet, normalbae, lineart, lineart_coarse, lineart_anime, + canny, content_shuffle, zoe, mediapipe_face, tile' + params (Optional[Dict]): parameters for the processor + """ + LOGGER.info("Loading %s".format(processor_id)) + + if processor_id not in MODELS: + raise ValueError(f"{processor_id} is not a valid processor id. Please make sure to choose one of {', '.join(MODELS.keys())}") + + self.processor_id = processor_id + self.processor = self.load_processor(self.processor_id) + + # load default params + self.params = MODEL_PARAMS[self.processor_id] + # update with user params + if params: + self.params.update(params) + + def load_processor(self, processor_id: str) -> 'Processor': + """Load controlnet aux processors + + Args: + processor_id (str): processor name + + Returns: + Processor: controlnet aux processor + """ + processor = MODELS[processor_id]['class'] + + # check if the proecssor is a checkpoint model + if MODELS[processor_id]['checkpoint']: + processor = processor.from_pretrained("lllyasviel/Annotators") + else: + processor = processor() + return processor + + def __call__(self, image: Union[Image.Image, bytes], + to_pil: bool = True) -> Union[Image.Image, bytes]: + """processes an image with a controlnet aux processor + + Args: + image (Union[Image.Image, bytes]): input image in bytes or PIL Image + to_pil (bool): whether to return bytes or PIL Image + + Returns: + Union[Image.Image, bytes]: processed image in bytes or PIL Image + """ + # check if bytes or PIL Image + if isinstance(image, bytes): + image = Image.open(io.BytesIO(image)).convert("RGB") + + processed_image = self.processor(image, **self.params) + + if to_pil: + return processed_image + else: + output_bytes = io.BytesIO() + processed_image.save(output_bytes, format='JPEG') + return output_bytes.getvalue() diff --git a/custom_controlnet_aux/recolor/__init__.py b/custom_controlnet_aux/recolor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..274ce3a73358450e4fe203cf9f913ef6f88fdb50 --- /dev/null +++ b/custom_controlnet_aux/recolor/__init__.py @@ -0,0 +1,39 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3 + +#https://github.com/Mikubill/sd-webui-controlnet/blob/416c345072c9c2066101e225964e3986abe6945e/scripts/processor.py#L639 +def recolor_luminance(img, thr_a=1.0): + result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2LAB) + result = result[:, :, 0].astype(np.float32) / 255.0 + result = result ** thr_a + result = (result * 255.0).clip(0, 255).astype(np.uint8) + result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB) + return result + + +def recolor_intensity(img, thr_a=1.0): + result = cv2.cvtColor(HWC3(img), cv2.COLOR_BGR2HSV) + result = result[:, :, 2].astype(np.float32) / 255.0 + result = result ** thr_a + result = (result * 255.0).clip(0, 255).astype(np.uint8) + result = cv2.cvtColor(result, cv2.COLOR_GRAY2RGB) + return result + +recolor_methods = { + "luminance": recolor_luminance, + "intensity": recolor_intensity +} + +class Recolorizer: + def __call__(self, input_image=None, mode="luminance", gamma_correction=1.0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + assert mode in recolor_methods.keys() + detected_map = recolor_methods[mode](input_image, gamma_correction) + detected_map = HWC3(remove_pad(detected_map)) + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + return detected_map diff --git a/custom_controlnet_aux/sam/__init__.py b/custom_controlnet_aux/sam/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4e5c5090a867c3f87c96c04043a0bc1cfe14cecb --- /dev/null +++ b/custom_controlnet_aux/sam/__init__.py @@ -0,0 +1,74 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import os +import warnings +from typing import Union + +import cv2 +import numpy as np +import torch +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, SAM_MODEL_NAME +from .automatic_mask_generator import SamAutomaticMaskGenerator +from .build_sam import sam_model_registry + + +class SamDetector: + def __init__(self, mask_generator: SamAutomaticMaskGenerator): + self.mask_generator = mask_generator + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=SAM_MODEL_NAME, model_type="vit_t", filename="mobile_sam.pt", subfolder=None): + """ + Possible model_type : vit_h, vit_l, vit_b, vit_t + download weights from https://github.com/facebookresearch/segment-anything + """ + model_path = custom_hf_download(pretrained_model_or_path, filename) + + sam = sam_model_registry[model_type](checkpoint=model_path) + mask_generator = SamAutomaticMaskGenerator(sam) + + return cls(mask_generator) + + def to(self, device): + model = self.mask_generator.predictor.model.to(device) + model.train(False) #Update attention_bias in https://github.com/Fannovel16/comfyui_controlnet_aux/blob/main/src/custom_controlnet_aux/segment_anything/modeling/tiny_vit_sam.py#L251 + self.mask_generator = SamAutomaticMaskGenerator(model) + return self + + + def show_anns(self, anns): + if len(anns) == 0: + return + sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True) + h, w = anns[0]['segmentation'].shape + final_img = Image.fromarray(np.zeros((h, w, 3), dtype=np.uint8), mode="RGB") + for ann in sorted_anns: + m = ann['segmentation'] + img = np.empty((m.shape[0], m.shape[1], 3), dtype=np.uint8) + for i in range(3): + img[:,:,i] = np.random.randint(255, dtype=np.uint8) + final_img.paste(Image.fromarray(img, mode="RGB"), (0, 0), Image.fromarray(np.uint8(m*255))) + + return np.array(final_img, dtype=np.uint8) + + def __call__(self, input_image: Union[np.ndarray, Image.Image]=None, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", **kwargs) -> Image.Image: + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + # Generate Masks + masks = self.mask_generator.generate(input_image) + # Create map + map = self.show_anns(masks) + + detected_map = HWC3(remove_pad(map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/sam/automatic_mask_generator.py b/custom_controlnet_aux/sam/automatic_mask_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..d7ac3589d81890aca612023c85b46c0c8176195d --- /dev/null +++ b/custom_controlnet_aux/sam/automatic_mask_generator.py @@ -0,0 +1,372 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torchvision.ops.boxes import batched_nms, box_area # type: ignore + +from typing import Any, Dict, List, Optional, Tuple + +from .modeling import Sam +from .predictor import SamPredictor +from .utils.amg import ( + MaskData, + area_from_rle, + batch_iterator, + batched_mask_to_box, + box_xyxy_to_xywh, + build_all_layer_point_grids, + calculate_stability_score, + coco_encode_rle, + generate_crop_boxes, + is_box_near_crop_edge, + mask_to_rle_pytorch, + remove_small_regions, + rle_to_mask, + uncrop_boxes_xyxy, + uncrop_masks, + uncrop_points, +) + + +class SamAutomaticMaskGenerator: + def __init__( + self, + model: Sam, + points_per_side: Optional[int] = 32, + points_per_batch: int = 64, + pred_iou_thresh: float = 0.88, + stability_score_thresh: float = 0.95, + stability_score_offset: float = 1.0, + box_nms_thresh: float = 0.7, + crop_n_layers: int = 0, + crop_nms_thresh: float = 0.7, + crop_overlap_ratio: float = 512 / 1500, + crop_n_points_downscale_factor: int = 1, + point_grids: Optional[List[np.ndarray]] = None, + min_mask_region_area: int = 0, + output_mode: str = "binary_mask", + ) -> None: + """ + Using a SAM model, generates masks for the entire image. + Generates a grid of point prompts over the image, then filters + low quality and duplicate masks. The default settings are chosen + for SAM with a ViT-H backbone. + + Arguments: + model (Sam): The SAM model to use for mask prediction. + points_per_side (int or None): The number of points to be sampled + along one side of the image. The total number of points is + points_per_side**2. If None, 'point_grids' must provide explicit + point sampling. + points_per_batch (int): Sets the number of points run simultaneously + by the model. Higher numbers may be faster but use more GPU memory. + pred_iou_thresh (float): A filtering threshold in [0,1], using the + model's predicted mask quality. + stability_score_thresh (float): A filtering threshold in [0,1], using + the stability of the mask under changes to the cutoff used to binarize + the model's mask predictions. + stability_score_offset (float): The amount to shift the cutoff when + calculated the stability score. + box_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks. + crop_n_layers (int): If >0, mask prediction will be run again on + crops of the image. Sets the number of layers to run, where each + layer has 2**i_layer number of image crops. + crop_nms_thresh (float): The box IoU cutoff used by non-maximal + suppression to filter duplicate masks between different crops. + crop_overlap_ratio (float): Sets the degree to which crops overlap. + In the first crop layer, crops will overlap by this fraction of + the image length. Later layers with more crops scale down this overlap. + crop_n_points_downscale_factor (int): The number of points-per-side + sampled in layer n is scaled down by crop_n_points_downscale_factor**n. + point_grids (list(np.ndarray) or None): A list over explicit grids + of points used for sampling, normalized to [0,1]. The nth grid in the + list is used in the nth crop layer. Exclusive with points_per_side. + min_mask_region_area (int): If >0, postprocessing will be applied + to remove disconnected regions and holes in masks with area smaller + than min_mask_region_area. Requires opencv. + output_mode (str): The form masks are returned in. Can be 'binary_mask', + 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools. + For large resolutions, 'binary_mask' may consume large amounts of + memory. + """ + + assert (points_per_side is None) != ( + point_grids is None + ), "Exactly one of points_per_side or point_grid must be provided." + if points_per_side is not None: + self.point_grids = build_all_layer_point_grids( + points_per_side, + crop_n_layers, + crop_n_points_downscale_factor, + ) + elif point_grids is not None: + self.point_grids = point_grids + else: + raise ValueError("Can't have both points_per_side and point_grid be None.") + + assert output_mode in [ + "binary_mask", + "uncompressed_rle", + "coco_rle", + ], f"Unknown output_mode {output_mode}." + if output_mode == "coco_rle": + from custom_pycocotools import mask as mask_utils # type: ignore # noqa: F401 + + if min_mask_region_area > 0: + import cv2 # type: ignore # noqa: F401 + + self.predictor = SamPredictor(model) + self.points_per_batch = points_per_batch + self.pred_iou_thresh = pred_iou_thresh + self.stability_score_thresh = stability_score_thresh + self.stability_score_offset = stability_score_offset + self.box_nms_thresh = box_nms_thresh + self.crop_n_layers = crop_n_layers + self.crop_nms_thresh = crop_nms_thresh + self.crop_overlap_ratio = crop_overlap_ratio + self.crop_n_points_downscale_factor = crop_n_points_downscale_factor + self.min_mask_region_area = min_mask_region_area + self.output_mode = output_mode + + @torch.no_grad() + def generate(self, image: np.ndarray) -> List[Dict[str, Any]]: + """ + Generates masks for the given image. + + Arguments: + image (np.ndarray): The image to generate masks for, in HWC uint8 format. + + Returns: + list(dict(str, any)): A list over records for masks. Each record is + a dict containing the following keys: + segmentation (dict(str, any) or np.ndarray): The mask. If + output_mode='binary_mask', is an array of shape HW. Otherwise, + is a dictionary containing the RLE. + bbox (list(float)): The box around the mask, in XYWH format. + area (int): The area in pixels of the mask. + predicted_iou (float): The model's own prediction of the mask's + quality. This is filtered by the pred_iou_thresh parameter. + point_coords (list(list(float))): The point coordinates input + to the model to generate this mask. + stability_score (float): A measure of the mask's quality. This + is filtered on using the stability_score_thresh parameter. + crop_box (list(float)): The crop of the image used to generate + the mask, given in XYWH format. + """ + + # Generate masks + mask_data = self._generate_masks(image) + + # Filter small disconnected regions and holes in masks + if self.min_mask_region_area > 0: + mask_data = self.postprocess_small_regions( + mask_data, + self.min_mask_region_area, + max(self.box_nms_thresh, self.crop_nms_thresh), + ) + + # Encode masks + if self.output_mode == "coco_rle": + mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]] + elif self.output_mode == "binary_mask": + mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]] + else: + mask_data["segmentations"] = mask_data["rles"] + + # Write mask records + curr_anns = [] + for idx in range(len(mask_data["segmentations"])): + ann = { + "segmentation": mask_data["segmentations"][idx], + "area": area_from_rle(mask_data["rles"][idx]), + "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(), + "predicted_iou": mask_data["iou_preds"][idx].item(), + "point_coords": [mask_data["points"][idx].tolist()], + "stability_score": mask_data["stability_score"][idx].item(), + "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(), + } + curr_anns.append(ann) + + return curr_anns + + def _generate_masks(self, image: np.ndarray) -> MaskData: + orig_size = image.shape[:2] + crop_boxes, layer_idxs = generate_crop_boxes( + orig_size, self.crop_n_layers, self.crop_overlap_ratio + ) + + # Iterate over image crops + data = MaskData() + for crop_box, layer_idx in zip(crop_boxes, layer_idxs): + crop_data = self._process_crop(image, crop_box, layer_idx, orig_size) + data.cat(crop_data) + + # Remove duplicate masks between crops + if len(crop_boxes) > 1: + # Prefer masks from smaller crops + scores = 1 / box_area(data["crop_boxes"]) + scores = scores.to(data["boxes"].device) + keep_by_nms = batched_nms( + data["boxes"].float(), + scores, + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.crop_nms_thresh, + ) + data.filter(keep_by_nms) + + data.to_numpy() + return data + + def _process_crop( + self, + image: np.ndarray, + crop_box: List[int], + crop_layer_idx: int, + orig_size: Tuple[int, ...], + ) -> MaskData: + # Crop the image and calculate embeddings + x0, y0, x1, y1 = crop_box + cropped_im = image[y0:y1, x0:x1, :] + cropped_im_size = cropped_im.shape[:2] + self.predictor.set_image(cropped_im) + + # Get points for this crop + points_scale = np.array(cropped_im_size)[None, ::-1] + points_for_image = self.point_grids[crop_layer_idx] * points_scale + + # Generate masks for this crop in batches + data = MaskData() + for (points,) in batch_iterator(self.points_per_batch, points_for_image): + batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size) + data.cat(batch_data) + del batch_data + self.predictor.reset_image() + + # Remove duplicates within this crop. + keep_by_nms = batched_nms( + data["boxes"].float(), + data["iou_preds"], + torch.zeros_like(data["boxes"][:, 0]), # categories + iou_threshold=self.box_nms_thresh, + ) + data.filter(keep_by_nms) + + # Return to the original image frame + data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box) + data["points"] = uncrop_points(data["points"], crop_box) + data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))]) + + return data + + def _process_batch( + self, + points: np.ndarray, + im_size: Tuple[int, ...], + crop_box: List[int], + orig_size: Tuple[int, ...], + ) -> MaskData: + orig_h, orig_w = orig_size + + # Run model on this batch + transformed_points = self.predictor.transform.apply_coords(points, im_size) + in_points = torch.as_tensor(transformed_points, device=self.predictor.device) + in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device) + masks, iou_preds, _ = self.predictor.predict_torch( + in_points[:, None, :], + in_labels[:, None], + multimask_output=True, + return_logits=True, + ) + + # Serialize predictions and store in MaskData + data = MaskData( + masks=masks.flatten(0, 1), + iou_preds=iou_preds.flatten(0, 1), + points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)), + ) + del masks + + # Filter by predicted IoU + if self.pred_iou_thresh > 0.0: + keep_mask = data["iou_preds"] > self.pred_iou_thresh + data.filter(keep_mask) + + # Calculate stability score + data["stability_score"] = calculate_stability_score( + data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset + ) + if self.stability_score_thresh > 0.0: + keep_mask = data["stability_score"] >= self.stability_score_thresh + data.filter(keep_mask) + + # Threshold masks and calculate boxes + data["masks"] = data["masks"] > self.predictor.model.mask_threshold + data["boxes"] = batched_mask_to_box(data["masks"]) + + # Filter boxes that touch crop boundaries + keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h]) + if not torch.all(keep_mask): + data.filter(keep_mask) + + # Compress to RLE + data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w) + data["rles"] = mask_to_rle_pytorch(data["masks"]) + del data["masks"] + + return data + + @staticmethod + def postprocess_small_regions( + mask_data: MaskData, min_area: int, nms_thresh: float + ) -> MaskData: + """ + Removes small disconnected regions and holes in masks, then reruns + box NMS to remove any new duplicates. + + Edits mask_data in place. + + Requires open-cv as a dependency. + """ + if len(mask_data["rles"]) == 0: + return mask_data + + # Filter small disconnected regions and holes + new_masks = [] + scores = [] + for rle in mask_data["rles"]: + mask = rle_to_mask(rle) + + mask, changed = remove_small_regions(mask, min_area, mode="holes") + unchanged = not changed + mask, changed = remove_small_regions(mask, min_area, mode="islands") + unchanged = unchanged and not changed + + new_masks.append(torch.as_tensor(mask).unsqueeze(0)) + # Give score=0 to changed masks and score=1 to unchanged masks + # so NMS will prefer ones that didn't need postprocessing + scores.append(float(unchanged)) + + # Recalculate boxes and remove any new duplicates + masks = torch.cat(new_masks, dim=0) + boxes = batched_mask_to_box(masks) + keep_by_nms = batched_nms( + boxes.float(), + torch.as_tensor(scores), + torch.zeros_like(boxes[:, 0]), # categories + iou_threshold=nms_thresh, + ) + + # Only recalculate RLEs for masks that have changed + for i_mask in keep_by_nms: + if scores[i_mask] == 0.0: + mask_torch = masks[i_mask].unsqueeze(0) + mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0] + mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly + mask_data.filter(keep_by_nms) + + return mask_data diff --git a/custom_controlnet_aux/sam/build_sam.py b/custom_controlnet_aux/sam/build_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..9a52c506b69d29ee2356cc0e62274fe6f6ee075b --- /dev/null +++ b/custom_controlnet_aux/sam/build_sam.py @@ -0,0 +1,159 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch + +from functools import partial + +from .modeling import ImageEncoderViT, MaskDecoder, PromptEncoder, Sam, TwoWayTransformer, TinyViT + + +def build_sam_vit_h(checkpoint=None): + return _build_sam( + encoder_embed_dim=1280, + encoder_depth=32, + encoder_num_heads=16, + encoder_global_attn_indexes=[7, 15, 23, 31], + checkpoint=checkpoint, + ) + + +build_sam = build_sam_vit_h + + +def build_sam_vit_l(checkpoint=None): + return _build_sam( + encoder_embed_dim=1024, + encoder_depth=24, + encoder_num_heads=16, + encoder_global_attn_indexes=[5, 11, 17, 23], + checkpoint=checkpoint, + ) + + +def build_sam_vit_b(checkpoint=None): + return _build_sam( + encoder_embed_dim=768, + encoder_depth=12, + encoder_num_heads=12, + encoder_global_attn_indexes=[2, 5, 8, 11], + checkpoint=checkpoint, + ) + + +def build_sam_vit_t(checkpoint=None): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + mobile_sam = Sam( + image_encoder=TinyViT(img_size=1024, in_chans=3, num_classes=1000, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.0, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=0.8 + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + + mobile_sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + mobile_sam.load_state_dict(state_dict) + return mobile_sam + + +sam_model_registry = { + "default": build_sam_vit_h, + "vit_h": build_sam_vit_h, + "vit_l": build_sam_vit_l, + "vit_b": build_sam_vit_b, + "vit_t": build_sam_vit_t, +} + + +def _build_sam( + encoder_embed_dim, + encoder_depth, + encoder_num_heads, + encoder_global_attn_indexes, + checkpoint=None, +): + prompt_embed_dim = 256 + image_size = 1024 + vit_patch_size = 16 + image_embedding_size = image_size // vit_patch_size + sam = Sam( + image_encoder=ImageEncoderViT( + depth=encoder_depth, + embed_dim=encoder_embed_dim, + img_size=image_size, + mlp_ratio=4, + norm_layer=partial(torch.nn.LayerNorm, eps=1e-6), + num_heads=encoder_num_heads, + patch_size=vit_patch_size, + qkv_bias=True, + use_rel_pos=True, + global_attn_indexes=encoder_global_attn_indexes, + window_size=14, + out_chans=prompt_embed_dim, + ), + prompt_encoder=PromptEncoder( + embed_dim=prompt_embed_dim, + image_embedding_size=(image_embedding_size, image_embedding_size), + input_image_size=(image_size, image_size), + mask_in_chans=16, + ), + mask_decoder=MaskDecoder( + num_multimask_outputs=3, + transformer=TwoWayTransformer( + depth=2, + embedding_dim=prompt_embed_dim, + mlp_dim=2048, + num_heads=8, + ), + transformer_dim=prompt_embed_dim, + iou_head_depth=3, + iou_head_hidden_dim=256, + ), + pixel_mean=[123.675, 116.28, 103.53], + pixel_std=[58.395, 57.12, 57.375], + ) + sam.eval() + if checkpoint is not None: + with open(checkpoint, "rb") as f: + state_dict = torch.load(f) + sam.load_state_dict(state_dict) + return sam + + diff --git a/custom_controlnet_aux/sam/modeling/__init__.py b/custom_controlnet_aux/sam/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa261b8356b8c1174139c19782657abca0cfec2 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .sam import Sam +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder +from .transformer import TwoWayTransformer +from .tiny_vit_sam import TinyViT diff --git a/custom_controlnet_aux/sam/modeling/common.py b/custom_controlnet_aux/sam/modeling/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2bf15236a3eb24d8526073bc4fa2b274cccb3f96 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/common.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from typing import Type + + +class MLPBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + mlp_dim: int, + act: Type[nn.Module] = nn.GELU, + ) -> None: + super().__init__() + self.lin1 = nn.Linear(embedding_dim, mlp_dim) + self.lin2 = nn.Linear(mlp_dim, embedding_dim) + self.act = act() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.lin2(self.act(self.lin1(x))) + + +# From https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/batch_norm.py # noqa +# Itself from https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/custom_controlnet_aux/sam/modeling/image_encoder.py b/custom_controlnet_aux/sam/modeling/image_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..66351d9d7c589be693f4b3485901d3bdfed54d4a --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/image_encoder.py @@ -0,0 +1,395 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from typing import Optional, Tuple, Type + +from .common import LayerNorm2d, MLPBlock + + +# This class and its supporting functions below lightly adapted from the ViTDet backbone available at: https://github.com/facebookresearch/detectron2/blob/main/detectron2/modeling/backbone/vit.py # noqa +class ImageEncoderViT(nn.Module): + def __init__( + self, + img_size: int = 1024, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + depth: int = 12, + num_heads: int = 12, + mlp_ratio: float = 4.0, + out_chans: int = 256, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_abs_pos: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + global_attn_indexes: Tuple[int, ...] = (), + ) -> None: + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + global_attn_indexes (list): Indexes for blocks using global attention. + """ + super().__init__() + self.img_size = img_size + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + self.pos_embed: Optional[nn.Parameter] = None + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + self.pos_embed = nn.Parameter( + torch.zeros(1, img_size // patch_size, img_size // patch_size, embed_dim) + ) + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i not in global_attn_indexes else 0, + input_size=(img_size // patch_size, img_size // patch_size), + ) + self.blocks.append(block) + + self.neck = nn.Sequential( + nn.Conv2d( + embed_dim, + out_chans, + kernel_size=1, + bias=False, + ), + LayerNorm2d(out_chans), + nn.Conv2d( + out_chans, + out_chans, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(out_chans), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + self.pos_embed + + for blk in self.blocks: + x = blk(x) + + x = self.neck(x.permute(0, 3, 1, 2)) + + return x + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = True, + norm_layer: Type[nn.Module] = nn.LayerNorm, + act_layer: Type[nn.Module] = nn.GELU, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + window_size: int = 0, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then + use global attention. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + self.norm2 = norm_layer(dim) + self.mlp = MLPBlock(embedding_dim=dim, mlp_dim=int(dim * mlp_ratio), act=act_layer) + + self.window_size = window_size + + def forward(self, x: torch.Tensor) -> torch.Tensor: + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + x + x = x + self.mlp(self.norm2(x)) + + return x + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = True, + use_rel_pos: bool = False, + rel_pos_zero_init: bool = True, + input_size: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (tuple(int, int) or None): Input resolution for calculating the relative + positional parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + assert ( + input_size is not None + ), "Input size must be provided if using relative positional encoding." + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +def window_partition(x: torch.Tensor, window_size: int) -> Tuple[torch.Tensor, Tuple[int, int]]: + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition( + windows: torch.Tensor, window_size: int, pad_hw: Tuple[int, int], hw: Tuple[int, int] +) -> torch.Tensor: + """ + Window unpartition into original sequences and removing padding. + Args: + windows (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor: + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos( + attn: torch.Tensor, + q: torch.Tensor, + rel_pos_h: torch.Tensor, + rel_pos_w: torch.Tensor, + q_size: Tuple[int, int], + k_size: Tuple[int, int], +) -> torch.Tensor: + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, + kernel_size: Tuple[int, int] = (16, 16), + stride: Tuple[int, int] = (16, 16), + padding: Tuple[int, int] = (0, 0), + in_chans: int = 3, + embed_dim: int = 768, + ) -> None: + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/custom_controlnet_aux/sam/modeling/mask_decoder.py b/custom_controlnet_aux/sam/modeling/mask_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..5d2fdb03d535a91fa725d1ec4e92a7a1f217dfe0 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/mask_decoder.py @@ -0,0 +1,176 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import List, Tuple, Type + +from .common import LayerNorm2d + + +class MaskDecoder(nn.Module): + def __init__( + self, + *, + transformer_dim: int, + transformer: nn.Module, + num_multimask_outputs: int = 3, + activation: Type[nn.Module] = nn.GELU, + iou_head_depth: int = 3, + iou_head_hidden_dim: int = 256, + ) -> None: + """ + Predicts masks given an image and prompt embeddings, using a + transformer architecture. + + Arguments: + transformer_dim (int): the channel dimension of the transformer + transformer (nn.Module): the transformer used to predict masks + num_multimask_outputs (int): the number of masks to predict + when disambiguating masks + activation (nn.Module): the type of activation to use when + upscaling masks + iou_head_depth (int): the depth of the MLP used to predict + mask quality + iou_head_hidden_dim (int): the hidden dimension of the MLP + used to predict mask quality + """ + super().__init__() + self.transformer_dim = transformer_dim + self.transformer = transformer + + self.num_multimask_outputs = num_multimask_outputs + + self.iou_token = nn.Embedding(1, transformer_dim) + self.num_mask_tokens = num_multimask_outputs + 1 + self.mask_tokens = nn.Embedding(self.num_mask_tokens, transformer_dim) + + self.output_upscaling = nn.Sequential( + nn.ConvTranspose2d(transformer_dim, transformer_dim // 4, kernel_size=2, stride=2), + LayerNorm2d(transformer_dim // 4), + activation(), + nn.ConvTranspose2d(transformer_dim // 4, transformer_dim // 8, kernel_size=2, stride=2), + activation(), + ) + self.output_hypernetworks_mlps = nn.ModuleList( + [ + MLP(transformer_dim, transformer_dim, transformer_dim // 8, 3) + for i in range(self.num_mask_tokens) + ] + ) + + self.iou_prediction_head = MLP( + transformer_dim, iou_head_hidden_dim, self.num_mask_tokens, iou_head_depth + ) + + def forward( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + multimask_output: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Predict masks given image and prompt embeddings. + + Arguments: + image_embeddings (torch.Tensor): the embeddings from the image encoder + image_pe (torch.Tensor): positional encoding with the shape of image_embeddings + sparse_prompt_embeddings (torch.Tensor): the embeddings of the points and boxes + dense_prompt_embeddings (torch.Tensor): the embeddings of the mask inputs + multimask_output (bool): Whether to return multiple masks or a single + mask. + + Returns: + torch.Tensor: batched predicted masks + torch.Tensor: batched predictions of mask quality + """ + masks, iou_pred = self.predict_masks( + image_embeddings=image_embeddings, + image_pe=image_pe, + sparse_prompt_embeddings=sparse_prompt_embeddings, + dense_prompt_embeddings=dense_prompt_embeddings, + ) + + # Select the correct mask or masks for output + if multimask_output: + mask_slice = slice(1, None) + else: + mask_slice = slice(0, 1) + masks = masks[:, mask_slice, :, :] + iou_pred = iou_pred[:, mask_slice] + + # Prepare output + return masks, iou_pred + + def predict_masks( + self, + image_embeddings: torch.Tensor, + image_pe: torch.Tensor, + sparse_prompt_embeddings: torch.Tensor, + dense_prompt_embeddings: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Predicts masks. See 'forward' for more details.""" + # Concatenate output tokens + output_tokens = torch.cat([self.iou_token.weight, self.mask_tokens.weight], dim=0) + output_tokens = output_tokens.unsqueeze(0).expand(sparse_prompt_embeddings.size(0), -1, -1) + tokens = torch.cat((output_tokens, sparse_prompt_embeddings), dim=1) + + # Expand per-image data in batch direction to be per-mask + src = torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0) + src = src + dense_prompt_embeddings + pos_src = torch.repeat_interleave(image_pe, tokens.shape[0], dim=0) + b, c, h, w = src.shape + + # Run the transformer + hs, src = self.transformer(src, pos_src, tokens) + iou_token_out = hs[:, 0, :] + mask_tokens_out = hs[:, 1 : (1 + self.num_mask_tokens), :] + + # Upscale mask embeddings and predict masks using the mask tokens + src = src.transpose(1, 2).view(b, c, h, w) + upscaled_embedding = self.output_upscaling(src) + hyper_in_list: List[torch.Tensor] = [] + for i in range(self.num_mask_tokens): + hyper_in_list.append(self.output_hypernetworks_mlps[i](mask_tokens_out[:, i, :])) + hyper_in = torch.stack(hyper_in_list, dim=1) + b, c, h, w = upscaled_embedding.shape + masks = (hyper_in @ upscaled_embedding.view(b, c, h * w)).view(b, -1, h, w) + + # Generate mask quality predictions + iou_pred = self.iou_prediction_head(iou_token_out) + + return masks, iou_pred + + +# Lightly adapted from +# https://github.com/facebookresearch/MaskFormer/blob/main/mask_former/modeling/transformer/transformer_predictor.py # noqa +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + hidden_dim: int, + output_dim: int, + num_layers: int, + sigmoid_output: bool = False, + ) -> None: + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]) + ) + self.sigmoid_output = sigmoid_output + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + if self.sigmoid_output: + x = F.sigmoid(x) + return x diff --git a/custom_controlnet_aux/sam/modeling/prompt_encoder.py b/custom_controlnet_aux/sam/modeling/prompt_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c3143f4f8e02ddd7ca8587b40ff5d47c3a6b7ef3 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/prompt_encoder.py @@ -0,0 +1,214 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch import nn + +from typing import Any, Optional, Tuple, Type + +from .common import LayerNorm2d + + +class PromptEncoder(nn.Module): + def __init__( + self, + embed_dim: int, + image_embedding_size: Tuple[int, int], + input_image_size: Tuple[int, int], + mask_in_chans: int, + activation: Type[nn.Module] = nn.GELU, + ) -> None: + """ + Encodes prompts for input to SAM's mask decoder. + + Arguments: + embed_dim (int): The prompts' embedding dimension + image_embedding_size (tuple(int, int)): The spatial size of the + image embedding, as (H, W). + input_image_size (int): The padded size of the image as input + to the image encoder, as (H, W). + mask_in_chans (int): The number of hidden channels used for + encoding input masks. + activation (nn.Module): The activation to use when encoding + input masks. + """ + super().__init__() + self.embed_dim = embed_dim + self.input_image_size = input_image_size + self.image_embedding_size = image_embedding_size + self.pe_layer = PositionEmbeddingRandom(embed_dim // 2) + + self.num_point_embeddings: int = 4 # pos/neg point + 2 box corners + point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)] + self.point_embeddings = nn.ModuleList(point_embeddings) + self.not_a_point_embed = nn.Embedding(1, embed_dim) + + self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1]) + self.mask_downscaling = nn.Sequential( + nn.Conv2d(1, mask_in_chans // 4, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans // 4), + activation(), + nn.Conv2d(mask_in_chans // 4, mask_in_chans, kernel_size=2, stride=2), + LayerNorm2d(mask_in_chans), + activation(), + nn.Conv2d(mask_in_chans, embed_dim, kernel_size=1), + ) + self.no_mask_embed = nn.Embedding(1, embed_dim) + + def get_dense_pe(self) -> torch.Tensor: + """ + Returns the positional encoding used to encode point prompts, + applied to a dense set of points the shape of the image encoding. + + Returns: + torch.Tensor: Positional encoding with shape + 1x(embed_dim)x(embedding_h)x(embedding_w) + """ + return self.pe_layer(self.image_embedding_size).unsqueeze(0) + + def _embed_points( + self, + points: torch.Tensor, + labels: torch.Tensor, + pad: bool, + ) -> torch.Tensor: + """Embeds point prompts.""" + points = points + 0.5 # Shift to center of pixel + if pad: + padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device) + padding_label = -torch.ones((labels.shape[0], 1), device=labels.device) + points = torch.cat([points, padding_point], dim=1) + labels = torch.cat([labels, padding_label], dim=1) + point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size) + point_embedding[labels == -1] = 0.0 + point_embedding[labels == -1] += self.not_a_point_embed.weight + point_embedding[labels == 0] += self.point_embeddings[0].weight + point_embedding[labels == 1] += self.point_embeddings[1].weight + return point_embedding + + def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor: + """Embeds box prompts.""" + boxes = boxes + 0.5 # Shift to center of pixel + coords = boxes.reshape(-1, 2, 2) + corner_embedding = self.pe_layer.forward_with_coords(coords, self.input_image_size) + corner_embedding[:, 0, :] += self.point_embeddings[2].weight + corner_embedding[:, 1, :] += self.point_embeddings[3].weight + return corner_embedding + + def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor: + """Embeds mask inputs.""" + mask_embedding = self.mask_downscaling(masks) + return mask_embedding + + def _get_batch_size( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> int: + """ + Gets the batch size of the output given the batch size of the input prompts. + """ + if points is not None: + return points[0].shape[0] + elif boxes is not None: + return boxes.shape[0] + elif masks is not None: + return masks.shape[0] + else: + return 1 + + def _get_device(self) -> torch.device: + return self.point_embeddings[0].weight.device + + def forward( + self, + points: Optional[Tuple[torch.Tensor, torch.Tensor]], + boxes: Optional[torch.Tensor], + masks: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Embeds different types of prompts, returning both sparse and dense + embeddings. + + Arguments: + points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates + and labels to embed. + boxes (torch.Tensor or none): boxes to embed + masks (torch.Tensor or none): masks to embed + + Returns: + torch.Tensor: sparse embeddings for the points and boxes, with shape + BxNx(embed_dim), where N is determined by the number of input points + and boxes. + torch.Tensor: dense embeddings for the masks, in the shape + Bx(embed_dim)x(embed_H)x(embed_W) + """ + bs = self._get_batch_size(points, boxes, masks) + sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device()) + if points is not None: + coords, labels = points + point_embeddings = self._embed_points(coords, labels, pad=(boxes is None)) + sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1) + if boxes is not None: + box_embeddings = self._embed_boxes(boxes) + sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1) + + if masks is not None: + dense_embeddings = self._embed_masks(masks) + else: + dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand( + bs, -1, self.image_embedding_size[0], self.image_embedding_size[1] + ) + + return sparse_embeddings, dense_embeddings + + +class PositionEmbeddingRandom(nn.Module): + """ + Positional encoding using random spatial frequencies. + """ + + def __init__(self, num_pos_feats: int = 64, scale: Optional[float] = None) -> None: + super().__init__() + if scale is None or scale <= 0.0: + scale = 1.0 + self.register_buffer( + "positional_encoding_gaussian_matrix", + scale * torch.randn((2, num_pos_feats)), + ) + + def _pe_encoding(self, coords: torch.Tensor) -> torch.Tensor: + """Positionally encode points that are normalized to [0,1].""" + # assuming coords are in [0, 1]^2 square and have d_1 x ... x d_n x 2 shape + coords = 2 * coords - 1 + coords = coords @ self.positional_encoding_gaussian_matrix + coords = 2 * np.pi * coords + # outputs d_1 x ... x d_n x C shape + return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1) + + def forward(self, size: Tuple[int, int]) -> torch.Tensor: + """Generate positional encoding for a grid of the specified size.""" + h, w = size + device: Any = self.positional_encoding_gaussian_matrix.device + grid = torch.ones((h, w), device=device, dtype=torch.float32) + y_embed = grid.cumsum(dim=0) - 0.5 + x_embed = grid.cumsum(dim=1) - 0.5 + y_embed = y_embed / h + x_embed = x_embed / w + + pe = self._pe_encoding(torch.stack([x_embed, y_embed], dim=-1)) + return pe.permute(2, 0, 1) # C x H x W + + def forward_with_coords( + self, coords_input: torch.Tensor, image_size: Tuple[int, int] + ) -> torch.Tensor: + """Positionally encode points that are not normalized to [0,1].""" + coords = coords_input.clone() + coords[:, :, 0] = coords[:, :, 0] / image_size[1] + coords[:, :, 1] = coords[:, :, 1] / image_size[0] + return self._pe_encoding(coords.to(torch.float)) # B x N x C diff --git a/custom_controlnet_aux/sam/modeling/sam.py b/custom_controlnet_aux/sam/modeling/sam.py new file mode 100644 index 0000000000000000000000000000000000000000..45b9e7c56d10cc47e7ed0739e35d850bfccbb257 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/sam.py @@ -0,0 +1,175 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn import functional as F + +from typing import Any, Dict, List, Tuple, Union + +from .tiny_vit_sam import TinyViT +from .image_encoder import ImageEncoderViT +from .mask_decoder import MaskDecoder +from .prompt_encoder import PromptEncoder + + +class Sam(nn.Module): + mask_threshold: float = 0.0 + image_format: str = "RGB" + + def __init__( + self, + image_encoder: Union[ImageEncoderViT, TinyViT], + prompt_encoder: PromptEncoder, + mask_decoder: MaskDecoder, + pixel_mean: List[float] = [123.675, 116.28, 103.53], + pixel_std: List[float] = [58.395, 57.12, 57.375], + ) -> None: + """ + SAM predicts object masks from an image and input prompts. + + Arguments: + image_encoder (ImageEncoderViT): The backbone used to encode the + image into image embeddings that allow for efficient mask prediction. + prompt_encoder (PromptEncoder): Encodes various types of input prompts. + mask_decoder (MaskDecoder): Predicts masks from the image embeddings + and encoded prompts. + pixel_mean (list(float)): Mean values for normalizing pixels in the input image. + pixel_std (list(float)): Std values for normalizing pixels in the input image. + """ + super().__init__() + self.image_encoder = image_encoder + self.prompt_encoder = prompt_encoder + self.mask_decoder = mask_decoder + self.register_buffer("pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self) -> Any: + return self.pixel_mean.device + + @torch.no_grad() + def forward( + self, + batched_input: List[Dict[str, Any]], + multimask_output: bool, + ) -> List[Dict[str, torch.Tensor]]: + """ + Predicts masks end-to-end from provided images and prompts. + If prompts are not known in advance, using SamPredictor is + recommended over calling the model directly. + + Arguments: + batched_input (list(dict)): A list over input images, each a + dictionary with the following keys. A prompt key can be + excluded if it is not present. + 'image': The image as a torch tensor in 3xHxW format, + already transformed for input to the model. + 'original_size': (tuple(int, int)) The original size of + the image before transformation, as (H, W). + 'point_coords': (torch.Tensor) Batched point prompts for + this image, with shape BxNx2. Already transformed to the + input frame of the model. + 'point_labels': (torch.Tensor) Batched labels for point prompts, + with shape BxN. + 'boxes': (torch.Tensor) Batched box inputs, with shape Bx4. + Already transformed to the input frame of the model. + 'mask_inputs': (torch.Tensor) Batched mask inputs to the model, + in the form Bx1xHxW. + multimask_output (bool): Whether the model should predict multiple + disambiguating masks, or return a single mask. + + Returns: + (list(dict)): A list over input images, where each element is + as dictionary with the following keys. + 'masks': (torch.Tensor) Batched binary mask predictions, + with shape BxCxHxW, where B is the number of input prompts, + C is determined by multimask_output, and (H, W) is the + original size of the image. + 'iou_predictions': (torch.Tensor) The model's predictions + of mask quality, in shape BxC. + 'low_res_logits': (torch.Tensor) Low resolution logits with + shape BxCxHxW, where H=W=256. Can be passed as mask input + to subsequent iterations of prediction. + """ + input_images = torch.stack([self.preprocess(x["image"]) for x in batched_input], dim=0) + image_embeddings = self.image_encoder(input_images) + + outputs = [] + for image_record, curr_embedding in zip(batched_input, image_embeddings): + if "point_coords" in image_record: + points = (image_record["point_coords"], image_record["point_labels"]) + else: + points = None + sparse_embeddings, dense_embeddings = self.prompt_encoder( + points=points, + boxes=image_record.get("boxes", None), + masks=image_record.get("mask_inputs", None), + ) + low_res_masks, iou_predictions = self.mask_decoder( + image_embeddings=curr_embedding.unsqueeze(0), + image_pe=self.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + masks = self.postprocess_masks( + low_res_masks, + input_size=image_record["image"].shape[-2:], + original_size=image_record["original_size"], + ) + masks = masks > self.mask_threshold + outputs.append( + { + "masks": masks, + "iou_predictions": iou_predictions, + "low_res_logits": low_res_masks, + } + ) + return outputs + + def postprocess_masks( + self, + masks: torch.Tensor, + input_size: Tuple[int, ...], + original_size: Tuple[int, ...], + ) -> torch.Tensor: + """ + Remove padding and upscale masks to the original image size. + + Arguments: + masks (torch.Tensor): Batched masks from the mask_decoder, + in BxCxHxW format. + input_size (tuple(int, int)): The size of the image input to the + model, in (H, W) format. Used to remove padding. + original_size (tuple(int, int)): The original size of the image + before resizing for input to the model, in (H, W) format. + + Returns: + (torch.Tensor): Batched masks in BxCxHxW format, where (H, W) + is given by original_size. + """ + masks = F.interpolate( + masks, + (self.image_encoder.img_size, self.image_encoder.img_size), + mode="bilinear", + align_corners=False, + ) + masks = masks[..., : input_size[0], : input_size[1]] + masks = F.interpolate(masks, original_size, mode="bilinear", align_corners=False) + return masks + + def preprocess(self, x: torch.Tensor) -> torch.Tensor: + """Normalize pixel values and pad to a square input.""" + # Normalize colors + x = (x - self.pixel_mean) / self.pixel_std + + # Pad + h, w = x.shape[-2:] + padh = self.image_encoder.img_size - h + padw = self.image_encoder.img_size - w + x = F.pad(x, (0, padw, 0, padh)) + return x diff --git a/custom_controlnet_aux/sam/modeling/tiny_vit_sam.py b/custom_controlnet_aux/sam/modeling/tiny_vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..ce1aad263f07652385010194738f8f69254df644 --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/tiny_vit_sam.py @@ -0,0 +1,716 @@ +# -------------------------------------------------------- +# TinyViT Model Architecture +# Copyright (c) 2022 Microsoft +# Adapted from LeViT and Swin Transformer +# LeViT: (https://github.com/facebookresearch/levit) +# Swin: (https://github.com/microsoft/swin-transformer) +# Build the TinyViT Model +# -------------------------------------------------------- + +import itertools +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from custom_timm.models.layers import DropPath as TimmDropPath,\ + to_2tuple, trunc_normal_ +from custom_timm.models.registry import register_model +from typing import Tuple + + +class Conv2d_BN(torch.nn.Sequential): + def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1, + groups=1, bn_weight_init=1): + super().__init__() + self.add_module('c', torch.nn.Conv2d( + a, b, ks, stride, pad, dilation, groups, bias=False)) + bn = torch.nn.BatchNorm2d(b) + torch.nn.init.constant_(bn.weight, bn_weight_init) + torch.nn.init.constant_(bn.bias, 0) + self.add_module('bn', bn) + + @torch.no_grad() + def fuse(self): + c, bn = self._modules.values() + w = bn.weight / (bn.running_var + bn.eps)**0.5 + w = c.weight * w[:, None, None, None] + b = bn.bias - bn.running_mean * bn.weight / \ + (bn.running_var + bn.eps)**0.5 + m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size( + 0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) + m.weight.data.copy_(w) + m.bias.data.copy_(b) + return m + + +class DropPath(TimmDropPath): + def __init__(self, drop_prob=None): + super().__init__(drop_prob=drop_prob) + self.drop_prob = drop_prob + + def __repr__(self): + msg = super().__repr__() + msg += f'(drop_prob={self.drop_prob})' + return msg + + +class PatchEmbed(nn.Module): + def __init__(self, in_chans, embed_dim, resolution, activation): + super().__init__() + img_size: Tuple[int, int] = to_2tuple(resolution) + self.patches_resolution = (img_size[0] // 4, img_size[1] // 4) + self.num_patches = self.patches_resolution[0] * \ + self.patches_resolution[1] + self.in_chans = in_chans + self.embed_dim = embed_dim + n = embed_dim + self.seq = nn.Sequential( + Conv2d_BN(in_chans, n // 2, 3, 2, 1), + activation(), + Conv2d_BN(n // 2, n, 3, 2, 1), + ) + + def forward(self, x): + return self.seq(x) + + +class MBConv(nn.Module): + def __init__(self, in_chans, out_chans, expand_ratio, + activation, drop_path): + super().__init__() + self.in_chans = in_chans + self.hidden_chans = int(in_chans * expand_ratio) + self.out_chans = out_chans + + self.conv1 = Conv2d_BN(in_chans, self.hidden_chans, ks=1) + self.act1 = activation() + + self.conv2 = Conv2d_BN(self.hidden_chans, self.hidden_chans, + ks=3, stride=1, pad=1, groups=self.hidden_chans) + self.act2 = activation() + + self.conv3 = Conv2d_BN( + self.hidden_chans, out_chans, ks=1, bn_weight_init=0.0) + self.act3 = activation() + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + def forward(self, x): + shortcut = x + + x = self.conv1(x) + x = self.act1(x) + + x = self.conv2(x) + x = self.act2(x) + + x = self.conv3(x) + + x = self.drop_path(x) + + x += shortcut + x = self.act3(x) + + return x + + +class PatchMerging(nn.Module): + def __init__(self, input_resolution, dim, out_dim, activation): + super().__init__() + + self.input_resolution = input_resolution + self.dim = dim + self.out_dim = out_dim + self.act = activation() + self.conv1 = Conv2d_BN(dim, out_dim, 1, 1, 0) + stride_c=2 + if(out_dim==320 or out_dim==448 or out_dim==576): + stride_c=1 + self.conv2 = Conv2d_BN(out_dim, out_dim, 3, stride_c, 1, groups=out_dim) + self.conv3 = Conv2d_BN(out_dim, out_dim, 1, 1, 0) + + def forward(self, x): + if x.ndim == 3: + H, W = self.input_resolution + B = len(x) + # (B, C, H, W) + x = x.view(B, H, W, -1).permute(0, 3, 1, 2) + + x = self.conv1(x) + x = self.act(x) + + x = self.conv2(x) + x = self.act(x) + x = self.conv3(x) + x = x.flatten(2).transpose(1, 2) + return x + + +class ConvLayer(nn.Module): + def __init__(self, dim, input_resolution, depth, + activation, + drop_path=0., downsample=None, use_checkpoint=False, + out_dim=None, + conv_expand_ratio=4., + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + MBConv(dim, dim, conv_expand_ratio, activation, + drop_path[i] if isinstance(drop_path, list) else drop_path, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, + out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.norm = nn.LayerNorm(in_features) + self.fc1 = nn.Linear(in_features, hidden_features) + self.fc2 = nn.Linear(hidden_features, out_features) + self.act = act_layer() + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.norm(x) + + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class Attention(torch.nn.Module): + def __init__(self, dim, key_dim, num_heads=8, + attn_ratio=4, + resolution=(14, 14), + ): + super().__init__() + # (h, w) + assert isinstance(resolution, tuple) and len(resolution) == 2 + self.num_heads = num_heads + self.scale = key_dim ** -0.5 + self.key_dim = key_dim + self.nh_kd = nh_kd = key_dim * num_heads + self.d = int(attn_ratio * key_dim) + self.dh = int(attn_ratio * key_dim) * num_heads + self.attn_ratio = attn_ratio + h = self.dh + nh_kd * 2 + + self.norm = nn.LayerNorm(dim) + self.qkv = nn.Linear(dim, h) + self.proj = nn.Linear(self.dh, dim) + + points = list(itertools.product( + range(resolution[0]), range(resolution[1]))) + N = len(points) + attention_offsets = {} + idxs = [] + for p1 in points: + for p2 in points: + offset = (abs(p1[0] - p2[0]), abs(p1[1] - p2[1])) + if offset not in attention_offsets: + attention_offsets[offset] = len(attention_offsets) + idxs.append(attention_offsets[offset]) + self.attention_biases = torch.nn.Parameter( + torch.zeros(num_heads, len(attention_offsets))) + self.register_buffer('attention_bias_idxs', + torch.LongTensor(idxs).view(N, N), + persistent=False) + + @torch.no_grad() + def train(self, mode=True): + super().train(mode) + if mode and hasattr(self, 'ab'): + del self.ab + else: + self.ab = self.attention_biases[:, self.attention_bias_idxs] + + def forward(self, x): # x (B,N,C) + B, N, _ = x.shape + + # Normalization + x = self.norm(x) + + qkv = self.qkv(x) + # (B, N, num_heads, d) + q, k, v = qkv.view(B, N, self.num_heads, - + 1).split([self.key_dim, self.key_dim, self.d], dim=3) + # (B, num_heads, N, d) + q = q.permute(0, 2, 1, 3) + k = k.permute(0, 2, 1, 3) + v = v.permute(0, 2, 1, 3) + + attn = ( + (q @ k.transpose(-2, -1)) * self.scale + + + (self.attention_biases[:, self.attention_bias_idxs] + if self.training else self.ab) + ) + attn = attn.softmax(dim=-1) + x = (attn @ v).transpose(1, 2).reshape(B, N, self.dh) + x = self.proj(x) + return x + + +class TinyViTBlock(nn.Module): + r""" TinyViT Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int, int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + local_conv_size (int): the kernel size of the convolution between + Attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + """ + + def __init__(self, dim, input_resolution, num_heads, window_size=7, + mlp_ratio=4., drop=0., drop_path=0., + local_conv_size=3, + activation=nn.GELU, + ): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + assert window_size > 0, 'window_size must be greater than 0' + self.window_size = window_size + self.mlp_ratio = mlp_ratio + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + + assert dim % num_heads == 0, 'dim must be divisible by num_heads' + head_dim = dim // num_heads + + window_resolution = (window_size, window_size) + self.attn = Attention(dim, head_dim, num_heads, + attn_ratio=1, resolution=window_resolution) + + mlp_hidden_dim = int(dim * mlp_ratio) + mlp_activation = activation + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, + act_layer=mlp_activation, drop=drop) + + pad = local_conv_size // 2 + self.local_conv = Conv2d_BN( + dim, dim, ks=local_conv_size, stride=1, pad=pad, groups=dim) + + def forward(self, x): + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + res_x = x + if H == self.window_size and W == self.window_size: + x = self.attn(x) + else: + x = x.view(B, H, W, C) + pad_b = (self.window_size - H % + self.window_size) % self.window_size + pad_r = (self.window_size - W % + self.window_size) % self.window_size + padding = pad_b > 0 or pad_r > 0 + + if padding: + x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b)) + + pH, pW = H + pad_b, W + pad_r + nH = pH // self.window_size + nW = pW // self.window_size + # window partition + x = x.view(B, nH, self.window_size, nW, self.window_size, C).transpose(2, 3).reshape( + B * nH * nW, self.window_size * self.window_size, C) + x = self.attn(x) + # window reverse + x = x.view(B, nH, nW, self.window_size, self.window_size, + C).transpose(2, 3).reshape(B, pH, pW, C) + + if padding: + x = x[:, :H, :W].contiguous() + + x = x.view(B, L, C) + + x = res_x + self.drop_path(x) + + x = x.transpose(1, 2).reshape(B, C, H, W) + x = self.local_conv(x) + x = x.view(B, C, L).transpose(1, 2) + + x = x + self.drop_path(self.mlp(x)) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \ + f"window_size={self.window_size}, mlp_ratio={self.mlp_ratio}" + + +class BasicLayer(nn.Module): + """ A basic TinyViT layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + drop (float, optional): Dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + local_conv_size: the kernel size of the depthwise convolution between attention and MLP. Default: 3 + activation: the activation function. Default: nn.GELU + out_dim: the output dimension of the layer. Default: dim + """ + + def __init__(self, dim, input_resolution, depth, num_heads, window_size, + mlp_ratio=4., drop=0., + drop_path=0., downsample=None, use_checkpoint=False, + local_conv_size=3, + activation=nn.GELU, + out_dim=None, + ): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList([ + TinyViTBlock(dim=dim, input_resolution=input_resolution, + num_heads=num_heads, window_size=window_size, + mlp_ratio=mlp_ratio, + drop=drop, + drop_path=drop_path[i] if isinstance( + drop_path, list) else drop_path, + local_conv_size=local_conv_size, + activation=activation, + ) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, out_dim=out_dim, activation=activation) + else: + self.downsample = None + + def forward(self, x): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + +class LayerNorm2d(nn.Module): + def __init__(self, num_channels: int, eps: float = 1e-6) -> None: + super().__init__() + self.weight = nn.Parameter(torch.ones(num_channels)) + self.bias = nn.Parameter(torch.zeros(num_channels)) + self.eps = eps + + def forward(self, x: torch.Tensor) -> torch.Tensor: + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x +class TinyViT(nn.Module): + def __init__(self, img_size=224, in_chans=3, num_classes=1000, + embed_dims=[96, 192, 384, 768], depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_sizes=[7, 7, 14, 7], + mlp_ratio=4., + drop_rate=0., + drop_path_rate=0.1, + use_checkpoint=False, + mbconv_expand_ratio=4.0, + local_conv_size=3, + layer_lr_decay=1.0, + ): + super().__init__() + self.img_size=img_size + self.num_classes = num_classes + self.depths = depths + self.num_layers = len(depths) + self.mlp_ratio = mlp_ratio + + activation = nn.GELU + + self.patch_embed = PatchEmbed(in_chans=in_chans, + embed_dim=embed_dims[0], + resolution=img_size, + activation=activation) + + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, + sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + kwargs = dict(dim=embed_dims[i_layer], + input_resolution=(patches_resolution[0] // (2 ** (i_layer-1 if i_layer == 3 else i_layer)), + patches_resolution[1] // (2 ** (i_layer-1 if i_layer == 3 else i_layer))), + # input_resolution=(patches_resolution[0] // (2 ** i_layer), + # patches_resolution[1] // (2 ** i_layer)), + depth=depths[i_layer], + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + downsample=PatchMerging if ( + i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + out_dim=embed_dims[min( + i_layer + 1, len(embed_dims) - 1)], + activation=activation, + ) + if i_layer == 0: + layer = ConvLayer( + conv_expand_ratio=mbconv_expand_ratio, + **kwargs, + ) + else: + layer = BasicLayer( + num_heads=num_heads[i_layer], + window_size=window_sizes[i_layer], + mlp_ratio=self.mlp_ratio, + drop=drop_rate, + local_conv_size=local_conv_size, + **kwargs) + self.layers.append(layer) + + # Classifier head + self.norm_head = nn.LayerNorm(embed_dims[-1]) + self.head = nn.Linear( + embed_dims[-1], num_classes) if num_classes > 0 else torch.nn.Identity() + + # init weights + self.apply(self._init_weights) + self.set_layer_lr_decay(layer_lr_decay) + self.neck = nn.Sequential( + nn.Conv2d( + embed_dims[-1], + 256, + kernel_size=1, + bias=False, + ), + LayerNorm2d(256), + nn.Conv2d( + 256, + 256, + kernel_size=3, + padding=1, + bias=False, + ), + LayerNorm2d(256), + ) + def set_layer_lr_decay(self, layer_lr_decay): + decay_rate = layer_lr_decay + + # layers -> blocks (depth) + depth = sum(self.depths) + lr_scales = [decay_rate ** (depth - i - 1) for i in range(depth)] + #print("LR SCALES:", lr_scales) + + def _set_lr_scale(m, scale): + for p in m.parameters(): + p.lr_scale = scale + + self.patch_embed.apply(lambda x: _set_lr_scale(x, lr_scales[0])) + i = 0 + for layer in self.layers: + for block in layer.blocks: + block.apply(lambda x: _set_lr_scale(x, lr_scales[i])) + i += 1 + if layer.downsample is not None: + layer.downsample.apply( + lambda x: _set_lr_scale(x, lr_scales[i - 1])) + assert i == depth + for m in [self.norm_head, self.head]: + m.apply(lambda x: _set_lr_scale(x, lr_scales[-1])) + + for k, p in self.named_parameters(): + p.param_name = k + + def _check_lr_scale(m): + for p in m.parameters(): + assert hasattr(p, 'lr_scale'), p.param_name + + self.apply(_check_lr_scale) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {'attention_biases'} + + def forward_features(self, x): + # x: (N, C, H, W) + x = self.patch_embed(x) + + x = self.layers[0](x) + start_i = 1 + + for i in range(start_i, len(self.layers)): + layer = self.layers[i] + x = layer(x) + B,_,C=x.size() + x = x.view(B, 64, 64, C) + x=x.permute(0, 3, 1, 2) + x=self.neck(x) + return x + + def forward(self, x): + x = self.forward_features(x) + #x = self.norm_head(x) + #x = self.head(x) + return x + + +_checkpoint_url_format = \ + 'https://github.com/wkcn/TinyViT-model-zoo/releases/download/checkpoints/{}.pth' +_provided_checkpoints = { + 'tiny_vit_5m_224': 'tiny_vit_5m_22kto1k_distill', + 'tiny_vit_11m_224': 'tiny_vit_11m_22kto1k_distill', + 'tiny_vit_21m_224': 'tiny_vit_21m_22kto1k_distill', + 'tiny_vit_21m_384': 'tiny_vit_21m_22kto1k_384_distill', + 'tiny_vit_21m_512': 'tiny_vit_21m_22kto1k_512_distill', +} + + +def register_tiny_vit_model(fn): + '''Register a TinyViT model + It is a wrapper of `register_model` with loading the pretrained checkpoint. + ''' + def fn_wrapper(pretrained=False, **kwargs): + model = fn() + if pretrained: + model_name = fn.__name__ + assert model_name in _provided_checkpoints, \ + f'Sorry that the checkpoint `{model_name}` is not provided yet.' + url = _checkpoint_url_format.format( + _provided_checkpoints[model_name]) + checkpoint = torch.hub.load_state_dict_from_url( + url=url, + map_location='cpu', check_hash=False, + ) + model.load_state_dict(checkpoint['model']) + + return model + + # rename the name of fn_wrapper + fn_wrapper.__name__ = fn.__name__ + return register_model(fn_wrapper) + + +@register_tiny_vit_model +def tiny_vit_5m_224(pretrained=False, num_classes=1000, drop_path_rate=0.0): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 160, 320], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 5, 10], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_11m_224(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + num_classes=num_classes, + embed_dims=[64, 128, 256, 448], + depths=[2, 2, 6, 2], + num_heads=[2, 4, 8, 14], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_224(pretrained=False, num_classes=1000, drop_path_rate=0.2): + return TinyViT( + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[7, 7, 14, 7], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_384(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=384, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[12, 12, 24, 12], + drop_path_rate=drop_path_rate, + ) + + +@register_tiny_vit_model +def tiny_vit_21m_512(pretrained=False, num_classes=1000, drop_path_rate=0.1): + return TinyViT( + img_size=512, + num_classes=num_classes, + embed_dims=[96, 192, 384, 576], + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 18], + window_sizes=[16, 16, 32, 16], + drop_path_rate=drop_path_rate, + ) diff --git a/custom_controlnet_aux/sam/modeling/transformer.py b/custom_controlnet_aux/sam/modeling/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..28fafea52288603fea275f3a100790471825c34a --- /dev/null +++ b/custom_controlnet_aux/sam/modeling/transformer.py @@ -0,0 +1,240 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import Tensor, nn + +import math +from typing import Tuple, Type + +from .common import MLPBlock + + +class TwoWayTransformer(nn.Module): + def __init__( + self, + depth: int, + embedding_dim: int, + num_heads: int, + mlp_dim: int, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + ) -> None: + """ + A transformer decoder that attends to an input image using + queries whose positional embedding is supplied. + + Args: + depth (int): number of layers in the transformer + embedding_dim (int): the channel dimension for the input embeddings + num_heads (int): the number of heads for multihead attention. Must + divide embedding_dim + mlp_dim (int): the channel dimension internal to the MLP block + activation (nn.Module): the activation to use in the MLP block + """ + super().__init__() + self.depth = depth + self.embedding_dim = embedding_dim + self.num_heads = num_heads + self.mlp_dim = mlp_dim + self.layers = nn.ModuleList() + + for i in range(depth): + self.layers.append( + TwoWayAttentionBlock( + embedding_dim=embedding_dim, + num_heads=num_heads, + mlp_dim=mlp_dim, + activation=activation, + attention_downsample_rate=attention_downsample_rate, + skip_first_layer_pe=(i == 0), + ) + ) + + self.final_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm_final_attn = nn.LayerNorm(embedding_dim) + + def forward( + self, + image_embedding: Tensor, + image_pe: Tensor, + point_embedding: Tensor, + ) -> Tuple[Tensor, Tensor]: + """ + Args: + image_embedding (torch.Tensor): image to attend to. Should be shape + B x embedding_dim x h x w for any h and w. + image_pe (torch.Tensor): the positional encoding to add to the image. Must + have the same shape as image_embedding. + point_embedding (torch.Tensor): the embedding to add to the query points. + Must have shape B x N_points x embedding_dim for any N_points. + + Returns: + torch.Tensor: the processed point_embedding + torch.Tensor: the processed image_embedding + """ + # BxCxHxW -> BxHWxC == B x N_image_tokens x C + bs, c, h, w = image_embedding.shape + image_embedding = image_embedding.flatten(2).permute(0, 2, 1) + image_pe = image_pe.flatten(2).permute(0, 2, 1) + + # Prepare queries + queries = point_embedding + keys = image_embedding + + # Apply transformer blocks and final layernorm + for layer in self.layers: + queries, keys = layer( + queries=queries, + keys=keys, + query_pe=point_embedding, + key_pe=image_pe, + ) + + # Apply the final attention layer from the points to the image + q = queries + point_embedding + k = keys + image_pe + attn_out = self.final_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm_final_attn(queries) + + return queries, keys + + +class TwoWayAttentionBlock(nn.Module): + def __init__( + self, + embedding_dim: int, + num_heads: int, + mlp_dim: int = 2048, + activation: Type[nn.Module] = nn.ReLU, + attention_downsample_rate: int = 2, + skip_first_layer_pe: bool = False, + ) -> None: + """ + A transformer block with four layers: (1) self-attention of sparse + inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp + block on sparse inputs, and (4) cross attention of dense inputs to sparse + inputs. + + Arguments: + embedding_dim (int): the channel dimension of the embeddings + num_heads (int): the number of heads in the attention layers + mlp_dim (int): the hidden dimension of the mlp block + activation (nn.Module): the activation of the mlp block + skip_first_layer_pe (bool): skip the PE on the first layer + """ + super().__init__() + self.self_attn = Attention(embedding_dim, num_heads) + self.norm1 = nn.LayerNorm(embedding_dim) + + self.cross_attn_token_to_image = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + self.norm2 = nn.LayerNorm(embedding_dim) + + self.mlp = MLPBlock(embedding_dim, mlp_dim, activation) + self.norm3 = nn.LayerNorm(embedding_dim) + + self.norm4 = nn.LayerNorm(embedding_dim) + self.cross_attn_image_to_token = Attention( + embedding_dim, num_heads, downsample_rate=attention_downsample_rate + ) + + self.skip_first_layer_pe = skip_first_layer_pe + + def forward( + self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor + ) -> Tuple[Tensor, Tensor]: + # Self attention block + if self.skip_first_layer_pe: + queries = self.self_attn(q=queries, k=queries, v=queries) + else: + q = queries + query_pe + attn_out = self.self_attn(q=q, k=q, v=queries) + queries = queries + attn_out + queries = self.norm1(queries) + + # Cross attention block, tokens attending to image embedding + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_token_to_image(q=q, k=k, v=keys) + queries = queries + attn_out + queries = self.norm2(queries) + + # MLP block + mlp_out = self.mlp(queries) + queries = queries + mlp_out + queries = self.norm3(queries) + + # Cross attention block, image embedding attending to tokens + q = queries + query_pe + k = keys + key_pe + attn_out = self.cross_attn_image_to_token(q=k, k=q, v=queries) + keys = keys + attn_out + keys = self.norm4(keys) + + return queries, keys + + +class Attention(nn.Module): + """ + An attention layer that allows for downscaling the size of the embedding + after projection to queries, keys, and values. + """ + + def __init__( + self, + embedding_dim: int, + num_heads: int, + downsample_rate: int = 1, + ) -> None: + super().__init__() + self.embedding_dim = embedding_dim + self.internal_dim = embedding_dim // downsample_rate + self.num_heads = num_heads + assert self.internal_dim % num_heads == 0, "num_heads must divide embedding_dim." + + self.q_proj = nn.Linear(embedding_dim, self.internal_dim) + self.k_proj = nn.Linear(embedding_dim, self.internal_dim) + self.v_proj = nn.Linear(embedding_dim, self.internal_dim) + self.out_proj = nn.Linear(self.internal_dim, embedding_dim) + + def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor: + b, n, c = x.shape + x = x.reshape(b, n, num_heads, c // num_heads) + return x.transpose(1, 2) # B x N_heads x N_tokens x C_per_head + + def _recombine_heads(self, x: Tensor) -> Tensor: + b, n_heads, n_tokens, c_per_head = x.shape + x = x.transpose(1, 2) + return x.reshape(b, n_tokens, n_heads * c_per_head) # B x N_tokens x C + + def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor: + # Input projections + q = self.q_proj(q) + k = self.k_proj(k) + v = self.v_proj(v) + + # Separate into heads + q = self._separate_heads(q, self.num_heads) + k = self._separate_heads(k, self.num_heads) + v = self._separate_heads(v, self.num_heads) + + # Attention + _, _, _, c_per_head = q.shape + attn = q @ k.permute(0, 1, 3, 2) # B x N_heads x N_tokens x N_tokens + attn = attn / math.sqrt(c_per_head) + attn = torch.softmax(attn, dim=-1) + + # Get output + out = attn @ v + out = self._recombine_heads(out) + out = self.out_proj(out) + + return out diff --git a/custom_controlnet_aux/sam/predictor.py b/custom_controlnet_aux/sam/predictor.py new file mode 100644 index 0000000000000000000000000000000000000000..a3820fb7de8647e5d6adf229debc498b33caad62 --- /dev/null +++ b/custom_controlnet_aux/sam/predictor.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +from .modeling import Sam + +from typing import Optional, Tuple + +from .utils.transforms import ResizeLongestSide + + +class SamPredictor: + def __init__( + self, + sam_model: Sam, + ) -> None: + """ + Uses SAM to calculate the image embedding for an image, and then + allow repeated, efficient mask prediction given prompts. + + Arguments: + sam_model (Sam): The model to use for mask prediction. + """ + super().__init__() + self.model = sam_model + self.transform = ResizeLongestSide(sam_model.image_encoder.img_size) + self.reset_image() + + def set_image( + self, + image: np.ndarray, + image_format: str = "RGB", + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. + + Arguments: + image (np.ndarray): The image for calculating masks. Expects an + image in HWC uint8 format, with pixel values in [0, 255]. + image_format (str): The color format of the image, in ['RGB', 'BGR']. + """ + assert image_format in [ + "RGB", + "BGR", + ], f"image_format must be in ['RGB', 'BGR'], is {image_format}." + if image_format != self.model.image_format: + image = image[..., ::-1] + + # Transform the image to the form expected by the model + input_image = self.transform.apply_image(image) + input_image_torch = torch.as_tensor(input_image, device=self.device) + input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :] + + self.set_torch_image(input_image_torch, image.shape[:2]) + + @torch.no_grad() + def set_torch_image( + self, + transformed_image: torch.Tensor, + original_image_size: Tuple[int, ...], + ) -> None: + """ + Calculates the image embeddings for the provided image, allowing + masks to be predicted with the 'predict' method. Expects the input + image to be already transformed to the format expected by the model. + + Arguments: + transformed_image (torch.Tensor): The input image, with shape + 1x3xHxW, which has been transformed with ResizeLongestSide. + original_image_size (tuple(int, int)): The size of the image + before transformation, in (H, W) format. + """ + assert ( + len(transformed_image.shape) == 4 + and transformed_image.shape[1] == 3 + and max(*transformed_image.shape[2:]) == self.model.image_encoder.img_size + ), f"set_torch_image input must be BCHW with long side {self.model.image_encoder.img_size}." + self.reset_image() + + self.original_size = original_image_size + self.input_size = tuple(transformed_image.shape[-2:]) + input_image = self.model.preprocess(transformed_image) + self.features = self.model.image_encoder(input_image) + self.is_image_set = True + + def predict( + self, + point_coords: Optional[np.ndarray] = None, + point_labels: Optional[np.ndarray] = None, + box: Optional[np.ndarray] = None, + mask_input: Optional[np.ndarray] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Predict masks for the given input prompts, using the currently set image. + + Arguments: + point_coords (np.ndarray or None): A Nx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (np.ndarray or None): A length N array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + box (np.ndarray or None): A length 4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form 1xHxW, where + for SAM, H=W=256. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (np.ndarray): The output masks in CxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (np.ndarray): An array of length C containing the model's + predictions for the quality of each mask. + (np.ndarray): An array of shape CxHxW, where C is the number + of masks and H=W=256. These low resolution logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + # Transform input prompts + coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None + if point_coords is not None: + assert ( + point_labels is not None + ), "point_labels must be supplied if point_coords is supplied." + point_coords = self.transform.apply_coords(point_coords, self.original_size) + coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device) + labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=self.device) + coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] + if box is not None: + box = self.transform.apply_boxes(box, self.original_size) + box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device) + box_torch = box_torch[None, :] + if mask_input is not None: + mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=self.device) + mask_input_torch = mask_input_torch[None, :, :, :] + + masks, iou_predictions, low_res_masks = self.predict_torch( + coords_torch, + labels_torch, + box_torch, + mask_input_torch, + multimask_output, + return_logits=return_logits, + ) + + masks_np = masks[0].detach().cpu().numpy() + iou_predictions_np = iou_predictions[0].detach().cpu().numpy() + low_res_masks_np = low_res_masks[0].detach().cpu().numpy() + return masks_np, iou_predictions_np, low_res_masks_np + + @torch.no_grad() + def predict_torch( + self, + point_coords: Optional[torch.Tensor], + point_labels: Optional[torch.Tensor], + boxes: Optional[torch.Tensor] = None, + mask_input: Optional[torch.Tensor] = None, + multimask_output: bool = True, + return_logits: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Predict masks for the given input prompts, using the currently set image. + Input prompts are batched torch tensors and are expected to already be + transformed to the input frame using ResizeLongestSide. + + Arguments: + point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the + model. Each point is in (X,Y) in pixels. + point_labels (torch.Tensor or None): A BxN array of labels for the + point prompts. 1 indicates a foreground point and 0 indicates a + background point. + boxes (np.ndarray or None): A Bx4 array given a box prompt to the + model, in XYXY format. + mask_input (np.ndarray): A low resolution mask input to the model, typically + coming from a previous prediction iteration. Has form Bx1xHxW, where + for SAM, H=W=256. Masks returned by a previous iteration of the + predict method do not need further transformation. + multimask_output (bool): If true, the model will return three masks. + For ambiguous input prompts (such as a single click), this will often + produce better masks than a single prediction. If only a single + mask is needed, the model's predicted quality score can be used + to select the best mask. For non-ambiguous prompts, such as multiple + input prompts, multimask_output=False can give better results. + return_logits (bool): If true, returns un-thresholded masks logits + instead of a binary mask. + + Returns: + (torch.Tensor): The output masks in BxCxHxW format, where C is the + number of masks, and (H, W) is the original image size. + (torch.Tensor): An array of shape BxC containing the model's + predictions for the quality of each mask. + (torch.Tensor): An array of shape BxCxHxW, where C is the number + of masks and H=W=256. These low res logits can be passed to + a subsequent iteration as mask input. + """ + if not self.is_image_set: + raise RuntimeError("An image must be set with .set_image(...) before mask prediction.") + + if point_coords is not None: + points = (point_coords, point_labels) + else: + points = None + + # Embed prompts + sparse_embeddings, dense_embeddings = self.model.prompt_encoder( + points=points, + boxes=boxes, + masks=mask_input, + ) + + # Predict masks + low_res_masks, iou_predictions = self.model.mask_decoder( + image_embeddings=self.features, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embeddings, + dense_prompt_embeddings=dense_embeddings, + multimask_output=multimask_output, + ) + + # Upscale the masks to the original image resolution + masks = self.model.postprocess_masks(low_res_masks, self.input_size, self.original_size) + + if not return_logits: + masks = masks > self.model.mask_threshold + + return masks, iou_predictions, low_res_masks + + def get_image_embedding(self) -> torch.Tensor: + """ + Returns the image embeddings for the currently set image, with + shape 1xCxHxW, where C is the embedding dimension and (H,W) are + the embedding spatial dimension of SAM (typically C=256, H=W=64). + """ + if not self.is_image_set: + raise RuntimeError( + "An image must be set with .set_image(...) to generate an embedding." + ) + assert self.features is not None, "Features must exist if an image has been set." + return self.features + + @property + def device(self) -> torch.device: + return self.model.device + + def reset_image(self) -> None: + """Resets the currently set image.""" + self.is_image_set = False + self.features = None + self.orig_h = None + self.orig_w = None + self.input_h = None + self.input_w = None diff --git a/custom_controlnet_aux/sam/utils/__init__.py b/custom_controlnet_aux/sam/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5277f46157403e47fd830fc519144b97ef69d4ae --- /dev/null +++ b/custom_controlnet_aux/sam/utils/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. diff --git a/custom_controlnet_aux/sam/utils/amg.py b/custom_controlnet_aux/sam/utils/amg.py new file mode 100644 index 0000000000000000000000000000000000000000..de480030b10c60fa398fbd6019b7ca32b6ea970f --- /dev/null +++ b/custom_controlnet_aux/sam/utils/amg.py @@ -0,0 +1,346 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch + +import math +from copy import deepcopy +from itertools import product +from typing import Any, Dict, Generator, ItemsView, List, Tuple + + +class MaskData: + """ + A structure for storing masks and their related data in batched format. + Implements basic filtering and concatenation. + """ + + def __init__(self, **kwargs) -> None: + for v in kwargs.values(): + assert isinstance( + v, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats = dict(**kwargs) + + def __setitem__(self, key: str, item: Any) -> None: + assert isinstance( + item, (list, np.ndarray, torch.Tensor) + ), "MaskData only supports list, numpy arrays, and torch tensors." + self._stats[key] = item + + def __delitem__(self, key: str) -> None: + del self._stats[key] + + def __getitem__(self, key: str) -> Any: + return self._stats[key] + + def items(self) -> ItemsView[str, Any]: + return self._stats.items() + + def filter(self, keep: torch.Tensor) -> None: + for k, v in self._stats.items(): + if v is None: + self._stats[k] = None + elif isinstance(v, torch.Tensor): + self._stats[k] = v[torch.as_tensor(keep, device=v.device)] + elif isinstance(v, np.ndarray): + self._stats[k] = v[keep.detach().cpu().numpy()] + elif isinstance(v, list) and keep.dtype == torch.bool: + self._stats[k] = [a for i, a in enumerate(v) if keep[i]] + elif isinstance(v, list): + self._stats[k] = [v[i] for i in keep] + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def cat(self, new_stats: "MaskData") -> None: + for k, v in new_stats.items(): + if k not in self._stats or self._stats[k] is None: + self._stats[k] = deepcopy(v) + elif isinstance(v, torch.Tensor): + self._stats[k] = torch.cat([self._stats[k], v], dim=0) + elif isinstance(v, np.ndarray): + self._stats[k] = np.concatenate([self._stats[k], v], axis=0) + elif isinstance(v, list): + self._stats[k] = self._stats[k] + deepcopy(v) + else: + raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.") + + def to_numpy(self) -> None: + for k, v in self._stats.items(): + if isinstance(v, torch.Tensor): + self._stats[k] = v.detach().cpu().numpy() + + +def is_box_near_crop_edge( + boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0 +) -> torch.Tensor: + """Filter masks at the edge of a crop, but not at the edge of the original image.""" + crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device) + orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device) + boxes = uncrop_boxes_xyxy(boxes, crop_box).float() + near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0) + near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0) + near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge) + return torch.any(near_crop_edge, dim=1) + + +def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor: + box_xywh = deepcopy(box_xyxy) + box_xywh[2] = box_xywh[2] - box_xywh[0] + box_xywh[3] = box_xywh[3] - box_xywh[1] + return box_xywh + + +def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]: + assert len(args) > 0 and all( + len(a) == len(args[0]) for a in args + ), "Batched iteration must have inputs of all the same size." + n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0) + for b in range(n_batches): + yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args] + + +def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]: + """ + Encodes masks to an uncompressed RLE, in the format expected by + pycoco tools. + """ + # Put in fortran order and flatten h,w + b, h, w = tensor.shape + tensor = tensor.permute(0, 2, 1).flatten(1) + + # Compute change indices + diff = tensor[:, 1:] ^ tensor[:, :-1] + change_indices = diff.nonzero() + + # Encode run length + out = [] + for i in range(b): + cur_idxs = change_indices[change_indices[:, 0] == i, 1] + cur_idxs = torch.cat( + [ + torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device), + cur_idxs + 1, + torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device), + ] + ) + btw_idxs = cur_idxs[1:] - cur_idxs[:-1] + counts = [] if tensor[i, 0] == 0 else [0] + counts.extend(btw_idxs.detach().cpu().tolist()) + out.append({"size": [h, w], "counts": counts}) + return out + + +def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray: + """Compute a binary mask from an uncompressed RLE.""" + h, w = rle["size"] + mask = np.empty(h * w, dtype=bool) + idx = 0 + parity = False + for count in rle["counts"]: + mask[idx : idx + count] = parity + idx += count + parity ^= True + mask = mask.reshape(w, h) + return mask.transpose() # Put in C order + + +def area_from_rle(rle: Dict[str, Any]) -> int: + return sum(rle["counts"][1::2]) + + +def calculate_stability_score( + masks: torch.Tensor, mask_threshold: float, threshold_offset: float +) -> torch.Tensor: + """ + Computes the stability score for a batch of masks. The stability + score is the IoU between the binary masks obtained by thresholding + the predicted mask logits at high and low values. + """ + # One mask is always contained inside the other. + # Save memory by preventing unnecessary cast to torch.int64 + intersections = ( + (masks > (mask_threshold + threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + unions = ( + (masks > (mask_threshold - threshold_offset)) + .sum(-1, dtype=torch.int16) + .sum(-1, dtype=torch.int32) + ) + return intersections / unions + + +def build_point_grid(n_per_side: int) -> np.ndarray: + """Generates a 2D grid of points evenly spaced in [0,1]x[0,1].""" + offset = 1 / (2 * n_per_side) + points_one_side = np.linspace(offset, 1 - offset, n_per_side) + points_x = np.tile(points_one_side[None, :], (n_per_side, 1)) + points_y = np.tile(points_one_side[:, None], (1, n_per_side)) + points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2) + return points + + +def build_all_layer_point_grids( + n_per_side: int, n_layers: int, scale_per_layer: int +) -> List[np.ndarray]: + """Generates point grids for all crop layers.""" + points_by_layer = [] + for i in range(n_layers + 1): + n_points = int(n_per_side / (scale_per_layer**i)) + points_by_layer.append(build_point_grid(n_points)) + return points_by_layer + + +def generate_crop_boxes( + im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float +) -> Tuple[List[List[int]], List[int]]: + """ + Generates a list of crop boxes of different sizes. Each layer + has (2**i)**2 boxes for the ith layer. + """ + crop_boxes, layer_idxs = [], [] + im_h, im_w = im_size + short_side = min(im_h, im_w) + + # Original image + crop_boxes.append([0, 0, im_w, im_h]) + layer_idxs.append(0) + + def crop_len(orig_len, n_crops, overlap): + return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops)) + + for i_layer in range(n_layers): + n_crops_per_side = 2 ** (i_layer + 1) + overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side)) + + crop_w = crop_len(im_w, n_crops_per_side, overlap) + crop_h = crop_len(im_h, n_crops_per_side, overlap) + + crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)] + crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)] + + # Crops in XYWH format + for x0, y0 in product(crop_box_x0, crop_box_y0): + box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)] + crop_boxes.append(box) + layer_idxs.append(i_layer + 1) + + return crop_boxes, layer_idxs + + +def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device) + # Check if boxes has a channel dimension + if len(boxes.shape) == 3: + offset = offset.unsqueeze(1) + return boxes + offset + + +def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor: + x0, y0, _, _ = crop_box + offset = torch.tensor([[x0, y0]], device=points.device) + # Check if points has a channel dimension + if len(points.shape) == 3: + offset = offset.unsqueeze(1) + return points + offset + + +def uncrop_masks( + masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int +) -> torch.Tensor: + x0, y0, x1, y1 = crop_box + if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h: + return masks + # Coordinate transform masks + pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0) + pad = (x0, pad_x - x0, y0, pad_y - y0) + return torch.nn.functional.pad(masks, pad, value=0) + + +def remove_small_regions( + mask: np.ndarray, area_thresh: float, mode: str +) -> Tuple[np.ndarray, bool]: + """ + Removes small disconnected regions and holes in a mask. Returns the + mask and an indicator of if the mask has been modified. + """ + import cv2 # type: ignore + + assert mode in ["holes", "islands"] + correct_holes = mode == "holes" + working_mask = (correct_holes ^ mask).astype(np.uint8) + n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8) + sizes = stats[:, -1][1:] # Row 0 is background label + small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh] + if len(small_regions) == 0: + return mask, False + fill_labels = [0] + small_regions + if not correct_holes: + fill_labels = [i for i in range(n_labels) if i not in fill_labels] + # If every region is below threshold, keep largest + if len(fill_labels) == 0: + fill_labels = [int(np.argmax(sizes)) + 1] + mask = np.isin(regions, fill_labels) + return mask, True + + +def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]: + from custom_pycocotools import mask as mask_utils # type: ignore + + h, w = uncompressed_rle["size"] + rle = mask_utils.frPyObjects(uncompressed_rle, h, w) + rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json + return rle + + +def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor: + """ + Calculates boxes in XYXY format around masks. Return [0,0,0,0] for + an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4. + """ + # torch.max below raises an error on empty inputs, just skip in this case + if torch.numel(masks) == 0: + return torch.zeros(*masks.shape[:-2], 4, device=masks.device) + + # Normalize shape to CxHxW + shape = masks.shape + h, w = shape[-2:] + if len(shape) > 2: + masks = masks.flatten(0, -3) + else: + masks = masks.unsqueeze(0) + + # Get top and bottom edges + in_height, _ = torch.max(masks, dim=-1) + in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :] + bottom_edges, _ = torch.max(in_height_coords, dim=-1) + in_height_coords = in_height_coords + h * (~in_height) + top_edges, _ = torch.min(in_height_coords, dim=-1) + + # Get left and right edges + in_width, _ = torch.max(masks, dim=-2) + in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :] + right_edges, _ = torch.max(in_width_coords, dim=-1) + in_width_coords = in_width_coords + w * (~in_width) + left_edges, _ = torch.min(in_width_coords, dim=-1) + + # If the mask is empty the right edge will be to the left of the left edge. + # Replace these boxes with [0, 0, 0, 0] + empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges) + out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1) + out = out * (~empty_filter).unsqueeze(-1) + + # Return to original shape + if len(shape) > 2: + out = out.reshape(*shape[:-2], 4) + else: + out = out[0] + + return out diff --git a/custom_controlnet_aux/sam/utils/onnx.py b/custom_controlnet_aux/sam/utils/onnx.py new file mode 100644 index 0000000000000000000000000000000000000000..3196bdf4b782e6eeb3da4ad66ef3c7b1741535fe --- /dev/null +++ b/custom_controlnet_aux/sam/utils/onnx.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn import functional as F + +from typing import Tuple + +from ..modeling import Sam +from .amg import calculate_stability_score + + +class SamOnnxModel(nn.Module): + """ + This model should not be called directly, but is used in ONNX export. + It combines the prompt encoder, mask decoder, and mask postprocessing of Sam, + with some functions modified to enable model tracing. Also supports extra + options controlling what information. See the ONNX export script for details. + """ + + def __init__( + self, + model: Sam, + return_single_mask: bool, + use_stability_score: bool = False, + return_extra_metrics: bool = False, + ) -> None: + super().__init__() + self.mask_decoder = model.mask_decoder + self.model = model + self.img_size = model.image_encoder.img_size + self.return_single_mask = return_single_mask + self.use_stability_score = use_stability_score + self.stability_score_offset = 1.0 + self.return_extra_metrics = return_extra_metrics + + @staticmethod + def resize_longest_image_size( + input_image_size: torch.Tensor, longest_side: int + ) -> torch.Tensor: + input_image_size = input_image_size.to(torch.float32) + scale = longest_side / torch.max(input_image_size) + transformed_size = scale * input_image_size + transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64) + return transformed_size + + def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor: + point_coords = point_coords + 0.5 + point_coords = point_coords / self.img_size + point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords) + point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding) + + point_embedding = point_embedding * (point_labels != -1) + point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * ( + point_labels == -1 + ) + + for i in range(self.model.prompt_encoder.num_point_embeddings): + point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[ + i + ].weight * (point_labels == i) + + return point_embedding + + def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor: + mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask) + mask_embedding = mask_embedding + ( + 1 - has_mask_input + ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1) + return mask_embedding + + def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor: + masks = F.interpolate( + masks, + size=(self.img_size, self.img_size), + mode="bilinear", + align_corners=False, + ) + + prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64) + masks = masks[..., : prepadded_size[0], : prepadded_size[1]] # type: ignore + + orig_im_size = orig_im_size.to(torch.int64) + h, w = orig_im_size[0], orig_im_size[1] + masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False) + return masks + + def select_masks( + self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Determine if we should return the multiclick mask or not from the number of points. + # The reweighting is used to avoid control flow. + score_reweight = torch.tensor( + [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)] + ).to(iou_preds.device) + score = iou_preds + (num_points - 2.5) * score_reweight + best_idx = torch.argmax(score, dim=1) + masks = masks[torch.arange(masks.shape[0]), best_idx, :, :].unsqueeze(1) + iou_preds = iou_preds[torch.arange(masks.shape[0]), best_idx].unsqueeze(1) + + return masks, iou_preds + + @torch.no_grad() + def forward( + self, + image_embeddings: torch.Tensor, + point_coords: torch.Tensor, + point_labels: torch.Tensor, + mask_input: torch.Tensor, + has_mask_input: torch.Tensor, + orig_im_size: torch.Tensor, + ): + sparse_embedding = self._embed_points(point_coords, point_labels) + dense_embedding = self._embed_masks(mask_input, has_mask_input) + + masks, scores = self.model.mask_decoder.predict_masks( + image_embeddings=image_embeddings, + image_pe=self.model.prompt_encoder.get_dense_pe(), + sparse_prompt_embeddings=sparse_embedding, + dense_prompt_embeddings=dense_embedding, + ) + + if self.use_stability_score: + scores = calculate_stability_score( + masks, self.model.mask_threshold, self.stability_score_offset + ) + + if self.return_single_mask: + masks, scores = self.select_masks(masks, scores, point_coords.shape[1]) + + upscaled_masks = self.mask_postprocessing(masks, orig_im_size) + + if self.return_extra_metrics: + stability_scores = calculate_stability_score( + upscaled_masks, self.model.mask_threshold, self.stability_score_offset + ) + areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1) + return upscaled_masks, scores, stability_scores, areas, masks + + return upscaled_masks, scores, masks diff --git a/custom_controlnet_aux/sam/utils/transforms.py b/custom_controlnet_aux/sam/utils/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c08ba1e3db751f3a5483a003be38c69c2cf2df85 --- /dev/null +++ b/custom_controlnet_aux/sam/utils/transforms.py @@ -0,0 +1,102 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import torch +from torch.nn import functional as F +from torchvision.transforms.functional import resize, to_pil_image # type: ignore + +from copy import deepcopy +from typing import Tuple + + +class ResizeLongestSide: + """ + Resizes images to the longest side 'target_length', as well as provides + methods for resizing coordinates and boxes. Provides methods for + transforming both numpy array and batched torch tensors. + """ + + def __init__(self, target_length: int) -> None: + self.target_length = target_length + + def apply_image(self, image: np.ndarray) -> np.ndarray: + """ + Expects a numpy array with shape HxWxC in uint8 format. + """ + target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length) + return np.array(resize(to_pil_image(image), target_size)) + + def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array of length 2 in the final dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).astype(float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray: + """ + Expects a numpy array shape Bx4. Requires the original image size + in (H, W) format. + """ + boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor: + """ + Expects batched images with shape BxCxHxW and float format. This + transformation may not exactly match apply_image. apply_image is + the transformation expected by the model. + """ + # Expects an image in BCHW format. May not exactly match apply_image. + target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length) + return F.interpolate( + image, target_size, mode="bilinear", align_corners=False, antialias=True + ) + + def apply_coords_torch( + self, coords: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with length 2 in the last dimension. Requires the + original image size in (H, W) format. + """ + old_h, old_w = original_size + new_h, new_w = self.get_preprocess_shape( + original_size[0], original_size[1], self.target_length + ) + coords = deepcopy(coords).to(torch.float) + coords[..., 0] = coords[..., 0] * (new_w / old_w) + coords[..., 1] = coords[..., 1] * (new_h / old_h) + return coords + + def apply_boxes_torch( + self, boxes: torch.Tensor, original_size: Tuple[int, ...] + ) -> torch.Tensor: + """ + Expects a torch tensor with shape Bx4. Requires the original image + size in (H, W) format. + """ + boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size) + return boxes.reshape(-1, 4) + + @staticmethod + def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]: + """ + Compute the output size given input size and target long side length. + """ + scale = long_side_length * 1.0 / max(oldh, oldw) + newh, neww = oldh * scale, oldw * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) diff --git a/custom_controlnet_aux/scribble/__init__.py b/custom_controlnet_aux/scribble/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..53bb98b5470b4e7d410ab15948a44cdd595fcf9d --- /dev/null +++ b/custom_controlnet_aux/scribble/__init__.py @@ -0,0 +1,41 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, HWC3 + +#Not to be confused with "scribble" from HED. That is "fake scribble" which is more accurate and less picky than this. +class ScribbleDetector: + def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_AREA", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + detected_map = np.zeros_like(input_image, dtype=np.uint8) + detected_map[np.min(input_image, axis=2) < 127] = 255 + detected_map = 255 - detected_map + + detected_map = remove_pad(detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + +class ScribbleXDog_Detector: + def __call__(self, input_image=None, detect_resolution=512, thr_a=32, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + g1 = cv2.GaussianBlur(input_image.astype(np.float32), (0, 0), 0.5) + g2 = cv2.GaussianBlur(input_image.astype(np.float32), (0, 0), 5.0) + dog = (255 - np.min(g2 - g1, axis=2)).clip(0, 255).astype(np.uint8) + result = np.zeros_like(input_image, dtype=np.uint8) + result[2 * (255 - dog) > thr_a] = 255 + #result = 255 - result + + detected_map = HWC3(remove_pad(result)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/shuffle/__init__.py b/custom_controlnet_aux/shuffle/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..278db7973d501161cca19f0364c71c76220f5128 --- /dev/null +++ b/custom_controlnet_aux/shuffle/__init__.py @@ -0,0 +1,87 @@ +import warnings + +import cv2 +import numpy as np +from PIL import Image +import random + +from custom_controlnet_aux.util import HWC3, common_input_validate, img2mask, make_noise_disk, resize_image_with_pad + + +class ContentShuffleDetector: + def __call__(self, input_image, h=None, w=None, f=None, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", seed=-1, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + H, W, C = input_image.shape + if h is None: + h = H + if w is None: + w = W + if f is None: + f = 256 + rng = np.random.default_rng(seed) if seed else None + x = make_noise_disk(h, w, 1, f, rng=rng) * float(W - 1) + y = make_noise_disk(h, w, 1, f, rng=rng) * float(H - 1) + flow = np.concatenate([x, y], axis=2).astype(np.float32) + detected_map = cv2.remap(input_image, flow, None, cv2.INTER_LINEAR) + detected_map = remove_pad(detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + + +class ColorShuffleDetector: + def __call__(self, img): + H, W, C = img.shape + F = np.random.randint(64, 384) + A = make_noise_disk(H, W, 3, F) + B = make_noise_disk(H, W, 3, F) + C = (A + B) / 2.0 + A = (C + (A - C) * 3.0).clip(0, 1) + B = (C + (B - C) * 3.0).clip(0, 1) + L = img.astype(np.float32) / 255.0 + Y = A * L + B * (1 - L) + Y -= np.min(Y, axis=(0, 1), keepdims=True) + Y /= np.maximum(np.max(Y, axis=(0, 1), keepdims=True), 1e-5) + Y *= 255.0 + return Y.clip(0, 255).astype(np.uint8) + + +class GrayDetector: + def __call__(self, img): + eps = 1e-5 + X = img.astype(np.float32) + r, g, b = X[:, :, 0], X[:, :, 1], X[:, :, 2] + kr, kg, kb = [random.random() + eps for _ in range(3)] + ks = kr + kg + kb + kr /= ks + kg /= ks + kb /= ks + Y = r * kr + g * kg + b * kb + Y = np.stack([Y] * 3, axis=2) + return Y.clip(0, 255).astype(np.uint8) + + +class DownSampleDetector: + def __call__(self, img, level=3, k=16.0): + h = img.astype(np.float32) + for _ in range(level): + h += np.random.normal(loc=0.0, scale=k, size=h.shape) + h = cv2.pyrDown(h) + for _ in range(level): + h = cv2.pyrUp(h) + h += np.random.normal(loc=0.0, scale=k, size=h.shape) + return h.clip(0, 255).astype(np.uint8) + + +class Image2MaskShuffleDetector: + def __init__(self, resolution=(640, 512)): + self.H, self.W = resolution + + def __call__(self, img): + m = img2mask(img, self.H, self.W) + m *= 255.0 + return m.clip(0, 255).astype(np.uint8) diff --git a/custom_controlnet_aux/teed/Fmish.py b/custom_controlnet_aux/teed/Fmish.py new file mode 100644 index 0000000000000000000000000000000000000000..40c867a272bdaf948d435a46a6aaa70478036994 --- /dev/null +++ b/custom_controlnet_aux/teed/Fmish.py @@ -0,0 +1,17 @@ +""" +Script provides functional interface for Mish activation function. +""" + +# import pytorch +import torch +import torch.nn.functional as F + + +@torch.jit.script +def mish(input): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + See additional documentation for mish class. + """ + return input * torch.tanh(F.softplus(input)) \ No newline at end of file diff --git a/custom_controlnet_aux/teed/Fsmish.py b/custom_controlnet_aux/teed/Fsmish.py new file mode 100644 index 0000000000000000000000000000000000000000..eb8c55cad89953f202384eee81173e2b2ae10712 --- /dev/null +++ b/custom_controlnet_aux/teed/Fsmish.py @@ -0,0 +1,20 @@ +""" +Script based on: +Wang, Xueliang, Honge Ren, and Achuan Wang. + "Smish: A Novel Activation Function for Deep Learning Methods. + " Electronics 11.4 (2022): 540. +""" + +# import pytorch +import torch +import torch.nn.functional as F + + +@torch.jit.script +def smish(input): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(sigmoid(x)))) + See additional documentation for mish class. + """ + return input * torch.tanh(torch.log(1+torch.sigmoid(input))) \ No newline at end of file diff --git a/custom_controlnet_aux/teed/LICENSE.txt b/custom_controlnet_aux/teed/LICENSE.txt new file mode 100644 index 0000000000000000000000000000000000000000..4a99ffdd7372b1bfa44ea302330343cb7370d0e9 --- /dev/null +++ b/custom_controlnet_aux/teed/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Xavier Soria Poma + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/custom_controlnet_aux/teed/Xmish.py b/custom_controlnet_aux/teed/Xmish.py new file mode 100644 index 0000000000000000000000000000000000000000..15e84ed98d165ab6eb4db672dbfdf50ef0953e31 --- /dev/null +++ b/custom_controlnet_aux/teed/Xmish.py @@ -0,0 +1,43 @@ +""" +Applies the mish function element-wise: +mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) +""" + +# import pytorch +import torch +import torch.nn.functional as F +from torch import nn + +# import activation functions +from .Fmish import mish + + +class Mish(nn.Module): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + Examples: + >>> m = Mish() + >>> input = torch.randn(2) + >>> output = m(input) + Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() + + def forward(self, input): + """ + Forward pass of the function. + """ + if torch.__version__ >= "1.9": + return F.mish(input) + else: + return mish(input) \ No newline at end of file diff --git a/custom_controlnet_aux/teed/Xsmish.py b/custom_controlnet_aux/teed/Xsmish.py new file mode 100644 index 0000000000000000000000000000000000000000..df75bee4d3d3585b1b265435a713ef0186b1e701 --- /dev/null +++ b/custom_controlnet_aux/teed/Xsmish.py @@ -0,0 +1,43 @@ +""" +Script based on: +Wang, Xueliang, Honge Ren, and Achuan Wang. + "Smish: A Novel Activation Function for Deep Learning Methods. + " Electronics 11.4 (2022): 540. +smish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + sigmoid(x))) +""" + +# import pytorch +import torch +import torch.nn.functional as F +from torch import nn + +# import activation functions +from .Fsmish import smish + + +class Smish(nn.Module): + """ + Applies the mish function element-wise: + mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) + Shape: + - Input: (N, *) where * means, any number of additional + dimensions + - Output: (N, *), same shape as the input + Examples: + >>> m = Mish() + >>> input = torch.randn(2) + >>> output = m(input) + Reference: https://pytorch.org/docs/stable/generated/torch.nn.Mish.html + """ + + def __init__(self): + """ + Init method. + """ + super().__init__() + + def forward(self, input): + """ + Forward pass of the function. + """ + return smish(input) \ No newline at end of file diff --git a/custom_controlnet_aux/teed/__init__.py b/custom_controlnet_aux/teed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e08bb3f695e92de4f3cb4bd01a106af9cea1cbfb --- /dev/null +++ b/custom_controlnet_aux/teed/__init__.py @@ -0,0 +1,58 @@ +""" +Hello, welcome on board, +""" +from __future__ import print_function + +import os +import cv2 +import numpy as np + +import torch + +from .ted import TED # TEED architecture +from einops import rearrange +from custom_controlnet_aux.util import safe_step, custom_hf_download, BDS_MODEL_NAME, common_input_validate, resize_image_with_pad, HWC3 +from PIL import Image + + +class TEDDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="7_model.pth", subfolder="Annotators"): + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder=subfolder) + model = TED() + model.load_state_dict(torch.load(model_path, map_location='cpu')) + model.eval() + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + + def __call__(self, input_image, detect_resolution=512, safe_steps=2, upscale_method="INTER_CUBIC", output_type="pil", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + H, W, _ = input_image.shape + with torch.no_grad(): + image_teed = torch.from_numpy(input_image.copy()).float().to(self.device) + image_teed = rearrange(image_teed, 'h w c -> 1 c h w') + edges = self.model(image_teed) + edges = [e.detach().cpu().numpy().astype(np.float32)[0, 0] for e in edges] + edges = [cv2.resize(e, (W, H), interpolation=cv2.INTER_LINEAR) for e in edges] + edges = np.stack(edges, axis=2) + edge = 1 / (1 + np.exp(-np.mean(edges, axis=2).astype(np.float64))) + if safe_steps != 0: + edge = safe_step(edge, safe_steps) + edge = (edge * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = remove_pad(HWC3(edge)) + if output_type == "pil": + detected_map = Image.fromarray(detected_map[..., :3]) + + return detected_map diff --git a/custom_controlnet_aux/teed/ted.py b/custom_controlnet_aux/teed/ted.py new file mode 100644 index 0000000000000000000000000000000000000000..ff347d5acf767126cc95022a2c2036c0262db40d --- /dev/null +++ b/custom_controlnet_aux/teed/ted.py @@ -0,0 +1,296 @@ +# TEED: is a Tiny but Efficient Edge Detection, it comes from the LDC-B3 +# with a Slightly modification +# LDC parameters: +# 155665 +# TED > 58K + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .Fsmish import smish as Fsmish +from .Xsmish import Smish + + +def weight_init(m): + if isinstance(m, (nn.Conv2d,)): + torch.nn.init.xavier_normal_(m.weight, gain=1.0) + + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + + # for fusion layer + if isinstance(m, (nn.ConvTranspose2d,)): + torch.nn.init.xavier_normal_(m.weight, gain=1.0) + if m.bias is not None: + torch.nn.init.zeros_(m.bias) + +class CoFusion(nn.Module): + # from LDC + + def __init__(self, in_ch, out_ch): + super(CoFusion, self).__init__() + self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, + stride=1, padding=1) # before 64 + self.conv3= nn.Conv2d(32, out_ch, kernel_size=3, + stride=1, padding=1)# before 64 instead of 32 + self.relu = nn.ReLU() + self.norm_layer1 = nn.GroupNorm(4, 32) # before 64 + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.relu(self.norm_layer1(self.conv1(x))) + attn = F.softmax(self.conv3(attn), dim=1) + return ((x * attn).sum(1)).unsqueeze(1) + + +class CoFusion2(nn.Module): + # TEDv14-3 + def __init__(self, in_ch, out_ch): + super(CoFusion2, self).__init__() + self.conv1 = nn.Conv2d(in_ch, 32, kernel_size=3, + stride=1, padding=1) # before 64 + # self.conv2 = nn.Conv2d(32, 32, kernel_size=3, + # stride=1, padding=1)# before 64 + self.conv3 = nn.Conv2d(32, out_ch, kernel_size=3, + stride=1, padding=1)# before 64 instead of 32 + self.smish= Smish()#nn.ReLU(inplace=True) + + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.conv1(self.smish(x)) + attn = self.conv3(self.smish(attn)) # before , )dim=1) + + # return ((fusecat * attn).sum(1)).unsqueeze(1) + return ((x * attn).sum(1)).unsqueeze(1) + +class DoubleFusion(nn.Module): + # TED fusion before the final edge map prediction + def __init__(self, in_ch, out_ch): + super(DoubleFusion, self).__init__() + self.DWconv1 = nn.Conv2d(in_ch, in_ch*8, kernel_size=3, + stride=1, padding=1, groups=in_ch) # before 64 + self.PSconv1 = nn.PixelShuffle(1) + + self.DWconv2 = nn.Conv2d(24, 24*1, kernel_size=3, + stride=1, padding=1,groups=24)# before 64 instead of 32 + + self.AF= Smish()#XAF() #nn.Tanh()# XAF() # # Smish()# + + + def forward(self, x): + # fusecat = torch.cat(x, dim=1) + attn = self.PSconv1(self.DWconv1(self.AF(x))) # #TEED best res TEDv14 [8, 32, 352, 352] + + attn2 = self.PSconv1(self.DWconv2(self.AF(attn))) # #TEED best res TEDv14[8, 3, 352, 352] + + return Fsmish(((attn2 +attn).sum(1)).unsqueeze(1)) #TED best res + +class _DenseLayer(nn.Sequential): + def __init__(self, input_features, out_features): + super(_DenseLayer, self).__init__() + + self.add_module('conv1', nn.Conv2d(input_features, out_features, + kernel_size=3, stride=1, padding=2, bias=True)), + self.add_module('smish1', Smish()), + self.add_module('conv2', nn.Conv2d(out_features, out_features, + kernel_size=3, stride=1, bias=True)) + def forward(self, x): + x1, x2 = x + + new_features = super(_DenseLayer, self).forward(Fsmish(x1)) # F.relu() + + return 0.5 * (new_features + x2), x2 + + +class _DenseBlock(nn.Sequential): + def __init__(self, num_layers, input_features, out_features): + super(_DenseBlock, self).__init__() + for i in range(num_layers): + layer = _DenseLayer(input_features, out_features) + self.add_module('denselayer%d' % (i + 1), layer) + input_features = out_features + + +class UpConvBlock(nn.Module): + def __init__(self, in_features, up_scale): + super(UpConvBlock, self).__init__() + self.up_factor = 2 + self.constant_features = 16 + + layers = self.make_deconv_layers(in_features, up_scale) + assert layers is not None, layers + self.features = nn.Sequential(*layers) + + def make_deconv_layers(self, in_features, up_scale): + layers = [] + all_pads=[0,0,1,3,7] + for i in range(up_scale): + kernel_size = 2 ** up_scale + pad = all_pads[up_scale] # kernel_size-1 + out_features = self.compute_out_features(i, up_scale) + layers.append(nn.Conv2d(in_features, out_features, 1)) + layers.append(Smish()) + layers.append(nn.ConvTranspose2d( + out_features, out_features, kernel_size, stride=2, padding=pad)) + in_features = out_features + return layers + + def compute_out_features(self, idx, up_scale): + return 1 if idx == up_scale - 1 else self.constant_features + + def forward(self, x): + return self.features(x) + + +class SingleConvBlock(nn.Module): + def __init__(self, in_features, out_features, stride, use_ac=False): + super(SingleConvBlock, self).__init__() + # self.use_bn = use_bs + self.use_ac=use_ac + self.conv = nn.Conv2d(in_features, out_features, 1, stride=stride, + bias=True) + if self.use_ac: + self.smish = Smish() + + def forward(self, x): + x = self.conv(x) + if self.use_ac: + return self.smish(x) + else: + return x + +class DoubleConvBlock(nn.Module): + def __init__(self, in_features, mid_features, + out_features=None, + stride=1, + use_act=True): + super(DoubleConvBlock, self).__init__() + + self.use_act = use_act + if out_features is None: + out_features = mid_features + self.conv1 = nn.Conv2d(in_features, mid_features, + 3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(mid_features, out_features, 3, padding=1) + self.smish= Smish()#nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv1(x) + x = self.smish(x) + x = self.conv2(x) + if self.use_act: + x = self.smish(x) + return x + + +class TED(nn.Module): + """ Definition of Tiny and Efficient Edge Detector + model + """ + + def __init__(self): + super(TED, self).__init__() + self.block_1 = DoubleConvBlock(3, 16, 16, stride=2,) + self.block_2 = DoubleConvBlock(16, 32, use_act=False) + self.dblock_3 = _DenseBlock(1, 32, 48) # [32,48,100,100] before (2, 32, 64) + + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + + # skip1 connection, see fig. 2 + self.side_1 = SingleConvBlock(16, 32, 2) + + # skip2 connection, see fig. 2 + self.pre_dense_3 = SingleConvBlock(32, 48, 1) # before (32, 64, 1) + + # USNet + self.up_block_1 = UpConvBlock(16, 1) + self.up_block_2 = UpConvBlock(32, 1) + self.up_block_3 = UpConvBlock(48, 2) # (32, 64, 1) + + self.block_cat = DoubleFusion(3,3) # TEED: DoubleFusion + + self.apply(weight_init) + + def slice(self, tensor, slice_shape): + t_shape = tensor.shape + img_h, img_w = slice_shape + if img_w!=t_shape[-1] or img_h!=t_shape[2]: + new_tensor = F.interpolate( + tensor, size=(img_h, img_w), mode='bicubic',align_corners=False) + + else: + new_tensor=tensor + # tensor[..., :height, :width] + return new_tensor + def resize_input(self,tensor): + t_shape = tensor.shape + if t_shape[2] % 8 != 0 or t_shape[3] % 8 != 0: + img_w= ((t_shape[3]// 8) + 1) * 8 + img_h = ((t_shape[2] // 8) + 1) * 8 + new_tensor = F.interpolate( + tensor, size=(img_h, img_w), mode='bicubic', align_corners=False) + else: + new_tensor = tensor + return new_tensor + + def crop_bdcn(data1, h, w, crop_h, crop_w): + # Based on BDCN Implementation @ https://github.com/pkuCactus/BDCN + _, _, h1, w1 = data1.size() + assert (h <= h1 and w <= w1) + data = data1[:, :, crop_h:crop_h + h, crop_w:crop_w + w] + return data + + + def forward(self, x, single_test=False): + assert x.ndim == 4, x.shape + # supose the image size is 352x352 + + # Block 1 + block_1 = self.block_1(x) # [8,16,176,176] + block_1_side = self.side_1(block_1) # 16 [8,32,88,88] + + # Block 2 + block_2 = self.block_2(block_1) # 32 # [8,32,176,176] + block_2_down = self.maxpool(block_2) # [8,32,88,88] + block_2_add = block_2_down + block_1_side # [8,32,88,88] + + # Block 3 + block_3_pre_dense = self.pre_dense_3(block_2_down) # [8,64,88,88] block 3 L connection + block_3, _ = self.dblock_3([block_2_add, block_3_pre_dense]) # [8,64,88,88] + + # upsampling blocks + out_1 = self.up_block_1(block_1) + out_2 = self.up_block_2(block_2) + out_3 = self.up_block_3(block_3) + + results = [out_1, out_2, out_3] + + # concatenate multiscale outputs + block_cat = torch.cat(results, dim=1) # Bx6xHxW + block_cat = self.block_cat(block_cat) # Bx1xHxW DoubleFusion + + results.append(block_cat) + return results + + +if __name__ == '__main__': + batch_size = 8 + img_height = 352 + img_width = 352 + + # device = "cuda" if torch.cuda.is_available() else "cpu" + device = "cpu" + input = torch.rand(batch_size, 3, img_height, img_width).to(device) + # target = torch.rand(batch_size, 1, img_height, img_width).to(device) + print(f"input shape: {input.shape}") + model = TED().to(device) + output = model(input) + print(f"output shapes: {[t.shape for t in output]}") + + # for i in range(20000): + # print(i) + # output = model(input) + # loss = nn.MSELoss()(output[-1], target) + # loss.backward() diff --git a/custom_controlnet_aux/tests/requirements.txt b/custom_controlnet_aux/tests/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/tests/test_image.png b/custom_controlnet_aux/tests/test_image.png new file mode 100644 index 0000000000000000000000000000000000000000..e0c185fcecef3045fed44818d62546464b855b16 --- /dev/null +++ b/custom_controlnet_aux/tests/test_image.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8c5524e32ceda656a903e9aee1afb2f53c7554b5e5ab0cf666095fe7acc9edfd +size 337869 diff --git a/custom_controlnet_aux/tests/test_processor.py b/custom_controlnet_aux/tests/test_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..f0b420a9a60fcd6ed0a04b29cd5cab91a553173e --- /dev/null +++ b/custom_controlnet_aux/tests/test_processor.py @@ -0,0 +1,95 @@ +"""Test the Processor class.""" +import unittest +from PIL import Image + +from custom_controlnet_aux.processor import Processor + + +class TestProcessor(unittest.TestCase): + def test_hed(self): + processor = Processor('hed') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_midas(self): + processor = Processor('midas') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_mlsd(self): + processor = Processor('mlsd') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_openpose(self): + processor = Processor('openpose') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_pidinet(self): + processor = Processor('pidinet') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_normalbae(self): + processor = Processor('normalbae') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart(self): + processor = Processor('lineart') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart_coarse(self): + processor = Processor('lineart_coarse') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_lineart_anime(self): + processor = Processor('lineart_anime') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_canny(self): + processor = Processor('canny') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_content_shuffle(self): + processor = Processor('content_shuffle') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_zoe(self): + processor = Processor('zoe') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_mediapipe_face(self): + processor = Processor('mediapipe_face') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + def test_tile(self): + processor = Processor('tile') + image = Image.open('test_image.png') + processed_image = processor(image) + self.assertIsInstance(processed_image, bytes) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/custom_controlnet_aux/tests/test_processor_pytest.py b/custom_controlnet_aux/tests/test_processor_pytest.py new file mode 100644 index 0000000000000000000000000000000000000000..08cb07eb38f8689987b5c6e9978b52a40d58763e --- /dev/null +++ b/custom_controlnet_aux/tests/test_processor_pytest.py @@ -0,0 +1,78 @@ +import io + +import numpy as np +import pytest +from PIL import Image + +from custom_controlnet_aux.processor import MODELS, Processor + + +@pytest.fixture(params=[ + 'scribble_hed', + 'softedge_hed', + 'scribble_hedsafe', + 'softedge_hedsafe', + 'depth_midas', + 'mlsd', + 'openpose', + 'openpose_hand', + 'openpose_face', + 'openpose_faceonly', + 'openpose_full', + 'scribble_pidinet', + 'softedge_pidinet', + 'scribble_pidsafe', + 'softedge_pidsafe', + 'normal_bae', + 'lineart_coarse', + 'lineart_realistic', + 'lineart_anime', + 'canny', + 'shuffle', + 'depth_zoe', + 'depth_leres', + 'depth_leres++', + 'mediapipe_face', + 'tile' +]) +def processor(request): + return Processor(request.param) + + +def test_processor_init(processor): + assert isinstance(processor.processor, MODELS[processor.processor_id]['class']) + assert isinstance(processor.params, dict) + + +def test_processor_call(processor): + # Load test image + with open('test_image.png', 'rb') as f: + image_bytes = f.read() + image = Image.open(io.BytesIO(image_bytes)) + + # Output size + resolution = 512 + W, H = image.size + H = float(H) + W = float(W) + k = float(resolution) / min(H, W) + H *= k + W *= k + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + + # Test processing + processed_image = processor(image) + assert isinstance(processed_image, Image.Image) + assert processed_image.size == (W, H) + + +def test_processor_call_bytes(processor): + # Load test image + with open('test_image.png', 'rb') as f: + image_bytes = f.read() + + # Test processing + processed_image_bytes = processor(image_bytes, to_pil=False) + assert isinstance(processed_image_bytes, bytes) + assert len(processed_image_bytes) > 0 \ No newline at end of file diff --git a/custom_controlnet_aux/tile/__init__.py b/custom_controlnet_aux/tile/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c58555bcaa928a3044c50e30bf2021b0e3b1926 --- /dev/null +++ b/custom_controlnet_aux/tile/__init__.py @@ -0,0 +1,82 @@ +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import get_upscale_method, common_input_validate, HWC3 +from .guided_filter import FastGuidedFilter + +class TileDetector: + def __call__(self, input_image=None, pyrUp_iters=3, output_type=None, upscale_method="INTER_AREA", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + H, W, _ = input_image.shape + H = int(np.round(H / 64.0)) * 64 + W = int(np.round(W / 64.0)) * 64 + detected_map = cv2.resize(input_image, (W // (2 ** pyrUp_iters), H // (2 ** pyrUp_iters)), + interpolation=get_upscale_method(upscale_method)) + detected_map = HWC3(detected_map) + + for _ in range(pyrUp_iters): + detected_map = cv2.pyrUp(detected_map) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + + +# Source: https://huggingface.co/TTPlanet/TTPLanet_SDXL_Controlnet_Tile_Realistic/blob/main/TTP_tile_preprocessor_v5.py + +def apply_gaussian_blur(image_np, ksize=5, sigmaX=1.0): + if ksize % 2 == 0: + ksize += 1 # ksize must be odd + blurred_image = cv2.GaussianBlur(image_np, (ksize, ksize), sigmaX=sigmaX) + return blurred_image + +def apply_guided_filter(image_np, radius, eps, scale): + filter = FastGuidedFilter(image_np, radius, eps, scale) + return filter.filter(image_np) + +class TTPlanet_Tile_Detector_GF: + def __call__(self, input_image, scale_factor, blur_strength, radius, eps, output_type=None, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + img_np = input_image[:, :, ::-1] # RGB to BGR + + # Apply Gaussian blur + img_np = apply_gaussian_blur(img_np, ksize=int(blur_strength), sigmaX=blur_strength / 2) + + # Apply Guided Filter + img_np = apply_guided_filter(img_np, radius, eps, scale_factor) + + # Resize image + height, width = img_np.shape[:2] + new_width = int(width / scale_factor) + new_height = int(height / scale_factor) + resized_down = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_AREA) + resized_img = cv2.resize(resized_down, (width, height), interpolation=cv2.INTER_CUBIC) + detected_map = HWC3(resized_img[:, :, ::-1]) # BGR to RGB + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + +class TTPLanet_Tile_Detector_Simple: + def __call__(self, input_image, scale_factor, blur_strength, output_type=None, **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + img_np = input_image[:, :, ::-1] # RGB to BGR + + # Resize image first if you want blur to apply after resizing + height, width = img_np.shape[:2] + new_width = int(width / scale_factor) + new_height = int(height / scale_factor) + resized_down = cv2.resize(img_np, (new_width, new_height), interpolation=cv2.INTER_AREA) + resized_img = cv2.resize(resized_down, (width, height), interpolation=cv2.INTER_LANCZOS4) + + # Apply Gaussian blur after resizing + img_np = apply_gaussian_blur(resized_img, ksize=int(blur_strength), sigmaX=blur_strength / 2) + detected_map = HWC3(img_np[:, :, ::-1]) # BGR to RGB + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map diff --git a/custom_controlnet_aux/tile/guided_filter.py b/custom_controlnet_aux/tile/guided_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..88b953e504ccd5cf7fe00efa81207d41f5421091 --- /dev/null +++ b/custom_controlnet_aux/tile/guided_filter.py @@ -0,0 +1,281 @@ + +# -*- coding: utf-8 -*- +## @package guided_filter.core.filters +# +# Implementation of guided filter. +# * GuidedFilter: Original guided filter. +# * FastGuidedFilter: Fast version of the guided filter. +# @author tody +# @date 2015/08/26 + +import numpy as np +import cv2 + +## Convert image into float32 type. +def to32F(img): + if img.dtype == np.float32: + return img + return (1.0 / 255.0) * np.float32(img) + +## Convert image into uint8 type. +def to8U(img): + if img.dtype == np.uint8: + return img + return np.clip(np.uint8(255.0 * img), 0, 255) + +## Return if the input image is gray or not. +def _isGray(I): + return len(I.shape) == 2 + + +## Return down sampled image. +# @param scale (w/s, h/s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _downSample(I, scale=4, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_NEAREST) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w / scale), int(h / scale)), interpolation=cv2.INTER_NEAREST) + + +## Return up sampled image. +# @param scale (w*s, h*s) image will be created. +# @param shape I.shape[:2]=(h, w). numpy friendly size parameter. +def _upSample(I, scale=2, shape=None): + if shape is not None: + h, w = shape + return cv2.resize(I, (w, h), interpolation=cv2.INTER_LINEAR) + + h, w = I.shape[:2] + return cv2.resize(I, (int(w * scale), int(h * scale)), interpolation=cv2.INTER_LINEAR) + +## Fast guide filter. +class FastGuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + # @param scale Down sampled scale. + def __init__(self, I, radius=5, epsilon=0.4, scale=4): + I_32F = to32F(I) + self._I = I_32F + h, w = I.shape[:2] + + I_sub = _downSample(I_32F, scale) + + self._I_sub = I_sub + radius = int(radius / scale) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_sub, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_sub, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + shape_original = p.shape[:2] + + p_sub = _downSample(p_32F, shape=self._I_sub.shape[:2]) + + if _isGray(p_sub): + return self._filterGray(p_sub, shape_original) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_sub[:, :, ci], shape_original) + return to8U(q) + + def _filterGray(self, p_sub, shape_original): + ab_sub = self._guided_filter._computeCoefficients(p_sub) + ab = [_upSample(abi, shape=shape_original) for abi in ab_sub] + return self._guided_filter._computeOutput(ab, self._I) + + +## Guide filter. +class GuidedFilter: + ## Constructor. + # @param I Input guidance image. Color or gray. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + I_32F = to32F(I) + + if _isGray(I): + self._guided_filter = GuidedFilterGray(I_32F, radius, epsilon) + else: + self._guided_filter = GuidedFilterColor(I_32F, radius, epsilon) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return to8U(self._guided_filter.filter(p)) + + +## Common parts of guided filter. +# +# This class is used by guided_filter class. GuidedFilterGray and GuidedFilterColor. +# Based on guided_filter._computeCoefficients, guided_filter._computeOutput, +# GuidedFilterCommon.filter computes filtered image for color and gray. +class GuidedFilterCommon: + def __init__(self, guided_filter): + self._guided_filter = guided_filter + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + p_32F = to32F(p) + if _isGray(p_32F): + return self._filterGray(p_32F) + + cs = p.shape[2] + q = np.array(p_32F) + + for ci in range(cs): + q[:, :, ci] = self._filterGray(p_32F[:, :, ci]) + return q + + def _filterGray(self, p): + ab = self._guided_filter._computeCoefficients(p) + return self._guided_filter._computeOutput(ab, self._guided_filter._I) + + +## Guided filter for gray guidance image. +class GuidedFilterGray: + # @param I Input gray guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.4): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + self._I_mean = cv2.blur(I, (r, r)) + I_mean_sq = cv2.blur(I ** 2, (r, r)) + self._I_var = I_mean_sq - self._I_mean ** 2 + + def _computeCoefficients(self, p): + r = self._radius + p_mean = cv2.blur(p, (r, r)) + p_cov = p_mean - self._I_mean * p_mean + a = p_cov / (self._I_var + self._epsilon) + b = p_mean - a * self._I_mean + a_mean = cv2.blur(a, (r, r)) + b_mean = cv2.blur(b, (r, r)) + return a_mean, b_mean + + def _computeOutput(self, ab, I): + a_mean, b_mean = ab + return a_mean * I + b_mean + + +## Guided filter for color guidance image. +class GuidedFilterColor: + # @param I Input color guidance image. + # @param radius Radius of Guided Filter. + # @param epsilon Regularization term of Guided Filter. + def __init__(self, I, radius=5, epsilon=0.2): + self._radius = 2 * radius + 1 + self._epsilon = epsilon + self._I = to32F(I) + self._initFilter() + self._filter_common = GuidedFilterCommon(self) + + ## Apply filter for the input image. + # @param p Input image for the filtering. + def filter(self, p): + return self._filter_common.filter(p) + + def _initFilter(self): + I = self._I + r = self._radius + eps = self._epsilon + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + self._Ir_mean = cv2.blur(Ir, (r, r)) + self._Ig_mean = cv2.blur(Ig, (r, r)) + self._Ib_mean = cv2.blur(Ib, (r, r)) + + Irr_var = cv2.blur(Ir ** 2, (r, r)) - self._Ir_mean ** 2 + eps + Irg_var = cv2.blur(Ir * Ig, (r, r)) - self._Ir_mean * self._Ig_mean + Irb_var = cv2.blur(Ir * Ib, (r, r)) - self._Ir_mean * self._Ib_mean + Igg_var = cv2.blur(Ig * Ig, (r, r)) - self._Ig_mean * self._Ig_mean + eps + Igb_var = cv2.blur(Ig * Ib, (r, r)) - self._Ig_mean * self._Ib_mean + Ibb_var = cv2.blur(Ib * Ib, (r, r)) - self._Ib_mean * self._Ib_mean + eps + + Irr_inv = Igg_var * Ibb_var - Igb_var * Igb_var + Irg_inv = Igb_var * Irb_var - Irg_var * Ibb_var + Irb_inv = Irg_var * Igb_var - Igg_var * Irb_var + Igg_inv = Irr_var * Ibb_var - Irb_var * Irb_var + Igb_inv = Irb_var * Irg_var - Irr_var * Igb_var + Ibb_inv = Irr_var * Igg_var - Irg_var * Irg_var + + I_cov = Irr_inv * Irr_var + Irg_inv * Irg_var + Irb_inv * Irb_var + Irr_inv /= I_cov + Irg_inv /= I_cov + Irb_inv /= I_cov + Igg_inv /= I_cov + Igb_inv /= I_cov + Ibb_inv /= I_cov + + self._Irr_inv = Irr_inv + self._Irg_inv = Irg_inv + self._Irb_inv = Irb_inv + self._Igg_inv = Igg_inv + self._Igb_inv = Igb_inv + self._Ibb_inv = Ibb_inv + + def _computeCoefficients(self, p): + r = self._radius + I = self._I + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + p_mean = cv2.blur(p, (r, r)) + + Ipr_mean = cv2.blur(Ir * p, (r, r)) + Ipg_mean = cv2.blur(Ig * p, (r, r)) + Ipb_mean = cv2.blur(Ib * p, (r, r)) + + Ipr_cov = Ipr_mean - self._Ir_mean * p_mean + Ipg_cov = Ipg_mean - self._Ig_mean * p_mean + Ipb_cov = Ipb_mean - self._Ib_mean * p_mean + + ar = self._Irr_inv * Ipr_cov + self._Irg_inv * Ipg_cov + self._Irb_inv * Ipb_cov + ag = self._Irg_inv * Ipr_cov + self._Igg_inv * Ipg_cov + self._Igb_inv * Ipb_cov + ab = self._Irb_inv * Ipr_cov + self._Igb_inv * Ipg_cov + self._Ibb_inv * Ipb_cov + b = p_mean - ar * self._Ir_mean - ag * self._Ig_mean - ab * self._Ib_mean + + ar_mean = cv2.blur(ar, (r, r)) + ag_mean = cv2.blur(ag, (r, r)) + ab_mean = cv2.blur(ab, (r, r)) + b_mean = cv2.blur(b, (r, r)) + + return ar_mean, ag_mean, ab_mean, b_mean + + def _computeOutput(self, ab, I): + ar_mean, ag_mean, ab_mean, b_mean = ab + + Ir, Ig, Ib = I[:, :, 0], I[:, :, 1], I[:, :, 2] + + q = (ar_mean * Ir + + ag_mean * Ig + + ab_mean * Ib + + b_mean) + + return q \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/__init__.py b/custom_controlnet_aux/uniformer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ccb75770f8a34809975f28c10116b7ccb5a112b1 --- /dev/null +++ b/custom_controlnet_aux/uniformer/__init__.py @@ -0,0 +1,68 @@ +import os +from .inference import init_segmentor, inference_segmentor, show_result_pyplot +import warnings +import cv2 +import numpy as np +from PIL import Image +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME +import torch + +from custom_mmpkg.custom_mmseg.core.evaluation import get_palette + +config_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "upernet_global_small.py") + + + +class UniformerSegmentor: + def __init__(self, netNetwork): + self.model = netNetwork + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="upernet_global_small.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + netNetwork = init_segmentor(config_file, model_path, device="cpu") + netNetwork.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_path)['state_dict'].items()}) + netNetwork.eval() + + return cls(netNetwork) + + def to(self, device): + self.model.to(device) + return self + + def _inference(self, img): + if next(self.model.parameters()).device.type == 'mps': + # adaptive_avg_pool2d can fail on MPS, workaround with CPU + import torch.nn.functional + + orig_adaptive_avg_pool2d = torch.nn.functional.adaptive_avg_pool2d + def cpu_if_exception(input, *args, **kwargs): + try: + return orig_adaptive_avg_pool2d(input, *args, **kwargs) + except: + return orig_adaptive_avg_pool2d(input.cpu(), *args, **kwargs).to(input.device) + + try: + torch.nn.functional.adaptive_avg_pool2d = cpu_if_exception + result = inference_segmentor(self.model, img) + finally: + torch.nn.functional.adaptive_avg_pool2d = orig_adaptive_avg_pool2d + else: + result = inference_segmentor(self.model, img) + + res_img = show_result_pyplot(self.model, img, result, get_palette('ade'), opacity=1) + return res_img + + def __call__(self, input_image=None, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + detected_map = self._inference(input_image) + detected_map = remove_pad(HWC3(detected_map)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py new file mode 100644 index 0000000000000000000000000000000000000000..efc8b4bb20c981f3db6df7eb52b3dc0744c94cc0 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/ade20k.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'ADE20KDataset' +data_root = 'data/ade/ADEChallengeData2016' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py new file mode 100644 index 0000000000000000000000000000000000000000..298594ea925f87f22b37094a2ec50e370aec96a0 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/chase_db1.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'ChaseDB1Dataset' +data_root = 'data/CHASE_DB1' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (960, 999) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..f21867c63e1835f6fceb61f066e802fd8fd2a735 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes.py @@ -0,0 +1,54 @@ +# dataset settings +dataset_type = 'CityscapesDataset' +data_root = 'data/cityscapes/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 1024) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 1024), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=2, + workers_per_gpu=2, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/train', + ann_dir='gtFine/train', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='leftImg8bit/val', + ann_dir='gtFine/val', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py new file mode 100644 index 0000000000000000000000000000000000000000..336c7b254fe392b4703039fec86a83acdbd2e1a5 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/cityscapes_769x769.py @@ -0,0 +1,35 @@ +_base_ = './cityscapes.py' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (769, 769) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2049, 1025), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2049, 1025), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + train=dict(pipeline=train_pipeline), + val=dict(pipeline=test_pipeline), + test=dict(pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py new file mode 100644 index 0000000000000000000000000000000000000000..06e8ff606e0d2a4514ec8b7d2c6c436a32efcbf4 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/drive.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'DRIVEDataset' +data_root = 'data/DRIVE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (584, 565) +crop_size = (64, 64) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py new file mode 100644 index 0000000000000000000000000000000000000000..242d790eb1b83e75cf6b7eaa7a35c674099311ad --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/hrf.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'HRFDataset' +data_root = 'data/HRF' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (2336, 3504) +crop_size = (256, 256) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py new file mode 100644 index 0000000000000000000000000000000000000000..ff65bad1b86d7e3a5980bb5b9fc55798dc8df5f4 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py new file mode 100644 index 0000000000000000000000000000000000000000..37585abab89834b95cd5bdd993b994fca1db65f6 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_context_59.py @@ -0,0 +1,60 @@ +# dataset settings +dataset_type = 'PascalContextDataset59' +data_root = 'data/VOCdevkit/VOC2010/' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) + +img_scale = (520, 520) +crop_size = (480, 480) + +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', reduce_zero_label=True), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClassContext', + split='ImageSets/SegmentationContext/val.txt', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1d42d0c5781f56dc177d860d856bb34adce555 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12.py @@ -0,0 +1,57 @@ +# dataset settings +dataset_type = 'PascalVOCDataset' +data_root = 'data/VOCdevkit/VOC2012' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +crop_size = (512, 512) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=(2048, 512), ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']), +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=(2048, 512), + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']), + ]) +] +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/train.txt', + pipeline=train_pipeline), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='JPEGImages', + ann_dir='SegmentationClass', + split='ImageSets/Segmentation/val.txt', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py new file mode 100644 index 0000000000000000000000000000000000000000..3f23b6717d53ad29f02dd15046802a2631a5076b --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/pascal_voc12_aug.py @@ -0,0 +1,9 @@ +_base_ = './pascal_voc12.py' +# dataset settings +data = dict( + train=dict( + ann_dir=['SegmentationClass', 'SegmentationClassAug'], + split=[ + 'ImageSets/Segmentation/train.txt', + 'ImageSets/Segmentation/aug.txt' + ])) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py b/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py new file mode 100644 index 0000000000000000000000000000000000000000..3f71b25488cc11a6b4d582ac52b5a24e1ad1cf8e --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/datasets/stare.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'STAREDataset' +data_root = 'data/STARE' +img_norm_cfg = dict( + mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True) +img_scale = (605, 700) +crop_size = (128, 128) +train_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations'), + dict(type='Resize', img_scale=img_scale, ratio_range=(0.5, 2.0)), + dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75), + dict(type='RandomFlip', prob=0.5), + dict(type='PhotoMetricDistortion'), + dict(type='Normalize', **img_norm_cfg), + dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255), + dict(type='DefaultFormatBundle'), + dict(type='Collect', keys=['img', 'gt_semantic_seg']) +] +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict( + type='MultiScaleFlipAug', + img_scale=img_scale, + # img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0], + flip=False, + transforms=[ + dict(type='Resize', keep_ratio=True), + dict(type='RandomFlip'), + dict(type='Normalize', **img_norm_cfg), + dict(type='ImageToTensor', keys=['img']), + dict(type='Collect', keys=['img']) + ]) +] + +data = dict( + samples_per_gpu=4, + workers_per_gpu=4, + train=dict( + type='RepeatDataset', + times=40000, + dataset=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/training', + ann_dir='annotations/training', + pipeline=train_pipeline)), + val=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline), + test=dict( + type=dataset_type, + data_root=data_root, + img_dir='images/validation', + ann_dir='annotations/validation', + pipeline=test_pipeline)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py b/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py new file mode 100644 index 0000000000000000000000000000000000000000..b564cc4e7e7d9a67dacaaddecb100e4d8f5c005b --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/default_runtime.py @@ -0,0 +1,14 @@ +# yapf:disable +log_config = dict( + interval=50, + hooks=[ + dict(type='TextLoggerHook', by_epoch=False), + # dict(type='TensorboardLoggerHook') + ]) +# yapf:enable +dist_params = dict(backend='nccl') +log_level = 'INFO' +load_from = None +resume_from = None +workflow = [('train', 1)] +cudnn_benchmark = True diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..a2cb653827e44e6015b3b83bc578003e614a6aa1 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/ann_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ANNHead', + in_channels=[1024, 2048], + in_index=[2, 3], + channels=512, + project_channels=256, + query_scales=(1, ), + key_pool_scales=(1, 3, 6, 8), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..c8f5316cbcf3896ba9de7ca2c801eba512f01d5e --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/apcnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='APCHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..794148f576b9e215c3c6963e73dffe98204b7717 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/ccnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='CCHead', + in_channels=2048, + in_index=3, + channels=512, + recurrence=2, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py b/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..eff8d9458c877c5db894957e0b1b4597e40da6ab --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/cgnet.py @@ -0,0 +1,35 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=1e-03, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='CGNet', + norm_cfg=norm_cfg, + in_channels=3, + num_channels=(32, 64, 128), + num_blocks=(3, 21), + dilations=(2, 4), + reductions=(8, 16)), + decode_head=dict( + type='FCNHead', + in_channels=256, + in_index=2, + channels=256, + num_convs=0, + concat_input=False, + dropout_ratio=0, + num_classes=19, + norm_cfg=norm_cfg, + loss_decode=dict( + type='CrossEntropyLoss', + use_sigmoid=False, + loss_weight=1.0, + class_weight=[ + 2.5959933, 6.7415504, 3.5354059, 9.8663225, 9.690899, 9.369352, + 10.289121, 9.953208, 4.3097677, 9.490387, 7.674431, 9.396905, + 10.347791, 6.3927646, 10.226669, 10.241062, 10.280587, + 10.396974, 10.055647 + ])), + # model training and testing settings + train_cfg=dict(sampler=None), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..2c934939fac48525f22ad86f489a041dd7db7d09 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/danet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DAHead', + in_channels=2048, + in_index=3, + channels=512, + pam_channels=64, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..d7a43bee01422ad4795dd27874e0cd4bb6cbfecf --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='ASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd262999d8b2cb8e14a5c32190ae73f479d8e81 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='ASPPHead', + in_channels=64, + in_index=4, + channels=16, + dilations=(1, 12, 24, 36), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..050e39e091d816df9028d23aa3ecf9db74e441e1 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/deeplabv3plus_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DepthwiseSeparableASPPHead', + in_channels=2048, + in_index=3, + channels=512, + dilations=(1, 12, 24, 36), + c1_in_channels=256, + c1_channels=48, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..d22ba52640bebd805b3b8d07025e276dfb023759 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/dmnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DMHead', + in_channels=2048, + in_index=3, + channels=512, + filter_sizes=(1, 3, 5, 7), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=dict(type='SyncBN', requires_grad=True), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..edb4c174c51e34c103737ba39bfc48bf831e561d --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/dnl_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='DNLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..26adcd430926de0862204a71d345f2543167f27b --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/emanet_r50-d8.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EMAHead', + in_channels=2048, + in_index=3, + channels=256, + ema_channels=512, + num_bases=64, + num_stages=3, + momentum=0.1, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..be777123a886503172a95fe0719e956a147bbd68 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/encnet_r50-d8.py @@ -0,0 +1,48 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='EncHead', + in_channels=[512, 1024, 2048], + in_index=(1, 2, 3), + channels=512, + num_codes=32, + use_se_loss=True, + add_lateral=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0), + loss_se_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py new file mode 100644 index 0000000000000000000000000000000000000000..32fdeb659355a5ce5ef2cc7c2f30742703811cdf --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fast_scnn.py @@ -0,0 +1,57 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True, momentum=0.01) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='FastSCNN', + downsample_dw_channels=(32, 48), + global_in_channels=64, + global_block_channels=(64, 96, 128), + global_block_strides=(2, 2, 1), + global_out_channels=128, + higher_in_channels=64, + lower_in_channels=128, + fusion_out_channels=128, + out_indices=(0, 1, 2), + norm_cfg=norm_cfg, + align_corners=False), + decode_head=dict( + type='DepthwiseSeparableFCNHead', + in_channels=128, + channels=128, + concat_input=False, + num_classes=19, + in_index=-1, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + auxiliary_head=[ + dict( + type='FCNHead', + in_channels=128, + channels=32, + num_convs=1, + num_classes=19, + in_index=-2, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + dict( + type='FCNHead', + in_channels=64, + channels=32, + num_convs=1, + num_classes=19, + in_index=-3, + norm_cfg=norm_cfg, + concat_input=False, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.4)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py new file mode 100644 index 0000000000000000000000000000000000000000..c3e299bc89ada56ca14bbffcbdb08a586b8ed9e9 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_hr18.py @@ -0,0 +1,52 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + channels=sum([18, 36, 72, 144]), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..5e98f6cc918b6146fc6d613c6918e825ef1355c3 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_r50-d8.py @@ -0,0 +1,45 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='FCNHead', + in_channels=2048, + in_index=3, + channels=512, + num_convs=2, + concat_input=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..a33e7972877f902d0e7d18401ca675e3e4e60a18 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fcn_unet_s5-d16.py @@ -0,0 +1,51 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='FCNHead', + in_channels=64, + in_index=4, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..86ab327db92e44c14822d65f1c9277cb007f17c1 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_r50.py @@ -0,0 +1,36 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py b/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..8aae98c5991055bfcc08e82ccdc09f8b1d9f8a8d --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/fpn_uniformer.py @@ -0,0 +1,35 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1), + neck=dict( + type='FPN', + in_channels=[64, 128, 320, 512], + out_channels=256, + num_outs=4), + decode_head=dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=0.1, + num_classes=150, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole') +) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..3d2ad69f5c22adfe79d5fdabf920217628987166 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/gcnet_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='GCHead', + in_channels=2048, + in_index=3, + channels=512, + ratio=1 / 4., + pooling_type='att', + fusion_types=('channel_add', ), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..93258242a90695cc94a7c6bd41562d6a75988771 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/lraspp_m-v3-d8.py @@ -0,0 +1,25 @@ +# model settings +norm_cfg = dict(type='SyncBN', eps=0.001, requires_grad=True) +model = dict( + type='EncoderDecoder', + backbone=dict( + type='MobileNetV3', + arch='large', + out_indices=(1, 3, 16), + norm_cfg=norm_cfg), + decode_head=dict( + type='LRASPPHead', + in_channels=(16, 24, 960), + in_index=(0, 1, 2), + channels=128, + input_transform='multiple_select', + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..5674a39854cafd1f2e363bac99c58ccae62f24da --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/nonlocal_r50-d8.py @@ -0,0 +1,46 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='NLHead', + in_channels=2048, + in_index=3, + channels=512, + dropout_ratio=0.1, + reduction=2, + use_scale=True, + mode='embedded_gaussian', + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py b/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py new file mode 100644 index 0000000000000000000000000000000000000000..c60f62a7cdf3f5c5096a7a7e725e8268fddcb057 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_hr18.py @@ -0,0 +1,68 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://msra/hrnetv2_w18', + backbone=dict( + type='HRNet', + norm_cfg=norm_cfg, + norm_eval=False, + extra=dict( + stage1=dict( + num_modules=1, + num_branches=1, + block='BOTTLENECK', + num_blocks=(4, ), + num_channels=(64, )), + stage2=dict( + num_modules=1, + num_branches=2, + block='BASIC', + num_blocks=(4, 4), + num_channels=(18, 36)), + stage3=dict( + num_modules=4, + num_branches=3, + block='BASIC', + num_blocks=(4, 4, 4), + num_channels=(18, 36, 72)), + stage4=dict( + num_modules=3, + num_branches=4, + block='BASIC', + num_blocks=(4, 4, 4, 4), + num_channels=(18, 36, 72, 144)))), + decode_head=[ + dict( + type='FCNHead', + in_channels=[18, 36, 72, 144], + channels=sum([18, 36, 72, 144]), + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + kernel_size=1, + num_convs=1, + concat_input=False, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=[18, 36, 72, 144], + in_index=(0, 1, 2, 3), + input_transform='resize_concat', + channels=512, + ocr_channels=256, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..615aa3ff703942b6c22b2d6e9642504dd3e41ebd --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/ocrnet_r50-d8.py @@ -0,0 +1,47 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=[ + dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + dict( + type='OCRHead', + in_channels=2048, + in_index=3, + channels=512, + ocr_channels=256, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py b/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..9d323dbf9466d41e0800aa57ef84045f3d874bdf --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/pointrend_r50.py @@ -0,0 +1,56 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='CascadeEncoderDecoder', + num_stages=2, + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + neck=dict( + type='FPN', + in_channels=[256, 512, 1024, 2048], + out_channels=256, + num_outs=4), + decode_head=[ + dict( + type='FPNHead', + in_channels=[256, 256, 256, 256], + in_index=[0, 1, 2, 3], + feature_strides=[4, 8, 16, 32], + channels=128, + dropout_ratio=-1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + dict( + type='PointHead', + in_channels=[256], + in_index=[0], + channels=256, + num_fcs=3, + coarse_pred_each_layer=True, + dropout_ratio=-1, + num_classes=19, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)) + ], + # model training and testing settings + train_cfg=dict( + num_points=2048, oversample_ratio=3, importance_sample_ratio=0.75), + test_cfg=dict( + mode='whole', + subdivision_steps=2, + subdivision_num_points=8196, + scale_factor=2)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..689513fa9d2a40f14bf0ae4ae61f38f0dcc1b3da --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/psanet_r50-d8.py @@ -0,0 +1,49 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSAHead', + in_channels=2048, + in_index=3, + channels=512, + mask_size=(97, 97), + psa_type='bi-direction', + compact=False, + shrink_factor=2, + normalization_factor=1.0, + psa_softmax=True, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py b/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py new file mode 100644 index 0000000000000000000000000000000000000000..f451e08ad2eb0732dcb806b1851eb978d4acf136 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_r50-d8.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 2, 4), + strides=(1, 2, 1, 1), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='PSPHead', + in_channels=2048, + in_index=3, + channels=512, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py b/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py new file mode 100644 index 0000000000000000000000000000000000000000..fcff9ec4f41fad158344ecd77313dc14564f3682 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/pspnet_unet_s5-d16.py @@ -0,0 +1,50 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UNet', + in_channels=3, + base_channels=64, + num_stages=5, + strides=(1, 1, 1, 1, 1), + enc_num_convs=(2, 2, 2, 2, 2), + dec_num_convs=(2, 2, 2, 2), + downsamples=(True, True, True, True), + enc_dilations=(1, 1, 1, 1, 1), + dec_dilations=(1, 1, 1, 1), + with_cp=False, + conv_cfg=None, + norm_cfg=norm_cfg, + act_cfg=dict(type='ReLU'), + upsample_cfg=dict(type='InterpConv'), + norm_eval=False), + decode_head=dict( + type='PSPHead', + in_channels=64, + in_index=4, + channels=16, + pool_scales=(1, 2, 3, 6), + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=128, + in_index=3, + channels=64, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=2, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='slide', crop_size=256, stride=170)) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py b/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py new file mode 100644 index 0000000000000000000000000000000000000000..10974962fdd7136031fd06de1700f497d355ceaa --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_r50.py @@ -0,0 +1,44 @@ +# model settings +norm_cfg = dict(type='SyncBN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained='open-mmlab://resnet50_v1c', + backbone=dict( + type='ResNetV1c', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + dilations=(1, 1, 1, 1), + strides=(1, 2, 2, 2), + norm_cfg=norm_cfg, + norm_eval=False, + style='pytorch', + contract_dilation=True), + decode_head=dict( + type='UPerHead', + in_channels=[256, 512, 1024, 2048], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=1024, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) diff --git a/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py b/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..41aa4db809dc6e2c508e98051f61807d07477903 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/models/upernet_uniformer.py @@ -0,0 +1,43 @@ +# model settings +norm_cfg = dict(type='BN', requires_grad=True) +model = dict( + type='EncoderDecoder', + pretrained=None, + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + mlp_ratio=4., + qkv_bias=True, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.1), + decode_head=dict( + type='UPerHead', + in_channels=[64, 128, 320, 512], + in_index=[0, 1, 2, 3], + pool_scales=(1, 2, 3, 6), + channels=512, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)), + auxiliary_head=dict( + type='FCNHead', + in_channels=320, + in_index=2, + channels=256, + num_convs=1, + concat_input=False, + dropout_ratio=0.1, + num_classes=19, + norm_cfg=norm_cfg, + align_corners=False, + loss_decode=dict( + type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)), + # model training and testing settings + train_cfg=dict(), + test_cfg=dict(mode='whole')) \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py new file mode 100644 index 0000000000000000000000000000000000000000..52603890b10f25faf8eec9f9e5a4468fae09b811 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_160k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=160000) +checkpoint_config = dict(by_epoch=False, interval=16000) +evaluation = dict(interval=16000, metric='mIoU') diff --git a/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py new file mode 100644 index 0000000000000000000000000000000000000000..bf780a1b6f6521833c6a5859675147824efa599d --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_20k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=20000) +checkpoint_config = dict(by_epoch=False, interval=2000) +evaluation = dict(interval=2000, metric='mIoU') diff --git a/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py new file mode 100644 index 0000000000000000000000000000000000000000..cdbf841abcb26eed87bf76ab816aff4bae0630ee --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_40k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=40000) +checkpoint_config = dict(by_epoch=False, interval=4000) +evaluation = dict(interval=4000, metric='mIoU') diff --git a/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py new file mode 100644 index 0000000000000000000000000000000000000000..c190cee6bdc7922b688ea75dc8f152fa15c24617 --- /dev/null +++ b/custom_controlnet_aux/uniformer/configs/_base_/schedules/schedule_80k.py @@ -0,0 +1,9 @@ +# optimizer +optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005) +optimizer_config = dict() +# learning policy +lr_config = dict(policy='poly', power=0.9, min_lr=1e-4, by_epoch=False) +# runtime settings +runner = dict(type='IterBasedRunner', max_iters=80000) +checkpoint_config = dict(by_epoch=False, interval=8000) +evaluation = dict(interval=8000, metric='mIoU') diff --git a/custom_controlnet_aux/uniformer/inference.py b/custom_controlnet_aux/uniformer/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7efc93e16f51e70d80340f76d74c6f3db5a26443 --- /dev/null +++ b/custom_controlnet_aux/uniformer/inference.py @@ -0,0 +1,137 @@ + +import torch + +import custom_mmpkg.custom_mmcv as mmcv +from custom_mmpkg.custom_mmcv.parallel import collate, scatter +from custom_mmpkg.custom_mmcv.runner import load_checkpoint +from custom_mmpkg.custom_mmseg.datasets.pipelines import Compose +from custom_mmpkg.custom_mmseg.models import build_segmentor + +def init_segmentor(config, checkpoint=None, device='cuda:0'): + """Initialize a segmentor from config file. + + Args: + config (str or :obj:`mmcv.Config`): Config file path or the config + object. + checkpoint (str, optional): Checkpoint path. If left as None, the model + will not load any weights. + device (str, optional) CPU/CUDA device option. Default 'cuda:0'. + Use 'cpu' for loading model on CPU. + Returns: + nn.Module: The constructed segmentor. + """ + if isinstance(config, str): + config = mmcv.Config.fromfile(config) + elif not isinstance(config, mmcv.Config): + raise TypeError('config must be a filename or Config object, ' + 'but got {}'.format(type(config))) + config.model.pretrained = None + config.model.train_cfg = None + model = build_segmentor(config.model, test_cfg=config.get('test_cfg')) + if checkpoint is not None: + checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') + model.CLASSES = checkpoint['meta']['CLASSES'] + model.PALETTE = checkpoint['meta']['PALETTE'] + model.cfg = config # save the config in the model for convenience + model.to(device) + model.eval() + return model + + +class LoadImage: + """A simple pipeline to load image.""" + + def __call__(self, results): + """Call function to load images into results. + + Args: + results (dict): A result dict contains the file name + of the image to be read. + + Returns: + dict: ``results`` will be returned containing loaded image. + """ + + if isinstance(results['img'], str): + results['filename'] = results['img'] + results['ori_filename'] = results['img'] + else: + results['filename'] = None + results['ori_filename'] = None + img = mmcv.imread(results['img']) + results['img'] = img + results['img_shape'] = img.shape + results['ori_shape'] = img.shape + return results + + +def inference_segmentor(model, img): + """Inference image(s) with the segmentor. + + Args: + model (nn.Module): The loaded segmentor. + imgs (str/ndarray or list[str/ndarray]): Either image files or loaded + images. + + Returns: + (list[Tensor]): The segmentation result. + """ + cfg = model.cfg + device = next(model.parameters()).device # model device + # build the data pipeline + test_pipeline = [LoadImage()] + cfg.data.test.pipeline[1:] + test_pipeline = Compose(test_pipeline) + # prepare data + data = dict(img=img) + data = test_pipeline(data) + data = collate([data], samples_per_gpu=1) + if next(model.parameters()).is_cuda: + # scatter to specified GPU + data = scatter(data, [device])[0] + else: + data['img_metas'] = [i.data[0] for i in data['img_metas']] + + data['img'] = [x.to(device) for x in data['img']] + + # forward the model + with torch.no_grad(): + result = model(return_loss=False, rescale=True, **data) + return result + + +def show_result_pyplot(model, + img, + result, + palette=None, + fig_size=(15, 10), + opacity=0.5, + title='', + block=True): + """Visualize the segmentation results on the image. + + Args: + model (nn.Module): The loaded segmentor. + img (str or np.ndarray): Image filename or loaded image. + result (list): The segmentation result. + palette (list[list[int]]] | None): The palette of segmentation + map. If None is given, random palette will be generated. + Default: None + fig_size (tuple): Figure size of the pyplot figure. + opacity(float): Opacity of painted segmentation map. + Default 0.5. + Must be in (0, 1] range. + title (str): The title of pyplot figure. + Default is ''. + block (bool): Whether to block the pyplot figure. + Default is True. + """ + if hasattr(model, 'module'): + model = model.module + img = model.show_result( + img, result, palette=palette, show=False, opacity=opacity) + # plt.figure(figsize=fig_size) + # plt.imshow(mmcv.bgr2rgb(img)) + # plt.title(title) + # plt.tight_layout() + # plt.show(block=block) + return mmcv.bgr2rgb(img) \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py b/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b958738b9fd93bfcec239c550df1d9a44b8c536 --- /dev/null +++ b/custom_controlnet_aux/uniformer/mmcv_custom/__init__.py @@ -0,0 +1,5 @@ +# -*- coding: utf-8 -*- + +from .checkpoint import load_checkpoint + +__all__ = ['load_checkpoint'] \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py b/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8453fedcd47fafbedd40ca7ed485dce2e23434e0 --- /dev/null +++ b/custom_controlnet_aux/uniformer/mmcv_custom/checkpoint.py @@ -0,0 +1,500 @@ +# Copyright (c) Open-MMLab. All rights reserved. +import io +import os +import os.path as osp +import pkgutil +import time +import warnings +from collections import OrderedDict +from importlib import import_module +from tempfile import TemporaryDirectory + +import torch +import torchvision +from torch.optim import Optimizer +from torch.utils import model_zoo +from torch.nn import functional as F + +import custom_mmpkg.custom_mmcv as mmcv +from custom_mmpkg.custom_mmcv.fileio import FileClient +from custom_mmpkg.custom_mmcv.fileio import load as load_file +from custom_mmpkg.custom_mmcv.parallel import is_module_wrapper +from custom_mmpkg.custom_mmcv.utils import mkdir_or_exist +from custom_mmpkg.custom_mmcv.runner import get_dist_info + +ENV_MMCV_HOME = 'MMCV_HOME' +ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME' +DEFAULT_CACHE_DIR = '~/.cache' + + +def _get_mmcv_home(): + mmcv_home = os.path.expanduser( + os.getenv( + ENV_MMCV_HOME, + os.path.join( + os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv'))) + + mkdir_or_exist(mmcv_home) + return mmcv_home + + +def load_state_dict(module, state_dict, strict=False, logger=None): + """Load state_dict to a module. + + This method is modified from :meth:`torch.nn.Module.load_state_dict`. + Default value for ``strict`` is set to ``False`` and the message for + param mismatch will be shown even if strict is False. + + Args: + module (Module): Module that receives the state_dict. + state_dict (OrderedDict): Weights. + strict (bool): whether to strictly enforce that the keys + in :attr:`state_dict` match the keys returned by this module's + :meth:`~torch.nn.Module.state_dict` function. Default: ``False``. + logger (:obj:`logging.Logger`, optional): Logger to log the error + message. If not specified, print function will be used. + """ + unexpected_keys = [] + all_missing_keys = [] + err_msg = [] + + metadata = getattr(state_dict, '_metadata', None) + state_dict = state_dict.copy() + if metadata is not None: + state_dict._metadata = metadata + + # use _load_from_state_dict to enable checkpoint version control + def load(module, prefix=''): + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + local_metadata = {} if metadata is None else metadata.get( + prefix[:-1], {}) + module._load_from_state_dict(state_dict, prefix, local_metadata, True, + all_missing_keys, unexpected_keys, + err_msg) + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + '.') + + load(module) + load = None # break load->load reference cycle + + # ignore "num_batches_tracked" of BN layers + missing_keys = [ + key for key in all_missing_keys if 'num_batches_tracked' not in key + ] + + if unexpected_keys: + err_msg.append('unexpected key in source ' + f'state_dict: {", ".join(unexpected_keys)}\n') + if missing_keys: + err_msg.append( + f'missing keys in source state_dict: {", ".join(missing_keys)}\n') + + rank, _ = get_dist_info() + if len(err_msg) > 0 and rank == 0: + err_msg.insert( + 0, 'The model and loaded state dict do not match exactly\n') + err_msg = '\n'.join(err_msg) + if strict: + raise RuntimeError(err_msg) + elif logger is not None: + logger.warning(err_msg) + else: + print(err_msg) + + +def load_url_dist(url, model_dir=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + checkpoint = model_zoo.load_url(url, model_dir=model_dir) + return checkpoint + + +def load_pavimodel_dist(model_path, map_location=None): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + try: + from pavi import modelcloud + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + if rank == 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load(downloaded_file, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + model = modelcloud.get(model_path) + with TemporaryDirectory() as tmp_dir: + downloaded_file = osp.join(tmp_dir, model.name) + model.download(downloaded_file) + checkpoint = torch.load( + downloaded_file, map_location=map_location) + return checkpoint + + +def load_fileclient_dist(filename, backend, map_location): + """In distributed setting, this function only download checkpoint at local + rank 0.""" + rank, world_size = get_dist_info() + rank = int(os.environ.get('LOCAL_RANK', rank)) + allowed_backends = ['ceph'] + if backend not in allowed_backends: + raise ValueError(f'Load from Backend {backend} is not supported.') + if rank == 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + if world_size > 1: + torch.distributed.barrier() + if rank > 0: + fileclient = FileClient(backend=backend) + buffer = io.BytesIO(fileclient.get(filename)) + checkpoint = torch.load(buffer, map_location=map_location) + return checkpoint + + +def get_torchvision_models(): + model_urls = dict() + for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__): + if ispkg: + continue + _zoo = import_module(f'torchvision.models.{name}') + if hasattr(_zoo, 'model_urls'): + _urls = getattr(_zoo, 'model_urls') + model_urls.update(_urls) + return model_urls + + +def get_external_models(): + mmcv_home = _get_mmcv_home() + default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json') + default_urls = load_file(default_json_path) + assert isinstance(default_urls, dict) + external_json_path = osp.join(mmcv_home, 'open_mmlab.json') + if osp.exists(external_json_path): + external_urls = load_file(external_json_path) + assert isinstance(external_urls, dict) + default_urls.update(external_urls) + + return default_urls + + +def get_mmcls_models(): + mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json') + mmcls_urls = load_file(mmcls_json_path) + + return mmcls_urls + + +def get_deprecated_model_names(): + deprecate_json_path = osp.join(mmcv.__path__[0], + 'model_zoo/deprecated.json') + deprecate_urls = load_file(deprecate_json_path) + assert isinstance(deprecate_urls, dict) + + return deprecate_urls + + +def _process_mmcls_checkpoint(checkpoint): + state_dict = checkpoint['state_dict'] + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + if k.startswith('backbone.'): + new_state_dict[k[9:]] = v + new_checkpoint = dict(state_dict=new_state_dict) + + return new_checkpoint + + +def _load_checkpoint(filename, map_location=None): + """Load checkpoint from somewhere (modelzoo, file, url). + + Args: + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str | None): Same as :func:`torch.load`. Default: None. + + Returns: + dict | OrderedDict: The loaded checkpoint. It can be either an + OrderedDict storing model weights or a dict containing other + information, which depends on the checkpoint. + """ + if filename.startswith('modelzoo://'): + warnings.warn('The URL scheme of "modelzoo://" is deprecated, please ' + 'use "torchvision://" instead') + model_urls = get_torchvision_models() + model_name = filename[11:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('torchvision://'): + model_urls = get_torchvision_models() + model_name = filename[14:] + checkpoint = load_url_dist(model_urls[model_name]) + elif filename.startswith('open-mmlab://'): + model_urls = get_external_models() + model_name = filename[13:] + deprecated_urls = get_deprecated_model_names() + if model_name in deprecated_urls: + warnings.warn(f'open-mmlab://{model_name} is deprecated in favor ' + f'of open-mmlab://{deprecated_urls[model_name]}') + model_name = deprecated_urls[model_name] + model_url = model_urls[model_name] + # check if is url + if model_url.startswith(('http://', 'https://')): + checkpoint = load_url_dist(model_url) + else: + filename = osp.join(_get_mmcv_home(), model_url) + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + elif filename.startswith('mmcls://'): + model_urls = get_mmcls_models() + model_name = filename[8:] + checkpoint = load_url_dist(model_urls[model_name]) + checkpoint = _process_mmcls_checkpoint(checkpoint) + elif filename.startswith(('http://', 'https://')): + checkpoint = load_url_dist(filename) + elif filename.startswith('pavi://'): + model_path = filename[7:] + checkpoint = load_pavimodel_dist(model_path, map_location=map_location) + elif filename.startswith('s3://'): + checkpoint = load_fileclient_dist( + filename, backend='ceph', map_location=map_location) + else: + if not osp.isfile(filename): + raise IOError(f'{filename} is not a checkpoint file') + checkpoint = torch.load(filename, map_location=map_location) + return checkpoint + + +def load_checkpoint(model, + filename, + map_location='cpu', + strict=False, + logger=None): + """Load checkpoint from a file or URI. + + Args: + model (Module): Module to load checkpoint. + filename (str): Accept local filepath, URL, ``torchvision://xxx``, + ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for + details. + map_location (str): Same as :func:`torch.load`. + strict (bool): Whether to allow different params for the model and + checkpoint. + logger (:mod:`logging.Logger` or None): The logger for error message. + + Returns: + dict or OrderedDict: The loaded checkpoint. + """ + checkpoint = _load_checkpoint(filename, map_location) + # OrderedDict is a subclass of dict + if not isinstance(checkpoint, dict): + raise RuntimeError( + f'No state_dict found in checkpoint file {filename}') + # get state_dict from checkpoint + if 'state_dict' in checkpoint: + state_dict = checkpoint['state_dict'] + elif 'model' in checkpoint: + state_dict = checkpoint['model'] + else: + state_dict = checkpoint + # strip prefix of state_dict + if list(state_dict.keys())[0].startswith('module.'): + state_dict = {k[7:]: v for k, v in state_dict.items()} + + # for MoBY, load model of online branch + if sorted(list(state_dict.keys()))[0].startswith('encoder'): + state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')} + + # reshape absolute position embedding + if state_dict.get('absolute_pos_embed') is not None: + absolute_pos_embed = state_dict['absolute_pos_embed'] + N1, L, C1 = absolute_pos_embed.size() + N2, C2, H, W = model.absolute_pos_embed.size() + if N1 != N2 or C1 != C2 or L != H*W: + logger.warning("Error in loading absolute_pos_embed, pass") + else: + state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2) + + # interpolate position bias table if needed + relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k] + for table_key in relative_position_bias_table_keys: + table_pretrained = state_dict[table_key] + table_current = model.state_dict()[table_key] + L1, nH1 = table_pretrained.size() + L2, nH2 = table_current.size() + if nH1 != nH2: + logger.warning(f"Error in loading {table_key}, pass") + else: + if L1 != L2: + S1 = int(L1 ** 0.5) + S2 = int(L2 ** 0.5) + table_pretrained_resized = F.interpolate( + table_pretrained.permute(1, 0).view(1, nH1, S1, S1), + size=(S2, S2), mode='bicubic') + state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0) + + # load state_dict + load_state_dict(model, state_dict, strict, logger) + return checkpoint + + +def weights_to_cpu(state_dict): + """Copy a model state_dict to cpu. + + Args: + state_dict (OrderedDict): Model weights on GPU. + + Returns: + OrderedDict: Model weights on GPU. + """ + state_dict_cpu = OrderedDict() + for key, val in state_dict.items(): + state_dict_cpu[key] = val.cpu() + return state_dict_cpu + + +def _save_to_state_dict(module, destination, prefix, keep_vars): + """Saves module state to `destination` dictionary. + + This method is modified from :meth:`torch.nn.Module._save_to_state_dict`. + + Args: + module (nn.Module): The module to generate state_dict. + destination (dict): A dict where state will be stored. + prefix (str): The prefix for parameters and buffers used in this + module. + """ + for name, param in module._parameters.items(): + if param is not None: + destination[prefix + name] = param if keep_vars else param.detach() + for name, buf in module._buffers.items(): + # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d + if buf is not None: + destination[prefix + name] = buf if keep_vars else buf.detach() + + +def get_state_dict(module, destination=None, prefix='', keep_vars=False): + """Returns a dictionary containing a whole state of the module. + + Both parameters and persistent buffers (e.g. running averages) are + included. Keys are corresponding parameter and buffer names. + + This method is modified from :meth:`torch.nn.Module.state_dict` to + recursively check parallel module in case that the model has a complicated + structure, e.g., nn.Module(nn.Module(DDP)). + + Args: + module (nn.Module): The module to generate state_dict. + destination (OrderedDict): Returned dict for the state of the + module. + prefix (str): Prefix of the key. + keep_vars (bool): Whether to keep the variable property of the + parameters. Default: False. + + Returns: + dict: A dictionary containing a whole state of the module. + """ + # recursively check parallel module in case that the model has a + # complicated structure, e.g., nn.Module(nn.Module(DDP)) + if is_module_wrapper(module): + module = module.module + + # below is the same as torch.nn.Module.state_dict() + if destination is None: + destination = OrderedDict() + destination._metadata = OrderedDict() + destination._metadata[prefix[:-1]] = local_metadata = dict( + version=module._version) + _save_to_state_dict(module, destination, prefix, keep_vars) + for name, child in module._modules.items(): + if child is not None: + get_state_dict( + child, destination, prefix + name + '.', keep_vars=keep_vars) + for hook in module._state_dict_hooks.values(): + hook_result = hook(module, destination, prefix, local_metadata) + if hook_result is not None: + destination = hook_result + return destination + + +def save_checkpoint(model, filename, optimizer=None, meta=None): + """Save checkpoint to file. + + The checkpoint will have 3 fields: ``meta``, ``state_dict`` and + ``optimizer``. By default ``meta`` will contain version and time info. + + Args: + model (Module): Module whose params are to be saved. + filename (str): Checkpoint filename. + optimizer (:obj:`Optimizer`, optional): Optimizer to be saved. + meta (dict, optional): Metadata to be saved in checkpoint. + """ + if meta is None: + meta = {} + elif not isinstance(meta, dict): + raise TypeError(f'meta must be a dict or None, but got {type(meta)}') + meta.update(mmcv_version=mmcv.__version__, time=time.asctime()) + + if is_module_wrapper(model): + model = model.module + + if hasattr(model, 'CLASSES') and model.CLASSES is not None: + # save class name to the meta + meta.update(CLASSES=model.CLASSES) + + checkpoint = { + 'meta': meta, + 'state_dict': weights_to_cpu(get_state_dict(model)) + } + # save optimizer state dict in the checkpoint + if isinstance(optimizer, Optimizer): + checkpoint['optimizer'] = optimizer.state_dict() + elif isinstance(optimizer, dict): + checkpoint['optimizer'] = {} + for name, optim in optimizer.items(): + checkpoint['optimizer'][name] = optim.state_dict() + + if filename.startswith('pavi://'): + try: + from pavi import modelcloud + from pavi.exception import NodeNotFoundError + except ImportError: + raise ImportError( + 'Please install pavi to load checkpoint from modelcloud.') + model_path = filename[7:] + root = modelcloud.Folder() + model_dir, model_name = osp.split(model_path) + try: + model = modelcloud.get(model_dir) + except NodeNotFoundError: + model = root.create_training_model(model_dir) + with TemporaryDirectory() as tmp_dir: + checkpoint_file = osp.join(tmp_dir, model_name) + with open(checkpoint_file, 'wb') as f: + torch.save(checkpoint, f) + f.flush() + model.create_file(checkpoint_file, name=model_name) + else: + mmcv.mkdir_or_exist(osp.dirname(filename)) + # immediately flush buffer + with open(filename, 'wb') as f: + torch.save(checkpoint, f) + f.flush() \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/uniformer.py b/custom_controlnet_aux/uniformer/uniformer.py new file mode 100644 index 0000000000000000000000000000000000000000..30039e2eb45b01c336493ad7a86a5a4e33aa9c1b --- /dev/null +++ b/custom_controlnet_aux/uniformer/uniformer.py @@ -0,0 +1,421 @@ +# -------------------------------------------------------- +# UniFormer +# Copyright (c) 2022 SenseTime X-Lab +# Licensed under The MIT License [see LICENSE for details] +# Written by Kunchang Li +# -------------------------------------------------------- + + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from functools import partial +from collections import OrderedDict +from custom_timm.models.layers import DropPath, to_2tuple, trunc_normal_ +from custom_mmpkg.custom_mmseg.utils import get_root_logger +from custom_mmpkg.custom_mmseg.models.builder import BACKBONES + +from .mmcv_custom import load_checkpoint + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CMlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class CBlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = nn.BatchNorm2d(dim) + self.conv1 = nn.Conv2d(dim, dim, 1) + self.conv2 = nn.Conv2d(dim, dim, 1) + self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = nn.BatchNorm2d(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = CMlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + x = x + self.drop_path(self.conv2(self.attn(self.conv1(self.norm1(x))))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class Attention(nn.Module): + def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights + self.scale = qk_scale or head_dim ** -0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SABlock(nn.Module): + def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + B, N, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = x + self.drop_path(self.attn(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.transpose(1, 2).reshape(B, N, H, W) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class SABlock_Windows(nn.Module): + def __init__(self, dim, num_heads, window_size=14, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., + drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.window_size=window_size + self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim) + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, + attn_drop=attn_drop, proj_drop=drop) + # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + def forward(self, x): + x = x + self.pos_embed(x) + x = x.permute(0, 2, 3, 1) + B, H, W, C = x.shape + shortcut = x + x = self.norm1(x) + + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + x_windows = window_partition(x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + x = x.permute(0, 3, 1, 2).reshape(B, C, H, W) + return x + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + """ + def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.norm = nn.LayerNorm(embed_dim) + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + B, _, H, W = x.shape + x = self.proj(x) + B, _, H, W = x.shape + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + return x + + +@BACKBONES.register_module() +class UniFormer(nn.Module): + """ Vision Transformer + A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - + https://arxiv.org/abs/2010.11929 + """ + def __init__(self, layers=[3, 4, 8, 3], img_size=224, in_chans=3, num_classes=80, embed_dim=[64, 128, 320, 512], + head_dim=64, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, + drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6), + pretrained_path=None, use_checkpoint=False, checkpoint_num=[0, 0, 0, 0], + windows=False, hybrid=False, window_size=14): + """ + Args: + layer (list): number of block in each layer + img_size (int, tuple): input image size + in_chans (int): number of input channels + num_classes (int): number of classes for classification head + embed_dim (int): embedding dimension + head_dim (int): dimension of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + qk_scale (float): override default qk scale of head_dim ** -0.5 if set + representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set + drop_rate (float): dropout rate + attn_drop_rate (float): attention dropout rate + drop_path_rate (float): stochastic depth rate + norm_layer (nn.Module): normalization layer + pretrained_path (str): path of pretrained model + use_checkpoint (bool): whether use checkpoint + checkpoint_num (list): index for using checkpoint in every stage + windows (bool): whether use window MHRA + hybrid (bool): whether use hybrid MHRA + window_size (int): size of window (>14) + """ + super().__init__() + self.num_classes = num_classes + self.use_checkpoint = use_checkpoint + self.checkpoint_num = checkpoint_num + self.windows = windows + print(f'Use Checkpoint: {self.use_checkpoint}') + print(f'Checkpoint Number: {self.checkpoint_num}') + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) + + self.patch_embed1 = PatchEmbed( + img_size=img_size, patch_size=4, in_chans=in_chans, embed_dim=embed_dim[0]) + self.patch_embed2 = PatchEmbed( + img_size=img_size // 4, patch_size=2, in_chans=embed_dim[0], embed_dim=embed_dim[1]) + self.patch_embed3 = PatchEmbed( + img_size=img_size // 8, patch_size=2, in_chans=embed_dim[1], embed_dim=embed_dim[2]) + self.patch_embed4 = PatchEmbed( + img_size=img_size // 16, patch_size=2, in_chans=embed_dim[2], embed_dim=embed_dim[3]) + + self.pos_drop = nn.Dropout(p=drop_rate) + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(layers))] # stochastic depth decay rule + num_heads = [dim // head_dim for dim in embed_dim] + self.blocks1 = nn.ModuleList([ + CBlock( + dim=embed_dim[0], num_heads=num_heads[0], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) + for i in range(layers[0])]) + self.norm1=norm_layer(embed_dim[0]) + self.blocks2 = nn.ModuleList([ + CBlock( + dim=embed_dim[1], num_heads=num_heads[1], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]], norm_layer=norm_layer) + for i in range(layers[1])]) + self.norm2 = norm_layer(embed_dim[1]) + if self.windows: + print('Use local window for all blocks in stage3') + self.blocks3 = nn.ModuleList([ + SABlock_Windows( + dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer) + for i in range(layers[2])]) + elif hybrid: + print('Use hybrid window for blocks in stage3') + block3 = [] + for i in range(layers[2]): + if (i + 1) % 4 == 0: + block3.append(SABlock( + dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)) + else: + block3.append(SABlock_Windows( + dim=embed_dim[2], num_heads=num_heads[2], window_size=window_size, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer)) + self.blocks3 = nn.ModuleList(block3) + else: + print('Use global window for all blocks in stage3') + self.blocks3 = nn.ModuleList([ + SABlock( + dim=embed_dim[2], num_heads=num_heads[2], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]], norm_layer=norm_layer) + for i in range(layers[2])]) + self.norm3 = norm_layer(embed_dim[2]) + self.blocks4 = nn.ModuleList([ + SABlock( + dim=embed_dim[3], num_heads=num_heads[3], mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, + drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i+layers[0]+layers[1]+layers[2]], norm_layer=norm_layer) + for i in range(layers[3])]) + self.norm4 = norm_layer(embed_dim[3]) + + # Representation layer + if representation_size: + self.num_features = representation_size + self.pre_logits = nn.Sequential(OrderedDict([ + ('fc', nn.Linear(embed_dim, representation_size)), + ('act', nn.Tanh()) + ])) + else: + self.pre_logits = nn.Identity() + + self.apply(self._init_weights) + self.init_weights(pretrained=pretrained_path) + + def init_weights(self, pretrained): + if isinstance(pretrained, str): + logger = get_root_logger() + load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger) + print(f'Load pretrained model from {pretrained}') + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {'pos_embed', 'cls_token'} + + def get_classifier(self): + return self.head + + def reset_classifier(self, num_classes, global_pool=''): + self.num_classes = num_classes + self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() + + def forward_features(self, x): + out = [] + x = self.patch_embed1(x) + x = self.pos_drop(x) + for i, blk in enumerate(self.blocks1): + if self.use_checkpoint and i < self.checkpoint_num[0]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm1(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed2(x) + for i, blk in enumerate(self.blocks2): + if self.use_checkpoint and i < self.checkpoint_num[1]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm2(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed3(x) + for i, blk in enumerate(self.blocks3): + if self.use_checkpoint and i < self.checkpoint_num[2]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm3(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + x = self.patch_embed4(x) + for i, blk in enumerate(self.blocks4): + if self.use_checkpoint and i < self.checkpoint_num[3]: + x = checkpoint.checkpoint(blk, x) + else: + x = blk(x) + x_out = self.norm4(x.permute(0, 2, 3, 1)) + out.append(x_out.permute(0, 3, 1, 2).contiguous()) + return tuple(out) + + def forward(self, x): + x = self.forward_features(x) + return x \ No newline at end of file diff --git a/custom_controlnet_aux/uniformer/upernet_global_small.py b/custom_controlnet_aux/uniformer/upernet_global_small.py new file mode 100644 index 0000000000000000000000000000000000000000..b1084a1148461ad65e6f206463cf59148f822f0b --- /dev/null +++ b/custom_controlnet_aux/uniformer/upernet_global_small.py @@ -0,0 +1,44 @@ +_base_ = [ + 'configs/_base_/models/upernet_uniformer.py', + 'configs/_base_/datasets/ade20k.py', + 'configs/_base_/default_runtime.py', + 'configs/_base_/schedules/schedule_160k.py' +] + +custom_imports = dict( + imports=['custom_controlnet_aux.uniformer.uniformer'], + allow_failed_imports=False +) + +model = dict( + backbone=dict( + type='UniFormer', + embed_dim=[64, 128, 320, 512], + layers=[3, 4, 8, 3], + head_dim=64, + drop_path_rate=0.25, + windows=False, + hybrid=False + ), + decode_head=dict( + in_channels=[64, 128, 320, 512], + num_classes=150 + ), + auxiliary_head=dict( + in_channels=320, + num_classes=150 + )) + +# AdamW optimizer, no weight decay for position embedding & layer norm in backbone +optimizer = dict(_delete_=True, type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01, + paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.), + 'relative_position_bias_table': dict(decay_mult=0.), + 'norm': dict(decay_mult=0.)})) + +lr_config = dict(_delete_=True, policy='poly', + warmup='linear', + warmup_iters=1500, + warmup_ratio=1e-6, + power=1.0, min_lr=0.0, by_epoch=False) + +data=dict(samples_per_gpu=2) \ No newline at end of file diff --git a/custom_controlnet_aux/unimatch/__init__.py b/custom_controlnet_aux/unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..add304dc30aaaa85ca722190b27171c4191f5d72 --- /dev/null +++ b/custom_controlnet_aux/unimatch/__init__.py @@ -0,0 +1,195 @@ +import os +import warnings + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import resize_image_with_pad,common_input_validate, custom_hf_download, UNIMATCH_MODEL_NAME +from .utils.flow_viz import save_vis_flow_tofile, flow_to_image +from .unimatch.unimatch import UniMatch +import torch.nn.functional as F +from argparse import Namespace + +def inference_flow(model, + image1, #np array of HWC + image2, + padding_factor=8, + inference_size=None, + attn_type='swin', + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + num_reg_refine=1, + pred_bidir_flow=False, + pred_bwd_flow=False, + fwd_bwd_consistency_check=False, + device="cpu", + **kwargs + ): + fixed_inference_size = inference_size + transpose_img = False + image1 = torch.from_numpy(image1).permute(2, 0, 1).float().unsqueeze(0).to(device) + image2 = torch.from_numpy(image2).permute(2, 0, 1).float().unsqueeze(0).to(device) + + # the model is trained with size: width > height + if image1.size(-2) > image1.size(-1): + image1 = torch.transpose(image1, -2, -1) + image2 = torch.transpose(image2, -2, -1) + transpose_img = True + + nearest_size = [int(np.ceil(image1.size(-2) / padding_factor)) * padding_factor, + int(np.ceil(image1.size(-1) / padding_factor)) * padding_factor] + # resize to nearest size or specified size + inference_size = nearest_size if fixed_inference_size is None else fixed_inference_size + assert isinstance(inference_size, list) or isinstance(inference_size, tuple) + ori_size = image1.shape[-2:] + + # resize before inference + if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]: + image1 = F.interpolate(image1, size=inference_size, mode='bilinear', + align_corners=True) + image2 = F.interpolate(image2, size=inference_size, mode='bilinear', + align_corners=True) + if pred_bwd_flow: + image1, image2 = image2, image1 + + results_dict = model(image1, image2, + attn_type=attn_type, + attn_splits_list=attn_splits_list, + corr_radius_list=corr_radius_list, + prop_radius_list=prop_radius_list, + num_reg_refine=num_reg_refine, + task='flow', + pred_bidir_flow=pred_bidir_flow, + ) + flow_pr = results_dict['flow_preds'][-1] # [B, 2, H, W] + + # resize back + if inference_size[0] != ori_size[0] or inference_size[1] != ori_size[1]: + flow_pr = F.interpolate(flow_pr, size=ori_size, mode='bilinear', + align_corners=True) + flow_pr[:, 0] = flow_pr[:, 0] * ori_size[-1] / inference_size[-1] + flow_pr[:, 1] = flow_pr[:, 1] * ori_size[-2] / inference_size[-2] + + if transpose_img: + flow_pr = torch.transpose(flow_pr, -2, -1) + + flow = flow_pr[0].permute(1, 2, 0).cpu().numpy() # [H, W, 2] + + vis_image = flow_to_image(flow) + + # also predict backward flow + if pred_bidir_flow: + assert flow_pr.size(0) == 2 # [2, H, W, 2] + flow_bwd = flow_pr[1].permute(1, 2, 0).cpu().numpy() # [H, W, 2] + vis_image = flow_to_image(flow_bwd) + flow = flow_bwd + return flow, vis_image + +MODEL_CONFIGS = { + "gmflow-scale1": Namespace( + num_scales=1, + upsample_factor=8, + + attn_type="swin", + feature_channels=128, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + + attn_splits_list=[2], + corr_radius_list=[-1], + prop_radius_list=[-1], + + reg_refine=False, + num_reg_refine=1 + ), + "gmflow-scale2": Namespace( + num_scales=2, + upsample_factor=4, + padding_factor=32, + + attn_type="swin", + feature_channels=128, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + + reg_refine=False, + num_reg_refine=1 + ), + "gmflow-scale2-regrefine6": Namespace( + num_scales=2, + upsample_factor=4, + padding_factor=32, + + attn_type="swin", + feature_channels=128, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + + attn_splits_list=[2, 8], + corr_radius_list=[-1, 4], + prop_radius_list=[-1, 1], + + reg_refine=True, + num_reg_refine=6 + ) +} + +class UnimatchDetector: + def __init__(self, unimatch, config_args): + self.unimatch = unimatch + self.config_args = config_args + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=UNIMATCH_MODEL_NAME, filename="gmflow-scale2-regrefine6-mixdata.pth"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + config_args = None + for key in list(MODEL_CONFIGS.keys())[::-1]: + if key in filename: + config_args = MODEL_CONFIGS[key] + break + assert config_args, f"Couldn't find hardcoded Unimatch config for {filename}" + + model = UniMatch(feature_channels=config_args.feature_channels, + num_scales=config_args.num_scales, + upsample_factor=config_args.upsample_factor, + num_head=config_args.num_head, + ffn_dim_expansion=config_args.ffn_dim_expansion, + num_transformer_layers=config_args.num_transformer_layers, + reg_refine=config_args.reg_refine, + task='flow') + + sd = torch.load(model_path, map_location="cpu") + model.load_state_dict(sd['model']) + return cls(model, config_args) + + def to(self, device): + self.unimatch.to(device) + self.device = device + return self + + def __call__(self, image1, image2, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", pred_bwd_flow=False, pred_bidir_flow=False, **kwargs): + assert image1.shape == image2.shape, f"[Unimatch] image1 and image2 must have the same size, got {image1.shape} and {image2.shape}" + + image1, output_type = common_input_validate(image1, output_type, **kwargs) + #image1, remove_pad = resize_image_with_pad(image1, detect_resolution, upscale_method) + image2, output_type = common_input_validate(image2, output_type, **kwargs) + #image2, remove_pad = resize_image_with_pad(image2, detect_resolution, upscale_method) + with torch.no_grad(): + flow, vis_image = inference_flow(self.unimatch, image1, image2, device=self.device, pred_bwd_flow=pred_bwd_flow, pred_bidir_flow=pred_bidir_flow, **vars(self.config_args)) + + if output_type == "pil": + vis_image = Image.fromarray(vis_image) + + return flow, vis_image diff --git a/custom_controlnet_aux/unimatch/unimatch/__init__.py b/custom_controlnet_aux/unimatch/unimatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/unimatch/unimatch/attention.py b/custom_controlnet_aux/unimatch/unimatch/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..92a3c878afe541753022ba85c43b5b2e86e4d254 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/attention.py @@ -0,0 +1,253 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import split_feature, merge_splits, split_feature_1d, merge_splits_1d + + +def single_head_full_attention(q, k, v): + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + scores = torch.matmul(q, k.permute(0, 2, 1)) / (q.size(2) ** .5) # [B, L, L] + attn = torch.softmax(scores, dim=2) # [B, L, L] + out = torch.matmul(attn, v) # [B, L, C] + + return out + + +def single_head_full_attention_1d(q, k, v, + h=None, + w=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + scores = torch.matmul(q, k.permute(0, 1, 3, 2)) / scale_factor # [B, H, W, W] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v).view(b, -1, c) # [B, H*W, C] + + return out + + +def single_head_split_window_attention(q, k, v, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # q, k, v: [B, L, C] + assert q.dim() == k.dim() == v.dim() == 3 + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * num_splits + + window_size_h = h // num_splits + window_size_w = w // num_splits + + q = q.view(b, h, w, c) # [B, H, W, C] + k = k.view(b, h, w, c) + v = v.view(b, h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_h = window_size_h // 2 + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + k = torch.roll(k, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + v = torch.roll(v, shifts=(-shift_size_h, -shift_size_w), dims=(1, 2)) + + q = split_feature(q, num_splits=num_splits, channel_last=True) # [B*K*K, H/K, W/K, C] + k = split_feature(k, num_splits=num_splits, channel_last=True) + v = split_feature(v, num_splits=num_splits, channel_last=True) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*K*K, H/K*W/K, H/K*W/K] + + if with_shift: + scores += attn_mask.repeat(b, 1, 1) + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*K*K, H/K*W/K, C] + + out = merge_splits(out.view(b_new, h // num_splits, w // num_splits, c), + num_splits=num_splits, channel_last=True) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=(shift_size_h, shift_size_w), dims=(1, 2)) + + out = out.view(b, -1, c) + + return out + + +def single_head_split_window_attention_1d(q, k, v, + relative_position_bias=None, + num_splits=1, + with_shift=False, + h=None, + w=None, + attn_mask=None, + ): + # q, k, v: [B, L, C] + + assert h is not None and w is not None + assert q.size(1) == h * w + + b, _, c = q.size() + + b_new = b * num_splits * h + + window_size_w = w // num_splits + + q = q.view(b * h, w, c) # [B*H, W, C] + k = k.view(b * h, w, c) + v = v.view(b * h, w, c) + + scale_factor = c ** 0.5 + + if with_shift: + assert attn_mask is not None # compute once + shift_size_w = window_size_w // 2 + + q = torch.roll(q, shifts=-shift_size_w, dims=1) + k = torch.roll(k, shifts=-shift_size_w, dims=1) + v = torch.roll(v, shifts=-shift_size_w, dims=1) + + q = split_feature_1d(q, num_splits=num_splits) # [B*H*K, W/K, C] + k = split_feature_1d(k, num_splits=num_splits) + v = split_feature_1d(v, num_splits=num_splits) + + scores = torch.matmul(q.view(b_new, -1, c), k.view(b_new, -1, c).permute(0, 2, 1) + ) / scale_factor # [B*H*K, W/K, W/K] + + if with_shift: + # attn_mask: [K, W/K, W/K] + scores += attn_mask.repeat(b * h, 1, 1) # [B*H*K, W/K, W/K] + + attn = torch.softmax(scores, dim=-1) + + out = torch.matmul(attn, v.view(b_new, -1, c)) # [B*H*K, W/K, C] + + out = merge_splits_1d(out, h, num_splits=num_splits) # [B, H, W, C] + + # shift back + if with_shift: + out = torch.roll(out, shifts=shift_size_w, dims=2) + + out = out.view(b, -1, c) + + return out + + +class SelfAttnPropagation(nn.Module): + """ + flow propagation with self-attention on feature + query: feature0, key: feature0, value: flow + """ + + def __init__(self, in_channels, + **kwargs, + ): + super(SelfAttnPropagation, self).__init__() + + self.q_proj = nn.Linear(in_channels, in_channels) + self.k_proj = nn.Linear(in_channels, in_channels) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, flow, + local_window_attn=False, + local_window_radius=1, + **kwargs, + ): + # q, k: feature [B, C, H, W], v: flow [B, 2, H, W] + if local_window_attn: + return self.forward_local_window_attn(feature0, flow, + local_window_radius=local_window_radius) + + b, c, h, w = feature0.size() + + query = feature0.view(b, c, h * w).permute(0, 2, 1) # [B, H*W, C] + + # a note: the ``correct'' implementation should be: + # ``query = self.q_proj(query), key = self.k_proj(query)'' + # this problem is observed while cleaning up the code + # however, this doesn't affect the performance since the projection is a linear operation, + # thus the two projection matrices for key can be merged + # so I just leave it as is in order to not re-train all models :) + query = self.q_proj(query) # [B, H*W, C] + key = self.k_proj(query) # [B, H*W, C] + + value = flow.view(b, flow.size(1), h * w).permute(0, 2, 1) # [B, H*W, 2] + + scores = torch.matmul(query, key.permute(0, 2, 1)) / (c ** 0.5) # [B, H*W, H*W] + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, value) # [B, H*W, 2] + out = out.view(b, h, w, value.size(-1)).permute(0, 3, 1, 2) # [B, 2, H, W] + + return out + + def forward_local_window_attn(self, feature0, flow, + local_window_radius=1, + ): + assert flow.size(1) == 2 or flow.size(1) == 1 # flow or disparity or depth + assert local_window_radius > 0 + + b, c, h, w = feature0.size() + + value_channel = flow.size(1) + + feature0_reshape = self.q_proj(feature0.view(b, c, -1).permute(0, 2, 1) + ).reshape(b * h * w, 1, c) # [B*H*W, 1, C] + + kernel_size = 2 * local_window_radius + 1 + + feature0_proj = self.k_proj(feature0.view(b, c, -1).permute(0, 2, 1)).permute(0, 2, 1).reshape(b, c, h, w) + + feature0_window = F.unfold(feature0_proj, kernel_size=kernel_size, + padding=local_window_radius) # [B, C*(2R+1)^2), H*W] + + feature0_window = feature0_window.view(b, c, kernel_size ** 2, h, w).permute( + 0, 3, 4, 1, 2).reshape(b * h * w, c, kernel_size ** 2) # [B*H*W, C, (2R+1)^2] + + flow_window = F.unfold(flow, kernel_size=kernel_size, + padding=local_window_radius) # [B, 2*(2R+1)^2), H*W] + + flow_window = flow_window.view(b, value_channel, kernel_size ** 2, h, w).permute( + 0, 3, 4, 2, 1).reshape(b * h * w, kernel_size ** 2, value_channel) # [B*H*W, (2R+1)^2, 2] + + scores = torch.matmul(feature0_reshape, feature0_window) / (c ** 0.5) # [B*H*W, 1, (2R+1)^2] + + prob = torch.softmax(scores, dim=-1) + + out = torch.matmul(prob, flow_window).view(b, h, w, value_channel + ).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + return out diff --git a/custom_controlnet_aux/unimatch/unimatch/backbone.py b/custom_controlnet_aux/unimatch/unimatch/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..a30942eca9cad56e75252c3026dca95bf1021df7 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/backbone.py @@ -0,0 +1,117 @@ +import torch.nn as nn + +from .trident_conv import MultiScaleTridentConv + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_layer=nn.InstanceNorm2d, stride=1, dilation=1, + ): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, stride=stride, bias=False) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, + dilation=dilation, padding=dilation, bias=False) + self.relu = nn.ReLU(inplace=True) + + self.norm1 = norm_layer(planes) + self.norm2 = norm_layer(planes) + if not stride == 1 or in_planes != planes: + self.norm3 = norm_layer(planes) + + if stride == 1 and in_planes == planes: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class CNNEncoder(nn.Module): + def __init__(self, output_dim=128, + norm_layer=nn.InstanceNorm2d, + num_output_scales=1, + **kwargs, + ): + super(CNNEncoder, self).__init__() + self.num_branch = num_output_scales + + feature_dims = [64, 96, 128] + + self.conv1 = nn.Conv2d(3, feature_dims[0], kernel_size=7, stride=2, padding=3, bias=False) # 1/2 + self.norm1 = norm_layer(feature_dims[0]) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = feature_dims[0] + self.layer1 = self._make_layer(feature_dims[0], stride=1, norm_layer=norm_layer) # 1/2 + self.layer2 = self._make_layer(feature_dims[1], stride=2, norm_layer=norm_layer) # 1/4 + + # highest resolution 1/4 or 1/8 + stride = 2 if num_output_scales == 1 else 1 + self.layer3 = self._make_layer(feature_dims[2], stride=stride, + norm_layer=norm_layer, + ) # 1/4 or 1/8 + + self.conv2 = nn.Conv2d(feature_dims[2], output_dim, 1, 1, 0) + + if self.num_branch > 1: + if self.num_branch == 4: + strides = (1, 2, 4, 8) + elif self.num_branch == 3: + strides = (1, 2, 4) + elif self.num_branch == 2: + strides = (1, 2) + else: + raise ValueError + + self.trident_conv = MultiScaleTridentConv(output_dim, output_dim, + kernel_size=3, + strides=strides, + paddings=1, + num_branch=self.num_branch, + ) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1, dilation=1, norm_layer=nn.InstanceNorm2d): + layer1 = ResidualBlock(self.in_planes, dim, norm_layer=norm_layer, stride=stride, dilation=dilation) + layer2 = ResidualBlock(dim, dim, norm_layer=norm_layer, stride=1, dilation=dilation) + + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) # 1/2 + x = self.layer2(x) # 1/4 + x = self.layer3(x) # 1/8 or 1/4 + + x = self.conv2(x) + + if self.num_branch > 1: + out = self.trident_conv([x] * self.num_branch) # high to low res + else: + out = [x] + + return out diff --git a/custom_controlnet_aux/unimatch/unimatch/geometry.py b/custom_controlnet_aux/unimatch/unimatch/geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..775a95783aeee66a44e6290525de94909af648df --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/geometry.py @@ -0,0 +1,195 @@ +import torch +import torch.nn.functional as F + + +def coords_grid(b, h, w, homogeneous=False, device=None): + y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W] + + stacks = [x, y] + + if homogeneous: + ones = torch.ones_like(x) # [H, W] + stacks.append(ones) + + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + + if device is not None: + grid = grid.to(device) + + return grid + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def bilinear_sample(img, sample_coords, mode='bilinear', padding_mode='zeros', return_mask=False): + # img: [B, C, H, W] + # sample_coords: [B, 2, H, W] in image scale + if sample_coords.size(1) != 2: # [B, H, W, 2] + sample_coords = sample_coords.permute(0, 3, 1, 2) + + b, _, h, w = sample_coords.shape + + # Normalize to [-1, 1] + x_grid = 2 * sample_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * sample_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, H, W, 2] + + img = F.grid_sample(img, grid, mode=mode, padding_mode=padding_mode, align_corners=True) + + if return_mask: + mask = (x_grid >= -1) & (y_grid >= -1) & (x_grid <= 1) & (y_grid <= 1) # [B, H, W] + + return img, mask + + return img + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sample(feature, grid, padding_mode=padding_mode, + return_mask=mask) + + +def forward_backward_consistency_check(fwd_flow, bwd_flow, + alpha=0.01, + beta=0.5 + ): + # fwd_flow, bwd_flow: [B, 2, H, W] + # alpha and beta values are following UnFlow (https://arxiv.org/abs/1711.07837) + assert fwd_flow.dim() == 4 and bwd_flow.dim() == 4 + assert fwd_flow.size(1) == 2 and bwd_flow.size(1) == 2 + flow_mag = torch.norm(fwd_flow, dim=1) + torch.norm(bwd_flow, dim=1) # [B, H, W] + + warped_bwd_flow = flow_warp(bwd_flow, fwd_flow) # [B, 2, H, W] + warped_fwd_flow = flow_warp(fwd_flow, bwd_flow) # [B, 2, H, W] + + diff_fwd = torch.norm(fwd_flow + warped_bwd_flow, dim=1) # [B, H, W] + diff_bwd = torch.norm(bwd_flow + warped_fwd_flow, dim=1) + + threshold = alpha * flow_mag + beta + + fwd_occ = (diff_fwd > threshold).float() # [B, H, W] + bwd_occ = (diff_bwd > threshold).float() + + return fwd_occ, bwd_occ + + +def back_project(depth, intrinsics): + # Back project 2D pixel coords to 3D points + # depth: [B, H, W] + # intrinsics: [B, 3, 3] + b, h, w = depth.shape + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + + intrinsics_inv = torch.inverse(intrinsics) # [B, 3, 3] + + points = intrinsics_inv.bmm(grid.view(b, 3, -1)).view(b, 3, h, w) * depth.unsqueeze(1) # [B, 3, H, W] + + return points + + +def camera_transform(points_ref, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None): + # Transform 3D points from reference camera to target camera + # points_ref: [B, 3, H, W] + # extrinsics_ref: [B, 4, 4] + # extrinsics_tgt: [B, 4, 4] + # extrinsics_rel: [B, 4, 4], relative pose transform + b, _, h, w = points_ref.shape + + if extrinsics_rel is None: + extrinsics_rel = torch.bmm(extrinsics_tgt, torch.inverse(extrinsics_ref)) # [B, 4, 4] + + points_tgt = torch.bmm(extrinsics_rel[:, :3, :3], + points_ref.view(b, 3, -1)) + extrinsics_rel[:, :3, -1:] # [B, 3, H*W] + + points_tgt = points_tgt.view(b, 3, h, w) # [B, 3, H, W] + + return points_tgt + + +def reproject(points_tgt, intrinsics, return_mask=False): + # reproject to target view + # points_tgt: [B, 3, H, W] + # intrinsics: [B, 3, 3] + + b, _, h, w = points_tgt.shape + + proj_points = torch.bmm(intrinsics, points_tgt.view(b, 3, -1)).view(b, 3, h, w) # [B, 3, H, W] + + X = proj_points[:, 0] + Y = proj_points[:, 1] + Z = proj_points[:, 2].clamp(min=1e-3) + + pixel_coords = torch.stack([X / Z, Y / Z], dim=1).view(b, 2, h, w) # [B, 2, H, W] in image scale + + if return_mask: + # valid mask in pixel space + mask = (pixel_coords[:, 0] >= 0) & (pixel_coords[:, 0] <= (w - 1)) & ( + pixel_coords[:, 1] >= 0) & (pixel_coords[:, 1] <= (h - 1)) # [B, H, W] + + return pixel_coords, mask + + return pixel_coords + + +def reproject_coords(depth_ref, intrinsics, extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + # Compute reprojection sample coords + points_ref = back_project(depth_ref, intrinsics) # [B, 3, H, W] + points_tgt = camera_transform(points_ref, extrinsics_ref, extrinsics_tgt, extrinsics_rel=extrinsics_rel) + + if return_mask: + reproj_coords, mask = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords, mask + + reproj_coords = reproject(points_tgt, intrinsics, + return_mask=return_mask) # [B, 2, H, W] in image scale + + return reproj_coords + + +def compute_flow_with_depth_pose(depth_ref, intrinsics, + extrinsics_ref=None, extrinsics_tgt=None, extrinsics_rel=None, + return_mask=False): + b, h, w = depth_ref.shape + coords_init = coords_grid(b, h, w, device=depth_ref.device) # [B, 2, H, W] + + if return_mask: + reproj_coords, mask = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + rigid_flow = reproj_coords - coords_init + + return rigid_flow, mask + + reproj_coords = reproject_coords(depth_ref, intrinsics, extrinsics_ref, extrinsics_tgt, + extrinsics_rel=extrinsics_rel, + return_mask=return_mask) # [B, 2, H, W] + + rigid_flow = reproj_coords - coords_init + + return rigid_flow diff --git a/custom_controlnet_aux/unimatch/unimatch/matching.py b/custom_controlnet_aux/unimatch/unimatch/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..595437f2307202ab36d7c2ee3dfa0ab44e4dc830 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/matching.py @@ -0,0 +1,279 @@ +import torch +import torch.nn.functional as F + +from .geometry import coords_grid, generate_window_grid, normalize_coords + + +def global_correlation_softmax(feature0, feature1, + pred_bidir_flow=False, + ): + # global correlation + b, c, h, w = feature0.shape + feature0 = feature0.view(b, c, -1).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.view(b, c, -1) # [B, C, H*W] + + correlation = torch.matmul(feature0, feature1).view(b, h, w, h, w) / (c ** 0.5) # [B, H, W, H, W] + + # flow from softmax + init_grid = coords_grid(b, h, w).to(correlation.device) # [B, 2, H, W] + grid = init_grid.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + correlation = correlation.view(b, h * w, h * w) # [B, H*W, H*W] + + if pred_bidir_flow: + correlation = torch.cat((correlation, correlation.permute(0, 2, 1)), dim=0) # [2*B, H*W, H*W] + init_grid = init_grid.repeat(2, 1, 1, 1) # [2*B, 2, H, W] + grid = grid.repeat(2, 1, 1) # [2*B, H*W, 2] + b = b * 2 + + prob = F.softmax(correlation, dim=-1) # [B, H*W, H*W] + + correspondence = torch.matmul(prob, grid).view(b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + # when predicting bidirectional flow, flow is the concatenation of forward flow and backward flow + flow = correspondence - init_grid + + return flow, prob + + +def local_correlation_softmax(feature0, feature1, local_radius, + padding_mode='zeros', + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1)^2, 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)^2] + + correspondence = torch.matmul(prob.unsqueeze(-2), sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2) # [B, 2, H, W] + + flow = correspondence - coords_init + match_prob = prob + + return flow, match_prob + + +def local_correlation_with_flow(feature0, feature1, + flow, + local_radius, + padding_mode='zeros', + dilation=1, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1) # [B, H*W, 2] + + local_h = 2 * local_radius + 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(-local_radius, local_radius, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [2R+1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1)^2, 2] + sample_coords = coords.unsqueeze(-2) + window_grid * dilation # [B, H*W, (2R+1)^2, 2] + + # flow can be zero when using features after transformer + if not isinstance(flow, float): + sample_coords = sample_coords + flow.view( + b, 2, -1).permute(0, 2, 1).unsqueeze(-2) # [B, H*W, (2R+1)^2, 2] + else: + assert flow == 0. + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode=padding_mode, align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)^2] + feature0_view = feature0.permute(0, 2, 3, 1).view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)^2] + + corr = corr.view(b, h, w, -1).permute(0, 3, 1, 2).contiguous() # [B, (2R+1)^2, H, W] + + return corr + + +def global_correlation_softmax_stereo(feature0, feature1, + ): + # global correlation on horizontal direction + b, c, h, w = feature0.shape + + x_grid = torch.linspace(0, w - 1, w, device=feature0.device) # [W] + + feature0 = feature0.permute(0, 2, 3, 1) # [B, H, W, C] + feature1 = feature1.permute(0, 2, 1, 3) # [B, H, C, W] + + correlation = torch.matmul(feature0, feature1) / (c ** 0.5) # [B, H, W, W] + + # mask subsequent positions to make disparity positive + mask = torch.triu(torch.ones((w, w)), diagonal=1).type_as(feature0) # [W, W] + valid_mask = (mask == 0).unsqueeze(0).unsqueeze(0).repeat(b, h, 1, 1) # [B, H, W, W] + + correlation[~valid_mask] = -1e9 + + prob = F.softmax(correlation, dim=-1) # [B, H, W, W] + + correspondence = (x_grid.view(1, 1, 1, w) * prob).sum(-1) # [B, H, W] + + # NOTE: unlike flow, disparity is typically positive + disparity = x_grid.view(1, 1, w).repeat(b, h, 1) - correspondence # [B, H, W] + + return disparity.unsqueeze(1), prob # feature resolution + + +def local_correlation_softmax_stereo(feature0, feature1, local_radius, + ): + b, c, h, w = feature0.size() + coords_init = coords_grid(b, h, w).to(feature0.device) # [B, 2, H, W] + coords = coords_init.view(b, 2, -1).permute(0, 2, 1).contiguous() # [B, H*W, 2] + + local_h = 1 + local_w = 2 * local_radius + 1 + + window_grid = generate_window_grid(0, 0, + -local_radius, local_radius, + local_h, local_w, device=feature0.device) # [1, 2R+1, 2] + window_grid = window_grid.reshape(-1, 2).repeat(b, 1, 1, 1) # [B, 1, (2R+1), 2] + sample_coords = coords.unsqueeze(-2) + window_grid # [B, H*W, (2R+1), 2] + + sample_coords_softmax = sample_coords + + # exclude coords that are out of image space + valid_x = (sample_coords[:, :, :, 0] >= 0) & (sample_coords[:, :, :, 0] < w) # [B, H*W, (2R+1)^2] + valid_y = (sample_coords[:, :, :, 1] >= 0) & (sample_coords[:, :, :, 1] < h) # [B, H*W, (2R+1)^2] + + valid = valid_x & valid_y # [B, H*W, (2R+1)^2], used to mask out invalid values when softmax + + # normalize coordinates to [-1, 1] + sample_coords_norm = normalize_coords(sample_coords, h, w) # [-1, 1] + window_feature = F.grid_sample(feature1, sample_coords_norm, + padding_mode='zeros', align_corners=True + ).permute(0, 2, 1, 3) # [B, H*W, C, (2R+1)] + feature0_view = feature0.permute(0, 2, 3, 1).contiguous().view(b, h * w, 1, c) # [B, H*W, 1, C] + + corr = torch.matmul(feature0_view, window_feature).view(b, h * w, -1) / (c ** 0.5) # [B, H*W, (2R+1)] + + # mask invalid locations + corr[~valid] = -1e9 + + prob = F.softmax(corr, -1) # [B, H*W, (2R+1)] + + correspondence = torch.matmul(prob.unsqueeze(-2), + sample_coords_softmax).squeeze(-2).view( + b, h, w, 2).permute(0, 3, 1, 2).contiguous() # [B, 2, H, W] + + flow = correspondence - coords_init # flow at feature resolution + match_prob = prob + + flow_x = -flow[:, :1] # [B, 1, H, W] + + return flow_x, match_prob + + +def correlation_softmax_depth(feature0, feature1, + intrinsics, + pose, + depth_candidates, + depth_from_argmax=False, + pred_bidir_depth=False, + ): + b, c, h, w = feature0.size() + assert depth_candidates.dim() == 4 # [B, D, H, W] + scale_factor = c ** 0.5 + + if pred_bidir_depth: + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + intrinsics = intrinsics.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + depth_candidates = depth_candidates.repeat(2, 1, 1, 1) + + # depth candidates are actually inverse depth + warped_feature1 = warp_with_pose_depth_candidates(feature1, intrinsics, pose, + 1. / depth_candidates, + ) # [B, C, D, H, W] + + correlation = (feature0.unsqueeze(2) * warped_feature1).sum(1) / scale_factor # [B, D, H, W] + + match_prob = F.softmax(correlation, dim=1) # [B, D, H, W] + + # for cross-task transfer (flow -> depth), extract depth with argmax at test time + if depth_from_argmax: + index = torch.argmax(match_prob, dim=1, keepdim=True) + depth = torch.gather(depth_candidates, dim=1, index=index) + else: + depth = (match_prob * depth_candidates).sum(dim=1, keepdim=True) # [B, 1, H, W] + + return depth, match_prob + + +def warp_with_pose_depth_candidates(feature1, intrinsics, pose, depth, + clamp_min_depth=1e-3, + ): + """ + feature1: [B, C, H, W] + intrinsics: [B, 3, 3] + pose: [B, 4, 4] + depth: [B, D, H, W] + """ + + assert intrinsics.size(1) == intrinsics.size(2) == 3 + assert pose.size(1) == pose.size(2) == 4 + assert depth.dim() == 4 + + b, d, h, w = depth.size() + c = feature1.size(1) + + with torch.no_grad(): + # pixel coordinates + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + # back project to 3D and transform viewpoint + points = torch.inverse(intrinsics).bmm(grid.view(b, 3, -1)) # [B, 3, H*W] + points = torch.bmm(pose[:, :3, :3], points).unsqueeze(2).repeat( + 1, 1, d, 1) * depth.view(b, 1, d, h * w) # [B, 3, D, H*W] + points = points + pose[:, :3, -1:].unsqueeze(-1) # [B, 3, D, H*W] + # reproject to 2D image plane + points = torch.bmm(intrinsics, points.view(b, 3, -1)).view(b, 3, d, h * w) # [B, 3, D, H*W] + pixel_coords = points[:, :2] / points[:, -1:].clamp(min=clamp_min_depth) # [B, 2, D, H*W] + + # normalize to [-1, 1] + x_grid = 2 * pixel_coords[:, 0] / (w - 1) - 1 + y_grid = 2 * pixel_coords[:, 1] / (h - 1) - 1 + + grid = torch.stack([x_grid, y_grid], dim=-1) # [B, D, H*W, 2] + + # sample features + warped_feature = F.grid_sample(feature1, grid.view(b, d * h, w, 2), mode='bilinear', + padding_mode='zeros', + align_corners=True).view(b, c, d, h, w) # [B, C, D, H, W] + + return warped_feature diff --git a/custom_controlnet_aux/unimatch/unimatch/position.py b/custom_controlnet_aux/unimatch/unimatch/position.py new file mode 100644 index 0000000000000000000000000000000000000000..14a6da436c818b7c2784e92dba66f7947d34b7ce --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/position.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +# https://github.com/facebookresearch/detr/blob/main/models/position_encoding.py + +import torch +import torch.nn as nn +import math + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=True, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x): + # x = tensor_list.tensors # [B, C, H, W] + # mask = tensor_list.mask # [B, H, W], input with padding, valid as 0 + b, c, h, w = x.size() + mask = torch.ones((b, h, w), device=x.device) # [B, H, W] + y_embed = mask.cumsum(1, dtype=torch.float32) + x_embed = mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos diff --git a/custom_controlnet_aux/unimatch/unimatch/reg_refine.py b/custom_controlnet_aux/unimatch/unimatch/reg_refine.py new file mode 100644 index 0000000000000000000000000000000000000000..47f83da1c5dcd476069e841d045db04998be3604 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/reg_refine.py @@ -0,0 +1,119 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256, + out_dim=2, + ): + super(FlowHead, self).__init__() + + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, out_dim, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv2(self.relu(self.conv1(x))) + + return out + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192 + 128, + kernel_size=5, + ): + padding = (kernel_size - 1) // 2 + + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convr1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + self.convq1 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (1, kernel_size), padding=(0, padding)) + + self.convz2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convr2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + self.convq2 = nn.Conv2d(hidden_dim + input_dim, hidden_dim, (kernel_size, 1), padding=(padding, 0)) + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) + h = (1 - z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, corr_channels=324, + flow_channels=2, + ): + super(BasicMotionEncoder, self).__init__() + + self.convc1 = nn.Conv2d(corr_channels, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(flow_channels, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64 + 192, 128 - flow_channels, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, corr_channels=324, + hidden_dim=128, + context_dim=128, + downsample_factor=8, + flow_dim=2, + bilinear_up=False, + ): + super(BasicUpdateBlock, self).__init__() + + self.encoder = BasicMotionEncoder(corr_channels=corr_channels, + flow_channels=flow_dim, + ) + + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=context_dim + hidden_dim) + + self.flow_head = FlowHead(hidden_dim, hidden_dim=256, + out_dim=flow_dim, + ) + + if bilinear_up: + self.mask = None + else: + self.mask = nn.Sequential( + nn.Conv2d(hidden_dim, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, downsample_factor ** 2 * 9, 1, padding=0)) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + if self.mask is not None: + mask = self.mask(net) + else: + mask = None + + return net, mask, delta_flow diff --git a/custom_controlnet_aux/unimatch/unimatch/transformer.py b/custom_controlnet_aux/unimatch/unimatch/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..4878e23a64f6609b1bf10740b0a794d8da836c31 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/transformer.py @@ -0,0 +1,294 @@ +import torch +import torch.nn as nn + +from .attention import (single_head_full_attention, single_head_split_window_attention, + single_head_full_attention_1d, single_head_split_window_attention_1d) +from .utils import generate_shift_window_attn_mask, generate_shift_window_attn_mask_1d + + +class TransformerLayer(nn.Module): + def __init__(self, + d_model=128, + nhead=1, + no_ffn=False, + ffn_dim_expansion=4, + ): + super(TransformerLayer, self).__init__() + + self.dim = d_model + self.nhead = nhead + self.no_ffn = no_ffn + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + + self.merge = nn.Linear(d_model, d_model, bias=False) + + self.norm1 = nn.LayerNorm(d_model) + + # no ffn after self-attn, with ffn after cross-attn + if not self.no_ffn: + in_channels = d_model * 2 + self.mlp = nn.Sequential( + nn.Linear(in_channels, in_channels * ffn_dim_expansion, bias=False), + nn.GELU(), + nn.Linear(in_channels * ffn_dim_expansion, d_model, bias=False), + ) + + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + query, key, value = source, target, target + + # for stereo: 2d attn in self-attn, 1d attn in cross-attn + is_self_attn = (query - key).abs().max() < 1e-6 + + # single-head attention + query = self.q_proj(query) # [B, L, C] + key = self.k_proj(key) # [B, L, C] + value = self.v_proj(value) # [B, L, C] + + if attn_type == 'swin' and attn_num_splits > 1: # self, cross-attn: both swin 2d + if self.nhead > 1: + # we observe that multihead attention slows down the speed and increases the memory consumption + # without bringing obvious performance gains and thus the implementation is removed + raise NotImplementedError + else: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + + elif attn_type == 'self_swin2d_cross_1d': # self-attn: swin 2d, cross-attn: full 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + + else: + # cross attn 1d + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + elif attn_type == 'self_swin2d_cross_swin1d': # self-attn: swin 2d, cross-attn: swin 1d + if self.nhead > 1: + raise NotImplementedError + else: + if is_self_attn: + if attn_num_splits > 1: + # self attn shift window + message = single_head_split_window_attention(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask, + ) + else: + # full 2d attn + message = single_head_full_attention(query, key, value) # [N, L, C] + else: + if attn_num_splits > 1: + assert shifted_window_attn_mask_1d is not None + # cross attn 1d shift + message = single_head_split_window_attention_1d(query, key, value, + num_splits=attn_num_splits, + with_shift=with_shift, + h=height, + w=width, + attn_mask=shifted_window_attn_mask_1d, + ) + else: + message = single_head_full_attention_1d(query, key, value, + h=height, + w=width, + ) + + else: + message = single_head_full_attention(query, key, value) # [B, L, C] + + message = self.merge(message) # [B, L, C] + message = self.norm1(message) + + if not self.no_ffn: + message = self.mlp(torch.cat([source, message], dim=-1)) + message = self.norm2(message) + + return source + message + + +class TransformerBlock(nn.Module): + """self attention + cross attention + FFN""" + + def __init__(self, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(TransformerBlock, self).__init__() + + self.self_attn = TransformerLayer(d_model=d_model, + nhead=nhead, + no_ffn=True, + ffn_dim_expansion=ffn_dim_expansion, + ) + + self.cross_attn_ffn = TransformerLayer(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + + def forward(self, source, target, + height=None, + width=None, + shifted_window_attn_mask=None, + shifted_window_attn_mask_1d=None, + attn_type='swin', + with_shift=False, + attn_num_splits=None, + ): + # source, target: [B, L, C] + + # self attention + source = self.self_attn(source, source, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + # cross attention and ffn + source = self.cross_attn_ffn(source, target, + height=height, + width=width, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + attn_type=attn_type, + with_shift=with_shift, + attn_num_splits=attn_num_splits, + ) + + return source + + +class FeatureTransformer(nn.Module): + def __init__(self, + num_layers=6, + d_model=128, + nhead=1, + ffn_dim_expansion=4, + ): + super(FeatureTransformer, self).__init__() + + self.d_model = d_model + self.nhead = nhead + + self.layers = nn.ModuleList([ + TransformerBlock(d_model=d_model, + nhead=nhead, + ffn_dim_expansion=ffn_dim_expansion, + ) + for i in range(num_layers)]) + + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feature0, feature1, + attn_type='swin', + attn_num_splits=None, + **kwargs, + ): + + b, c, h, w = feature0.shape + assert self.d_model == c + + feature0 = feature0.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + feature1 = feature1.flatten(-2).permute(0, 2, 1) # [B, H*W, C] + + # 2d attention + if 'swin' in attn_type and attn_num_splits > 1: + # global and refine use different number of splits + window_size_h = h // attn_num_splits + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask = generate_shift_window_attn_mask( + input_resolution=(h, w), + window_size_h=window_size_h, + window_size_w=window_size_w, + shift_size_h=window_size_h // 2, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K*K, H/K*W/K, H/K*W/K] + else: + shifted_window_attn_mask = None + + # 1d attention + if 'swin1d' in attn_type and attn_num_splits > 1: + window_size_w = w // attn_num_splits + + # compute attn mask once + shifted_window_attn_mask_1d = generate_shift_window_attn_mask_1d( + input_w=w, + window_size_w=window_size_w, + shift_size_w=window_size_w // 2, + device=feature0.device, + ) # [K, W/K, W/K] + else: + shifted_window_attn_mask_1d = None + + # concat feature0 and feature1 in batch dimension to compute in parallel + concat0 = torch.cat((feature0, feature1), dim=0) # [2B, H*W, C] + concat1 = torch.cat((feature1, feature0), dim=0) # [2B, H*W, C] + + for i, layer in enumerate(self.layers): + concat0 = layer(concat0, concat1, + height=h, + width=w, + attn_type=attn_type, + with_shift='swin' in attn_type and attn_num_splits > 1 and i % 2 == 1, + attn_num_splits=attn_num_splits, + shifted_window_attn_mask=shifted_window_attn_mask, + shifted_window_attn_mask_1d=shifted_window_attn_mask_1d, + ) + + # update feature1 + concat1 = torch.cat(concat0.chunk(chunks=2, dim=0)[::-1], dim=0) + + feature0, feature1 = concat0.chunk(chunks=2, dim=0) # [B, H*W, C] + + # reshape back + feature0 = feature0.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + feature1 = feature1.view(b, h, w, c).permute(0, 3, 1, 2).contiguous() # [B, C, H, W] + + return feature0, feature1 diff --git a/custom_controlnet_aux/unimatch/unimatch/trident_conv.py b/custom_controlnet_aux/unimatch/unimatch/trident_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..29a2a73e964a88b68bc095772d9c3cc443e3e0fe --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/trident_conv.py @@ -0,0 +1,90 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# https://github.com/facebookresearch/detectron2/blob/main/projects/TridentNet/tridentnet/trident_conv.py + +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.modules.utils import _pair + + +class MultiScaleTridentConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + strides=1, + paddings=0, + dilations=1, + dilation=1, + groups=1, + num_branch=1, + test_branch_idx=-1, + bias=False, + norm=None, + activation=None, + ): + super(MultiScaleTridentConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.num_branch = num_branch + self.stride = _pair(stride) + self.groups = groups + self.with_bias = bias + self.dilation = dilation + if isinstance(paddings, int): + paddings = [paddings] * self.num_branch + if isinstance(dilations, int): + dilations = [dilations] * self.num_branch + if isinstance(strides, int): + strides = [strides] * self.num_branch + self.paddings = [_pair(padding) for padding in paddings] + self.dilations = [_pair(dilation) for dilation in dilations] + self.strides = [_pair(stride) for stride in strides] + self.test_branch_idx = test_branch_idx + self.norm = norm + self.activation = activation + + assert len({self.num_branch, len(self.paddings), len(self.strides)}) == 1 + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, inputs): + num_branch = self.num_branch if self.training or self.test_branch_idx == -1 else 1 + assert len(inputs) == num_branch + + if self.training or self.test_branch_idx == -1: + outputs = [ + F.conv2d(input, self.weight, self.bias, stride, padding, self.dilation, self.groups) + for input, stride, padding in zip(inputs, self.strides, self.paddings) + ] + else: + outputs = [ + F.conv2d( + inputs[0], + self.weight, + self.bias, + self.strides[self.test_branch_idx] if self.test_branch_idx == -1 else self.strides[-1], + self.paddings[self.test_branch_idx] if self.test_branch_idx == -1 else self.paddings[-1], + self.dilation, + self.groups, + ) + ] + + if self.norm is not None: + outputs = [self.norm(x) for x in outputs] + if self.activation is not None: + outputs = [self.activation(x) for x in outputs] + return outputs diff --git a/custom_controlnet_aux/unimatch/unimatch/unimatch.py b/custom_controlnet_aux/unimatch/unimatch/unimatch.py new file mode 100644 index 0000000000000000000000000000000000000000..c442c67433ac2c42ebe6b4dca6ab7ff4d765e8cb --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/unimatch.py @@ -0,0 +1,367 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .backbone import CNNEncoder +from .transformer import FeatureTransformer +from .matching import (global_correlation_softmax, local_correlation_softmax, local_correlation_with_flow, + global_correlation_softmax_stereo, local_correlation_softmax_stereo, + correlation_softmax_depth) +from .attention import SelfAttnPropagation +from .geometry import flow_warp, compute_flow_with_depth_pose +from .reg_refine import BasicUpdateBlock +from .utils import normalize_img, feature_add_position, upsample_flow_with_mask + + +class UniMatch(nn.Module): + def __init__(self, + num_scales=1, + feature_channels=128, + upsample_factor=8, + num_head=1, + ffn_dim_expansion=4, + num_transformer_layers=6, + reg_refine=False, # optional local regression refinement + task='flow', + ): + super(UniMatch, self).__init__() + + self.feature_channels = feature_channels + self.num_scales = num_scales + self.upsample_factor = upsample_factor + self.reg_refine = reg_refine + + # CNN + self.backbone = CNNEncoder(output_dim=feature_channels, num_output_scales=num_scales) + + # Transformer + self.transformer = FeatureTransformer(num_layers=num_transformer_layers, + d_model=feature_channels, + nhead=num_head, + ffn_dim_expansion=ffn_dim_expansion, + ) + + # propagation with self-attn + self.feature_flow_attn = SelfAttnPropagation(in_channels=feature_channels) + + if not self.reg_refine or task == 'depth': + # convex upsampling simiar to RAFT + # concat feature0 and low res flow as input + self.upsampler = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1), + nn.ReLU(inplace=True), + nn.Conv2d(256, upsample_factor ** 2 * 9, 1, 1, 0)) + # thus far, all the learnable parameters are task-agnostic + + if reg_refine: + # optional task-specific local regression refinement + self.refine_proj = nn.Conv2d(128, 256, 1) + self.refine = BasicUpdateBlock(corr_channels=(2 * 4 + 1) ** 2, + downsample_factor=upsample_factor, + flow_dim=2 if task == 'flow' else 1, + bilinear_up=task == 'depth', + ) + + def extract_feature(self, img0, img1): + concat = torch.cat((img0, img1), dim=0) # [2B, C, H, W] + features = self.backbone(concat) # list of [2B, C, H, W], resolution from high to low + + # reverse: resolution from low to high + features = features[::-1] + + feature0, feature1 = [], [] + + for i in range(len(features)): + feature = features[i] + chunks = torch.chunk(feature, 2, 0) # tuple + feature0.append(chunks[0]) + feature1.append(chunks[1]) + + return feature0, feature1 + + def upsample_flow(self, flow, feature, bilinear=False, upsample_factor=8, + is_depth=False): + if bilinear: + multiplier = 1 if is_depth else upsample_factor + up_flow = F.interpolate(flow, scale_factor=upsample_factor, + mode='bilinear', align_corners=True) * multiplier + else: + concat = torch.cat((flow, feature), dim=1) + mask = self.upsampler(concat) + up_flow = upsample_flow_with_mask(flow, mask, upsample_factor=self.upsample_factor, + is_depth=is_depth) + + return up_flow + + def forward(self, img0, img1, + attn_type=None, + attn_splits_list=None, + corr_radius_list=None, + prop_radius_list=None, + num_reg_refine=1, + pred_bidir_flow=False, + task='flow', + intrinsics=None, + pose=None, # relative pose transform + min_depth=1. / 0.5, # inverse depth range + max_depth=1. / 10, + num_depth_candidates=64, + depth_from_argmax=False, + pred_bidir_depth=False, + **kwargs, + ): + + if pred_bidir_flow: + assert task == 'flow' + + if task == 'depth': + assert self.num_scales == 1 # multi-scale depth model is not supported yet + + results_dict = {} + flow_preds = [] + + if task == 'flow': + # stereo and depth tasks have normalized img in dataloader + img0, img1 = normalize_img(img0, img1) # [B, 3, H, W] + + # list of features, resolution low to high + feature0_list, feature1_list = self.extract_feature(img0, img1) # list of features + + flow = None + + if task != 'depth': + assert len(attn_splits_list) == len(corr_radius_list) == len(prop_radius_list) == self.num_scales + else: + assert len(attn_splits_list) == len(prop_radius_list) == self.num_scales == 1 + + for scale_idx in range(self.num_scales): + feature0, feature1 = feature0_list[scale_idx], feature1_list[scale_idx] + + if pred_bidir_flow and scale_idx > 0: + # predicting bidirectional flow with refinement + feature0, feature1 = torch.cat((feature0, feature1), dim=0), torch.cat((feature1, feature0), dim=0) + + feature0_ori, feature1_ori = feature0, feature1 + + upsample_factor = self.upsample_factor * (2 ** (self.num_scales - 1 - scale_idx)) + + if task == 'depth': + # scale intrinsics + intrinsics_curr = intrinsics.clone() + intrinsics_curr[:, :2] = intrinsics_curr[:, :2] / upsample_factor + + if scale_idx > 0: + assert task != 'depth' # not supported for multi-scale depth model + flow = F.interpolate(flow, scale_factor=2, mode='bilinear', align_corners=True) * 2 + + if flow is not None: + assert task != 'depth' + flow = flow.detach() + + if task == 'stereo': + # construct flow vector for disparity + # flow here is actually disparity + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + feature1 = flow_warp(feature1, displace) # [B, C, H, W] + elif task == 'flow': + feature1 = flow_warp(feature1, flow) # [B, C, H, W] + else: + raise NotImplementedError + + attn_splits = attn_splits_list[scale_idx] + if task != 'depth': + corr_radius = corr_radius_list[scale_idx] + prop_radius = prop_radius_list[scale_idx] + + # add position to features + feature0, feature1 = feature_add_position(feature0, feature1, attn_splits, self.feature_channels) + + # Transformer + feature0, feature1 = self.transformer(feature0, feature1, + attn_type=attn_type, + attn_num_splits=attn_splits, + ) + + # correlation and softmax + if task == 'depth': + # first generate depth candidates + b, _, h, w = feature0.size() + depth_candidates = torch.linspace(min_depth, max_depth, num_depth_candidates).type_as(feature0) + depth_candidates = depth_candidates.view(1, num_depth_candidates, 1, 1).repeat(b, 1, h, + w) # [B, D, H, W] + + flow_pred = correlation_softmax_depth(feature0, feature1, + intrinsics_curr, + pose, + depth_candidates=depth_candidates, + depth_from_argmax=depth_from_argmax, + pred_bidir_depth=pred_bidir_depth, + )[0] + + else: + if corr_radius == -1: # global matching + if task == 'flow': + flow_pred = global_correlation_softmax(feature0, feature1, pred_bidir_flow)[0] + elif task == 'stereo': + flow_pred = global_correlation_softmax_stereo(feature0, feature1)[0] + else: + raise NotImplementedError + else: # local matching + if task == 'flow': + flow_pred = local_correlation_softmax(feature0, feature1, corr_radius)[0] + elif task == 'stereo': + flow_pred = local_correlation_softmax_stereo(feature0, feature1, corr_radius)[0] + else: + raise NotImplementedError + + # flow or residual flow + flow = flow + flow_pred if flow is not None else flow_pred + + if task == 'stereo': + flow = flow.clamp(min=0) # positive disparity + + # upsample to the original resolution for supervison at training time only + if self.training: + flow_bilinear = self.upsample_flow(flow, None, bilinear=True, upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_bilinear) + + # flow propagation with self-attn + if (pred_bidir_flow or pred_bidir_depth) and scale_idx == 0: + feature0 = torch.cat((feature0, feature1), dim=0) # [2*B, C, H, W] for propagation + + flow = self.feature_flow_attn(feature0, flow.detach(), + local_window_attn=prop_radius > 0, + local_window_radius=prop_radius, + ) + + # bilinear exclude the last one + if self.training and scale_idx < self.num_scales - 1: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + if scale_idx == self.num_scales - 1: + if not self.reg_refine: + # upsample to the original image resolution + + if task == 'stereo': + flow_pad = torch.cat((-flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + flow_up_pad = self.upsample_flow(flow_pad, feature0) + flow_up = -flow_up_pad[:, :1] # [B, 1, H, W] + elif task == 'depth': + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + else: + flow_up = self.upsample_flow(flow, feature0) + + flow_preds.append(flow_up) + else: + # task-specific local regression refinement + # supervise current flow + if self.training: + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=task == 'depth') + flow_preds.append(flow_up) + + assert num_reg_refine > 0 + for refine_iter_idx in range(num_reg_refine): + flow = flow.detach() + + if task == 'stereo': + zeros = torch.zeros_like(flow) # [B, 1, H, W] + # NOTE: reverse disp, disparity is positive + displace = torch.cat((-flow, zeros), dim=1) # [B, 2, H, W] + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=displace, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + elif task == 'depth': + if pred_bidir_depth and refine_iter_idx == 0: + intrinsics_curr = intrinsics_curr.repeat(2, 1, 1) + pose = torch.cat((pose, torch.inverse(pose)), dim=0) + + feature0_ori, feature1_ori = torch.cat((feature0_ori, feature1_ori), + dim=0), torch.cat((feature1_ori, + feature0_ori), dim=0) + + flow_from_depth = compute_flow_with_depth_pose(1. / flow.squeeze(1), + intrinsics_curr, + extrinsics_rel=pose, + ) + + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow_from_depth, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + else: + correlation = local_correlation_with_flow( + feature0_ori, + feature1_ori, + flow=flow, + local_radius=4, + ) # [B, (2R+1)^2, H, W] + + proj = self.refine_proj(feature0) + + net, inp = torch.chunk(proj, chunks=2, dim=1) + + net = torch.tanh(net) + inp = torch.relu(inp) + + net, up_mask, residual_flow = self.refine(net, inp, correlation, flow.clone(), + ) + + if task == 'depth': + flow = (flow - residual_flow).clamp(min=min_depth, max=max_depth) + else: + flow = flow + residual_flow + + if task == 'stereo': + flow = flow.clamp(min=0) # positive + + if self.training or refine_iter_idx == num_reg_refine - 1: + if task == 'depth': + if refine_iter_idx < num_reg_refine - 1: + # bilinear upsampling + flow_up = self.upsample_flow(flow, feature0, bilinear=True, + upsample_factor=upsample_factor, + is_depth=True) + else: + # last one convex upsampling + # NOTE: clamp depth due to the zero padding in the unfold in the convex upsampling + # pad depth to 2 channels as flow + depth_pad = torch.cat((flow, torch.zeros_like(flow)), dim=1) # [B, 2, H, W] + depth_up_pad = self.upsample_flow(depth_pad, feature0, + is_depth=True).clamp(min=min_depth, + max=max_depth) + flow_up = depth_up_pad[:, :1] # [B, 1, H, W] + + else: + flow_up = upsample_flow_with_mask(flow, up_mask, upsample_factor=self.upsample_factor, + is_depth=task == 'depth') + + flow_preds.append(flow_up) + + if task == 'stereo': + for i in range(len(flow_preds)): + flow_preds[i] = flow_preds[i].squeeze(1) # [B, H, W] + + # convert inverse depth to depth + if task == 'depth': + for i in range(len(flow_preds)): + flow_preds[i] = 1. / flow_preds[i].squeeze(1) # [B, H, W] + + results_dict.update({'flow_preds': flow_preds}) + + return results_dict diff --git a/custom_controlnet_aux/unimatch/unimatch/utils.py b/custom_controlnet_aux/unimatch/unimatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a51d3882d082820cf4749ef6e4771b30ca3763b0 --- /dev/null +++ b/custom_controlnet_aux/unimatch/unimatch/utils.py @@ -0,0 +1,216 @@ +import torch +import torch.nn.functional as F +from .position import PositionEmbeddingSine + + +def generate_window_grid(h_min, h_max, w_min, w_max, len_h, len_w, device=None): + assert device is not None + + x, y = torch.meshgrid([torch.linspace(w_min, w_max, len_w, device=device), + torch.linspace(h_min, h_max, len_h, device=device)], + ) + grid = torch.stack((x, y), -1).transpose(0, 1).float() # [H, W, 2] + + return grid + + +def normalize_coords(coords, h, w): + # coords: [B, H, W, 2] + c = torch.Tensor([(w - 1) / 2., (h - 1) / 2.]).float().to(coords.device) + return (coords - c) / c # [-1, 1] + + +def normalize_img(img0, img1): + # loaded images are in [0, 255] + # normalize by ImageNet mean and std + mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(img1.device) + std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(img1.device) + img0 = (img0 / 255. - mean) / std + img1 = (img1 / 255. - mean) / std + + return img0, img1 + + +def split_feature(feature, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B, H, W, C] + b, h, w, c = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, h // num_splits, num_splits, w // num_splits, c + ).permute(0, 1, 3, 2, 4, 5).reshape(b_new, h_new, w_new, c) # [B*K*K, H/K, W/K, C] + else: # [B, C, H, W] + b, c, h, w = feature.size() + assert h % num_splits == 0 and w % num_splits == 0 + + b_new = b * num_splits * num_splits + h_new = h // num_splits + w_new = w // num_splits + + feature = feature.view(b, c, num_splits, h // num_splits, num_splits, w // num_splits + ).permute(0, 2, 4, 1, 3, 5).reshape(b_new, c, h_new, w_new) # [B*K*K, C, H/K, W/K] + + return feature + + +def merge_splits(splits, + num_splits=2, + channel_last=False, + ): + if channel_last: # [B*K*K, H/K, W/K, C] + b, h, w, c = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, h, w, c) + merge = splits.permute(0, 1, 3, 2, 4, 5).contiguous().view( + new_b, num_splits * h, num_splits * w, c) # [B, H, W, C] + else: # [B*K*K, C, H/K, W/K] + b, c, h, w = splits.size() + new_b = b // num_splits // num_splits + + splits = splits.view(new_b, num_splits, num_splits, c, h, w) + merge = splits.permute(0, 3, 1, 4, 2, 5).contiguous().view( + new_b, c, num_splits * h, num_splits * w) # [B, C, H, W] + + return merge + + +def generate_shift_window_attn_mask(input_resolution, window_size_h, window_size_w, + shift_size_h, shift_size_w, device=torch.device('cuda')): + # ref: https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py + # calculate attention mask for SW-MSA + h, w = input_resolution + img_mask = torch.zeros((1, h, w, 1)).to(device) # 1 H W 1 + h_slices = (slice(0, -window_size_h), + slice(-window_size_h, -shift_size_h), + slice(-shift_size_h, None)) + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = split_feature(img_mask, num_splits=input_resolution[-1] // window_size_w, channel_last=True) + + mask_windows = mask_windows.view(-1, window_size_h * window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask + + +def feature_add_position(feature0, feature1, attn_splits, feature_channels): + pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2) + + if attn_splits > 1: # add position in splited window + feature0_splits = split_feature(feature0, num_splits=attn_splits) + feature1_splits = split_feature(feature1, num_splits=attn_splits) + + position = pos_enc(feature0_splits) + + feature0_splits = feature0_splits + position + feature1_splits = feature1_splits + position + + feature0 = merge_splits(feature0_splits, num_splits=attn_splits) + feature1 = merge_splits(feature1_splits, num_splits=attn_splits) + else: + position = pos_enc(feature0) + + feature0 = feature0 + position + feature1 = feature1 + position + + return feature0, feature1 + + +def upsample_flow_with_mask(flow, up_mask, upsample_factor, + is_depth=False): + # convex upsampling following raft + + mask = up_mask + b, flow_channel, h, w = flow.shape + mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w) # [B, 1, 9, K, K, H, W] + mask = torch.softmax(mask, dim=2) + + multiplier = 1 if is_depth else upsample_factor + up_flow = F.unfold(multiplier * flow, [3, 3], padding=1) + up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w) # [B, 2, 9, 1, 1, H, W] + + up_flow = torch.sum(mask * up_flow, dim=2) # [B, 2, K, K, H, W] + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) # [B, 2, K, H, K, W] + up_flow = up_flow.reshape(b, flow_channel, upsample_factor * h, + upsample_factor * w) # [B, 2, K*H, K*W] + + return up_flow + + +def split_feature_1d(feature, + num_splits=2, + ): + # feature: [B, W, C] + b, w, c = feature.size() + assert w % num_splits == 0 + + b_new = b * num_splits + w_new = w // num_splits + + feature = feature.view(b, num_splits, w // num_splits, c + ).view(b_new, w_new, c) # [B*K, W/K, C] + + return feature + + +def merge_splits_1d(splits, + h, + num_splits=2, + ): + b, w, c = splits.size() + new_b = b // num_splits // h + + splits = splits.view(new_b, h, num_splits, w, c) + merge = splits.view( + new_b, h, num_splits * w, c) # [B, H, W, C] + + return merge + + +def window_partition_1d(x, window_size_w): + """ + Args: + x: (B, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, C) + """ + B, W, C = x.shape + x = x.view(B, W // window_size_w, window_size_w, C).view(-1, window_size_w, C) + return x + + +def generate_shift_window_attn_mask_1d(input_w, window_size_w, + shift_size_w, device=torch.device('cuda')): + # calculate attention mask for SW-MSA + img_mask = torch.zeros((1, input_w, 1)).to(device) # 1 W 1 + w_slices = (slice(0, -window_size_w), + slice(-window_size_w, -shift_size_w), + slice(-shift_size_w, None)) + cnt = 0 + for w in w_slices: + img_mask[:, w, :] = cnt + cnt += 1 + + mask_windows = window_partition_1d(img_mask, window_size_w) # nW, window_size, 1 + mask_windows = mask_windows.view(-1, window_size_w) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) # nW, window_size, window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + return attn_mask diff --git a/custom_controlnet_aux/unimatch/utils/dist_utils.py b/custom_controlnet_aux/unimatch/utils/dist_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9958a48fe7c0bb5b33457f132b49d513d8420419 --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/dist_utils.py @@ -0,0 +1,105 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# https://github.com/open-mmlab/mmcv/blob/7540cf73ac7e5d1e14d0ffbd9b6759e83929ecfc/mmcv/runner/dist_utils.py + +import os +import subprocess + +import torch +import torch.multiprocessing as mp +from torch import distributed as dist + + +def init_dist(launcher, backend='nccl', **kwargs): + if mp.get_start_method(allow_none=True) is None: + mp.set_start_method('spawn') + if launcher == 'pytorch': + _init_dist_pytorch(backend, **kwargs) + elif launcher == 'mpi': + _init_dist_mpi(backend, **kwargs) + elif launcher == 'slurm': + _init_dist_slurm(backend, **kwargs) + else: + raise ValueError(f'Invalid launcher type: {launcher}') + + +def _init_dist_pytorch(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_mpi(backend, **kwargs): + # TODO: use local_rank instead of rank % num_gpus + rank = int(os.environ['OMPI_COMM_WORLD_RANK']) + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(rank % num_gpus) + dist.init_process_group(backend=backend, **kwargs) + + +def _init_dist_slurm(backend, port=None): + """Initialize slurm distributed training environment. + If argument ``port`` is not specified, then the master port will be system + environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system + environment variable, then a default port ``29500`` will be used. + Args: + backend (str): Backend of torch.distributed. + port (int, optional): Master port. Defaults to None. + """ + proc_id = int(os.environ['SLURM_PROCID']) + ntasks = int(os.environ['SLURM_NTASKS']) + node_list = os.environ['SLURM_NODELIST'] + num_gpus = torch.cuda.device_count() + torch.cuda.set_device(proc_id % num_gpus) + addr = subprocess.getoutput( + f'scontrol show hostname {node_list} | head -n1') + # specify master port + if port is not None: + os.environ['MASTER_PORT'] = str(port) + elif 'MASTER_PORT' in os.environ: + pass # use MASTER_PORT in the environment variable + else: + # 29500 is torch.distributed default port + os.environ['MASTER_PORT'] = '29500' + # use MASTER_ADDR in the environment variable if it already exists + if 'MASTER_ADDR' not in os.environ: + os.environ['MASTER_ADDR'] = addr + os.environ['WORLD_SIZE'] = str(ntasks) + os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) + os.environ['RANK'] = str(proc_id) + dist.init_process_group(backend=backend) + + +def get_dist_info(): + # if (TORCH_VERSION != 'parrots' + # and digit_version(TORCH_VERSION) < digit_version('1.0')): + # initialized = dist._initialized + # else: + if dist.is_available(): + initialized = dist.is_initialized() + else: + initialized = False + if initialized: + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +# from DETR repo +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print diff --git a/custom_controlnet_aux/unimatch/utils/file_io.py b/custom_controlnet_aux/unimatch/utils/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..4f0a99098a2d82dc73987c269795718053880434 --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/file_io.py @@ -0,0 +1,224 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import re +from PIL import Image +import sys +import cv2 +import json +import os + + +def read_img(filename): + # convert to RGB for scene flow finalpass data + img = np.array(Image.open(filename).convert('RGB')).astype(np.float32) + return img + + +def read_disp(filename, subset=False, vkitti2=False, sintel=False, + tartanair=False, instereo2k=False, crestereo=False, + fallingthings=False, + argoverse=False, + raw_disp_png=False, + ): + # Scene Flow dataset + if filename.endswith('pfm'): + # For finalpass and cleanpass, gt disparity is positive, subset is negative + disp = np.ascontiguousarray(_read_pfm(filename)[0]) + if subset: + disp = -disp + # VKITTI2 dataset + elif vkitti2: + disp = _read_vkitti2_disp(filename) + # Sintel + elif sintel: + disp = _read_sintel_disparity(filename) + elif tartanair: + disp = _read_tartanair_disp(filename) + elif instereo2k: + disp = _read_instereo2k_disp(filename) + elif crestereo: + disp = _read_crestereo_disp(filename) + elif fallingthings: + disp = _read_fallingthings_disp(filename) + elif argoverse: + disp = _read_argoverse_disp(filename) + elif raw_disp_png: + disp = np.array(Image.open(filename)).astype(np.float32) + # KITTI + elif filename.endswith('png'): + disp = _read_kitti_disp(filename) + elif filename.endswith('npy'): + disp = np.load(filename) + else: + raise Exception('Invalid disparity file format!') + return disp # [H, W] + + +def _read_pfm(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header.decode("ascii") == 'PF': + color = True + elif header.decode("ascii") == 'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) + if dim_match: + width, height = list(map(int, dim_match.groups())) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().decode("ascii").rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data, scale + + +def write_pfm(file, image, scale=1): + file = open(file, 'wb') + + color = None + + if image.dtype.name != 'float32': + raise Exception('Image dtype must be float32.') + + image = np.flipud(image) + + if len(image.shape) == 3 and image.shape[2] == 3: # color image + color = True + elif len(image.shape) == 2 or len( + image.shape) == 3 and image.shape[2] == 1: # greyscale + color = False + else: + raise Exception( + 'Image must have H x W x 3, H x W x 1 or H x W dimensions.') + + file.write(b'PF\n' if color else b'Pf\n') + file.write(b'%d %d\n' % (image.shape[1], image.shape[0])) + + endian = image.dtype.byteorder + + if endian == '<' or endian == '=' and sys.byteorder == 'little': + scale = -scale + + file.write(b'%f\n' % scale) + + image.tofile(file) + + +def _read_kitti_disp(filename): + depth = np.array(Image.open(filename)) + depth = depth.astype(np.float32) / 256. + return depth + + +def _read_vkitti2_disp(filename): + # read depth + depth = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) # in cm + depth = (depth / 100).astype(np.float32) # depth clipped to 655.35m for sky + + valid = (depth > 0) & (depth < 655) # depth clipped to 655.35m for sky + + # convert to disparity + focal_length = 725.0087 # in pixels + baseline = 0.532725 # meter + + disp = baseline * focal_length / depth + + disp[~valid] = 0.000001 # invalid as very small value + + return disp + + +def _read_sintel_disparity(filename): + """ Return disparity read from filename. """ + f_in = np.array(Image.open(filename)) + + d_r = f_in[:, :, 0].astype('float32') + d_g = f_in[:, :, 1].astype('float32') + d_b = f_in[:, :, 2].astype('float32') + + depth = d_r * 4 + d_g / (2 ** 6) + d_b / (2 ** 14) + return depth + + +def _read_tartanair_disp(filename): + # the infinite distant object such as the sky has a large depth value (e.g. 10000) + depth = np.load(filename) + + # change to disparity image + disparity = 80.0 / depth + + return disparity + + +def _read_instereo2k_disp(filename): + disp = np.array(Image.open(filename)) + disp = disp.astype(np.float32) / 100. + return disp + + +def _read_crestereo_disp(filename): + disp = np.array(Image.open(filename)) + return disp.astype(np.float32) / 32. + + +def _read_fallingthings_disp(filename): + depth = np.array(Image.open(filename)) + camera_file = os.path.join(os.path.dirname(filename), '_camera_settings.json') + with open(camera_file, 'r') as f: + intrinsics = json.load(f) + fx = intrinsics['camera_settings'][0]['intrinsic_settings']['fx'] + disp = (fx * 6.0 * 100) / depth.astype(np.float32) + + return disp + + +def _read_argoverse_disp(filename): + disparity_map = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + return np.float32(disparity_map) / 256. + + +def extract_video(video_name): + cap = cv2.VideoCapture(video_name) + assert cap.isOpened(), f'Failed to load video file {video_name}' + # get video info + size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)), + int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))) + fps = cap.get(cv2.CAP_PROP_FPS) + + print('video size (hxw): %dx%d' % (size[1], size[0])) + print('fps: %d' % fps) + + imgs = [] + while cap.isOpened(): + # get frames + flag, img = cap.read() + if not flag: + break + # to rgb format + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + imgs.append(img) + + return imgs, fps diff --git a/custom_controlnet_aux/unimatch/utils/flow_viz.py b/custom_controlnet_aux/unimatch/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe3f139d8fc54478fc1880eb6aa5a286660540a --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/flow_viz.py @@ -0,0 +1,290 @@ +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from PIL import Image + + +def make_colorwheel(): + ''' + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + ''' + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col:col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col:col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col:col + MR, 0] = 255 + return colorwheel + + +def flow_compute_color(u, v, convert_to_bgr=False): + ''' + Applies the flow color wheel to (possibly clipped) flow components u and v. + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param u: np.ndarray, input horizontal flow + :param v: np.ndarray, input vertical flow + :param convert_to_bgr: bool, whether to change ordering and output BGR instead of RGB + :return: + ''' + + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 1 + f = fk - k0 + + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range? + + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + + return flow_image + + +def flow_to_color(flow_uv, clip_flow=None, convert_to_bgr=False): + ''' + Expects a two dimensional flow image of shape [H,W,2] + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + :param flow_uv: np.ndarray of shape [H,W,2] + :param clip_flow: float, maximum clipping value for flow + :return: + ''' + + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + + return flow_compute_color(u, v, convert_to_bgr) + + +UNKNOWN_FLOW_THRESH = 1e7 +SMALLFLOW = 0.0 +LARGEFLOW = 1e8 + + +def make_color_wheel(): + """ + Generate color wheel according Middlebury color code + :return: Color wheel + """ + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + + colorwheel = np.zeros([ncols, 3]) + + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.transpose(np.floor(255 * np.arange(0, RY) / RY)) + col += RY + + # YG + colorwheel[col:col + YG, 0] = 255 - np.transpose(np.floor(255 * np.arange(0, YG) / YG)) + colorwheel[col:col + YG, 1] = 255 + col += YG + + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.transpose(np.floor(255 * np.arange(0, GC) / GC)) + col += GC + + # CB + colorwheel[col:col + CB, 1] = 255 - np.transpose(np.floor(255 * np.arange(0, CB) / CB)) + colorwheel[col:col + CB, 2] = 255 + col += CB + + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.transpose(np.floor(255 * np.arange(0, BM) / BM)) + col += + BM + + # MR + colorwheel[col:col + MR, 2] = 255 - np.transpose(np.floor(255 * np.arange(0, MR) / MR)) + colorwheel[col:col + MR, 0] = 255 + + return colorwheel + + +def compute_color(u, v): + """ + compute optical flow color map + :param u: optical flow horizontal map + :param v: optical flow vertical map + :return: optical flow in color code + """ + [h, w] = u.shape + img = np.zeros([h, w, 3]) + nanIdx = np.isnan(u) | np.isnan(v) + u[nanIdx] = 0 + v[nanIdx] = 0 + + colorwheel = make_color_wheel() + ncols = np.size(colorwheel, 0) + + rad = np.sqrt(u ** 2 + v ** 2) + + a = np.arctan2(-v, -u) / np.pi + + fk = (a + 1) / 2 * (ncols - 1) + 1 + + k0 = np.floor(fk).astype(int) + + k1 = k0 + 1 + k1[k1 == ncols + 1] = 1 + f = fk - k0 + + for i in range(0, np.size(colorwheel, 1)): + tmp = colorwheel[:, i] + col0 = tmp[k0 - 1] / 255 + col1 = tmp[k1 - 1] / 255 + col = (1 - f) * col0 + f * col1 + + idx = rad <= 1 + col[idx] = 1 - rad[idx] * (1 - col[idx]) + notidx = np.logical_not(idx) + + col[notidx] *= 0.75 + img[:, :, i] = np.uint8(np.floor(255 * col * (1 - nanIdx))) + + return img + + +# from https://github.com/gengshan-y/VCN +def flow_to_image(flow): + """ + Convert flow into middlebury color code image + :param flow: optical flow map + :return: optical flow image in middlebury color + """ + u = flow[:, :, 0] + v = flow[:, :, 1] + + maxu = -999. + maxv = -999. + minu = 999. + minv = 999. + + idxUnknow = (abs(u) > UNKNOWN_FLOW_THRESH) | (abs(v) > UNKNOWN_FLOW_THRESH) + u[idxUnknow] = 0 + v[idxUnknow] = 0 + + maxu = max(maxu, np.max(u)) + minu = min(minu, np.min(u)) + + maxv = max(maxv, np.max(v)) + minv = min(minv, np.min(v)) + + rad = np.sqrt(u ** 2 + v ** 2) + maxrad = max(-1, np.max(rad)) + + u = u / (maxrad + np.finfo(float).eps) + v = v / (maxrad + np.finfo(float).eps) + + img = compute_color(u, v) + + idx = np.repeat(idxUnknow[:, :, np.newaxis], 3, axis=2) + img[idx] = 0 + + return np.uint8(img) + + +def save_vis_flow_tofile(flow, output_path): + vis_flow = flow_to_image(flow) + Image.fromarray(vis_flow).save(output_path) + + +def flow_tensor_to_image(flow): + """Used for tensorboard visualization""" + flow = flow.permute(1, 2, 0) # [H, W, 2] + flow = flow.detach().cpu().numpy() + flow = flow_to_image(flow) # [H, W, 3] + flow = np.transpose(flow, (2, 0, 1)) # [3, H, W] + + return flow diff --git a/custom_controlnet_aux/unimatch/utils/frame_utils.py b/custom_controlnet_aux/unimatch/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d10d2ee9b3b2832617ddf66225528af59e995c7d --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/frame_utils.py @@ -0,0 +1,158 @@ +import numpy as np +from PIL import Image +from os.path import * +import re +import cv2 + +TAG_CHAR = np.array([202021.25], np.float32) + + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) + # Reshape testdata into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + + +def writeFlow(filename, uv, v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert (uv.ndim == 3) + assert (uv.shape[2] == 2) + u = uv[:, :, 0] + v = uv[:, :, 1] + else: + u = uv + + assert (u.shape == v.shape) + height, width = u.shape + f = open(filename, 'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width * nBands)) + tmp[:, np.arange(width) * 2] = u + tmp[:, np.arange(width) * 2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH | cv2.IMREAD_COLOR) + flow = flow[:, :, ::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2 ** 15) / 64.0 + return flow, valid + + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2 ** 15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] + + +def read_vkitti2_flow(filename): + # In R, flow along x-axis normalized by image width and quantized to [0;2^16 – 1] + # In G, flow along x-axis normalized by image width and quantized to [0;2^16 – 1] + # B = 0 for invalid flow (e.g., sky pixels) + bgr = cv2.imread(filename, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + h, w, _c = bgr.shape + assert bgr.dtype == np.uint16 and _c == 3 + # b == invalid flow flag == 0 for sky or other invalid flow + invalid = bgr[:, :, 0] == 0 + # g,r == flow_y,x normalized by height,width and scaled to [0;2**16 – 1] + out_flow = 2.0 / (2 ** 16 - 1.0) * bgr[:, :, 2:0:-1].astype('f4') - 1 # [H, W, 2] + out_flow[..., 0] *= (w - 1) + out_flow[..., 1] *= (h - 1) + + out_flow[invalid] = 0.000001 # invalid as very small value to add supervison on the sky + valid = (np.logical_or(invalid, ~invalid)).astype(np.float32) + + return out_flow, valid diff --git a/custom_controlnet_aux/unimatch/utils/logger.py b/custom_controlnet_aux/unimatch/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..742b23c515932dea67a576cd9169845291f9091b --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/logger.py @@ -0,0 +1,104 @@ +import torch + +from utils.flow_viz import flow_tensor_to_image +from .visualization import viz_depth_tensor + + +class Logger: + def __init__(self, lr_scheduler, + summary_writer, + summary_freq=100, + start_step=0, + img_mean=None, + img_std=None, + ): + self.lr_scheduler = lr_scheduler + self.total_steps = start_step + self.running_loss = {} + self.summary_writer = summary_writer + self.summary_freq = summary_freq + + self.img_mean = img_mean + self.img_std = img_std + + def print_training_status(self, mode='train', is_depth=False): + if is_depth: + print('step: %06d \t loss: %.3f' % (self.total_steps, self.running_loss['total_loss'] / self.summary_freq)) + else: + print('step: %06d \t epe: %.3f' % (self.total_steps, self.running_loss['epe'] / self.summary_freq)) + + for k in self.running_loss: + self.summary_writer.add_scalar(mode + '/' + k, + self.running_loss[k] / self.summary_freq, self.total_steps) + self.running_loss[k] = 0.0 + + def lr_summary(self): + lr = self.lr_scheduler.get_last_lr()[0] + self.summary_writer.add_scalar('lr', lr, self.total_steps) + + def add_image_summary(self, img1, img2, flow_preds=None, flow_gt=None, mode='train', + is_depth=False, + ): + if self.total_steps % self.summary_freq == 0: + if is_depth: + img1 = self.unnormalize_image(img1.detach().cpu()) # [3, H, W], range [0, 1] + img2 = self.unnormalize_image(img2.detach().cpu()) + + concat = torch.cat((img1, img2), dim=-1) # [3, H, W*2] + + self.summary_writer.add_image(mode + '/img', concat, self.total_steps) + else: + img_concat = torch.cat((img1[0].detach().cpu(), img2[0].detach().cpu()), dim=-1) + img_concat = img_concat.type(torch.uint8) # convert to uint8 to visualize in tensorboard + + flow_pred = flow_tensor_to_image(flow_preds[-1][0]) + forward_flow_gt = flow_tensor_to_image(flow_gt[0]) + flow_concat = torch.cat((torch.from_numpy(flow_pred), + torch.from_numpy(forward_flow_gt)), dim=-1) + + concat = torch.cat((img_concat, flow_concat), dim=-2) + + self.summary_writer.add_image(mode + '/img_pred_gt', concat, self.total_steps) + + def add_depth_summary(self, depth_pred, depth_gt, mode='train'): + # assert depth_pred.dim() == 2 # [H, W] + if self.total_steps % self.summary_freq == 0 or 'val' in mode: + pred_viz = viz_depth_tensor(depth_pred.detach().cpu()) # [3, H, W] + gt_viz = viz_depth_tensor(depth_gt.detach().cpu()) + + concat = torch.cat((pred_viz, gt_viz), dim=-1) # [3, H, W*2] + + self.summary_writer.add_image(mode + '/depth_pred_gt', concat, self.total_steps) + + def unnormalize_image(self, img): + # img: [3, H, W], used for visualizing image + mean = torch.tensor(self.img_mean).view(3, 1, 1).type_as(img) + std = torch.tensor(self.img_std).view(3, 1, 1).type_as(img) + + out = img * std + mean + + return out + + def push(self, metrics, mode='train', is_depth=False, ): + self.total_steps += 1 + + self.lr_summary() + + for key in metrics: + if key not in self.running_loss: + self.running_loss[key] = 0.0 + + self.running_loss[key] += metrics[key] + + if self.total_steps % self.summary_freq == 0: + self.print_training_status(mode, is_depth=is_depth) + self.running_loss = {} + + def write_dict(self, results): + for key in results: + tag = key.split('_')[0] + tag = tag + '/' + key + self.summary_writer.add_scalar(tag, results[key], self.total_steps) + + def close(self): + self.summary_writer.close() diff --git a/custom_controlnet_aux/unimatch/utils/misc.py b/custom_controlnet_aux/unimatch/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3f84496c996287f9f46bdefa98d3697d136e6d --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/misc.py @@ -0,0 +1,36 @@ +import os +import sys +import json + + +def read_text_lines(filepath): + with open(filepath, 'r') as f: + lines = f.readlines() + lines = [l.rstrip() for l in lines] + return lines + + +def check_path(path): + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) # explicitly set exist_ok when multi-processing + + +def save_command(save_path, filename='command_train.txt'): + check_path(save_path) + command = sys.argv + save_file = os.path.join(save_path, filename) + # Save all training commands when resuming training + with open(save_file, 'a') as f: + f.write(' '.join(command)) + f.write('\n\n') + + +def save_args(args, filename='args.json'): + args_dict = vars(args) + check_path(args.checkpoint_dir) + save_path = os.path.join(args.checkpoint_dir, filename) + + # save all training args when resuming training + with open(save_path, 'a') as f: + json.dump(args_dict, f, indent=4, sort_keys=False) + f.write('\n\n') diff --git a/custom_controlnet_aux/unimatch/utils/utils.py b/custom_controlnet_aux/unimatch/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..73d780fa4e16fc1daf4f940791e1f708adf18232 --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/utils.py @@ -0,0 +1,157 @@ +import torch +import torch.nn.functional as F +import numpy as np + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + + def __init__(self, dims, mode='sintel', padding_factor=8): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // padding_factor) + 1) * padding_factor - self.ht) % padding_factor + pad_wd = (((self.wd // padding_factor) + 1) * padding_factor - self.wd) % padding_factor + if mode == 'sintel': + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False, padding_mode='zeros'): + """ Wrapper for grid_sample, uses pixel coordinates """ + if coords.size(-1) != 2: # [B, 2, H, W] -> [B, H, W, 2] + coords = coords.permute(0, 2, 3, 1) + + H, W = img.shape[-2:] + # H = height if height is not None else img.shape[-2] + # W = width if width is not None else img.shape[-1] + + xgrid, ygrid = coords.split([1, 1], dim=-1) + + # To handle H or W equals to 1 by explicitly defining height and width + if H == 1: + assert ygrid.abs().max() < 1e-8 + H = 10 + if W == 1: + assert xgrid.abs().max() < 1e-8 + W = 10 + + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, mode=mode, + padding_mode=padding_mode, + align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.squeeze(-1).float() + + return img + + +def coords_grid(batch, ht, wd, normalize=False): + if normalize: # [-1, 1] + coords = torch.meshgrid(2 * torch.arange(ht) / (ht - 1) - 1, + 2 * torch.arange(wd) / (wd - 1) - 1) + else: + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) # [B, 2, H, W] + + +def coords_grid_np(h, w): # used for accumulating high speed sintel flow testdata + coords = np.meshgrid(np.arange(h, dtype=np.float32), + np.arange(w, dtype=np.float32), indexing='ij') + coords = np.stack(coords[::-1], axis=-1) # [H, W, 2] + + return coords + + +def compute_out_of_boundary_mask(flow, downsample_factor=None): + # flow: [B, 2, H, W] + assert flow.dim() == 4 and flow.size(1) == 2 + b, _, h, w = flow.shape + init_coords = coords_grid(b, h, w).to(flow.device) + corres = init_coords + flow # [B, 2, H, W] + + if downsample_factor is not None: + assert w % downsample_factor == 0 and h % downsample_factor == 0 + # the actual max disp can predict is in the downsampled feature resolution, then upsample + max_w = (w // downsample_factor - 1) * downsample_factor + max_h = (h // downsample_factor - 1) * downsample_factor + # print('max_w: %d, max_h: %d' % (max_w, max_h)) + else: + max_w = w - 1 + max_h = h - 1 + + valid_mask = (corres[:, 0] >= 0) & (corres[:, 0] <= max_w) & (corres[:, 1] >= 0) & (corres[:, 1] <= max_h) + + # in case very large flow + flow_mask = (flow[:, 0].abs() <= max_w) & (flow[:, 1].abs() <= max_h) + + valid_mask = valid_mask & flow_mask + + return valid_mask # [B, H, W] + + +def normalize_coords(grid): + """Normalize coordinates of image scale to [-1, 1] + Args: + grid: [B, 2, H, W] + """ + assert grid.size(1) == 2 + h, w = grid.size()[2:] + grid[:, 0, :, :] = 2 * (grid[:, 0, :, :].clone() / (w - 1)) - 1 # x: [-1, 1] + grid[:, 1, :, :] = 2 * (grid[:, 1, :, :].clone() / (h - 1)) - 1 # y: [-1, 1] + # grid = grid.permute((0, 2, 3, 1)) # [B, H, W, 2] + return grid + + +def flow_warp(feature, flow, mask=False, padding_mode='zeros'): + b, c, h, w = feature.size() + assert flow.size(1) == 2 + + grid = coords_grid(b, h, w).to(flow.device) + flow # [B, 2, H, W] + + return bilinear_sampler(feature, grid, mask=mask, padding_mode=padding_mode) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +def bilinear_upflow(flow, scale_factor=8): + assert flow.size(1) == 2 + flow = F.interpolate(flow, scale_factor=scale_factor, + mode='bilinear', align_corners=True) * scale_factor + + return flow + + +def upsample_flow(flow, img): + if flow.size(-1) != img.size(-1): + scale_factor = img.size(-1) / flow.size(-1) + flow = F.interpolate(flow, size=img.size()[-2:], + mode='bilinear', align_corners=True) * scale_factor + return flow + + +def count_parameters(model): + num = sum(p.numel() for p in model.parameters() if p.requires_grad) + return num + + +def set_bn_eval(m): + classname = m.__class__.__name__ + if classname.find('BatchNorm') != -1: + m.eval() diff --git a/custom_controlnet_aux/unimatch/utils/visualization.py b/custom_controlnet_aux/unimatch/utils/visualization.py new file mode 100644 index 0000000000000000000000000000000000000000..30fe8ae6ee884d1e59518f17a7a80019f7433468 --- /dev/null +++ b/custom_controlnet_aux/unimatch/utils/visualization.py @@ -0,0 +1,107 @@ +import torch +import torch.utils.data +import numpy as np +import torchvision.utils as vutils +import cv2 +from matplotlib.cm import get_cmap +import matplotlib as mpl +import matplotlib.cm as cm + + +def vis_disparity(disp): + disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0 + disp_vis = disp_vis.astype("uint8") + disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO) + + return disp_vis + + +def gen_error_colormap(): + cols = np.array( + [[0 / 3.0, 0.1875 / 3.0, 49, 54, 149], + [0.1875 / 3.0, 0.375 / 3.0, 69, 117, 180], + [0.375 / 3.0, 0.75 / 3.0, 116, 173, 209], + [0.75 / 3.0, 1.5 / 3.0, 171, 217, 233], + [1.5 / 3.0, 3 / 3.0, 224, 243, 248], + [3 / 3.0, 6 / 3.0, 254, 224, 144], + [6 / 3.0, 12 / 3.0, 253, 174, 97], + [12 / 3.0, 24 / 3.0, 244, 109, 67], + [24 / 3.0, 48 / 3.0, 215, 48, 39], + [48 / 3.0, np.inf, 165, 0, 38]], dtype=np.float32) + cols[:, 2: 5] /= 255. + return cols + + +def disp_error_img(D_est_tensor, D_gt_tensor, abs_thres=3., rel_thres=0.05, dilate_radius=1): + D_gt_np = D_gt_tensor.detach().cpu().numpy() + D_est_np = D_est_tensor.detach().cpu().numpy() + B, H, W = D_gt_np.shape + # valid mask + mask = D_gt_np > 0 + # error in percentage. When error <= 1, the pixel is valid since <= 3px & 5% + error = np.abs(D_gt_np - D_est_np) + error[np.logical_not(mask)] = 0 + error[mask] = np.minimum(error[mask] / abs_thres, (error[mask] / D_gt_np[mask]) / rel_thres) + # get colormap + cols = gen_error_colormap() + # create error image + error_image = np.zeros([B, H, W, 3], dtype=np.float32) + for i in range(cols.shape[0]): + error_image[np.logical_and(error >= cols[i][0], error < cols[i][1])] = cols[i, 2:] + # TODO: imdilate + # error_image = cv2.imdilate(D_err, strel('disk', dilate_radius)); + error_image[np.logical_not(mask)] = 0. + # show color tag in the top-left cornor of the image + for i in range(cols.shape[0]): + distance = 20 + error_image[:, :10, i * distance:(i + 1) * distance, :] = cols[i, 2:] + + return torch.from_numpy(np.ascontiguousarray(error_image.transpose([0, 3, 1, 2]))) + + +def save_images(logger, mode_tag, images_dict, global_step): + images_dict = tensor2numpy(images_dict) + for tag, values in images_dict.items(): + if not isinstance(values, list) and not isinstance(values, tuple): + values = [values] + for idx, value in enumerate(values): + if len(value.shape) == 3: + value = value[:, np.newaxis, :, :] + value = value[:1] + value = torch.from_numpy(value) + + image_name = '{}/{}'.format(mode_tag, tag) + if len(values) > 1: + image_name = image_name + "_" + str(idx) + logger.add_image(image_name, vutils.make_grid(value, padding=0, nrow=1, normalize=True, scale_each=True), + global_step) + + +def tensor2numpy(var_dict): + for key, vars in var_dict.items(): + if isinstance(vars, np.ndarray): + var_dict[key] = vars + elif isinstance(vars, torch.Tensor): + var_dict[key] = vars.data.cpu().numpy() + else: + raise NotImplementedError("invalid input type for tensor2numpy") + + return var_dict + + +def viz_depth_tensor(disp, return_numpy=False, colormap='plasma'): + # visualize inverse depth + assert isinstance(disp, torch.Tensor) + + disp = disp.numpy() + vmax = np.percentile(disp, 95) + normalizer = mpl.colors.Normalize(vmin=disp.min(), vmax=vmax) + mapper = cm.ScalarMappable(norm=normalizer, cmap=colormap) + colormapped_im = (mapper.to_rgba(disp)[:, :, :3] * 255).astype(np.uint8) # [H, W, 3] + + if return_numpy: + return colormapped_im + + viz = torch.from_numpy(colormapped_im).permute(2, 0, 1) # [3, H, W] + + return viz diff --git a/custom_controlnet_aux/util.py b/custom_controlnet_aux/util.py new file mode 100644 index 0000000000000000000000000000000000000000..9eb744ceea2dedee737abe27ddd82ccc89119502 --- /dev/null +++ b/custom_controlnet_aux/util.py @@ -0,0 +1,352 @@ +import os +import random +import tempfile +import warnings +from contextlib import suppress +from pathlib import Path + +import cv2 +import numpy as np +import torch +from huggingface_hub import constants, hf_hub_download +from torch.hub import get_dir, download_url_to_file +from ast import literal_eval + + +TORCHHUB_PATH = Path(__file__).parent / 'depth_anything' / 'torchhub' +HF_MODEL_NAME = "lllyasviel/Annotators" +DWPOSE_MODEL_NAME = "yzd-v/DWPose" +BDS_MODEL_NAME = "bdsqlsz/qinglong_controlnet-lllite" +DENSEPOSE_MODEL_NAME = "LayerNorm/DensePose-TorchScript-with-hint-image" +MESH_GRAPHORMER_MODEL_NAME = "hr16/ControlNet-HandRefiner-pruned" +SAM_MODEL_NAME = "dhkim2810/MobileSAM" +UNIMATCH_MODEL_NAME = "hr16/Unimatch" +DEPTH_ANYTHING_MODEL_NAME = "LiheYoung/Depth-Anything" #HF Space +DIFFUSION_EDGE_MODEL_NAME = "hr16/Diffusion-Edge" +METRIC3D_MODEL_NAME = "JUGGHM/Metric3D" + +DEPTH_ANYTHING_V2_MODEL_NAME_DICT = { + "depth_anything_v2_vits.pth": "depth-anything/Depth-Anything-V2-Small", + "depth_anything_v2_vitb.pth": "depth-anything/Depth-Anything-V2-Base", + "depth_anything_v2_vitl.pth": "depth-anything/Depth-Anything-V2-Large", + "depth_anything_v2_vitg.pth": "depth-anything/Depth-Anything-V2-Giant", + "depth_anything_v2_metric_vkitti_vitl.pth": "depth-anything/Depth-Anything-V2-Metric-VKITTI-Large", + "depth_anything_v2_metric_hypersim_vitl.pth": "depth-anything/Depth-Anything-V2-Metric-Hypersim-Large" +} + +temp_dir = tempfile.gettempdir() +annotator_ckpts_path = os.path.join(Path(__file__).parents[2], 'ckpts') +USE_SYMLINKS = False + +try: + annotator_ckpts_path = os.environ['AUX_ANNOTATOR_CKPTS_PATH'] +except: + warnings.warn("Custom pressesor model path not set successfully.") + pass + +try: + USE_SYMLINKS = literal_eval(os.environ['AUX_USE_SYMLINKS']) +except: + warnings.warn("USE_SYMLINKS not set successfully. Using default value: False to download models.") + pass + +try: + temp_dir = os.environ['AUX_TEMP_DIR'] + if len(temp_dir) >= 60: + warnings.warn(f"custom temp dir is too long. Using default") + temp_dir = tempfile.gettempdir() +except: + warnings.warn(f"custom temp dir not set successfully") + pass + +here = Path(__file__).parent.resolve() + +def HWC3(x): + assert x.dtype == np.uint8 + if x.ndim == 2: + x = x[:, :, None] + assert x.ndim == 3 + H, W, C = x.shape + assert C == 1 or C == 3 or C == 4 + if C == 3: + return x + if C == 1: + return np.concatenate([x, x, x], axis=2) + if C == 4: + color = x[:, :, 0:3].astype(np.float32) + alpha = x[:, :, 3:4].astype(np.float32) / 255.0 + y = color * alpha + 255.0 * (1.0 - alpha) + y = y.clip(0, 255).astype(np.uint8) + return y + + +def make_noise_disk(H, W, C, F, rng=None): + if rng: + noise = rng.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + else: + noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C)) + noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC) + noise = noise[F: F + H, F: F + W] + noise -= np.min(noise) + noise /= np.max(noise) + if C == 1: + noise = noise[:, :, None] + return noise + + +def nms(x, t, s): + x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) + + f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) + f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) + f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) + f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) + + y = np.zeros_like(x) + + for f in [f1, f2, f3, f4]: + np.putmask(y, cv2.dilate(x, kernel=f) == x, x) + + z = np.zeros_like(y, dtype=np.uint8) + z[y > t] = 255 + return z + +def min_max_norm(x): + x -= np.min(x) + x /= np.maximum(np.max(x), 1e-5) + return x + + +def safe_step(x, step=2): + y = x.astype(np.float32) * float(step + 1) + y = y.astype(np.int32).astype(np.float32) / float(step) + return y + + +def img2mask(img, H, W, low=10, high=90): + assert img.ndim == 3 or img.ndim == 2 + assert img.dtype == np.uint8 + + if img.ndim == 3: + y = img[:, :, random.randrange(0, img.shape[2])] + else: + y = img + + y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC) + + if random.uniform(0, 1) < 0.5: + y = 255 - y + + return y < np.percentile(y, random.randrange(low, high)) + +def safer_memory(x): + # Fix many MAC/AMD problems + return np.ascontiguousarray(x.copy()).copy() + +UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"] +def get_upscale_method(method_str): + assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}" + return getattr(cv2, method_str) + +def pad64(x): + return int(np.ceil(float(x) / 64.0) * 64 - x) + +#https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17 +#Added upscale_method, mode params +def resize_image_with_pad(input_image, resolution, upscale_method = "", skip_hwc3=False, mode='edge'): + if skip_hwc3: + img = input_image + else: + img = HWC3(input_image) + H_raw, W_raw, _ = img.shape + if resolution == 0: + return img, lambda x: x + k = float(resolution) / float(min(H_raw, W_raw)) + H_target = int(np.round(float(H_raw) * k)) + W_target = int(np.round(float(W_raw) * k)) + img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA) + H_pad, W_pad = pad64(H_target), pad64(W_target) + img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode) + + def remove_pad(x): + return safer_memory(x[:H_target, :W_target, ...]) + + return safer_memory(img_padded), remove_pad + +def common_input_validate(input_image, output_type, **kwargs): + if "img" in kwargs: + warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning) + input_image = kwargs.pop("img") + + if "return_pil" in kwargs: + warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning) + output_type = "pil" if kwargs["return_pil"] else "np" + + if type(output_type) is bool: + warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions") + if output_type: + output_type = "pil" + + if input_image is None: + raise ValueError("input_image must be defined.") + + if not isinstance(input_image, np.ndarray): + input_image = np.array(input_image, dtype=np.uint8) + output_type = output_type or "pil" + else: + output_type = output_type or "np" + + return (input_image, output_type) + +def torch_gc(): + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + +def ade_palette(): + """ADE20K palette that maps each class to RGB values.""" + return [[120, 120, 120], [180, 120, 120], [6, 230, 230], [80, 50, 50], + [4, 200, 3], [120, 120, 80], [140, 140, 140], [204, 5, 255], + [230, 230, 230], [4, 250, 7], [224, 5, 255], [235, 255, 7], + [150, 5, 61], [120, 120, 70], [8, 255, 51], [255, 6, 82], + [143, 255, 140], [204, 255, 4], [255, 51, 7], [204, 70, 3], + [0, 102, 200], [61, 230, 250], [255, 6, 51], [11, 102, 255], + [255, 7, 71], [255, 9, 224], [9, 7, 230], [220, 220, 220], + [255, 9, 92], [112, 9, 255], [8, 255, 214], [7, 255, 224], + [255, 184, 6], [10, 255, 71], [255, 41, 10], [7, 255, 255], + [224, 255, 8], [102, 8, 255], [255, 61, 6], [255, 194, 7], + [255, 122, 8], [0, 255, 20], [255, 8, 41], [255, 5, 153], + [6, 51, 255], [235, 12, 255], [160, 150, 20], [0, 163, 255], + [140, 140, 140], [250, 10, 15], [20, 255, 0], [31, 255, 0], + [255, 31, 0], [255, 224, 0], [153, 255, 0], [0, 0, 255], + [255, 71, 0], [0, 235, 255], [0, 173, 255], [31, 0, 255], + [11, 200, 200], [255, 82, 0], [0, 255, 245], [0, 61, 255], + [0, 255, 112], [0, 255, 133], [255, 0, 0], [255, 163, 0], + [255, 102, 0], [194, 255, 0], [0, 143, 255], [51, 255, 0], + [0, 82, 255], [0, 255, 41], [0, 255, 173], [10, 0, 255], + [173, 255, 0], [0, 255, 153], [255, 92, 0], [255, 0, 255], + [255, 0, 245], [255, 0, 102], [255, 173, 0], [255, 0, 20], + [255, 184, 184], [0, 31, 255], [0, 255, 61], [0, 71, 255], + [255, 0, 204], [0, 255, 194], [0, 255, 82], [0, 10, 255], + [0, 112, 255], [51, 0, 255], [0, 194, 255], [0, 122, 255], + [0, 255, 163], [255, 153, 0], [0, 255, 10], [255, 112, 0], + [143, 255, 0], [82, 0, 255], [163, 255, 0], [255, 235, 0], + [8, 184, 170], [133, 0, 255], [0, 255, 92], [184, 0, 255], + [255, 0, 31], [0, 184, 255], [0, 214, 255], [255, 0, 112], + [92, 255, 0], [0, 224, 255], [112, 224, 255], [70, 184, 160], + [163, 0, 255], [153, 0, 255], [71, 255, 0], [255, 0, 163], + [255, 204, 0], [255, 0, 143], [0, 255, 235], [133, 255, 0], + [255, 0, 235], [245, 0, 255], [255, 0, 122], [255, 245, 0], + [10, 190, 212], [214, 255, 0], [0, 204, 255], [20, 0, 255], + [255, 255, 0], [0, 153, 255], [0, 41, 255], [0, 255, 204], + [41, 0, 255], [41, 255, 0], [173, 0, 255], [0, 245, 255], + [71, 0, 255], [122, 0, 255], [0, 255, 184], [0, 92, 255], + [184, 255, 0], [0, 133, 255], [255, 214, 0], [25, 194, 194], + [102, 255, 0], [92, 0, 255]] + +#https://stackoverflow.com/a/44873382 +#Assume that the minimum version of Python ppl use is 3.9 +def sha256sum(file_path): + import hashlib + h = hashlib.sha256() + b = bytearray(128*1024) + mv = memoryview(b) + with open(file_path, 'rb', buffering=0) as f: + while n := f.readinto(mv): + h.update(mv[:n]) + return h.hexdigest() + +def check_hash_from_torch_hub(file_path, filename): + basename, _ = filename.split('.') + _, ref_hash = basename.split('-') + curr_hash = sha256sum(file_path) + return curr_hash[:len(ref_hash)] == ref_hash + +def custom_torch_download(filename, ckpts_dir=annotator_ckpts_path): + local_dir = os.path.join(get_dir(), 'checkpoints') + model_path = os.path.join(local_dir, filename) + + if not os.path.exists(model_path): + print(f"Failed to find {model_path}.\n Downloading from pytorch.org") + local_dir = os.path.join(ckpts_dir, "torch") + if not os.path.exists(local_dir): + os.mkdir(local_dir) + + model_path = os.path.join(local_dir, filename) + + if not os.path.exists(model_path): + model_url = "https://download.pytorch.org/models/"+filename + try: + download_url_to_file(url = model_url, dst = model_path) + except: + warnings.warn(f"SSL verify failed, try use HTTP instead. {filename}'s hash will be checked") + download_url_to_file(url = model_url, dst = model_path) + assert check_hash_from_torch_hub(model_path, filename), f"Hash check failed as file {filename} is corrupted" + print("Hash check passed") + + print(f"model_path is {model_path}") + return model_path + +def custom_hf_download(pretrained_model_or_path, filename, cache_dir=temp_dir, ckpts_dir=annotator_ckpts_path, subfolder='', use_symlinks=USE_SYMLINKS, repo_type="model"): + + local_dir = os.path.join(ckpts_dir, pretrained_model_or_path) + model_path = os.path.join(local_dir, *subfolder.split('/'), filename) + + if len(str(model_path)) >= 255: + warnings.warn(f"Path {model_path} is too long, \n please change annotator_ckpts_path in config.yaml") + + if not os.path.exists(model_path): + print(f"Failed to find {model_path}.\n Downloading from huggingface.co") + print(f"cacher folder is {cache_dir}, you can change it by custom_tmp_path in config.yaml") + if use_symlinks: + cache_dir_d = constants.HF_HUB_CACHE # use huggingface newer env variables `HF_HUB_CACHE` + if cache_dir_d is None: + import platform + if platform.system() == "Windows": + cache_dir_d = os.path.join(os.getenv("USERPROFILE"), ".cache", "huggingface", "hub") + else: + cache_dir_d = os.path.join(os.getenv("HOME"), ".cache", "huggingface", "hub") + try: + # test_link + Path(cache_dir_d).mkdir(parents=True, exist_ok=True) + Path(ckpts_dir).mkdir(parents=True, exist_ok=True) + (Path(cache_dir_d) / f"linktest_{filename}.txt").touch() + # symlink instead of link avoid `invalid cross-device link` error. + os.symlink(os.path.join(cache_dir_d, f"linktest_{filename}.txt"), os.path.join(ckpts_dir, f"linktest_{filename}.txt")) + print("Using symlinks to download models. \n",\ + "Make sure you have enough space on your cache folder. \n",\ + "And do not purge the cache folder after downloading.\n",\ + "Otherwise, you will have to re-download the models every time you run the script.\n",\ + "You can use USE_SYMLINKS: False in config.yaml to avoid this behavior.") + except: + print("Maybe not able to create symlink. Disable using symlinks.") + use_symlinks = False + cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path) + finally: # always remove test link files + with suppress(FileNotFoundError): + os.remove(os.path.join(ckpts_dir, f"linktest_{filename}.txt")) + os.remove(os.path.join(cache_dir_d, f"linktest_{filename}.txt")) + else: + cache_dir_d = os.path.join(cache_dir, "ckpts", pretrained_model_or_path) + + model_path = hf_hub_download(repo_id=pretrained_model_or_path, + cache_dir=cache_dir_d, + local_dir=local_dir, + subfolder=subfolder, + filename=filename, + local_dir_use_symlinks=use_symlinks, + resume_download=True, + etag_timeout=100, + repo_type=repo_type + ) + if not use_symlinks: + try: + import shutil + shutil.rmtree(os.path.join(cache_dir, "ckpts")) + except Exception as e : + print(e) + + print(f"model_path is {model_path}") + + return model_path diff --git a/custom_controlnet_aux/zoe/LICENSE b/custom_controlnet_aux/zoe/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a1e90d007836c327846ce8e5151013b115042ab --- /dev/null +++ b/custom_controlnet_aux/zoe/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2022 Intelligent Systems Lab Org + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/__init__.py b/custom_controlnet_aux/zoe/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..27208cba0b3a834f23fa966f6af0176faefe012c --- /dev/null +++ b/custom_controlnet_aux/zoe/__init__.py @@ -0,0 +1,111 @@ +import os + +import cv2 +import numpy as np +import torch +from einops import rearrange +from PIL import Image + +from custom_controlnet_aux.util import HWC3, common_input_validate, resize_image_with_pad, custom_hf_download, HF_MODEL_NAME, DEPTH_ANYTHING_MODEL_NAME +from .zoedepth.models.zoedepth.zoedepth_v1 import ZoeDepth +from .zoedepth.models.zoedepth_anything.zoedepth_v1 import ZoeDepth as ZoeDepthAnything +from .zoedepth.utils.config import get_config + + +class ZoeDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=HF_MODEL_NAME, filename="ZoeD_M12_N.pt"): + model_path = custom_hf_download(pretrained_model_or_path, filename) + + conf = get_config("zoedepth", "infer") + model = ZoeDepth.build_from_config(conf) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model']) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().to(self.device) + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = remove_pad(HWC3(depth_image)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map + +class ZoeDepthAnythingDetector: + def __init__(self, model): + self.model = model + self.device = "cpu" + + @classmethod + def from_pretrained(cls, pretrained_model_or_path=DEPTH_ANYTHING_MODEL_NAME, filename="depth_anything_metric_depth_indoor.pt"): + model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="checkpoints_metric_depth", repo_type="space") + + conf = get_config("zoedepth", "infer") + model = ZoeDepthAnything.build_from_config(conf) + model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu'))['model']) + model.eval() + + return cls(model) + + def to(self, device): + self.model.to(device) + self.device = device + return self + + def __call__(self, input_image, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs): + input_image, output_type = common_input_validate(input_image, output_type, **kwargs) + input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method) + + image_depth = input_image + with torch.no_grad(): + image_depth = torch.from_numpy(image_depth).float().to(self.device) + image_depth = image_depth / 255.0 + image_depth = rearrange(image_depth, 'h w c -> 1 c h w') + depth = self.model.infer(image_depth) + + depth = depth[0, 0].cpu().numpy() + + vmin = np.percentile(depth, 2) + vmax = np.percentile(depth, 85) + + depth -= vmin + depth /= vmax - vmin + depth = 1.0 - depth + depth_image = (depth * 255.0).clip(0, 255).astype(np.uint8) + + detected_map = remove_pad(HWC3(depth_image)) + + if output_type == "pil": + detected_map = Image.fromarray(detected_map) + + return detected_map \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/__init__.py b/custom_controlnet_aux/zoe/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_controlnet_aux/zoe/zoedepth/models/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/custom_controlnet_aux/zoe/zoedepth/models/base_models/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/base_models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/base_models/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/custom_controlnet_aux/zoe/zoedepth/models/base_models/depth_anything.py b/custom_controlnet_aux/zoe/zoedepth/models/base_models/depth_anything.py new file mode 100644 index 0000000000000000000000000000000000000000..9149d93d9a028ef09374316c5ee5a31ba39d5b05 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/base_models/depth_anything.py @@ -0,0 +1,377 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize +from .dpt_dinov2.dpt import DPT_DINOv2 +from custom_controlnet_aux.util import custom_hf_download, DEPTH_ANYTHING_MODEL_NAME + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + print("Params passed to Resize transform:") + print("\twidth: ", width) + print("\theight: ", height) + print("\tresize_target: ", resize_target) + print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + print("\tensure_multiple_of: ", ensure_multiple_of) + print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (int(height), int(width)), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + # self.normalization = Normalize( + # mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.normalization = Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=14, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class DepthAnythingCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + # print('input to midas:', x.shape) + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + + with torch.set_grad_enabled(self.trainable): + + rel_depth = self.core(x) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "pos_embed" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "pos_embed" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.depth_head.scratch.output_conv2.children())[ + 1].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.depth_head.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.depth_head.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.depth_head.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.depth_head.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.depth_head.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self): + self.output_channels = [256, 256, 256, 256, 256] + + @staticmethod + def build(midas_model_type="dinov2_large", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if "img_size" in kwargs: + kwargs = DepthAnythingCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + + depth_anything = DPT_DINOv2(out_channels=[256, 512, 1024, 1024], use_clstoken=False) + depth_anything_path = custom_hf_download(DEPTH_ANYTHING_MODEL_NAME, "depth_anything_vitl14.pth", subfolder="checkpoints", repo_type="space") + state_dict = torch.load(depth_anything_path, map_location='cpu') + depth_anything.load_state_dict(state_dict) + + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + + depth_anything_core = DepthAnythingCore(depth_anything, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + + depth_anything_core.set_output_channels() + return depth_anything_core + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/blocks.py b/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc585ea8ca43a3ecca8e82200b0156599694c3b --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/blocks.py @@ -0,0 +1,153 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape*2 + out_shape3 = out_shape*4 + if len(in_shape) >= 4: + out_shape4 = out_shape*8 + + scratch.layer1_rn = nn.Conv2d( + in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer2_rn = nn.Conv2d( + in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + scratch.layer3_rn = nn.Conv2d( + in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d( + in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups + ) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + self.conv2 = nn.Conv2d( + features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups + ) + + if self.bn==True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn==True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn==True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None): + """Init. + + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand==True: + out_features = features//2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate( + output, **modifier, mode="bilinear", align_corners=self.align_corners + ) + + output = self.out_conv(output) + + return output \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/dpt.py b/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/dpt.py new file mode 100644 index 0000000000000000000000000000000000000000..2569ff4bb31ee8f4bc78d8e9dbbc6dd51e98bc1f --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/base_models/dpt_dinov2/dpt.py @@ -0,0 +1,158 @@ +import torch +import torch.nn as nn + +from .blocks import FeatureFusionBlock, _make_scratch +import torch.nn.functional as F +from custom_controlnet_aux.util import TORCHHUB_PATH + + +def _make_fusion_block(features, use_bn, size = None): + return FeatureFusionBlock( + features, + nn.ReLU(False), + deconv=False, + bn=use_bn, + expand=False, + align_corners=True, + size=size, + ) + + +class DPTHead(nn.Module): + def __init__(self, in_channels, features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False): + super(DPTHead, self).__init__() + + self.use_clstoken = use_clstoken + + # out_channels = [in_channels // 8, in_channels // 4, in_channels // 2, in_channels] + # out_channels = [in_channels // 4, in_channels // 2, in_channels, in_channels] + # out_channels = [in_channels, in_channels, in_channels, in_channels] + + self.projects = nn.ModuleList([ + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channel, + kernel_size=1, + stride=1, + padding=0, + ) for out_channel in out_channels + ]) + + self.resize_layers = nn.ModuleList([ + nn.ConvTranspose2d( + in_channels=out_channels[0], + out_channels=out_channels[0], + kernel_size=4, + stride=4, + padding=0), + nn.ConvTranspose2d( + in_channels=out_channels[1], + out_channels=out_channels[1], + kernel_size=2, + stride=2, + padding=0), + nn.Identity(), + nn.Conv2d( + in_channels=out_channels[3], + out_channels=out_channels[3], + kernel_size=3, + stride=2, + padding=1) + ]) + + if use_clstoken: + self.readout_projects = nn.ModuleList() + for _ in range(len(self.projects)): + self.readout_projects.append( + nn.Sequential( + nn.Linear(2 * in_channels, in_channels), + nn.GELU())) + + self.scratch = _make_scratch( + out_channels, + features, + groups=1, + expand=False, + ) + + self.scratch.stem_transpose = None + + self.scratch.refinenet1 = _make_fusion_block(features, use_bn) + self.scratch.refinenet2 = _make_fusion_block(features, use_bn) + self.scratch.refinenet3 = _make_fusion_block(features, use_bn) + self.scratch.refinenet4 = _make_fusion_block(features, use_bn) + + head_features_1 = features + head_features_2 = 32 + + self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1) + + self.scratch.output_conv2 = nn.Sequential( + nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1), + nn.ReLU(True), + nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0), + nn.ReLU(True), + nn.Identity(), + ) + + def forward(self, out_features, patch_h, patch_w): + out = [] + for i, x in enumerate(out_features): + if self.use_clstoken: + x, cls_token = x[0], x[1] + readout = cls_token.unsqueeze(1).expand_as(x) + x = self.readout_projects[i](torch.cat((x, readout), -1)) + else: + x = x[0] + + x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w)) + + x = self.projects[i](x) + x = self.resize_layers[i](x) + + out.append(x) + + layer_1, layer_2, layer_3, layer_4 = out + + layer_1_rn = self.scratch.layer1_rn(layer_1) + layer_2_rn = self.scratch.layer2_rn(layer_2) + layer_3_rn = self.scratch.layer3_rn(layer_3) + layer_4_rn = self.scratch.layer4_rn(layer_4) + + path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:]) + path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:]) + path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:]) + path_1 = self.scratch.refinenet1(path_2, layer_1_rn) + + out = self.scratch.output_conv1(path_1) + out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True) + out = self.scratch.output_conv2(out) + + return out + + +class DPT_DINOv2(nn.Module): + def __init__(self, encoder='vitl', features=256, use_bn=False, out_channels=[256, 512, 1024, 1024], use_clstoken=False): + + super(DPT_DINOv2, self).__init__() + + torch.manual_seed(1) + + self.pretrained = torch.hub.load(TORCHHUB_PATH / 'facebookresearch_dinov2_main', 'dinov2_{:}14'.format(encoder), source='local', pretrained=False) + + dim = self.pretrained.blocks[0].attn.qkv.in_features + + self.depth_head = DPTHead(dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken) + + def forward(self, x): + h, w = x.shape[-2:] + + features = self.pretrained.get_intermediate_layers(x, 4, return_class_token=True) + + patch_h, patch_w = h // 14, w // 14 + + depth = self.depth_head(features, patch_h, patch_w) + depth = F.interpolate(depth, size=(h, w), mode="bilinear", align_corners=True) + depth = F.relu(depth) + + return depth.squeeze(1) \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/base_models/midas.py b/custom_controlnet_aux/zoe/zoedepth/models/base_models/midas.py new file mode 100644 index 0000000000000000000000000000000000000000..e04f81f42e2179303be0e3ad3c354e70a993c349 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/base_models/midas.py @@ -0,0 +1,383 @@ +# MIT License +import os + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn +import numpy as np +from torchvision.transforms import Normalize +import inspect +from pathlib import Path + + +def denormalize(x): + """Reverses the imagenet normalization applied to the input. + + Args: + x (torch.Tensor - shape(N,3,H,W)): input tensor + + Returns: + torch.Tensor - shape(N,3,H,W): Denormalized input + """ + mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device) + std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device) + return x * std + mean + +def get_activation(name, bank): + def hook(model, input, output): + bank[name] = output + return hook + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + ): + """Init. + Args: + width (int): desired output width + height (int): desired output height + resize_target (bool, optional): + True: Resize the full sample (image, mask, target). + False: Resize image only. + Defaults to True. + keep_aspect_ratio (bool, optional): + True: Keep the aspect ratio of the input sample. + Output sample might not have the given width and height, and + resize behaviour depends on the parameter 'resize_method'. + Defaults to False. + ensure_multiple_of (int, optional): + Output width and height is constrained to be multiple of this parameter. + Defaults to 1. + resize_method (str, optional): + "lower_bound": Output will be at least as large as the given size. + "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) + "minimal": Scale as least as possible. (Output size might be smaller than given size.) + Defaults to "lower_bound". + """ + # print("Params passed to Resize transform:") + # print("\twidth: ", width) + # print("\theight: ", height) + # print("\tresize_target: ", resize_target) + # print("\tkeep_aspect_ratio: ", keep_aspect_ratio) + # print("\tensure_multiple_of: ", ensure_multiple_of) + # print("\tresize_method: ", resize_method) + + self.__width = width + self.__height = height + + self.__keep_aspect_ratio = keep_aspect_ratio + self.__multiple_of = ensure_multiple_of + self.__resize_method = resize_method + + def constrain_to_multiple_of(self, x, min_val=0, max_val=None): + y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) + + if max_val is not None and y > max_val: + y = (np.floor(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + if y < min_val: + y = (np.ceil(x / self.__multiple_of) + * self.__multiple_of).astype(int) + + return y + + def get_size(self, width, height): + # determine new height and width + scale_height = self.__height / height + scale_width = self.__width / width + + if self.__keep_aspect_ratio: + if self.__resize_method == "lower_bound": + # scale such that output size is lower bound + if scale_width > scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "upper_bound": + # scale such that output size is upper bound + if scale_width < scale_height: + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + elif self.__resize_method == "minimal": + # scale as least as possbile + if abs(1 - scale_width) < abs(1 - scale_height): + # fit width + scale_height = scale_width + else: + # fit height + scale_width = scale_height + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented" + ) + + if self.__resize_method == "lower_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, min_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, min_val=self.__width + ) + elif self.__resize_method == "upper_bound": + new_height = self.constrain_to_multiple_of( + scale_height * height, max_val=self.__height + ) + new_width = self.constrain_to_multiple_of( + scale_width * width, max_val=self.__width + ) + elif self.__resize_method == "minimal": + new_height = self.constrain_to_multiple_of(scale_height * height) + new_width = self.constrain_to_multiple_of(scale_width * width) + else: + raise ValueError( + f"resize_method {self.__resize_method} not implemented") + + return (new_width, new_height) + + def __call__(self, x): + width, height = self.get_size(*x.shape[-2:][::-1]) + return nn.functional.interpolate(x, (int(height), int(width)), mode='bilinear', align_corners=True) + +class PrepForMidas(object): + def __init__(self, resize_mode="minimal", keep_aspect_ratio=True, img_size=384, do_resize=True): + if isinstance(img_size, int): + img_size = (img_size, img_size) + net_h, net_w = img_size + self.normalization = Normalize( + mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + self.resizer = Resize(net_w, net_h, keep_aspect_ratio=keep_aspect_ratio, ensure_multiple_of=32, resize_method=resize_mode) \ + if do_resize else nn.Identity() + + def __call__(self, x): + return self.normalization(self.resizer(x)) + + +class MidasCore(nn.Module): + def __init__(self, midas, trainable=False, fetch_features=True, layer_names=('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'), freeze_bn=False, keep_aspect_ratio=True, + img_size=384, **kwargs): + """Midas Base model used for multi-scale feature extraction. + + Args: + midas (torch.nn.Module): Midas model. + trainable (bool, optional): Train midas model. Defaults to False. + fetch_features (bool, optional): Extract multi-scale features. Defaults to True. + layer_names (tuple, optional): Layers used for feature extraction. Order = (head output features, last layer features, ...decoder features). Defaults to ('out_conv', 'l4_rn', 'r4', 'r3', 'r2', 'r1'). + freeze_bn (bool, optional): Freeze BatchNorm. Generally results in better finetuning performance. Defaults to False. + keep_aspect_ratio (bool, optional): Keep the aspect ratio of input images while resizing. Defaults to True. + img_size (int, tuple, optional): Input resolution. Defaults to 384. + """ + super().__init__() + self.core = midas + self.output_channels = None + self.core_out = {} + self.trainable = trainable + self.fetch_features = fetch_features + # midas.scratch.output_conv = nn.Identity() + self.handles = [] + # self.layer_names = ['out_conv','l4_rn', 'r4', 'r3', 'r2', 'r1'] + self.layer_names = layer_names + + self.set_trainable(trainable) + self.set_fetch_features(fetch_features) + + self.prep = PrepForMidas(keep_aspect_ratio=keep_aspect_ratio, + img_size=img_size, do_resize=kwargs.get('do_resize', True)) + + if freeze_bn: + self.freeze_bn() + + def set_trainable(self, trainable): + self.trainable = trainable + if trainable: + self.unfreeze() + else: + self.freeze() + return self + + def set_fetch_features(self, fetch_features): + self.fetch_features = fetch_features + if fetch_features: + if len(self.handles) == 0: + self.attach_hooks(self.core) + else: + self.remove_hooks() + return self + + def freeze(self): + for p in self.parameters(): + p.requires_grad = False + self.trainable = False + return self + + def unfreeze(self): + for p in self.parameters(): + p.requires_grad = True + self.trainable = True + return self + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + return self + + def forward(self, x, denorm=False, return_rel_depth=False): + with torch.no_grad(): + if denorm: + x = denormalize(x) + x = self.prep(x) + # print("Shape after prep: ", x.shape) + + with torch.set_grad_enabled(self.trainable): + + # print("Input size to Midascore", x.shape) + rel_depth = self.core(x) + # print("Output from custom_midas_repo.midas shape", rel_depth.shape) + if not self.fetch_features: + return rel_depth + out = [self.core_out[k] for k in self.layer_names] + + if return_rel_depth: + return rel_depth, out + return out + + def get_rel_pos_params(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(self): + for name, p in self.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + def freeze_encoder(self, freeze_rel_pos=False): + if freeze_rel_pos: + for p in self.core.pretrained.parameters(): + p.requires_grad = False + else: + for p in self.get_enc_params_except_rel_pos(): + p.requires_grad = False + return self + + def attach_hooks(self, midas): + if len(self.handles) > 0: + self.remove_hooks() + if "out_conv" in self.layer_names: + self.handles.append(list(midas.scratch.output_conv.children())[ + 3].register_forward_hook(get_activation("out_conv", self.core_out))) + if "r4" in self.layer_names: + self.handles.append(midas.scratch.refinenet4.register_forward_hook( + get_activation("r4", self.core_out))) + if "r3" in self.layer_names: + self.handles.append(midas.scratch.refinenet3.register_forward_hook( + get_activation("r3", self.core_out))) + if "r2" in self.layer_names: + self.handles.append(midas.scratch.refinenet2.register_forward_hook( + get_activation("r2", self.core_out))) + if "r1" in self.layer_names: + self.handles.append(midas.scratch.refinenet1.register_forward_hook( + get_activation("r1", self.core_out))) + if "l4_rn" in self.layer_names: + self.handles.append(midas.scratch.layer4_rn.register_forward_hook( + get_activation("l4_rn", self.core_out))) + + return self + + def remove_hooks(self): + for h in self.handles: + h.remove() + return self + + def __del__(self): + self.remove_hooks() + + def set_output_channels(self, model_type): + self.output_channels = MIDAS_SETTINGS[model_type] + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", train_midas=False, use_pretrained_midas=True, fetch_features=False, freeze_bn=True, force_keep_ar=False, force_reload=False, **kwargs): + if midas_model_type not in MIDAS_SETTINGS: + raise ValueError( + f"Invalid model type: {midas_model_type}. Must be one of {list(MIDAS_SETTINGS.keys())}") + if "img_size" in kwargs: + kwargs = MidasCore.parse_img_size(kwargs) + img_size = kwargs.pop("img_size", [384, 384]) + # print("img_size", img_size) + import custom_midas_repo + midas_path = Path(inspect.getfile(custom_midas_repo)).parent.resolve() + del custom_midas_repo + midas = torch.hub.load(midas_path, midas_model_type, + pretrained=use_pretrained_midas, force_reload=force_reload, source='local') + kwargs.update({'keep_aspect_ratio': force_keep_ar}) + midas_core = MidasCore(midas, trainable=train_midas, fetch_features=fetch_features, + freeze_bn=freeze_bn, img_size=img_size, **kwargs) + midas_core.set_output_channels(midas_model_type) + return midas_core + + @staticmethod + def build_from_config(config): + return MidasCore.build(**config) + + @staticmethod + def parse_img_size(config): + assert 'img_size' in config + if isinstance(config['img_size'], str): + assert "," in config['img_size'], "img_size should be a string with comma separated img_size=H,W" + config['img_size'] = list(map(int, config['img_size'].split(","))) + assert len( + config['img_size']) == 2, "img_size should be a string with comma separated img_size=H,W" + elif isinstance(config['img_size'], int): + config['img_size'] = [config['img_size'], config['img_size']] + else: + assert isinstance(config['img_size'], list) and len( + config['img_size']) == 2, "img_size should be a list of H,W" + return config + + +nchannels2models = { + tuple([256]*5): ["DPT_BEiT_L_384", "DPT_BEiT_L_512", "DPT_BEiT_B_384", "DPT_SwinV2_L_384", "DPT_SwinV2_B_384", "DPT_SwinV2_T_256", "DPT_Large", "DPT_Hybrid"], + (512, 256, 128, 64, 64): ["MiDaS_small"] +} + +# Model name to number of output channels +MIDAS_SETTINGS = {m: k for k, v in nchannels2models.items() + for m in v + } diff --git a/custom_controlnet_aux/zoe/zoedepth/models/builder.py b/custom_controlnet_aux/zoe/zoedepth/models/builder.py new file mode 100644 index 0000000000000000000000000000000000000000..0818311b642561712a03a66655c638ce09a04cca --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/builder.py @@ -0,0 +1,51 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from importlib import import_module +from .depth_model import DepthModel + +def build_model(config) -> DepthModel: + """Builds a model from a config. The model is specified by the model name and version in the config. The model is then constructed using the build_from_config function of the model interface. + This function should be used to construct models for training and evaluation. + + Args: + config (dict): Config dict. Config is constructed in utils/config.py. Each model has its own config file(s) saved in its root model folder. + + Returns: + torch.nn.Module: Model corresponding to name and version as specified in config + """ + module_name = f"zoedepth.models.{config.model}" + try: + module = import_module(module_name) + except ModuleNotFoundError as e: + # print the original error message + print(e) + raise ValueError( + f"Model {config.model} not found. Refer above error for details.") from e + try: + get_version = getattr(module, "get_version") + except AttributeError as e: + raise ValueError( + f"Model {config.model} has no get_version function.") from e + return get_version(config.version_name).build_from_config(config) diff --git a/custom_controlnet_aux/zoe/zoedepth/models/depth_model.py b/custom_controlnet_aux/zoe/zoedepth/models/depth_model.py new file mode 100644 index 0000000000000000000000000000000000000000..fc421c108ea3928c9add62b4c190500d9bd4eda1 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/depth_model.py @@ -0,0 +1,152 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torchvision import transforms +import PIL.Image +from PIL import Image +from typing import Union + + +class DepthModel(nn.Module): + def __init__(self): + super().__init__() + self.device = 'cpu' + + def to(self, device) -> nn.Module: + self.device = device + return super().to(device) + + def forward(self, x, *args, **kwargs): + raise NotImplementedError + + def _infer(self, x: torch.Tensor): + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + return self(x)['metric_depth'] + + def _infer_with_pad_aug(self, x: torch.Tensor, pad_input: bool=True, fh: float=3, fw: float=3, upsampling_mode: str='bicubic', padding_mode="reflect", **kwargs) -> torch.Tensor: + """ + Inference interface for the model with padding augmentation + Padding augmentation fixes the boundary artifacts in the output depth map. + Boundary artifacts are sometimes caused by the fact that the model is trained on NYU raw dataset which has a black or white border around the image. + This augmentation pads the input image and crops the prediction back to the original size / view. + + Note: This augmentation is not required for the models trained with 'avoid_boundary'=True. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to pad the input or not. Defaults to True. + fh (float, optional): height padding factor. The padding is calculated as sqrt(h/2) * fh. Defaults to 3. + fw (float, optional): width padding factor. The padding is calculated as sqrt(w/2) * fw. Defaults to 3. + upsampling_mode (str, optional): upsampling mode. Defaults to 'bicubic'. + padding_mode (str, optional): padding mode. Defaults to "reflect". + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # assert x is nchw and c = 3 + assert x.dim() == 4, "x must be 4 dimensional, got {}".format(x.dim()) + assert x.shape[1] == 3, "x must have 3 channels, got {}".format(x.shape[1]) + + if pad_input: + assert fh > 0 or fw > 0, "atlease one of fh and fw must be greater than 0" + pad_h = int(np.sqrt(x.shape[2]/2) * fh) + pad_w = int(np.sqrt(x.shape[3]/2) * fw) + padding = [pad_w, pad_w] + if pad_h > 0: + padding += [pad_h, pad_h] + + x = F.pad(x, padding, mode=padding_mode, **kwargs) + out = self._infer(x) + if out.shape[-2:] != x.shape[-2:]: + out = F.interpolate(out, size=(x.shape[2], x.shape[3]), mode=upsampling_mode, align_corners=False) + if pad_input: + # crop to the original size, handling the case where pad_h and pad_w is 0 + if pad_h > 0: + out = out[:, :, pad_h:-pad_h,:] + if pad_w > 0: + out = out[:, :, :, pad_w:-pad_w] + return out + + def infer_with_flip_aug(self, x, pad_input: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model with horizontal flip augmentation + Horizontal flip augmentation improves the accuracy of the model by averaging the output of the model with and without horizontal flip. + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + # infer with horizontal flip and average + out = self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + out_flip = self._infer_with_pad_aug(torch.flip(x, dims=[3]), pad_input=pad_input, **kwargs) + out = (out + torch.flip(out_flip, dims=[3])) / 2 + return out + + def infer(self, x, pad_input: bool=True, with_flip_aug: bool=True, **kwargs) -> torch.Tensor: + """ + Inference interface for the model + Args: + x (torch.Tensor): input tensor of shape (b, c, h, w) + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + Returns: + torch.Tensor: output tensor of shape (b, 1, h, w) + """ + if with_flip_aug: + return self.infer_with_flip_aug(x, pad_input=pad_input, **kwargs) + else: + return self._infer_with_pad_aug(x, pad_input=pad_input, **kwargs) + + @torch.no_grad() + def infer_pil(self, pil_img, pad_input: bool=True, with_flip_aug: bool=True, output_type: str="numpy", **kwargs) -> Union[np.ndarray, PIL.Image.Image, torch.Tensor]: + """ + Inference interface for the model for PIL image + Args: + pil_img (PIL.Image.Image): input PIL image + pad_input (bool, optional): whether to use padding augmentation. Defaults to True. + with_flip_aug (bool, optional): whether to use horizontal flip augmentation. Defaults to True. + output_type (str, optional): output type. Supported values are 'numpy', 'pil' and 'tensor'. Defaults to "numpy". + """ + x = transforms.ToTensor()(pil_img).unsqueeze(0).to(self.device) + out_tensor = self.infer(x, pad_input=pad_input, with_flip_aug=with_flip_aug, **kwargs) + if output_type == "numpy": + return out_tensor.squeeze().cpu().numpy() + elif output_type == "pil": + # uint16 is required for depth pil image + out_16bit_numpy = (out_tensor.squeeze().cpu().numpy()*256).astype(np.uint16) + return Image.fromarray(out_16bit_numpy) + elif output_type == "tensor": + return out_tensor.squeeze().cpu() + else: + raise ValueError(f"output_type {output_type} not supported. Supported values are 'numpy', 'pil' and 'tensor'") + \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/layers/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c344f725c8a10dcaf29d4c308eb49d86ac51ff88 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/layers/__init__.py @@ -0,0 +1,23 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat diff --git a/custom_controlnet_aux/zoe/zoedepth/models/layers/attractor.py b/custom_controlnet_aux/zoe/zoedepth/models/layers/attractor.py new file mode 100644 index 0000000000000000000000000000000000000000..2a8efe645adea1d88a12e2ac5cc6bb2a251eef9d --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/layers/attractor.py @@ -0,0 +1,208 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +@torch.jit.script +def exp_attractor(dx, alpha: float = 300, gamma: int = 2): + """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx , where dx = a - c, a = attractor point, c = bin center, dc = shift in bin centermmary for exp_attractor + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor : Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return torch.exp(-alpha*(torch.abs(dx)**gamma)) * (dx) + + +@torch.jit.script +def inv_attractor(dx, alpha: float = 300, gamma: int = 2): + """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c, a = attractor point, c = bin center, dc = shift in bin center + This is the default one according to the accompanying paper. + + Args: + dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center. + alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300. + gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2. + + Returns: + torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc + """ + return dx.div(1+alpha*dx.pow(gamma)) + + +class AttractorLayer(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth) + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0), # x2 for linear norm + nn.ReLU(inplace=True) + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + eps = 1e-3 + A = A + eps + n, c, h, w = A.shape + A = A.view(n, self.n_attractors, 2, h, w) + A_normed = A / A.sum(dim=2, keepdim=True) # n, a, 2, h, w + A_normed = A[:, :, 0, ...] # n, na, h, w + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func(dist(A_normed.unsqueeze( + 2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + # .shape N, nbins, h, w + delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers) + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = (self.max_depth - self.min_depth) * \ + b_new_centers + self.min_depth + B_centers, _ = torch.sort(B_centers, dim=1) + B_centers = torch.clip(B_centers, self.min_depth, self.max_depth) + return b_new_centers, B_centers + + +class AttractorLayerUnnormed(nn.Module): + def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10, + alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False): + """ + Attractor layer for bin centers. Bin centers are unbounded + """ + super().__init__() + + self.n_attractors = n_attractors + self.n_bins = n_bins + self.min_depth = min_depth + self.max_depth = max_depth + self.alpha = alpha + self.gamma = gamma + self.kind = kind + self.attractor_type = attractor_type + self.memory_efficient = memory_efficient + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + Args: + x (torch.Tensor) : feature block; shape - n, c, h, w + b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w + + Returns: + tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate( + prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + + A = self._net(x) + n, c, h, w = A.shape + + b_prev = nn.functional.interpolate( + b_prev, (h, w), mode='bilinear', align_corners=True) + b_centers = b_prev + + if self.attractor_type == 'exp': + dist = exp_attractor + else: + dist = inv_attractor + + if not self.memory_efficient: + func = {'mean': torch.mean, 'sum': torch.sum}[self.kind] + # .shape N, nbins, h, w + delta_c = func( + dist(A.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1) + else: + delta_c = torch.zeros_like(b_centers, device=b_centers.device) + for i in range(self.n_attractors): + delta_c += dist(A[:, i, ...].unsqueeze(1) - + b_centers) # .shape N, nbins, h, w + + if self.kind == 'mean': + delta_c = delta_c / self.n_attractors + + b_new_centers = b_centers + delta_c + B_centers = b_new_centers + + return b_new_centers, B_centers diff --git a/custom_controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py b/custom_controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..3208405dfb78fdfc28d5765e5a6d5dbe31967a23 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/layers/dist_layers.py @@ -0,0 +1,121 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +def log_binom(n, k, eps=1e-7): + """ log(nCk) using stirling approximation """ + n = n + eps + k = k + eps + return n * torch.log(n) - k * torch.log(k) - (n-k) * torch.log(n-k+eps) + + +class LogBinomial(nn.Module): + def __init__(self, n_classes=256, act=torch.softmax): + """Compute log binomial distribution for n_classes + + Args: + n_classes (int, optional): number of output classes. Defaults to 256. + """ + super().__init__() + self.K = n_classes + self.act = act + self.register_buffer('k_idx', torch.arange( + 0, n_classes).view(1, -1, 1, 1)) + self.register_buffer('K_minus_1', torch.Tensor( + [self.K-1]).view(1, -1, 1, 1)) + + def forward(self, x, t=1., eps=1e-4): + """Compute log binomial distribution for x + + Args: + x (torch.Tensor - NCHW): probabilities + t (float, torch.Tensor - NCHW, optional): Temperature of distribution. Defaults to 1.. + eps (float, optional): Small number for numerical stability. Defaults to 1e-4. + + Returns: + torch.Tensor -NCHW: log binomial distribution logbinomial(p;t) + """ + if x.ndim == 3: + x = x.unsqueeze(1) # make it nchw + + one_minus_x = torch.clamp(1 - x, eps, 1) + x = torch.clamp(x, eps, 1) + y = log_binom(self.K_minus_1, self.k_idx) + self.k_idx * \ + torch.log(x) + (self.K - 1 - self.k_idx) * torch.log(one_minus_x) + return self.act(y/t, dim=1) + + +class ConditionalLogBinomial(nn.Module): + def __init__(self, in_features, condition_dim, n_classes=256, bottleneck_factor=2, p_eps=1e-4, max_temp=50, min_temp=1e-7, act=torch.softmax): + """Conditional Log Binomial distribution + + Args: + in_features (int): number of input channels in main feature + condition_dim (int): number of input channels in condition feature + n_classes (int, optional): Number of classes. Defaults to 256. + bottleneck_factor (int, optional): Hidden dim factor. Defaults to 2. + p_eps (float, optional): small eps value. Defaults to 1e-4. + max_temp (float, optional): Maximum temperature of output distribution. Defaults to 50. + min_temp (float, optional): Minimum temperature of output distribution. Defaults to 1e-7. + """ + super().__init__() + self.p_eps = p_eps + self.max_temp = max_temp + self.min_temp = min_temp + self.log_binomial_transform = LogBinomial(n_classes, act=act) + bottleneck = (in_features + condition_dim) // bottleneck_factor + self.mlp = nn.Sequential( + nn.Conv2d(in_features + condition_dim, bottleneck, + kernel_size=1, stride=1, padding=0), + nn.GELU(), + # 2 for p linear norm, 2 for t linear norm + nn.Conv2d(bottleneck, 2+2, kernel_size=1, stride=1, padding=0), + nn.Softplus() + ) + + def forward(self, x, cond): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Main feature + cond (torch.Tensor - NCHW): condition feature + + Returns: + torch.Tensor: Output log binomial distribution + """ + pt = self.mlp(torch.concat((x, cond), dim=1)) + p, t = pt[:, :2, ...], pt[:, 2:, ...] + + p = p + self.p_eps + p = p[:, 0, ...] / (p[:, 0, ...] + p[:, 1, ...]) + + t = t + self.p_eps + t = t[:, 0, ...] / (t[:, 0, ...] + t[:, 1, ...]) + t = t.unsqueeze(1) + t = (self.max_temp - self.min_temp) * t + self.min_temp + + return self.log_binomial_transform(p, t) diff --git a/custom_controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py b/custom_controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..f94481605c3e6958ce50e73b2eb31d9f0c07dc67 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/layers/localbins_layers.py @@ -0,0 +1,169 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class SeedBinRegressor(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are bounded on (min_depth, max_depth) interval. + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Min depth value. Defaults to 1e-3. + max_depth (float, optional): Max depth value. Defaults to 10. + """ + super().__init__() + self.version = "1_1" + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B = self._net(x) + eps = 1e-3 + B = B + eps + B_widths_normed = B / B.sum(dim=1, keepdim=True) + B_widths = (self.max_depth - self.min_depth) * \ + B_widths_normed # .shape NCHW + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad( + B_widths, (0, 0, 0, 0, 1, 0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:, 1:, ...]) + return B_widths_normed, B_centers + + +class SeedBinRegressorUnnormed(nn.Module): + def __init__(self, in_features, n_bins=16, mlp_dim=256, min_depth=1e-3, max_depth=10): + """Bin center regressor network. Bin centers are unbounded + + Args: + in_features (int): input channels + n_bins (int, optional): Number of bin centers. Defaults to 16. + mlp_dim (int, optional): Hidden dimension. Defaults to 256. + min_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + max_depth (float, optional): Not used. (for compatibility with SeedBinRegressor) + """ + super().__init__() + self.version = "1_1" + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, n_bins, 1, 1, 0), + nn.Softplus() + ) + + def forward(self, x): + """ + Returns tensor of bin_width vectors (centers). One vector b for every pixel + """ + B_centers = self._net(x) + return B_centers, B_centers + + +class Projector(nn.Module): + def __init__(self, in_features, out_features, mlp_dim=128): + """Projector MLP + + Args: + in_features (int): input channels + out_features (int): output channels + mlp_dim (int, optional): hidden dimension. Defaults to 128. + """ + super().__init__() + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.ReLU(inplace=True), + nn.Conv2d(mlp_dim, out_features, 1, 1, 0), + ) + + def forward(self, x): + return self._net(x) + + + +class LinearSplitter(nn.Module): + def __init__(self, in_features, prev_nbins, split_factor=2, mlp_dim=128, min_depth=1e-3, max_depth=10): + super().__init__() + + self.prev_nbins = prev_nbins + self.split_factor = split_factor + self.min_depth = min_depth + self.max_depth = max_depth + + self._net = nn.Sequential( + nn.Conv2d(in_features, mlp_dim, 1, 1, 0), + nn.GELU(), + nn.Conv2d(mlp_dim, prev_nbins * split_factor, 1, 1, 0), + nn.ReLU() + ) + + def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False): + """ + x : feature block; shape - n, c, h, w + b_prev : previous bin widths normed; shape - n, prev_nbins, h, w + """ + if prev_b_embedding is not None: + if interpolate: + prev_b_embedding = nn.functional.interpolate(prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True) + x = x + prev_b_embedding + S = self._net(x) + eps = 1e-3 + S = S + eps + n, c, h, w = S.shape + S = S.view(n, self.prev_nbins, self.split_factor, h, w) + S_normed = S / S.sum(dim=2, keepdim=True) # fractional splits + + b_prev = nn.functional.interpolate(b_prev, (h,w), mode='bilinear', align_corners=True) + + + b_prev = b_prev / b_prev.sum(dim=1, keepdim=True) # renormalize for gurantees + # print(b_prev.shape, S_normed.shape) + # if is_for_query:(1).expand(-1, b_prev.size(0)//n, -1, -1, -1, -1).flatten(0,1) # TODO ? can replace all this with a single torch.repeat? + b = b_prev.unsqueeze(2) * S_normed + b = b.flatten(1,2) # .shape n, prev_nbins * split_factor, h, w + + # calculate bin centers for loss calculation + B_widths = (self.max_depth - self.min_depth) * b # .shape N, nprev * splitfactor, H, W + # pad has the form (left, right, top, bottom, front, back) + B_widths = nn.functional.pad(B_widths, (0,0,0,0,1,0), mode='constant', value=self.min_depth) + B_edges = torch.cumsum(B_widths, dim=1) # .shape NCHW + + B_centers = 0.5 * (B_edges[:, :-1, ...] + B_edges[:,1:,...]) + return b, B_centers \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py b/custom_controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..99d9e51a06b981bae45ce7dd64eaef19a4121991 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/layers/patch_transformer.py @@ -0,0 +1,91 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch +import torch.nn as nn + + +class PatchTransformerEncoder(nn.Module): + def __init__(self, in_channels, patch_size=10, embedding_dim=128, num_heads=4, use_class_token=False): + """ViT-like transformer block + + Args: + in_channels (int): Input channels + patch_size (int, optional): patch size. Defaults to 10. + embedding_dim (int, optional): Embedding dimension in transformer model. Defaults to 128. + num_heads (int, optional): number of attention heads. Defaults to 4. + use_class_token (bool, optional): Whether to use extra token at the start for global accumulation (called as "class token"). Defaults to False. + """ + super(PatchTransformerEncoder, self).__init__() + self.use_class_token = use_class_token + encoder_layers = nn.TransformerEncoderLayer( + embedding_dim, num_heads, dim_feedforward=1024) + self.transformer_encoder = nn.TransformerEncoder( + encoder_layers, num_layers=4) # takes shape S,N,E + + self.embedding_convPxP = nn.Conv2d(in_channels, embedding_dim, + kernel_size=patch_size, stride=patch_size, padding=0) + + def positional_encoding_1d(self, sequence_length, batch_size, embedding_dim, device='cpu'): + """Generate positional encodings + + Args: + sequence_length (int): Sequence length + embedding_dim (int): Embedding dimension + + Returns: + torch.Tensor SBE: Positional encodings + """ + position = torch.arange( + 0, sequence_length, dtype=torch.float32, device=device).unsqueeze(1) + index = torch.arange( + 0, embedding_dim, 2, dtype=torch.float32, device=device).unsqueeze(0) + div_term = torch.exp(index * (-torch.log(torch.tensor(10000.0, device=device)) / embedding_dim)) + pos_encoding = position * div_term + pos_encoding = torch.cat([torch.sin(pos_encoding), torch.cos(pos_encoding)], dim=1) + pos_encoding = pos_encoding.unsqueeze(1).repeat(1, batch_size, 1) + return pos_encoding + + + def forward(self, x): + """Forward pass + + Args: + x (torch.Tensor - NCHW): Input feature tensor + + Returns: + torch.Tensor - SNE: Transformer output embeddings. S - sequence length (=HW/patch_size^2), N - batch size, E - embedding dim + """ + embeddings = self.embedding_convPxP(x).flatten( + 2) # .shape = n,c,s = n, embedding_dim, s + if self.use_class_token: + # extra special token at start ? + embeddings = nn.functional.pad(embeddings, (1, 0)) + + # change to S,N,E format required by transformer + embeddings = embeddings.permute(2, 0, 1) + S, N, E = embeddings.shape + embeddings = embeddings + self.positional_encoding_1d(S, N, E, device=embeddings.device) + x = self.transformer_encoder(embeddings) # .shape = S, N, E + return x diff --git a/custom_controlnet_aux/zoe/zoedepth/models/model_io.py b/custom_controlnet_aux/zoe/zoedepth/models/model_io.py new file mode 100644 index 0000000000000000000000000000000000000000..78b6579631dd847ac76651238cb5a948b5a66286 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/model_io.py @@ -0,0 +1,92 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import torch + +def load_state_dict(model, state_dict): + """Load state_dict into model, handling DataParallel and DistributedDataParallel. Also checks for "model" key in state_dict. + + DataParallel prefixes state_dict keys with 'module.' when saving. + If the model is not a DataParallel model but the state_dict is, then prefixes are removed. + If the model is a DataParallel model but the state_dict is not, then prefixes are added. + """ + state_dict = state_dict.get('model', state_dict) + # if model is a DataParallel model, then state_dict keys are prefixed with 'module.' + + do_prefix = isinstance( + model, (torch.nn.DataParallel, torch.nn.parallel.DistributedDataParallel)) + state = {} + for k, v in state_dict.items(): + if k.startswith('module.') and not do_prefix: + k = k[7:] + + if not k.startswith('module.') and do_prefix: + k = 'module.' + k + + state[k] = v + + model.load_state_dict(state) + print("Loaded successfully") + return model + + +def load_wts(model, checkpoint_path): + ckpt = torch.load(checkpoint_path, map_location='cpu') + return load_state_dict(model, ckpt) + + +def load_state_dict_from_url(model, url, **kwargs): + state_dict = torch.hub.load_state_dict_from_url(url, map_location='cpu', **kwargs) + return load_state_dict(model, state_dict) + + +def load_state_from_resource(model, resource: str): + """Loads weights to the model from a given resource. A resource can be of following types: + 1. URL. Prefixed with "url::" + e.g. url::http(s)://url.resource.com/ckpt.pt + + 2. Local path. Prefixed with "local::" + e.g. local::/path/to/ckpt.pt + + + Args: + model (torch.nn.Module): Model + resource (str): resource string + + Returns: + torch.nn.Module: Model with loaded weights + """ + print(f"Using pretrained resource {resource}") + + if resource.startswith('url::'): + url = resource.split('url::')[1] + return load_state_dict_from_url(model, url, progress=True) + + elif resource.startswith('local::'): + path = resource.split('local::')[1] + return load_wts(model, path) + + else: + raise ValueError("Invalid resource type, only url:: and local:: are supported") + \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json new file mode 100644 index 0000000000000000000000000000000000000000..3112ed78c89f00e1d13f5d6e5be87cd3216b6dc7 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth.json @@ -0,0 +1,58 @@ +{ + "model": { + "name": "ZoeDepth", + "version_name": "v1", + "n_bins": 64, + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "midas_model_type" : "DPT_BEiT_L_384", + "min_temp": 0.0212, + "max_temp": 50.0, + "output_distribution": "logbinomial", + "memory_efficient": true, + "inverse_midas": false, + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.000161, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 0.2, + "w_reg": 0, + "w_grad": 0, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "midas_lr_factor": 1, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10, + "freeze_midas_bn": true + + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null, + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : null + } +} \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json new file mode 100644 index 0000000000000000000000000000000000000000..b51802aa44b91c39e15aacaac4b5ab6bec884414 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/config_zoedepth_kitti.json @@ -0,0 +1,22 @@ +{ + "model": { + "bin_centers_type": "normed", + "img_size": [384, 768] + }, + + "train": { + }, + + "infer":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt", + "force_keep_ar": true + }, + + "eval":{ + "train_midas": false, + "use_pretrained_midas": false, + "pretrained_resource" : "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt" + } +} \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..bc931b059d6165c84e8ff4f09d5c62d19930cee9 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth/zoedepth_v1.py @@ -0,0 +1,250 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + b, c, h, w = x.shape + # print("input shape ", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc33f737d238766559f0e3a8def3c0b568f23b7f --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_v1 import ZoeDepth + +all_versions = { + "v1": ZoeDepth, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/zoedepth_v1.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/zoedepth_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..dd45f67a5f30cfe88c2fe4f03b4cb97526b00ddd --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_anything/zoedepth_v1.py @@ -0,0 +1,264 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn +from ..depth_model import DepthModel +from ..base_models.midas import MidasCore +from ..base_models.depth_anything import DepthAnythingCore +from ..layers.attractor import AttractorLayer, AttractorLayerUnnormed +from ..layers.dist_layers import ConditionalLogBinomial +from ..layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from ..model_io import load_state_from_resource + + +class ZoeDepth(DepthModel): + def __init__(self, core, n_bins=64, bin_centers_type="softplus", bin_embedding_dim=128, min_depth=1e-3, max_depth=10, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', min_temp=5, max_temp=50, train_midas=True, + midas_lr_factor=10, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepth model. This is the version of ZoeDepth that has a single metric head + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + n_bins (int, optional): Number of bin centers. Defaults to 64. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "softplus". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + min_depth (float, optional): Lower bound for normed bin centers. Defaults to 1e-3. + max_depth (float, optional): Upper bound for normed bin centers. Defaults to 10. + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + """ + super().__init__() + + self.core = core + self.max_depth = max_depth + self.min_depth = min_depth + self.min_temp = min_temp + self.bin_centers_type = bin_centers_type + + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.train_midas = train_midas + self.inverse_midas = inverse_midas + + if self.encoder_lr_factor <= 0: + self.core.freeze_encoder( + freeze_rel_pos=self.pos_enc_lr_factor <= 0) + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + + # print('core output channels:', self.core.output_channels) + + self.conv2 = nn.Conv2d(btlnck_features, btlnck_features, + kernel_size=1, stride=1, padding=0) # btlnck conv + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + + self.seed_bin_regressor = SeedBinRegressorLayer( + btlnck_features, n_bins=n_bins, min_depth=min_depth, max_depth=max_depth) + self.seed_projector = Projector(btlnck_features, bin_embedding_dim) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim) + for num_out in num_out_features + ]) + self.attractors = nn.ModuleList([ + Attractor(bin_embedding_dim, n_bins, n_attractors=n_attractors[i], min_depth=min_depth, max_depth=max_depth, + alpha=attractor_alpha, gamma=attractor_gamma, kind=attractor_kind, attractor_type=attractor_type) + for i in range(len(num_out_features)) + ]) + + last_in = N_MIDAS_OUT + 1 # +1 for relative depth + + # use log binomial instead of softmax + self.conditional_log_binomial = ConditionalLogBinomial( + last_in, bin_embedding_dim, n_classes=n_bins, min_temp=min_temp, max_temp=max_temp) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W) + return_final_centers (bool, optional): Whether to return the final bin centers. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. This reverses ImageNet normalization as midas normalization is different. Defaults to False. + return_probs (bool, optional): Whether to return the output probability distribution. Defaults to False. + + Returns: + dict: Dictionary containing the following keys: + - rel_depth (torch.Tensor): Relative depth map of shape (B, H, W) + - metric_depth (torch.Tensor): Metric depth map of shape (B, 1, H, W) + - bin_centers (torch.Tensor): Bin centers of shape (B, n_bins). Present only if return_final_centers is True + - probs (torch.Tensor): Output probability distribution of shape (B, n_bins, H, W). Present only if return_probs is True + + """ + # print('input shape', x.shape) + + b, c, h, w = x.shape + # print("input shape:", x.shape) + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + # print("output shapes", rel_depth.shape, out.shape) + # print('rel_depth shape:', rel_depth.shape) + # print('out type:', type(out)) + # for k in range(len(out)): + # print(k, out[k].shape) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + _, seed_b_centers = self.seed_bin_regressor(x) + + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - self.min_depth) / \ + (self.max_depth - self.min_depth) + else: + b_prev = seed_b_centers + + prev_b_embedding = self.seed_projector(x) + + # unroll this loop for better performance + for projector, attractor, x in zip(self.projectors, self.attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b.clone() + prev_b_embedding = b_embedding.clone() + + last = outconv_activation + + if self.inverse_midas: + # invert depth followed by normalization + rel_depth = 1.0 / (rel_depth + 1e-6) + rel_depth = (rel_depth - rel_depth.min()) / \ + (rel_depth.max() - rel_depth.min()) + # concat rel depth with last. First interpolate rel depth to last size + rel_cond = rel_depth.unsqueeze(1) + rel_cond = nn.functional.interpolate( + rel_cond, size=last.shape[2:], mode='bilinear', align_corners=True) + last = torch.cat([last, rel_cond], dim=1) + + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + x = self.conditional_log_binomial(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + b_centers = nn.functional.interpolate( + b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + # Structure output dict + output = dict(metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + if self.encoder_lr_factor > 0: + param_conf.append({'params': self.core.get_enc_params_except_rel_pos( + ), 'lr': lr / self.encoder_lr_factor}) + + if self.pos_enc_lr_factor > 0: + param_conf.append( + {'params': self.core.get_rel_pos_params(), 'lr': lr / self.pos_enc_lr_factor}) + + # midas_params = self.core.core.scratch.parameters() + midas_params = self.core.core.depth_head.parameters() + midas_lr_factor = self.midas_lr_factor + param_conf.append( + {'params': midas_params, 'lr': lr / midas_lr_factor}) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + + param_conf.append({'params': remaining_params, 'lr': lr}) + + return param_conf + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + # core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + # train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + + core = DepthAnythingCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + + model = ZoeDepth(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepth.build(**config) \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..513a278b939c10c010e3c0250ec73544d5663886 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/__init__.py @@ -0,0 +1,31 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +from .zoedepth_nk_v1 import ZoeDepthNK + +all_versions = { + "v1": ZoeDepthNK, +} + +get_version = lambda v : all_versions[v] \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json new file mode 100644 index 0000000000000000000000000000000000000000..42bab2a3ad159a09599a5aba270c491021a3cf1a --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/config_zoedepth_nk.json @@ -0,0 +1,67 @@ +{ + "model": { + "name": "ZoeDepthNK", + "version_name": "v1", + "bin_conf" : [ + { + "name": "nyu", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 10.0 + }, + { + "name": "kitti", + "n_bins": 64, + "min_depth": 1e-3, + "max_depth": 80.0 + } + ], + "bin_embedding_dim": 128, + "bin_centers_type": "softplus", + "n_attractors":[16, 8, 4, 1], + "attractor_alpha": 1000, + "attractor_gamma": 2, + "attractor_kind" : "mean", + "attractor_type" : "inv", + "min_temp": 0.0212, + "max_temp": 50.0, + "memory_efficient": true, + "midas_model_type" : "DPT_BEiT_L_384", + "img_size": [384, 512] + }, + + "train": { + "train_midas": true, + "use_pretrained_midas": true, + "trainer": "zoedepth_nk", + "epochs": 5, + "bs": 16, + "optim_kwargs": {"lr": 0.0002512, "wd": 0.01}, + "sched_kwargs": {"div_factor": 1, "final_div_factor": 10000, "pct_start": 0.7, "three_phase":false, "cycle_momentum": true}, + "same_lr": false, + "w_si": 1, + "w_domain": 100, + "avoid_boundary": false, + "random_crop": false, + "input_width": 640, + "input_height": 480, + "w_grad": 0, + "w_reg": 0, + "midas_lr_factor": 10, + "encoder_lr_factor":10, + "pos_enc_lr_factor":10 + }, + + "infer": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false, + "force_keep_ar": true + }, + + "eval": { + "train_midas": false, + "pretrained_resource": "url::https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt", + "use_pretrained_midas": false + } +} \ No newline at end of file diff --git a/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py new file mode 100644 index 0000000000000000000000000000000000000000..7368ae8031188a9f946d9d3f29633c96e791e68e --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/models/zoedepth_nk/zoedepth_nk_v1.py @@ -0,0 +1,333 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import itertools + +import torch +import torch.nn as nn + +from zoedepth.models.depth_model import DepthModel +from zoedepth.models.base_models.midas import MidasCore +from zoedepth.models.layers.attractor import AttractorLayer, AttractorLayerUnnormed +from zoedepth.models.layers.dist_layers import ConditionalLogBinomial +from zoedepth.models.layers.localbins_layers import (Projector, SeedBinRegressor, + SeedBinRegressorUnnormed) +from zoedepth.models.layers.patch_transformer import PatchTransformerEncoder +from zoedepth.models.model_io import load_state_from_resource + + +class ZoeDepthNK(DepthModel): + def __init__(self, core, bin_conf, bin_centers_type="softplus", bin_embedding_dim=128, + n_attractors=[16, 8, 4, 1], attractor_alpha=300, attractor_gamma=2, attractor_kind='sum', attractor_type='exp', + min_temp=5, max_temp=50, + memory_efficient=False, train_midas=True, + is_midas_pretrained=True, midas_lr_factor=1, encoder_lr_factor=10, pos_enc_lr_factor=10, inverse_midas=False, **kwargs): + """ZoeDepthNK model. This is the version of ZoeDepth that has two metric heads and uses a learned router to route to experts. + + Args: + core (models.base_models.midas.MidasCore): The base midas model that is used for extraction of "relative" features + + bin_conf (List[dict]): A list of dictionaries that contain the bin configuration for each metric head. Each dictionary should contain the following keys: + "name" (str, typically same as the dataset name), "n_bins" (int), "min_depth" (float), "max_depth" (float) + + The length of this list determines the number of metric heads. + bin_centers_type (str, optional): "normed" or "softplus". Activation type used for bin centers. For "normed" bin centers, linear normalization trick is applied. This results in bounded bin centers. + For "softplus", softplus activation is used and thus are unbounded. Defaults to "normed". + bin_embedding_dim (int, optional): bin embedding dimension. Defaults to 128. + + n_attractors (List[int], optional): Number of bin attractors at decoder layers. Defaults to [16, 8, 4, 1]. + attractor_alpha (int, optional): Proportional attractor strength. Refer to models.layers.attractor for more details. Defaults to 300. + attractor_gamma (int, optional): Exponential attractor strength. Refer to models.layers.attractor for more details. Defaults to 2. + attractor_kind (str, optional): Attraction aggregation "sum" or "mean". Defaults to 'sum'. + attractor_type (str, optional): Type of attractor to use; "inv" (Inverse attractor) or "exp" (Exponential attractor). Defaults to 'exp'. + + min_temp (int, optional): Lower bound for temperature of output probability distribution. Defaults to 5. + max_temp (int, optional): Upper bound for temperature of output probability distribution. Defaults to 50. + + memory_efficient (bool, optional): Whether to use memory efficient version of attractor layers. Memory efficient version is slower but is recommended incase of multiple metric heads in order save GPU memory. Defaults to False. + + train_midas (bool, optional): Whether to train "core", the base midas model. Defaults to True. + is_midas_pretrained (bool, optional): Is "core" pretrained? Defaults to True. + midas_lr_factor (int, optional): Learning rate reduction factor for base midas model except its encoder and positional encodings. Defaults to 10. + encoder_lr_factor (int, optional): Learning rate reduction factor for the encoder in midas model. Defaults to 10. + pos_enc_lr_factor (int, optional): Learning rate reduction factor for positional encodings in the base midas model. Defaults to 10. + + """ + + super().__init__() + + self.core = core + self.bin_conf = bin_conf + self.min_temp = min_temp + self.max_temp = max_temp + self.memory_efficient = memory_efficient + self.train_midas = train_midas + self.is_midas_pretrained = is_midas_pretrained + self.midas_lr_factor = midas_lr_factor + self.encoder_lr_factor = encoder_lr_factor + self.pos_enc_lr_factor = pos_enc_lr_factor + self.inverse_midas = inverse_midas + + N_MIDAS_OUT = 32 + btlnck_features = self.core.output_channels[0] + num_out_features = self.core.output_channels[1:] + # self.scales = [16, 8, 4, 2] # spatial scale factors + + self.conv2 = nn.Conv2d( + btlnck_features, btlnck_features, kernel_size=1, stride=1, padding=0) + + # Transformer classifier on the bottleneck + self.patch_transformer = PatchTransformerEncoder( + btlnck_features, 1, 128, use_class_token=True) + self.mlp_classifier = nn.Sequential( + nn.Linear(128, 128), + nn.ReLU(), + nn.Linear(128, 2) + ) + + if bin_centers_type == "normed": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayer + elif bin_centers_type == "softplus": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid1": + SeedBinRegressorLayer = SeedBinRegressor + Attractor = AttractorLayerUnnormed + elif bin_centers_type == "hybrid2": + SeedBinRegressorLayer = SeedBinRegressorUnnormed + Attractor = AttractorLayer + else: + raise ValueError( + "bin_centers_type should be one of 'normed', 'softplus', 'hybrid1', 'hybrid2'") + self.bin_centers_type = bin_centers_type + # We have bins for each bin conf. + # Create a map (ModuleDict) of 'name' -> seed_bin_regressor + self.seed_bin_regressors = nn.ModuleDict( + {conf['name']: SeedBinRegressorLayer(btlnck_features, conf["n_bins"], mlp_dim=bin_embedding_dim//2, min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for conf in bin_conf} + ) + + self.seed_projector = Projector( + btlnck_features, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + self.projectors = nn.ModuleList([ + Projector(num_out, bin_embedding_dim, mlp_dim=bin_embedding_dim//2) + for num_out in num_out_features + ]) + + # Create a map (ModuleDict) of 'name' -> attractors (ModuleList) + self.attractors = nn.ModuleDict( + {conf['name']: nn.ModuleList([ + Attractor(bin_embedding_dim, n_attractors[i], + mlp_dim=bin_embedding_dim, alpha=attractor_alpha, + gamma=attractor_gamma, kind=attractor_kind, + attractor_type=attractor_type, memory_efficient=memory_efficient, + min_depth=conf["min_depth"], max_depth=conf["max_depth"]) + for i in range(len(n_attractors)) + ]) + for conf in bin_conf} + ) + + last_in = N_MIDAS_OUT + # conditional log binomial for each bin conf + self.conditional_log_binomial = nn.ModuleDict( + {conf['name']: ConditionalLogBinomial(last_in, bin_embedding_dim, conf['n_bins'], bottleneck_factor=4, min_temp=self.min_temp, max_temp=self.max_temp) + for conf in bin_conf} + ) + + def forward(self, x, return_final_centers=False, denorm=False, return_probs=False, **kwargs): + """ + Args: + x (torch.Tensor): Input image tensor of shape (B, C, H, W). Assumes all images are from the same domain. + return_final_centers (bool, optional): Whether to return the final centers of the attractors. Defaults to False. + denorm (bool, optional): Whether to denormalize the input image. Defaults to False. + return_probs (bool, optional): Whether to return the probabilities of the bins. Defaults to False. + + Returns: + dict: Dictionary of outputs with keys: + - "rel_depth": Relative depth map of shape (B, 1, H, W) + - "metric_depth": Metric depth map of shape (B, 1, H, W) + - "domain_logits": Domain logits of shape (B, 2) + - "bin_centers": Bin centers of shape (B, N, H, W). Present only if return_final_centers is True + - "probs": Bin probabilities of shape (B, N, H, W). Present only if return_probs is True + """ + b, c, h, w = x.shape + self.orig_input_width = w + self.orig_input_height = h + rel_depth, out = self.core(x, denorm=denorm, return_rel_depth=True) + + outconv_activation = out[0] + btlnck = out[1] + x_blocks = out[2:] + + x_d0 = self.conv2(btlnck) + x = x_d0 + + # Predict which path to take + embedding = self.patch_transformer(x)[0] # N, E + domain_logits = self.mlp_classifier(embedding) # N, 2 + domain_vote = torch.softmax(domain_logits.sum( + dim=0, keepdim=True), dim=-1) # 1, 2 + + # Get the path + bin_conf_name = ["nyu", "kitti"][torch.argmax( + domain_vote, dim=-1).squeeze().item()] + + try: + conf = [c for c in self.bin_conf if c.name == bin_conf_name][0] + except IndexError: + raise ValueError( + f"bin_conf_name {bin_conf_name} not found in bin_confs") + + min_depth = conf['min_depth'] + max_depth = conf['max_depth'] + + seed_bin_regressor = self.seed_bin_regressors[bin_conf_name] + _, seed_b_centers = seed_bin_regressor(x) + if self.bin_centers_type == 'normed' or self.bin_centers_type == 'hybrid2': + b_prev = (seed_b_centers - min_depth)/(max_depth - min_depth) + else: + b_prev = seed_b_centers + prev_b_embedding = self.seed_projector(x) + + attractors = self.attractors[bin_conf_name] + for projector, attractor, x in zip(self.projectors, attractors, x_blocks): + b_embedding = projector(x) + b, b_centers = attractor( + b_embedding, b_prev, prev_b_embedding, interpolate=True) + b_prev = b + prev_b_embedding = b_embedding + + last = outconv_activation + + b_centers = nn.functional.interpolate( + b_centers, last.shape[-2:], mode='bilinear', align_corners=True) + b_embedding = nn.functional.interpolate( + b_embedding, last.shape[-2:], mode='bilinear', align_corners=True) + + clb = self.conditional_log_binomial[bin_conf_name] + x = clb(last, b_embedding) + + # Now depth value is Sum px * cx , where cx are bin_centers from the last bin tensor + # print(x.shape, b_centers.shape) + # b_centers = nn.functional.interpolate(b_centers, x.shape[-2:], mode='bilinear', align_corners=True) + out = torch.sum(x * b_centers, dim=1, keepdim=True) + + output = dict(domain_logits=domain_logits, metric_depth=out) + if return_final_centers or return_probs: + output['bin_centers'] = b_centers + + if return_probs: + output['probs'] = x + return output + + def get_lr_params(self, lr): + """ + Learning rate configuration for different layers of the model + + Args: + lr (float) : Base learning rate + Returns: + list : list of parameters to optimize and their learning rates, in the format required by torch optimizers. + """ + param_conf = [] + if self.train_midas: + def get_rel_pos_params(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" in name: + yield p + + def get_enc_params_except_rel_pos(): + for name, p in self.core.core.pretrained.named_parameters(): + if "relative_position" not in name: + yield p + + encoder_params = get_enc_params_except_rel_pos() + rel_pos_params = get_rel_pos_params() + midas_params = self.core.core.scratch.parameters() + midas_lr_factor = self.midas_lr_factor if self.is_midas_pretrained else 1.0 + param_conf.extend([ + {'params': encoder_params, 'lr': lr / self.encoder_lr_factor}, + {'params': rel_pos_params, 'lr': lr / self.pos_enc_lr_factor}, + {'params': midas_params, 'lr': lr / midas_lr_factor} + ]) + + remaining_modules = [] + for name, child in self.named_children(): + if name != 'core': + remaining_modules.append(child) + remaining_params = itertools.chain( + *[child.parameters() for child in remaining_modules]) + param_conf.append({'params': remaining_params, 'lr': lr}) + return param_conf + + def get_conf_parameters(self, conf_name): + """ + Returns parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + params = [] + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + if bin_conf_name == conf_name: + params += list(module.parameters()) + return params + + def freeze_conf(self, conf_name): + """ + Freezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = False + + def unfreeze_conf(self, conf_name): + """ + Unfreezes all the parameters of all the ModuleDicts children that are exclusively used for the given bin configuration + """ + for p in self.get_conf_parameters(conf_name): + p.requires_grad = True + + def freeze_all_confs(self): + """ + Freezes all the parameters of all the ModuleDicts children + """ + for name, child in self.named_children(): + if isinstance(child, nn.ModuleDict): + for bin_conf_name, module in child.items(): + for p in module.parameters(): + p.requires_grad = False + + @staticmethod + def build(midas_model_type="DPT_BEiT_L_384", pretrained_resource=None, use_pretrained_midas=False, train_midas=False, freeze_midas_bn=True, **kwargs): + core = MidasCore.build(midas_model_type=midas_model_type, use_pretrained_midas=use_pretrained_midas, + train_midas=train_midas, fetch_features=True, freeze_bn=freeze_midas_bn, **kwargs) + model = ZoeDepthNK(core, **kwargs) + if pretrained_resource: + assert isinstance(pretrained_resource, str), "pretrained_resource must be a string" + model = load_state_from_resource(model, pretrained_resource) + return model + + @staticmethod + def build_from_config(config): + return ZoeDepthNK.build(**config) diff --git a/custom_controlnet_aux/zoe/zoedepth/utils/__init__.py b/custom_controlnet_aux/zoe/zoedepth/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f2668792389157609abb2a0846fb620e7d67eb9 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/utils/__init__.py @@ -0,0 +1,24 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + diff --git a/custom_controlnet_aux/zoe/zoedepth/utils/arg_utils.py b/custom_controlnet_aux/zoe/zoedepth/utils/arg_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8a3004ec3679c0a40fd8961253733fb4343ad545 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/utils/arg_utils.py @@ -0,0 +1,33 @@ + + +def infer_type(x): # hacky way to infer type from string args + if not isinstance(x, str): + return x + + try: + x = int(x) + return x + except ValueError: + pass + + try: + x = float(x) + return x + except ValueError: + pass + + return x + + +def parse_unknown(unknown_args): + clean = [] + for a in unknown_args: + if "=" in a: + k, v = a.split("=") + clean.extend([k, v]) + else: + clean.append(a) + + keys = clean[::2] + values = clean[1::2] + return {k.replace("--", ""): infer_type(v) for k, v in zip(keys, values)} diff --git a/custom_controlnet_aux/zoe/zoedepth/utils/config.py b/custom_controlnet_aux/zoe/zoedepth/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..84996564663dadf0e720de2a68ef8c53106ed666 --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/utils/config.py @@ -0,0 +1,437 @@ +# MIT License + +# Copyright (c) 2022 Intelligent Systems Lab Org + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +# File author: Shariq Farooq Bhat + +import json +import os + +from .easydict import EasyDict as edict +from .arg_utils import infer_type + +import pathlib +import platform + +ROOT = pathlib.Path(__file__).parent.parent.resolve() + +HOME_DIR = os.path.expanduser("~") + +COMMON_CONFIG = { + "save_dir": os.path.expanduser("~/shortcuts/monodepth3_checkpoints"), + "project": "ZoeDepth", + "tags": '', + "notes": "", + "gpu": None, + "root": ".", + "uid": None, + "print_losses": False +} + +DATASETS_CONFIG = { + "kitti": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, # 704 + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "kitti_test": { + "dataset": "kitti", + "min_depth": 0.001, + "max_depth": 80, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file": "./train_test_inputs/kitti_eigen_train_files_with_gt.txt", + "input_height": 352, + "input_width": 1216, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/raw"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/kitti/gts"), + "filenames_file_eval": "./train_test_inputs/kitti_eigen_test_files_with_gt.txt", + + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + + "do_random_rotate": False, + "degree": 1.0, + "do_kb_crop": True, + "garg_crop": True, + "eigen_crop": False, + "use_right": False + }, + "nyu": { + "dataset": "nyu", + "avoid_boundary": False, + "min_depth": 1e-3, # originally 0.1 + "max_depth": 10, + "data_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "gt_path": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/sync/"), + "filenames_file": "./train_test_inputs/nyudepthv2_train_files_with_gt.txt", + "input_height": 480, + "input_width": 640, + "data_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "gt_path_eval": os.path.join(HOME_DIR, "shortcuts/datasets/nyu_depth_v2/official_splits/test/"), + "filenames_file_eval": "./train_test_inputs/nyudepthv2_test_files_with_gt.txt", + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth_diff": -10, + "max_depth_diff": 10, + + "do_random_rotate": True, + "degree": 1.0, + "do_kb_crop": False, + "garg_crop": False, + "eigen_crop": True + }, + "ibims": { + "dataset": "ibims", + "ibims_root": os.path.join(HOME_DIR, "shortcuts/datasets/ibims/ibims1_core_raw/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "sunrgbd": { + "dataset": "sunrgbd", + "sunrgbd_root": os.path.join(HOME_DIR, "shortcuts/datasets/SUNRGBD/test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 8, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_indoor": { + "dataset": "diml_indoor", + "diml_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_indoor_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 0, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diml_outdoor": { + "dataset": "diml_outdoor", + "diml_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diml_outdoor_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 2, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "diode_indoor": { + "dataset": "diode_indoor", + "diode_indoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_indoor/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 10, + "min_depth": 1e-3, + "max_depth": 10 + }, + "diode_outdoor": { + "dataset": "diode_outdoor", + "diode_outdoor_root": os.path.join(HOME_DIR, "shortcuts/datasets/diode_outdoor/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "hypersim_test": { + "dataset": "hypersim_test", + "hypersim_test_root": os.path.join(HOME_DIR, "shortcuts/datasets/hypersim_test/"), + "eigen_crop": True, + "garg_crop": False, + "do_kb_crop": False, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 10 + }, + "vkitti": { + "dataset": "vkitti", + "vkitti_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti_test/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80 + }, + "vkitti2": { + "dataset": "vkitti2", + "vkitti2_root": os.path.join(HOME_DIR, "shortcuts/datasets/vkitti2/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, + "ddad": { + "dataset": "ddad", + "ddad_root": os.path.join(HOME_DIR, "shortcuts/datasets/ddad/ddad_val/"), + "eigen_crop": False, + "garg_crop": True, + "do_kb_crop": True, + "min_depth_eval": 1e-3, + "max_depth_eval": 80, + "min_depth": 1e-3, + "max_depth": 80, + }, +} + +ALL_INDOOR = ["nyu", "ibims", "sunrgbd", "diode_indoor", "hypersim_test"] +ALL_OUTDOOR = ["kitti", "diml_outdoor", "diode_outdoor", "vkitti2", "ddad"] +ALL_EVAL_DATASETS = ALL_INDOOR + ALL_OUTDOOR + +COMMON_TRAINING_CONFIG = { + "dataset": "nyu", + "distributed": True, + "workers": 16, + "clip_grad": 0.1, + "use_shared_dict": False, + "shared_dict": None, + "use_amp": False, + + "aug": True, + "random_crop": False, + "random_translate": False, + "translate_prob": 0.2, + "max_translation": 100, + + "validate_every": 0.25, + "log_images_every": 0.1, + "prefetch": False, +} + + +def flatten(config, except_keys=('bin_conf')): + def recurse(inp): + if isinstance(inp, dict): + for key, value in inp.items(): + if key in except_keys: + yield (key, value) + if isinstance(value, dict): + yield from recurse(value) + else: + yield (key, value) + + return dict(list(recurse(config))) + + +def split_combined_args(kwargs): + """Splits the arguments that are combined with '__' into multiple arguments. + Combined arguments should have equal number of keys and values. + Keys are separated by '__' and Values are separated with ';'. + For example, '__n_bins__lr=256;0.001' + + Args: + kwargs (dict): key-value pairs of arguments where key-value is optionally combined according to the above format. + + Returns: + dict: Parsed dict with the combined arguments split into individual key-value pairs. + """ + new_kwargs = dict(kwargs) + for key, value in kwargs.items(): + if key.startswith("__"): + keys = key.split("__")[1:] + values = value.split(";") + assert len(keys) == len( + values), f"Combined arguments should have equal number of keys and values. Keys are separated by '__' and Values are separated with ';'. For example, '__n_bins__lr=256;0.001. Given (keys,values) is ({keys}, {values})" + for k, v in zip(keys, values): + new_kwargs[k] = v + return new_kwargs + + +def parse_list(config, key, dtype=int): + """Parse a list of values for the key if the value is a string. The values are separated by a comma. + Modifies the config in place. + """ + if key in config: + if isinstance(config[key], str): + config[key] = list(map(dtype, config[key].split(','))) + assert isinstance(config[key], list) and all([isinstance(e, dtype) for e in config[key]] + ), f"{key} should be a list of values dtype {dtype}. Given {config[key]} of type {type(config[key])} with values of type {[type(e) for e in config[key]]}." + + +def get_model_config(model_name, model_version=None): + """Find and parse the .json config file for the model. + + Args: + model_name (str): name of the model. The config file should be named config_{model_name}[_{model_version}].json under the models/{model_name} directory. + model_version (str, optional): Specific config version. If specified config_{model_name}_{model_version}.json is searched for and used. Otherwise config_{model_name}.json is used. Defaults to None. + + Returns: + easydict: the config dictionary for the model. + """ + config_fname = f"config_{model_name}_{model_version}.json" if model_version is not None else f"config_{model_name}.json" + config_file = os.path.join(ROOT, "models", model_name, config_fname) + if not os.path.exists(config_file): + return None + + with open(config_file, "r") as f: + config = edict(json.load(f)) + + # handle dictionary inheritance + # only training config is supported for inheritance + if "inherit" in config.train and config.train.inherit is not None: + inherit_config = get_model_config(config.train["inherit"]).train + for key, value in inherit_config.items(): + if key not in config.train: + config.train[key] = value + return edict(config) + + +def update_model_config(config, mode, model_name, model_version=None, strict=False): + model_config = get_model_config(model_name, model_version) + if model_config is not None: + config = {**config, ** + flatten({**model_config.model, **model_config[mode]})} + elif strict: + raise ValueError(f"Config file for model {model_name} not found.") + return config + + +def check_choices(name, value, choices): + # return # No checks in dev branch + if value not in choices: + raise ValueError(f"{name} {value} not in supported choices {choices}") + + +KEYS_TYPE_BOOL = ["use_amp", "distributed", "use_shared_dict", "same_lr", "aug", "three_phase", + "prefetch", "cycle_momentum"] # Casting is not necessary as their int casted values in config are 0 or 1 + + +def get_config(model_name, mode='train', dataset=None, **overwrite_kwargs): + """Main entry point to get the config for the model. + + Args: + model_name (str): name of the desired model. + mode (str, optional): "train" or "infer". Defaults to 'train'. + dataset (str, optional): If specified, the corresponding dataset configuration is loaded as well. Defaults to None. + + Keyword Args: key-value pairs of arguments to overwrite the default config. + + The order of precedence for overwriting the config is (Higher precedence first): + # 1. overwrite_kwargs + # 2. "config_version": Config file version if specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{config_version}.json + # 3. "version_name": Default Model version specific config specified in overwrite_kwargs. The corresponding config loaded is config_{model_name}_{version_name}.json + # 4. common_config: Default config for all models specified in COMMON_CONFIG + + Returns: + easydict: The config dictionary for the model. + """ + + + check_choices("Model", model_name, ["zoedepth", "zoedepth_nk"]) + check_choices("Mode", mode, ["train", "infer", "eval"]) + if mode == "train": + check_choices("Dataset", dataset, ["nyu", "kitti", "mix", None]) + + config = flatten({**COMMON_CONFIG, **COMMON_TRAINING_CONFIG}) + config = update_model_config(config, mode, model_name) + + # update with model version specific config + version_name = overwrite_kwargs.get("version_name", config["version_name"]) + config = update_model_config(config, mode, model_name, version_name) + + # update with config version if specified + config_version = overwrite_kwargs.get("config_version", None) + if config_version is not None: + print("Overwriting config with config_version", config_version) + config = update_model_config(config, mode, model_name, config_version) + + # update with overwrite_kwargs + # Combined args are useful for hyperparameter search + overwrite_kwargs = split_combined_args(overwrite_kwargs) + config = {**config, **overwrite_kwargs} + + # Casting to bool # TODO: Not necessary. Remove and test + for key in KEYS_TYPE_BOOL: + if key in config: + config[key] = bool(config[key]) + + # Model specific post processing of config + parse_list(config, "n_attractors") + + # adjust n_bins for each bin configuration if bin_conf is given and n_bins is passed in overwrite_kwargs + if 'bin_conf' in config and 'n_bins' in overwrite_kwargs: + bin_conf = config['bin_conf'] # list of dicts + n_bins = overwrite_kwargs['n_bins'] + new_bin_conf = [] + for conf in bin_conf: + conf['n_bins'] = n_bins + new_bin_conf.append(conf) + config['bin_conf'] = new_bin_conf + + if mode == "train": + orig_dataset = dataset + if dataset == "mix": + dataset = 'nyu' # Use nyu as default for mix. Dataset config is changed accordingly while loading the dataloader + if dataset is not None: + config['project'] = f"MonoDepth3-{orig_dataset}" # Set project for wandb + + if dataset is not None: + config['dataset'] = dataset + config = {**DATASETS_CONFIG[dataset], **config} + + + config['model'] = model_name + typed_config = {k: infer_type(v) for k, v in config.items()} + # add hostname to config + config['hostname'] = platform.node() + return edict(typed_config) + + +def change_dataset(config, new_dataset): + config.update(DATASETS_CONFIG[new_dataset]) + return config diff --git a/custom_controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py b/custom_controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..15928179b0182c6045d98bc0a7be1c6ca45f675e --- /dev/null +++ b/custom_controlnet_aux/zoe/zoedepth/utils/easydict/__init__.py @@ -0,0 +1,158 @@ +""" +EasyDict +Copy/pasted from https://github.com/makinacorpus/easydict +Original author: Mathieu Leplatre +""" + +class EasyDict(dict): + """ + Get attributes + + >>> d = EasyDict({'foo':3}) + >>> d['foo'] + 3 + >>> d.foo + 3 + >>> d.bar + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'bar' + + Works recursively + + >>> d = EasyDict({'foo':3, 'bar':{'x':1, 'y':2}}) + >>> isinstance(d.bar, dict) + True + >>> d.bar.x + 1 + + Bullet-proof + + >>> EasyDict({}) + {} + >>> EasyDict(d={}) + {} + >>> EasyDict(None) + {} + >>> d = {'a': 1} + >>> EasyDict(**d) + {'a': 1} + >>> EasyDict((('a', 1), ('b', 2))) + {'a': 1, 'b': 2} + + Set attributes + + >>> d = EasyDict() + >>> d.foo = 3 + >>> d.foo + 3 + >>> d.bar = {'prop': 'value'} + >>> d.bar.prop + 'value' + >>> d + {'foo': 3, 'bar': {'prop': 'value'}} + >>> d.bar.prop = 'newer' + >>> d.bar.prop + 'newer' + + + Values extraction + + >>> d = EasyDict({'foo':0, 'bar':[{'x':1, 'y':2}, {'x':3, 'y':4}]}) + >>> isinstance(d.bar, list) + True + >>> from operator import attrgetter + >>> list(map(attrgetter('x'), d.bar)) + [1, 3] + >>> list(map(attrgetter('y'), d.bar)) + [2, 4] + >>> d = EasyDict() + >>> list(d.keys()) + [] + >>> d = EasyDict(foo=3, bar=dict(x=1, y=2)) + >>> d.foo + 3 + >>> d.bar.x + 1 + + Still like a dict though + + >>> o = EasyDict({'clean':True}) + >>> list(o.items()) + [('clean', True)] + + And like a class + + >>> class Flower(EasyDict): + ... power = 1 + ... + >>> f = Flower() + >>> f.power + 1 + >>> f = Flower({'height': 12}) + >>> f.height + 12 + >>> f['power'] + 1 + >>> sorted(f.keys()) + ['height', 'power'] + + update and pop items + >>> d = EasyDict(a=1, b='2') + >>> e = EasyDict(c=3.0, a=9.0) + >>> d.update(e) + >>> d.c + 3.0 + >>> d['c'] + 3.0 + >>> d.get('c') + 3.0 + >>> d.update(a=4, b=4) + >>> d.b + 4 + >>> d.pop('a') + 4 + >>> d.a + Traceback (most recent call last): + ... + AttributeError: 'EasyDict' object has no attribute 'a' + """ + def __init__(self, d=None, **kwargs): + if d is None: + d = {} + else: + d = dict(d) + if kwargs: + d.update(**kwargs) + for k, v in d.items(): + setattr(self, k, v) + # Class attributes + for k in self.__class__.__dict__.keys(): + if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [self.__class__(x) + if isinstance(x, dict) else x for x in value] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(EasyDict, self).__setattr__(name, value) + super(EasyDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(EasyDict, self).pop(k, d) + + +if __name__ == "__main__": + import doctest + doctest.testmod() \ No newline at end of file diff --git a/custom_detectron2/__init__.py b/custom_detectron2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bdd994b49294485c27610772f97f177741f5518f --- /dev/null +++ b/custom_detectron2/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .utils.env import setup_environment + +setup_environment() + + +# This line will be programatically read/write by setup.py. +# Leave them at the bottom of this file and don't touch them. +__version__ = "0.6" diff --git a/custom_detectron2/checkpoint/__init__.py b/custom_detectron2/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..99da0469ae7e169d8970e4b642fed3f870076860 --- /dev/null +++ b/custom_detectron2/checkpoint/__init__.py @@ -0,0 +1,10 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +# File: + + +from . import catalog as _UNUSED # register the handler +from .detection_checkpoint import DetectionCheckpointer +from fvcore.common.checkpoint import Checkpointer, PeriodicCheckpointer + +__all__ = ["Checkpointer", "PeriodicCheckpointer", "DetectionCheckpointer"] diff --git a/custom_detectron2/checkpoint/c2_model_loading.py b/custom_detectron2/checkpoint/c2_model_loading.py new file mode 100644 index 0000000000000000000000000000000000000000..c6de2a3c830089aa7a0d27df96bb4a45fc5a7b0d --- /dev/null +++ b/custom_detectron2/checkpoint/c2_model_loading.py @@ -0,0 +1,412 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import re +from typing import Dict, List +import torch +from tabulate import tabulate + + +def convert_basic_c2_names(original_keys): + """ + Apply some basic name conversion to names in C2 weights. + It only deals with typical backbone models. + + Args: + original_keys (list[str]): + Returns: + list[str]: The same number of strings matching those in original_keys. + """ + layer_keys = copy.deepcopy(original_keys) + layer_keys = [ + {"pred_b": "linear_b", "pred_w": "linear_w"}.get(k, k) for k in layer_keys + ] # some hard-coded mappings + + layer_keys = [k.replace("_", ".") for k in layer_keys] + layer_keys = [re.sub("\\.b$", ".bias", k) for k in layer_keys] + layer_keys = [re.sub("\\.w$", ".weight", k) for k in layer_keys] + # Uniform both bn and gn names to "norm" + layer_keys = [re.sub("bn\\.s$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.bias$", "norm.bias", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.rm", "norm.running_mean", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.running.mean$", "norm.running_mean", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.riv$", "norm.running_var", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.running.var$", "norm.running_var", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.gamma$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("bn\\.beta$", "norm.bias", k) for k in layer_keys] + layer_keys = [re.sub("gn\\.s$", "norm.weight", k) for k in layer_keys] + layer_keys = [re.sub("gn\\.bias$", "norm.bias", k) for k in layer_keys] + + # stem + layer_keys = [re.sub("^res\\.conv1\\.norm\\.", "conv1.norm.", k) for k in layer_keys] + # to avoid mis-matching with "conv1" in other components (e.g. detection head) + layer_keys = [re.sub("^conv1\\.", "stem.conv1.", k) for k in layer_keys] + + # layer1-4 is used by torchvision, however we follow the C2 naming strategy (res2-5) + # layer_keys = [re.sub("^res2.", "layer1.", k) for k in layer_keys] + # layer_keys = [re.sub("^res3.", "layer2.", k) for k in layer_keys] + # layer_keys = [re.sub("^res4.", "layer3.", k) for k in layer_keys] + # layer_keys = [re.sub("^res5.", "layer4.", k) for k in layer_keys] + + # blocks + layer_keys = [k.replace(".branch1.", ".shortcut.") for k in layer_keys] + layer_keys = [k.replace(".branch2a.", ".conv1.") for k in layer_keys] + layer_keys = [k.replace(".branch2b.", ".conv2.") for k in layer_keys] + layer_keys = [k.replace(".branch2c.", ".conv3.") for k in layer_keys] + + # DensePose substitutions + layer_keys = [re.sub("^body.conv.fcn", "body_conv_fcn", k) for k in layer_keys] + layer_keys = [k.replace("AnnIndex.lowres", "ann_index_lowres") for k in layer_keys] + layer_keys = [k.replace("Index.UV.lowres", "index_uv_lowres") for k in layer_keys] + layer_keys = [k.replace("U.lowres", "u_lowres") for k in layer_keys] + layer_keys = [k.replace("V.lowres", "v_lowres") for k in layer_keys] + return layer_keys + + +def convert_c2_detectron_names(weights): + """ + Map Caffe2 Detectron weight names to Detectron2 names. + + Args: + weights (dict): name -> tensor + + Returns: + dict: detectron2 names -> tensor + dict: detectron2 names -> C2 names + """ + logger = logging.getLogger(__name__) + logger.info("Renaming Caffe2 weights ......") + original_keys = sorted(weights.keys()) + layer_keys = copy.deepcopy(original_keys) + + layer_keys = convert_basic_c2_names(layer_keys) + + # -------------------------------------------------------------------------- + # RPN hidden representation conv + # -------------------------------------------------------------------------- + # FPN case + # In the C2 model, the RPN hidden layer conv is defined for FPN level 2 and then + # shared for all other levels, hence the appearance of "fpn2" + layer_keys = [ + k.replace("conv.rpn.fpn2", "proposal_generator.rpn_head.conv") for k in layer_keys + ] + # Non-FPN case + layer_keys = [k.replace("conv.rpn", "proposal_generator.rpn_head.conv") for k in layer_keys] + + # -------------------------------------------------------------------------- + # RPN box transformation conv + # -------------------------------------------------------------------------- + # FPN case (see note above about "fpn2") + layer_keys = [ + k.replace("rpn.bbox.pred.fpn2", "proposal_generator.rpn_head.anchor_deltas") + for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits.fpn2", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + # Non-FPN case + layer_keys = [ + k.replace("rpn.bbox.pred", "proposal_generator.rpn_head.anchor_deltas") for k in layer_keys + ] + layer_keys = [ + k.replace("rpn.cls.logits", "proposal_generator.rpn_head.objectness_logits") + for k in layer_keys + ] + + # -------------------------------------------------------------------------- + # Fast R-CNN box head + # -------------------------------------------------------------------------- + layer_keys = [re.sub("^bbox\\.pred", "bbox_pred", k) for k in layer_keys] + layer_keys = [re.sub("^cls\\.score", "cls_score", k) for k in layer_keys] + layer_keys = [re.sub("^fc6\\.", "box_head.fc1.", k) for k in layer_keys] + layer_keys = [re.sub("^fc7\\.", "box_head.fc2.", k) for k in layer_keys] + # 4conv1fc head tensor names: head_conv1_w, head_conv1_gn_s + layer_keys = [re.sub("^head\\.conv", "box_head.conv", k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # FPN lateral and output convolutions + # -------------------------------------------------------------------------- + def fpn_map(name): + """ + Look for keys with the following patterns: + 1) Starts with "fpn.inner." + Example: "fpn.inner.res2.2.sum.lateral.weight" + Meaning: These are lateral pathway convolutions + 2) Starts with "fpn.res" + Example: "fpn.res2.2.sum.weight" + Meaning: These are FPN output convolutions + """ + splits = name.split(".") + norm = ".norm" if "norm" in splits else "" + if name.startswith("fpn.inner."): + # splits example: ['fpn', 'inner', 'res2', '2', 'sum', 'lateral', 'weight'] + stage = int(splits[2][len("res") :]) + return "fpn_lateral{}{}.{}".format(stage, norm, splits[-1]) + elif name.startswith("fpn.res"): + # splits example: ['fpn', 'res2', '2', 'sum', 'weight'] + stage = int(splits[1][len("res") :]) + return "fpn_output{}{}.{}".format(stage, norm, splits[-1]) + return name + + layer_keys = [fpn_map(k) for k in layer_keys] + + # -------------------------------------------------------------------------- + # Mask R-CNN mask head + # -------------------------------------------------------------------------- + # roi_heads.StandardROIHeads case + layer_keys = [k.replace(".[mask].fcn", "mask_head.mask_fcn") for k in layer_keys] + layer_keys = [re.sub("^\\.mask\\.fcn", "mask_head.mask_fcn", k) for k in layer_keys] + layer_keys = [k.replace("mask.fcn.logits", "mask_head.predictor") for k in layer_keys] + # roi_heads.Res5ROIHeads case + layer_keys = [k.replace("conv5.mask", "mask_head.deconv") for k in layer_keys] + + # -------------------------------------------------------------------------- + # Keypoint R-CNN head + # -------------------------------------------------------------------------- + # interestingly, the keypoint head convs have blob names that are simply "conv_fcnX" + layer_keys = [k.replace("conv.fcn", "roi_heads.keypoint_head.conv_fcn") for k in layer_keys] + layer_keys = [ + k.replace("kps.score.lowres", "roi_heads.keypoint_head.score_lowres") for k in layer_keys + ] + layer_keys = [k.replace("kps.score.", "roi_heads.keypoint_head.score.") for k in layer_keys] + + # -------------------------------------------------------------------------- + # Done with replacements + # -------------------------------------------------------------------------- + assert len(set(layer_keys)) == len(layer_keys) + assert len(original_keys) == len(layer_keys) + + new_weights = {} + new_keys_to_original_keys = {} + for orig, renamed in zip(original_keys, layer_keys): + new_keys_to_original_keys[renamed] = orig + if renamed.startswith("bbox_pred.") or renamed.startswith("mask_head.predictor."): + # remove the meaningless prediction weight for background class + new_start_idx = 4 if renamed.startswith("bbox_pred.") else 1 + new_weights[renamed] = weights[orig][new_start_idx:] + logger.info( + "Remove prediction weight for background class in {}. The shape changes from " + "{} to {}.".format( + renamed, tuple(weights[orig].shape), tuple(new_weights[renamed].shape) + ) + ) + elif renamed.startswith("cls_score."): + # move weights of bg class from original index 0 to last index + logger.info( + "Move classification weights for background class in {} from index 0 to " + "index {}.".format(renamed, weights[orig].shape[0] - 1) + ) + new_weights[renamed] = torch.cat([weights[orig][1:], weights[orig][:1]]) + else: + new_weights[renamed] = weights[orig] + + return new_weights, new_keys_to_original_keys + + +# Note the current matching is not symmetric. +# it assumes model_state_dict will have longer names. +def align_and_update_state_dicts(model_state_dict, ckpt_state_dict, c2_conversion=True): + """ + Match names between the two state-dict, and returns a new chkpt_state_dict with names + converted to match model_state_dict with heuristics. The returned dict can be later + loaded with fvcore checkpointer. + If `c2_conversion==True`, `ckpt_state_dict` is assumed to be a Caffe2 + model and will be renamed at first. + + Strategy: suppose that the models that we will create will have prefixes appended + to each of its keys, for example due to an extra level of nesting that the original + pre-trained weights from ImageNet won't contain. For example, model.state_dict() + might return backbone[0].body.res2.conv1.weight, while the pre-trained model contains + res2.conv1.weight. We thus want to match both parameters together. + For that, we look for each model weight, look among all loaded keys if there is one + that is a suffix of the current weight name, and use it if that's the case. + If multiple matches exist, take the one with longest size + of the corresponding name. For example, for the same model as before, the pretrained + weight file can contain both res2.conv1.weight, as well as conv1.weight. In this case, + we want to match backbone[0].body.conv1.weight to conv1.weight, and + backbone[0].body.res2.conv1.weight to res2.conv1.weight. + """ + model_keys = sorted(model_state_dict.keys()) + if c2_conversion: + ckpt_state_dict, original_keys = convert_c2_detectron_names(ckpt_state_dict) + # original_keys: the name in the original dict (before renaming) + else: + original_keys = {x: x for x in ckpt_state_dict.keys()} + ckpt_keys = sorted(ckpt_state_dict.keys()) + + def match(a, b): + # Matched ckpt_key should be a complete (starts with '.') suffix. + # For example, roi_heads.mesh_head.whatever_conv1 does not match conv1, + # but matches whatever_conv1 or mesh_head.whatever_conv1. + return a == b or a.endswith("." + b) + + # get a matrix of string matches, where each (i, j) entry correspond to the size of the + # ckpt_key string, if it matches + match_matrix = [len(j) if match(i, j) else 0 for i in model_keys for j in ckpt_keys] + match_matrix = torch.as_tensor(match_matrix).view(len(model_keys), len(ckpt_keys)) + # use the matched one with longest size in case of multiple matches + max_match_size, idxs = match_matrix.max(1) + # remove indices that correspond to no-match + idxs[max_match_size == 0] = -1 + + logger = logging.getLogger(__name__) + # matched_pairs (matched checkpoint key --> matched model key) + matched_keys = {} + result_state_dict = {} + for idx_model, idx_ckpt in enumerate(idxs.tolist()): + if idx_ckpt == -1: + continue + key_model = model_keys[idx_model] + key_ckpt = ckpt_keys[idx_ckpt] + value_ckpt = ckpt_state_dict[key_ckpt] + shape_in_model = model_state_dict[key_model].shape + + if shape_in_model != value_ckpt.shape: + logger.warning( + "Shape of {} in checkpoint is {}, while shape of {} in model is {}.".format( + key_ckpt, value_ckpt.shape, key_model, shape_in_model + ) + ) + logger.warning( + "{} will not be loaded. Please double check and see if this is desired.".format( + key_ckpt + ) + ) + continue + + assert key_model not in result_state_dict + result_state_dict[key_model] = value_ckpt + if key_ckpt in matched_keys: # already added to matched_keys + logger.error( + "Ambiguity found for {} in checkpoint!" + "It matches at least two keys in the model ({} and {}).".format( + key_ckpt, key_model, matched_keys[key_ckpt] + ) + ) + raise ValueError("Cannot match one checkpoint key to multiple keys in the model.") + + matched_keys[key_ckpt] = key_model + + # logging: + matched_model_keys = sorted(matched_keys.values()) + if len(matched_model_keys) == 0: + logger.warning("No weights in checkpoint matched with model.") + return ckpt_state_dict + common_prefix = _longest_common_prefix(matched_model_keys) + rev_matched_keys = {v: k for k, v in matched_keys.items()} + original_keys = {k: original_keys[rev_matched_keys[k]] for k in matched_model_keys} + + model_key_groups = _group_keys_by_module(matched_model_keys, original_keys) + table = [] + memo = set() + for key_model in matched_model_keys: + if key_model in memo: + continue + if key_model in model_key_groups: + group = model_key_groups[key_model] + memo |= set(group) + shapes = [tuple(model_state_dict[k].shape) for k in group] + table.append( + ( + _longest_common_prefix([k[len(common_prefix) :] for k in group]) + "*", + _group_str([original_keys[k] for k in group]), + " ".join([str(x).replace(" ", "") for x in shapes]), + ) + ) + else: + key_checkpoint = original_keys[key_model] + shape = str(tuple(model_state_dict[key_model].shape)) + table.append((key_model[len(common_prefix) :], key_checkpoint, shape)) + table_str = tabulate( + table, tablefmt="pipe", headers=["Names in Model", "Names in Checkpoint", "Shapes"] + ) + logger.info( + "Following weights matched with " + + (f"submodule {common_prefix[:-1]}" if common_prefix else "model") + + ":\n" + + table_str + ) + + unmatched_ckpt_keys = [k for k in ckpt_keys if k not in set(matched_keys.keys())] + for k in unmatched_ckpt_keys: + result_state_dict[k] = ckpt_state_dict[k] + return result_state_dict + + +def _group_keys_by_module(keys: List[str], original_names: Dict[str, str]): + """ + Params in the same submodule are grouped together. + + Args: + keys: names of all parameters + original_names: mapping from parameter name to their name in the checkpoint + + Returns: + dict[name -> all other names in the same group] + """ + + def _submodule_name(key): + pos = key.rfind(".") + if pos < 0: + return None + prefix = key[: pos + 1] + return prefix + + all_submodules = [_submodule_name(k) for k in keys] + all_submodules = [x for x in all_submodules if x] + all_submodules = sorted(all_submodules, key=len) + + ret = {} + for prefix in all_submodules: + group = [k for k in keys if k.startswith(prefix)] + if len(group) <= 1: + continue + original_name_lcp = _longest_common_prefix_str([original_names[k] for k in group]) + if len(original_name_lcp) == 0: + # don't group weights if original names don't share prefix + continue + + for k in group: + if k in ret: + continue + ret[k] = group + return ret + + +def _longest_common_prefix(names: List[str]) -> str: + """ + ["abc.zfg", "abc.zef"] -> "abc." + """ + names = [n.split(".") for n in names] + m1, m2 = min(names), max(names) + ret = [a for a, b in zip(m1, m2) if a == b] + ret = ".".join(ret) + "." if len(ret) else "" + return ret + + +def _longest_common_prefix_str(names: List[str]) -> str: + m1, m2 = min(names), max(names) + lcp = [] + for a, b in zip(m1, m2): + if a == b: + lcp.append(a) + else: + break + lcp = "".join(lcp) + return lcp + + +def _group_str(names: List[str]) -> str: + """ + Turn "common1", "common2", "common3" into "common{1,2,3}" + """ + lcp = _longest_common_prefix_str(names) + rest = [x[len(lcp) :] for x in names] + rest = "{" + ",".join(rest) + "}" + ret = lcp + rest + + # add some simplification for BN specifically + ret = ret.replace("bn_{beta,running_mean,running_var,gamma}", "bn_*") + ret = ret.replace("bn_beta,bn_running_mean,bn_running_var,bn_gamma", "bn_*") + return ret diff --git a/custom_detectron2/checkpoint/catalog.py b/custom_detectron2/checkpoint/catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..2287e9b7fbf4d3ed7ec9bdb26d6d1d4d8ed91196 --- /dev/null +++ b/custom_detectron2/checkpoint/catalog.py @@ -0,0 +1,114 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging + +from custom_detectron2.utils.file_io import PathHandler, PathManager + +class ModelCatalog(object): + """ + Store mappings from names to third-party models. + """ + + S3_C2_DETECTRON_PREFIX = "https://dl.fbaipublicfiles.com/detectron" + + # MSRA models have STRIDE_IN_1X1=True. False otherwise. + # NOTE: all BN models here have fused BN into an affine layer. + # As a result, you should only load them to a model with "FrozenBN". + # Loading them to a model with regular BN or SyncBN is wrong. + # Even when loaded to FrozenBN, it is still different from affine by an epsilon, + # which should be negligible for training. + # NOTE: all models here uses PIXEL_STD=[1,1,1] + # NOTE: Most of the BN models here are no longer used. We use the + # re-converted pre-trained models under detectron2 model zoo instead. + C2_IMAGENET_MODELS = { + "MSRA/R-50": "ImageNetPretrained/MSRA/R-50.pkl", + "MSRA/R-101": "ImageNetPretrained/MSRA/R-101.pkl", + "FAIR/R-50-GN": "ImageNetPretrained/47261647/R-50-GN.pkl", + "FAIR/R-101-GN": "ImageNetPretrained/47592356/R-101-GN.pkl", + "FAIR/X-101-32x8d": "ImageNetPretrained/20171220/X-101-32x8d.pkl", + "FAIR/X-101-64x4d": "ImageNetPretrained/FBResNeXt/X-101-64x4d.pkl", + "FAIR/X-152-32x8d-IN5k": "ImageNetPretrained/25093814/X-152-32x8d-IN5k.pkl", + } + + C2_DETECTRON_PATH_FORMAT = ( + "{prefix}/{url}/output/train/{dataset}/{type}/model_final.pkl" # noqa B950 + ) + + C2_DATASET_COCO = "coco_2014_train%3Acoco_2014_valminusminival" + C2_DATASET_COCO_KEYPOINTS = "keypoints_coco_2014_train%3Akeypoints_coco_2014_valminusminival" + + # format: {model_name} -> part of the url + C2_DETECTRON_MODELS = { + "35857197/e2e_faster_rcnn_R-50-C4_1x": "35857197/12_2017_baselines/e2e_faster_rcnn_R-50-C4_1x.yaml.01_33_49.iAX0mXvW", # noqa B950 + "35857345/e2e_faster_rcnn_R-50-FPN_1x": "35857345/12_2017_baselines/e2e_faster_rcnn_R-50-FPN_1x.yaml.01_36_30.cUF7QR7I", # noqa B950 + "35857890/e2e_faster_rcnn_R-101-FPN_1x": "35857890/12_2017_baselines/e2e_faster_rcnn_R-101-FPN_1x.yaml.01_38_50.sNxI7sX7", # noqa B950 + "36761737/e2e_faster_rcnn_X-101-32x8d-FPN_1x": "36761737/12_2017_baselines/e2e_faster_rcnn_X-101-32x8d-FPN_1x.yaml.06_31_39.5MIHi1fZ", # noqa B950 + "35858791/e2e_mask_rcnn_R-50-C4_1x": "35858791/12_2017_baselines/e2e_mask_rcnn_R-50-C4_1x.yaml.01_45_57.ZgkA7hPB", # noqa B950 + "35858933/e2e_mask_rcnn_R-50-FPN_1x": "35858933/12_2017_baselines/e2e_mask_rcnn_R-50-FPN_1x.yaml.01_48_14.DzEQe4wC", # noqa B950 + "35861795/e2e_mask_rcnn_R-101-FPN_1x": "35861795/12_2017_baselines/e2e_mask_rcnn_R-101-FPN_1x.yaml.02_31_37.KqyEK4tT", # noqa B950 + "36761843/e2e_mask_rcnn_X-101-32x8d-FPN_1x": "36761843/12_2017_baselines/e2e_mask_rcnn_X-101-32x8d-FPN_1x.yaml.06_35_59.RZotkLKI", # noqa B950 + "48616381/e2e_mask_rcnn_R-50-FPN_2x_gn": "GN/48616381/04_2018_gn_baselines/e2e_mask_rcnn_R-50-FPN_2x_gn_0416.13_23_38.bTlTI97Q", # noqa B950 + "37697547/e2e_keypoint_rcnn_R-50-FPN_1x": "37697547/12_2017_baselines/e2e_keypoint_rcnn_R-50-FPN_1x.yaml.08_42_54.kdzV35ao", # noqa B950 + "35998355/rpn_R-50-C4_1x": "35998355/12_2017_baselines/rpn_R-50-C4_1x.yaml.08_00_43.njH5oD9L", # noqa B950 + "35998814/rpn_R-50-FPN_1x": "35998814/12_2017_baselines/rpn_R-50-FPN_1x.yaml.08_06_03.Axg0r179", # noqa B950 + "36225147/fast_R-50-FPN_1x": "36225147/12_2017_baselines/fast_rcnn_R-50-FPN_1x.yaml.08_39_09.L3obSdQ2", # noqa B950 + } + + @staticmethod + def get(name): + if name.startswith("Caffe2Detectron/COCO"): + return ModelCatalog._get_c2_detectron_baseline(name) + if name.startswith("ImageNetPretrained/"): + return ModelCatalog._get_c2_imagenet_pretrained(name) + raise RuntimeError("model not present in the catalog: {}".format(name)) + + @staticmethod + def _get_c2_imagenet_pretrained(name): + prefix = ModelCatalog.S3_C2_DETECTRON_PREFIX + name = name[len("ImageNetPretrained/") :] + name = ModelCatalog.C2_IMAGENET_MODELS[name] + url = "/".join([prefix, name]) + return url + + @staticmethod + def _get_c2_detectron_baseline(name): + name = name[len("Caffe2Detectron/COCO/") :] + url = ModelCatalog.C2_DETECTRON_MODELS[name] + if "keypoint_rcnn" in name: + dataset = ModelCatalog.C2_DATASET_COCO_KEYPOINTS + else: + dataset = ModelCatalog.C2_DATASET_COCO + + if "35998355/rpn_R-50-C4_1x" in name: + # this one model is somehow different from others .. + type = "rpn" + else: + type = "generalized_rcnn" + + # Detectron C2 models are stored in the structure defined in `C2_DETECTRON_PATH_FORMAT`. + url = ModelCatalog.C2_DETECTRON_PATH_FORMAT.format( + prefix=ModelCatalog.S3_C2_DETECTRON_PREFIX, url=url, type=type, dataset=dataset + ) + return url + + +class ModelCatalogHandler(PathHandler): + """ + Resolve URL like catalog://. + """ + + PREFIX = "catalog://" + + def _get_supported_prefixes(self): + return [self.PREFIX] + + def _get_local_path(self, path, **kwargs): + logger = logging.getLogger(__name__) + catalog_path = ModelCatalog.get(path[len(self.PREFIX) :]) + logger.info("Catalog entry {} points to {}".format(path, catalog_path)) + return PathManager.get_local_path(catalog_path, **kwargs) + + def _open(self, path, mode="r", **kwargs): + return PathManager.open(self._get_local_path(path), mode, **kwargs) + + +PathManager.register_handler(ModelCatalogHandler()) diff --git a/custom_detectron2/checkpoint/detection_checkpoint.py b/custom_detectron2/checkpoint/detection_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2776783ac380d931ded4af69fc5e08bab5ab159a --- /dev/null +++ b/custom_detectron2/checkpoint/detection_checkpoint.py @@ -0,0 +1,145 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import os +import pickle +from urllib.parse import parse_qs, urlparse +import torch +from fvcore.common.checkpoint import Checkpointer +from torch.nn.parallel import DistributedDataParallel + +import custom_detectron2.utils.comm as comm +from custom_detectron2.utils.file_io import PathManager + +from .c2_model_loading import align_and_update_state_dicts + + +class DetectionCheckpointer(Checkpointer): + """ + Same as :class:`Checkpointer`, but is able to: + 1. handle models in detectron & detectron2 model zoo, and apply conversions for legacy models. + 2. correctly load checkpoints that are only available on the master worker + """ + + def __init__(self, model, save_dir="", *, save_to_disk=None, **checkpointables): + is_main_process = comm.is_main_process() + super().__init__( + model, + save_dir, + save_to_disk=is_main_process if save_to_disk is None else save_to_disk, + **checkpointables, + ) + self.path_manager = PathManager + self._parsed_url_during_load = None + + def load(self, path, *args, **kwargs): + assert self._parsed_url_during_load is None + need_sync = False + logger = logging.getLogger(__name__) + logger.info("[DetectionCheckpointer] Loading from {} ...".format(path)) + + if path and isinstance(self.model, DistributedDataParallel): + path = self.path_manager.get_local_path(path) + has_file = os.path.isfile(path) + all_has_file = comm.all_gather(has_file) + if not all_has_file[0]: + raise OSError(f"File {path} not found on main worker.") + if not all(all_has_file): + logger.warning( + f"Not all workers can read checkpoint {path}. " + "Training may fail to fully resume." + ) + # TODO: broadcast the checkpoint file contents from main + # worker, and load from it instead. + need_sync = True + if not has_file: + path = None # don't load if not readable + + if path: + parsed_url = urlparse(path) + self._parsed_url_during_load = parsed_url + path = parsed_url._replace(query="").geturl() # remove query from filename + path = self.path_manager.get_local_path(path) + + self.logger.setLevel('CRITICAL') + ret = super().load(path, *args, **kwargs) + + if need_sync: + logger.info("Broadcasting model states from main worker ...") + self.model._sync_params_and_buffers() + self._parsed_url_during_load = None # reset to None + return ret + + def _load_file(self, filename): + if filename.endswith(".pkl"): + with PathManager.open(filename, "rb") as f: + data = pickle.load(f, encoding="latin1") + if "model" in data and "__author__" in data: + # file is in Detectron2 model zoo format + self.logger.info("Reading a file from '{}'".format(data["__author__"])) + return data + else: + # assume file is from Caffe2 / Detectron1 model zoo + if "blobs" in data: + # Detection models have "blobs", but ImageNet models don't + data = data["blobs"] + data = {k: v for k, v in data.items() if not k.endswith("_momentum")} + return {"model": data, "__author__": "Caffe2", "matching_heuristics": True} + elif filename.endswith(".pyth"): + # assume file is from pycls; no one else seems to use the ".pyth" extension + with PathManager.open(filename, "rb") as f: + data = torch.load(f) + assert ( + "model_state" in data + ), f"Cannot load .pyth file {filename}; pycls checkpoints must contain 'model_state'." + model_state = { + k: v + for k, v in data["model_state"].items() + if not k.endswith("num_batches_tracked") + } + return {"model": model_state, "__author__": "pycls", "matching_heuristics": True} + + loaded = self._torch_load(filename) + if "model" not in loaded: + loaded = {"model": loaded} + assert self._parsed_url_during_load is not None, "`_load_file` must be called inside `load`" + parsed_url = self._parsed_url_during_load + queries = parse_qs(parsed_url.query) + if queries.pop("matching_heuristics", "False") == ["True"]: + loaded["matching_heuristics"] = True + if len(queries) > 0: + raise ValueError( + f"Unsupported query remaining: f{queries}, orginal filename: {parsed_url.geturl()}" + ) + return loaded + + def _torch_load(self, f): + return super()._load_file(f) + + def _load_model(self, checkpoint): + if checkpoint.get("matching_heuristics", False): + self._convert_ndarray_to_tensor(checkpoint["model"]) + # convert weights by name-matching heuristics + checkpoint["model"] = align_and_update_state_dicts( + self.model.state_dict(), + checkpoint["model"], + c2_conversion=checkpoint.get("__author__", None) == "Caffe2", + ) + # for non-caffe2 models, use standard ways to load it + incompatible = super()._load_model(checkpoint) + + model_buffers = dict(self.model.named_buffers(recurse=False)) + for k in ["pixel_mean", "pixel_std"]: + # Ignore missing key message about pixel_mean/std. + # Though they may be missing in old checkpoints, they will be correctly + # initialized from config anyway. + if k in model_buffers: + try: + incompatible.missing_keys.remove(k) + except ValueError: + pass + for k in incompatible.unexpected_keys[:]: + # Ignore unexpected keys about cell anchors. They exist in old checkpoints + # but now they are non-persistent buffers and will not be in new checkpoints. + if "anchor_generator.cell_anchors" in k: + incompatible.unexpected_keys.remove(k) + return incompatible diff --git a/custom_detectron2/config/__init__.py b/custom_detectron2/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b98b0872b423a665525d25ab8c203c217543e149 --- /dev/null +++ b/custom_detectron2/config/__init__.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .compat import downgrade_config, upgrade_config +from .config import CfgNode, get_cfg, global_cfg, set_global_cfg, configurable +from .instantiate import instantiate +from .lazy import LazyCall, LazyConfig + +__all__ = [ + "CfgNode", + "get_cfg", + "global_cfg", + "set_global_cfg", + "downgrade_config", + "upgrade_config", + "configurable", + "instantiate", + "LazyCall", + "LazyConfig", +] + + +from custom_detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/custom_detectron2/config/compat.py b/custom_detectron2/config/compat.py new file mode 100644 index 0000000000000000000000000000000000000000..11a08c439bf14defd880e37a938fab8a08e68eeb --- /dev/null +++ b/custom_detectron2/config/compat.py @@ -0,0 +1,229 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Backward compatibility of configs. + +Instructions to bump version: ++ It's not needed to bump version if new keys are added. + It's only needed when backward-incompatible changes happen + (i.e., some existing keys disappear, or the meaning of a key changes) ++ To bump version, do the following: + 1. Increment _C.VERSION in defaults.py + 2. Add a converter in this file. + + Each ConverterVX has a function "upgrade" which in-place upgrades config from X-1 to X, + and a function "downgrade" which in-place downgrades config from X to X-1 + + In each function, VERSION is left unchanged. + + Each converter assumes that its input has the relevant keys + (i.e., the input is not a partial config). + 3. Run the tests (test_config.py) to make sure the upgrade & downgrade + functions are consistent. +""" + +import logging +from typing import List, Optional, Tuple + +from .config import CfgNode as CN +from .defaults import _C + +__all__ = ["upgrade_config", "downgrade_config"] + + +def upgrade_config(cfg: CN, to_version: Optional[int] = None) -> CN: + """ + Upgrade a config from its current version to a newer version. + + Args: + cfg (CfgNode): + to_version (int): defaults to the latest version. + """ + cfg = cfg.clone() + if to_version is None: + to_version = _C.VERSION + + assert cfg.VERSION <= to_version, "Cannot upgrade from v{} to v{}!".format( + cfg.VERSION, to_version + ) + for k in range(cfg.VERSION, to_version): + converter = globals()["ConverterV" + str(k + 1)] + converter.upgrade(cfg) + cfg.VERSION = k + 1 + return cfg + + +def downgrade_config(cfg: CN, to_version: int) -> CN: + """ + Downgrade a config from its current version to an older version. + + Args: + cfg (CfgNode): + to_version (int): + + Note: + A general downgrade of arbitrary configs is not always possible due to the + different functionalities in different versions. + The purpose of downgrade is only to recover the defaults in old versions, + allowing it to load an old partial yaml config. + Therefore, the implementation only needs to fill in the default values + in the old version when a general downgrade is not possible. + """ + cfg = cfg.clone() + assert cfg.VERSION >= to_version, "Cannot downgrade from v{} to v{}!".format( + cfg.VERSION, to_version + ) + for k in range(cfg.VERSION, to_version, -1): + converter = globals()["ConverterV" + str(k)] + converter.downgrade(cfg) + cfg.VERSION = k - 1 + return cfg + + +def guess_version(cfg: CN, filename: str) -> int: + """ + Guess the version of a partial config where the VERSION field is not specified. + Returns the version, or the latest if cannot make a guess. + + This makes it easier for users to migrate. + """ + logger = logging.getLogger(__name__) + + def _has(name: str) -> bool: + cur = cfg + for n in name.split("."): + if n not in cur: + return False + cur = cur[n] + return True + + # Most users' partial configs have "MODEL.WEIGHT", so guess on it + ret = None + if _has("MODEL.WEIGHT") or _has("TEST.AUG_ON"): + ret = 1 + + if ret is not None: + logger.warning("Config '{}' has no VERSION. Assuming it to be v{}.".format(filename, ret)) + else: + ret = _C.VERSION + logger.warning( + "Config '{}' has no VERSION. Assuming it to be compatible with latest v{}.".format( + filename, ret + ) + ) + return ret + + +def _rename(cfg: CN, old: str, new: str) -> None: + old_keys = old.split(".") + new_keys = new.split(".") + + def _set(key_seq: List[str], val: str) -> None: + cur = cfg + for k in key_seq[:-1]: + if k not in cur: + cur[k] = CN() + cur = cur[k] + cur[key_seq[-1]] = val + + def _get(key_seq: List[str]) -> CN: + cur = cfg + for k in key_seq: + cur = cur[k] + return cur + + def _del(key_seq: List[str]) -> None: + cur = cfg + for k in key_seq[:-1]: + cur = cur[k] + del cur[key_seq[-1]] + if len(cur) == 0 and len(key_seq) > 1: + _del(key_seq[:-1]) + + _set(new_keys, _get(old_keys)) + _del(old_keys) + + +class _RenameConverter: + """ + A converter that handles simple rename. + """ + + RENAME: List[Tuple[str, str]] = [] # list of tuples of (old name, new name) + + @classmethod + def upgrade(cls, cfg: CN) -> None: + for old, new in cls.RENAME: + _rename(cfg, old, new) + + @classmethod + def downgrade(cls, cfg: CN) -> None: + for old, new in cls.RENAME[::-1]: + _rename(cfg, new, old) + + +class ConverterV1(_RenameConverter): + RENAME = [("MODEL.RPN_HEAD.NAME", "MODEL.RPN.HEAD_NAME")] + + +class ConverterV2(_RenameConverter): + """ + A large bulk of rename, before public release. + """ + + RENAME = [ + ("MODEL.WEIGHT", "MODEL.WEIGHTS"), + ("MODEL.PANOPTIC_FPN.SEMANTIC_LOSS_SCALE", "MODEL.SEM_SEG_HEAD.LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.RPN_LOSS_SCALE", "MODEL.RPN.LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.INSTANCE_LOSS_SCALE", "MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT"), + ("MODEL.PANOPTIC_FPN.COMBINE_ON", "MODEL.PANOPTIC_FPN.COMBINE.ENABLED"), + ( + "MODEL.PANOPTIC_FPN.COMBINE_OVERLAP_THRESHOLD", + "MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH", + ), + ( + "MODEL.PANOPTIC_FPN.COMBINE_STUFF_AREA_LIMIT", + "MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT", + ), + ( + "MODEL.PANOPTIC_FPN.COMBINE_INSTANCES_CONFIDENCE_THRESHOLD", + "MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH", + ), + ("MODEL.ROI_HEADS.SCORE_THRESH", "MODEL.ROI_HEADS.SCORE_THRESH_TEST"), + ("MODEL.ROI_HEADS.NMS", "MODEL.ROI_HEADS.NMS_THRESH_TEST"), + ("MODEL.RETINANET.INFERENCE_SCORE_THRESHOLD", "MODEL.RETINANET.SCORE_THRESH_TEST"), + ("MODEL.RETINANET.INFERENCE_TOPK_CANDIDATES", "MODEL.RETINANET.TOPK_CANDIDATES_TEST"), + ("MODEL.RETINANET.INFERENCE_NMS_THRESHOLD", "MODEL.RETINANET.NMS_THRESH_TEST"), + ("TEST.DETECTIONS_PER_IMG", "TEST.DETECTIONS_PER_IMAGE"), + ("TEST.AUG_ON", "TEST.AUG.ENABLED"), + ("TEST.AUG_MIN_SIZES", "TEST.AUG.MIN_SIZES"), + ("TEST.AUG_MAX_SIZE", "TEST.AUG.MAX_SIZE"), + ("TEST.AUG_FLIP", "TEST.AUG.FLIP"), + ] + + @classmethod + def upgrade(cls, cfg: CN) -> None: + super().upgrade(cfg) + + if cfg.MODEL.META_ARCHITECTURE == "RetinaNet": + _rename( + cfg, "MODEL.RETINANET.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS" + ) + _rename(cfg, "MODEL.RETINANET.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES") + del cfg["MODEL"]["RPN"]["ANCHOR_SIZES"] + del cfg["MODEL"]["RPN"]["ANCHOR_ASPECT_RATIOS"] + else: + _rename(cfg, "MODEL.RPN.ANCHOR_ASPECT_RATIOS", "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS") + _rename(cfg, "MODEL.RPN.ANCHOR_SIZES", "MODEL.ANCHOR_GENERATOR.SIZES") + del cfg["MODEL"]["RETINANET"]["ANCHOR_SIZES"] + del cfg["MODEL"]["RETINANET"]["ANCHOR_ASPECT_RATIOS"] + del cfg["MODEL"]["RETINANET"]["ANCHOR_STRIDES"] + + @classmethod + def downgrade(cls, cfg: CN) -> None: + super().downgrade(cfg) + + _rename(cfg, "MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS", "MODEL.RPN.ANCHOR_ASPECT_RATIOS") + _rename(cfg, "MODEL.ANCHOR_GENERATOR.SIZES", "MODEL.RPN.ANCHOR_SIZES") + cfg.MODEL.RETINANET.ANCHOR_ASPECT_RATIOS = cfg.MODEL.RPN.ANCHOR_ASPECT_RATIOS + cfg.MODEL.RETINANET.ANCHOR_SIZES = cfg.MODEL.RPN.ANCHOR_SIZES + cfg.MODEL.RETINANET.ANCHOR_STRIDES = [] # this is not used anywhere in any version diff --git a/custom_detectron2/config/config.py b/custom_detectron2/config/config.py new file mode 100644 index 0000000000000000000000000000000000000000..e9ba1a5e4b7006bece7388efb59008dc902db0ba --- /dev/null +++ b/custom_detectron2/config/config.py @@ -0,0 +1,265 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import inspect +import logging +from fvcore.common.config import CfgNode as _CfgNode + +from custom_detectron2.utils.file_io import PathManager + + +class CfgNode(_CfgNode): + """ + The same as `fvcore.common.config.CfgNode`, but different in: + + 1. Use unsafe yaml loading by default. + Note that this may lead to arbitrary code execution: you must not + load a config file from untrusted sources before manually inspecting + the content of the file. + 2. Support config versioning. + When attempting to merge an old config, it will convert the old config automatically. + + .. automethod:: clone + .. automethod:: freeze + .. automethod:: defrost + .. automethod:: is_frozen + .. automethod:: load_yaml_with_base + .. automethod:: merge_from_list + .. automethod:: merge_from_other_cfg + """ + + @classmethod + def _open_cfg(cls, filename): + return PathManager.open(filename, "r") + + # Note that the default value of allow_unsafe is changed to True + def merge_from_file(self, cfg_filename: str, allow_unsafe: bool = True) -> None: + """ + Load content from the given config file and merge it into self. + + Args: + cfg_filename: config filename + allow_unsafe: allow unsafe yaml syntax + """ + assert PathManager.isfile(cfg_filename), f"Config file '{cfg_filename}' does not exist!" + loaded_cfg = self.load_yaml_with_base(cfg_filename, allow_unsafe=allow_unsafe) + loaded_cfg = type(self)(loaded_cfg) + + # defaults.py needs to import CfgNode + from .defaults import _C + + latest_ver = _C.VERSION + assert ( + latest_ver == self.VERSION + ), "CfgNode.merge_from_file is only allowed on a config object of latest version!" + + logger = logging.getLogger(__name__) + + loaded_ver = loaded_cfg.get("VERSION", None) + if loaded_ver is None: + from .compat import guess_version + + loaded_ver = guess_version(loaded_cfg, cfg_filename) + assert loaded_ver <= self.VERSION, "Cannot merge a v{} config into a v{} config.".format( + loaded_ver, self.VERSION + ) + + if loaded_ver == self.VERSION: + self.merge_from_other_cfg(loaded_cfg) + else: + # compat.py needs to import CfgNode + from .compat import upgrade_config, downgrade_config + + logger.warning( + "Loading an old v{} config file '{}' by automatically upgrading to v{}. " + "See docs/CHANGELOG.md for instructions to update your files.".format( + loaded_ver, cfg_filename, self.VERSION + ) + ) + # To convert, first obtain a full config at an old version + old_self = downgrade_config(self, to_version=loaded_ver) + old_self.merge_from_other_cfg(loaded_cfg) + new_config = upgrade_config(old_self) + self.clear() + self.update(new_config) + + def dump(self, *args, **kwargs): + """ + Returns: + str: a yaml string representation of the config + """ + # to make it show up in docs + return super().dump(*args, **kwargs) + + +global_cfg = CfgNode() + + +def get_cfg() -> CfgNode: + """ + Get a copy of the default config. + + Returns: + a detectron2 CfgNode instance. + """ + from .defaults import _C + + return _C.clone() + + +def set_global_cfg(cfg: CfgNode) -> None: + """ + Let the global config point to the given cfg. + + Assume that the given "cfg" has the key "KEY", after calling + `set_global_cfg(cfg)`, the key can be accessed by: + :: + from custom_detectron2.config import global_cfg + print(global_cfg.KEY) + + By using a hacky global config, you can access these configs anywhere, + without having to pass the config object or the values deep into the code. + This is a hacky feature introduced for quick prototyping / research exploration. + """ + global global_cfg + global_cfg.clear() + global_cfg.update(cfg) + + +def configurable(init_func=None, *, from_config=None): + """ + Decorate a function or a class's __init__ method so that it can be called + with a :class:`CfgNode` object using a :func:`from_config` function that translates + :class:`CfgNode` to arguments. + + Examples: + :: + # Usage 1: Decorator on __init__: + class A: + @configurable + def __init__(self, a, b=2, c=3): + pass + + @classmethod + def from_config(cls, cfg): # 'cfg' must be the first argument + # Returns kwargs to be passed to __init__ + return {"a": cfg.A, "b": cfg.B} + + a1 = A(a=1, b=2) # regular construction + a2 = A(cfg) # construct with a cfg + a3 = A(cfg, b=3, c=4) # construct with extra overwrite + + # Usage 2: Decorator on any function. Needs an extra from_config argument: + @configurable(from_config=lambda cfg: {"a: cfg.A, "b": cfg.B}) + def a_func(a, b=2, c=3): + pass + + a1 = a_func(a=1, b=2) # regular call + a2 = a_func(cfg) # call with a cfg + a3 = a_func(cfg, b=3, c=4) # call with extra overwrite + + Args: + init_func (callable): a class's ``__init__`` method in usage 1. The + class must have a ``from_config`` classmethod which takes `cfg` as + the first argument. + from_config (callable): the from_config function in usage 2. It must take `cfg` + as its first argument. + """ + + if init_func is not None: + assert ( + inspect.isfunction(init_func) + and from_config is None + and init_func.__name__ == "__init__" + ), "Incorrect use of @configurable. Check API documentation for examples." + + @functools.wraps(init_func) + def wrapped(self, *args, **kwargs): + try: + from_config_func = type(self).from_config + except AttributeError as e: + raise AttributeError( + "Class with @configurable must have a 'from_config' classmethod." + ) from e + if not inspect.ismethod(from_config_func): + raise TypeError("Class with @configurable must have a 'from_config' classmethod.") + + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config_func, *args, **kwargs) + init_func(self, **explicit_args) + else: + init_func(self, *args, **kwargs) + + return wrapped + + else: + if from_config is None: + return configurable # @configurable() is made equivalent to @configurable + assert inspect.isfunction( + from_config + ), "from_config argument of configurable must be a function!" + + def wrapper(orig_func): + @functools.wraps(orig_func) + def wrapped(*args, **kwargs): + if _called_with_cfg(*args, **kwargs): + explicit_args = _get_args_from_config(from_config, *args, **kwargs) + return orig_func(**explicit_args) + else: + return orig_func(*args, **kwargs) + + wrapped.from_config = from_config + return wrapped + + return wrapper + + +def _get_args_from_config(from_config_func, *args, **kwargs): + """ + Use `from_config` to obtain explicit arguments. + + Returns: + dict: arguments to be used for cls.__init__ + """ + signature = inspect.signature(from_config_func) + if list(signature.parameters.keys())[0] != "cfg": + if inspect.isfunction(from_config_func): + name = from_config_func.__name__ + else: + name = f"{from_config_func.__self__}.from_config" + raise TypeError(f"{name} must take 'cfg' as the first argument!") + support_var_arg = any( + param.kind in [param.VAR_POSITIONAL, param.VAR_KEYWORD] + for param in signature.parameters.values() + ) + if support_var_arg: # forward all arguments to from_config, if from_config accepts them + ret = from_config_func(*args, **kwargs) + else: + # forward supported arguments to from_config + supported_arg_names = set(signature.parameters.keys()) + extra_kwargs = {} + for name in list(kwargs.keys()): + if name not in supported_arg_names: + extra_kwargs[name] = kwargs.pop(name) + ret = from_config_func(*args, **kwargs) + # forward the other arguments to __init__ + ret.update(extra_kwargs) + return ret + + +def _called_with_cfg(*args, **kwargs): + """ + Returns: + bool: whether the arguments contain CfgNode and should be considered + forwarded to from_config. + """ + from omegaconf import DictConfig + + if len(args) and isinstance(args[0], (_CfgNode, DictConfig)): + return True + if isinstance(kwargs.pop("cfg", None), (_CfgNode, DictConfig)): + return True + # `from_config`'s first argument is forced to be "cfg". + # So the above check covers all cases. + return False diff --git a/custom_detectron2/config/defaults.py b/custom_detectron2/config/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..6ec2f00a69ac53d97dcf67623c7b262d911454b4 --- /dev/null +++ b/custom_detectron2/config/defaults.py @@ -0,0 +1,650 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .config import CfgNode as CN + +# NOTE: given the new config system +# (https://detectron2.readthedocs.io/en/latest/tutorials/lazyconfigs.html), +# we will stop adding new functionalities to default CfgNode. + +# ----------------------------------------------------------------------------- +# Convention about Training / Test specific parameters +# ----------------------------------------------------------------------------- +# Whenever an argument can be either used for training or for testing, the +# corresponding name will be post-fixed by a _TRAIN for a training parameter, +# or _TEST for a test-specific parameter. +# For example, the number of images during training will be +# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be +# IMAGES_PER_BATCH_TEST + +# ----------------------------------------------------------------------------- +# Config definition +# ----------------------------------------------------------------------------- + +_C = CN() + +# The version number, to upgrade from old configs to new ones if any +# changes happen. It's recommended to keep a VERSION in your config file. +_C.VERSION = 2 + +_C.MODEL = CN() +_C.MODEL.LOAD_PROPOSALS = False +_C.MODEL.MASK_ON = False +_C.MODEL.KEYPOINT_ON = False +_C.MODEL.DEVICE = "cuda" +_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN" + +# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file +# to be loaded to the model. You can find available models in the model zoo. +_C.MODEL.WEIGHTS = "" + +# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR). +# To train on images of different number of channels, just set different mean & std. +# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675] +_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675] +# When using pre-trained models in Detectron1 or any MSRA models, +# std has been absorbed into its conv1 weights, so the std needs to be set 1. +# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) +_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0] + + +# ----------------------------------------------------------------------------- +# INPUT +# ----------------------------------------------------------------------------- +_C.INPUT = CN() +# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge. +# Please refer to ResizeShortestEdge for detailed definition. +# Size of the smallest side of the image during training +_C.INPUT.MIN_SIZE_TRAIN = (800,) +# Sample size of smallest side by choice or random selection from range give by +# INPUT.MIN_SIZE_TRAIN +_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice" +# Maximum size of the side of the image during training +_C.INPUT.MAX_SIZE_TRAIN = 1333 +# Size of the smallest side of the image during testing. Set to zero to disable resize in testing. +_C.INPUT.MIN_SIZE_TEST = 800 +# Maximum size of the side of the image during testing +_C.INPUT.MAX_SIZE_TEST = 1333 +# Mode for flipping images used in data augmentation during training +# choose one of ["horizontal, "vertical", "none"] +_C.INPUT.RANDOM_FLIP = "horizontal" + +# `True` if cropping is used for data augmentation during training +_C.INPUT.CROP = CN({"ENABLED": False}) +# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation. +_C.INPUT.CROP.TYPE = "relative_range" +# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of +# pixels if CROP.TYPE is "absolute" +_C.INPUT.CROP.SIZE = [0.9, 0.9] + + +# Whether the model needs RGB, YUV, HSV etc. +# Should be one of the modes defined here, as we use PIL to read the image: +# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes +# with BGR being the one exception. One can set image format to BGR, we will +# internally use RGB for conversion and flip the channels over +_C.INPUT.FORMAT = "BGR" +# The ground truth mask format that the model will use. +# Mask R-CNN supports either "polygon" or "bitmask" as ground truth. +_C.INPUT.MASK_FORMAT = "polygon" # alternative: "bitmask" + + +# ----------------------------------------------------------------------------- +# Dataset +# ----------------------------------------------------------------------------- +_C.DATASETS = CN() +# List of the dataset names for training. Must be registered in DatasetCatalog +# Samples from these datasets will be merged and used as one dataset. +_C.DATASETS.TRAIN = () +# List of the pre-computed proposal files for training, which must be consistent +# with datasets listed in DATASETS.TRAIN. +_C.DATASETS.PROPOSAL_FILES_TRAIN = () +# Number of top scoring precomputed proposals to keep for training +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000 +# List of the dataset names for testing. Must be registered in DatasetCatalog +_C.DATASETS.TEST = () +# List of the pre-computed proposal files for test, which must be consistent +# with datasets listed in DATASETS.TEST. +_C.DATASETS.PROPOSAL_FILES_TEST = () +# Number of top scoring precomputed proposals to keep for test +_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000 + +# ----------------------------------------------------------------------------- +# DataLoader +# ----------------------------------------------------------------------------- +_C.DATALOADER = CN() +# Number of data loading threads +_C.DATALOADER.NUM_WORKERS = 4 +# If True, each batch should contain only images for which the aspect ratio +# is compatible. This groups portrait images together, and landscape images +# are not batched with portrait images. +_C.DATALOADER.ASPECT_RATIO_GROUPING = True +# Options: TrainingSampler, RepeatFactorTrainingSampler +_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler" +# Repeat threshold for RepeatFactorTrainingSampler +_C.DATALOADER.REPEAT_THRESHOLD = 0.0 +# Tf True, when working on datasets that have instance annotations, the +# training dataloader will filter out images without associated annotations +_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True + +# ---------------------------------------------------------------------------- # +# Backbone options +# ---------------------------------------------------------------------------- # +_C.MODEL.BACKBONE = CN() + +_C.MODEL.BACKBONE.NAME = "build_resnet_backbone" +# Freeze the first several stages so they are not trained. +# There are 5 stages in ResNet. The first is a convolution, and the following +# stages are each group of residual blocks. +_C.MODEL.BACKBONE.FREEZE_AT = 2 + + +# ---------------------------------------------------------------------------- # +# FPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.FPN = CN() +# Names of the input feature maps to be used by FPN +# They must have contiguous power of 2 strides +# e.g., ["res2", "res3", "res4", "res5"] +_C.MODEL.FPN.IN_FEATURES = [] +_C.MODEL.FPN.OUT_CHANNELS = 256 + +# Options: "" (no norm), "GN" +_C.MODEL.FPN.NORM = "" + +# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg" +_C.MODEL.FPN.FUSE_TYPE = "sum" + + +# ---------------------------------------------------------------------------- # +# Proposal generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.PROPOSAL_GENERATOR = CN() +# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals" +_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN" +# Proposal height and width both need to be greater than MIN_SIZE +# (a the scale used during training or inference) +_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0 + + +# ---------------------------------------------------------------------------- # +# Anchor generator options +# ---------------------------------------------------------------------------- # +_C.MODEL.ANCHOR_GENERATOR = CN() +# The generator can be any name in the ANCHOR_GENERATOR registry +_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator" +# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input. +# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for +# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1. +# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]] +# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect +# ratios are generated by an anchor generator. +# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W) +# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true, +# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used +# for all IN_FEATURES. +_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]] +# Anchor angles. +# list[list[float]], the angle in degrees, for each input feature map. +# ANGLES[i] specifies the list of angles for IN_FEATURES[i]. +_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]] +# Relative offset between the center of the first anchor and the top-left corner of the image +# Value has to be in [0, 1). Recommend to use 0.5, which means half stride. +# The value is not expected to affect model accuracy. +_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0 + +# ---------------------------------------------------------------------------- # +# RPN options +# ---------------------------------------------------------------------------- # +_C.MODEL.RPN = CN() +_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead" # used by RPN_HEAD_REGISTRY + +# Names of the input feature maps to be used by RPN +# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN +_C.MODEL.RPN.IN_FEATURES = ["res4"] +# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels +# Set to -1 or a large value, e.g. 100000, to disable pruning anchors +_C.MODEL.RPN.BOUNDARY_THRESH = -1 +# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD] +# Minimum overlap required between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD +# ==> positive RPN example: 1) +# Maximum overlap allowed between an anchor and ground-truth box for the +# (anchor, gt box) pair to be a negative examples (IoU < BG_IOU_THRESHOLD +# ==> negative RPN example: 0) +# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD) +# are ignored (-1) +_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7] +_C.MODEL.RPN.IOU_LABELS = [0, -1, 1] +# Number of regions per image used to train RPN +_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256 +# Target fraction of foreground (positive) examples per RPN minibatch +_C.MODEL.RPN.POSITIVE_FRACTION = 0.5 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1" +_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0 +# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets +_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) +# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0 +_C.MODEL.RPN.LOSS_WEIGHT = 1.0 +# Number of top scoring RPN proposals to keep before applying NMS +# When FPN is used, this is *per FPN level* (not total) +_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000 +_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000 +# Number of top scoring RPN proposals to keep after applying NMS +# When FPN is used, this limit is applied per level and then again to the union +# of proposals from all levels +# NOTE: When FPN is used, the meaning of this config is different from Detectron1. +# It means per-batch topk in Detectron1, but per-image topk here. +# See the "find_top_rpn_proposals" function for details. +_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000 +_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000 +# NMS threshold used on RPN proposals +_C.MODEL.RPN.NMS_THRESH = 0.7 +# Set this to -1 to use the same number of output channels as input channels. +_C.MODEL.RPN.CONV_DIMS = [-1] + +# ---------------------------------------------------------------------------- # +# ROI HEADS options +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_HEADS = CN() +_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads" +# Number of foreground classes +_C.MODEL.ROI_HEADS.NUM_CLASSES = 80 +# Names of the input feature maps to be used by ROI heads +# Currently all heads (box, mask, ...) use the same input feature map list +# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN +_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"] +# IOU overlap ratios [IOU_THRESHOLD] +# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD) +# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD) +_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5] +_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1] +# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training +# Total number of RoIs per training minibatch = +# ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH +# E.g., a common configuration is: 512 * 16 = 8192 +_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512 +# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0) +_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25 + +# Only used on test mode + +# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to +# balance obtaining high recall with not having too many low precision +# detections that will slow down inference post processing steps (like NMS) +# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down +# inference. +_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05 +# Overlap threshold used for non-maximum suppression (suppress boxes with +# IoU >= this threshold) +_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5 +# If True, augment proposals with ground-truth boxes before sampling proposals to +# train ROI heads. +_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True + +# ---------------------------------------------------------------------------- # +# Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_HEAD = CN() +# C4 don't use head name option +# Options for non-C4 models: FastRCNNConvFCHead, +_C.MODEL.ROI_BOX_HEAD.NAME = "" +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1" +# The final scaling coefficient on the box regression loss, used to balance the magnitude of its +# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`. +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0 +# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets +# These are empirically chosen to approximately lead to unit variance targets +_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0) +# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1. +_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0 +_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2" + +_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0 +# Hidden layer dimension for FC layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024 +_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0 +# Channel dimension for Conv layers in the RoI box head +_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". +_C.MODEL.ROI_BOX_HEAD.NORM = "" +# Whether to use class agnostic for bbox regression +_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False +# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes. +_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False + +# Federated loss can be used to improve the training of LVIS +_C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False +# Sigmoid cross entrophy is used with federated loss +_C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False +# The power value applied to image_count when calcualting frequency weight +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER = 0.5 +# Number of classes to keep in total +_C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES = 50 + +# ---------------------------------------------------------------------------- # +# Cascaded Box Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_BOX_CASCADE_HEAD = CN() +# The number of cascade stages is implicitly defined by the length of the following two configs. +_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = ( + (10.0, 10.0, 5.0, 5.0), + (20.0, 20.0, 10.0, 10.0), + (30.0, 30.0, 15.0, 15.0), +) +_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7) + + +# ---------------------------------------------------------------------------- # +# Mask Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_MASK_HEAD = CN() +_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead" +_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0 # The number of convs in the mask head +_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256 +# Normalization method for the convolution layers. +# Options: "" (no norm), "GN", "SyncBN". +_C.MODEL.ROI_MASK_HEAD.NORM = "" +# Whether to use class agnostic for mask prediction +_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2" + + +# ---------------------------------------------------------------------------- # +# Keypoint Head +# ---------------------------------------------------------------------------- # +_C.MODEL.ROI_KEYPOINT_HEAD = CN() +_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead" +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14 +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0 +_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8)) +_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17 # 17 is the number of keypoints in COCO. + +# Images with too few (or no) keypoints are excluded from training. +_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1 +# Normalize by the total number of visible keypoints in the minibatch if True. +# Otherwise, normalize by the total number of keypoints that could ever exist +# in the minibatch. +# The keypoint softmax loss is only calculated on visible keypoints. +# Since the number of visible keypoints can vary significantly between +# minibatches, this has the effect of up-weighting the importance of +# minibatches with few visible keypoints. (Imagine the extreme case of +# only one visible keypoint versus N: in the case of N, each one +# contributes 1/N to the gradient compared to the single keypoint +# determining the gradient direction). Instead, we can normalize the +# loss by the total number of keypoints, if it were the case that all +# keypoints were visible in a full minibatch. (Returning to the example, +# this means that the one visible keypoint contributes as much as each +# of the N keypoints.) +_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True +# Multi-task loss weight to use for keypoints +# Recommended values: +# - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True +# - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False +_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0 +# Type of pooling operation applied to the incoming feature map for each RoI +_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2" + +# ---------------------------------------------------------------------------- # +# Semantic Segmentation Head +# ---------------------------------------------------------------------------- # +_C.MODEL.SEM_SEG_HEAD = CN() +_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead" +_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"] +# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for +# the correposnding pixel. +_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255 +# Number of classes in the semantic segmentation head +_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54 +# Number of channels in the 3x3 convs inside semantic-FPN heads. +_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128 +# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride. +_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4 +# Normalization method for the convolution layers. Options: "" (no norm), "GN". +_C.MODEL.SEM_SEG_HEAD.NORM = "GN" +_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0 + +_C.MODEL.PANOPTIC_FPN = CN() +# Scaling of all losses from instance detection / segmentation head. +_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0 + +# options when combining instance & semantic segmentation outputs +_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True}) # "COMBINE.ENABLED" is deprecated & not used +_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5 +_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096 +_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5 + + +# ---------------------------------------------------------------------------- # +# RetinaNet Head +# ---------------------------------------------------------------------------- # +_C.MODEL.RETINANET = CN() + +# This is the number of foreground classes. +_C.MODEL.RETINANET.NUM_CLASSES = 80 + +_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"] + +# Convolutions to use in the cls and bbox tower +# NOTE: this doesn't include the last conv for logits +_C.MODEL.RETINANET.NUM_CONVS = 4 + +# IoU overlap ratio [bg, fg] for labeling anchors. +# Anchors with < bg are labeled negative (0) +# Anchors with >= bg and < fg are ignored (-1) +# Anchors with >= fg are labeled positive (1) +_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5] +_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1] + +# Prior prob for rare case (i.e. foreground) at the beginning of training. +# This is used to set the bias for the logits layer of the classifier subnet. +# This improves training stability in the case of heavy class imbalance. +_C.MODEL.RETINANET.PRIOR_PROB = 0.01 + +# Inference cls score threshold, only anchors with score > INFERENCE_TH are +# considered for inference (to improve speed) +_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05 +# Select topk candidates before NMS +_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000 +_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5 + +# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets +_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0) + +# Loss parameters +_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0 +_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25 +_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1 +# Options are: "smooth_l1", "giou", "diou", "ciou" +_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1" + +# One of BN, SyncBN, FrozenBN, GN +# Only supports GN until unshared norm is implemented +_C.MODEL.RETINANET.NORM = "" + + +# ---------------------------------------------------------------------------- # +# ResNe[X]t options (ResNets = {ResNet, ResNeXt} +# Note that parts of a resnet may be used for both the backbone and the head +# These options apply to both +# ---------------------------------------------------------------------------- # +_C.MODEL.RESNETS = CN() + +_C.MODEL.RESNETS.DEPTH = 50 +_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone + +# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt +_C.MODEL.RESNETS.NUM_GROUPS = 1 + +# Options: FrozenBN, GN, "SyncBN", "BN" +_C.MODEL.RESNETS.NORM = "FrozenBN" + +# Baseline width of each group. +# Scaling this parameters will scale the width of all bottleneck layers. +_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64 + +# Place the stride 2 conv on the 1x1 filter +# Use True only for the original MSRA ResNet; use False for C2 and Torch models +_C.MODEL.RESNETS.STRIDE_IN_1X1 = True + +# Apply dilation in stage "res5" +_C.MODEL.RESNETS.RES5_DILATION = 1 + +# Output width of res2. Scaling this parameters will scale the width of all 1x1 convs in ResNet +# For R18 and R34, this needs to be set to 64 +_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256 +_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64 + +# Apply Deformable Convolution in stages +# Specify if apply deform_conv on Res2, Res3, Res4, Res5 +_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False] +# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168); +# Use False for DeformableV1. +_C.MODEL.RESNETS.DEFORM_MODULATED = False +# Number of groups in deformable conv. +_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1 + + +# ---------------------------------------------------------------------------- # +# Solver +# ---------------------------------------------------------------------------- # +_C.SOLVER = CN() + +# Options: WarmupMultiStepLR, WarmupCosineLR. +# See detectron2/solver/build.py for definition. +_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR" + +_C.SOLVER.MAX_ITER = 40000 + +_C.SOLVER.BASE_LR = 0.001 +# The end lr, only used by WarmupCosineLR +_C.SOLVER.BASE_LR_END = 0.0 + +_C.SOLVER.MOMENTUM = 0.9 + +_C.SOLVER.NESTEROV = False + +_C.SOLVER.WEIGHT_DECAY = 0.0001 +# The weight decay that's applied to parameters of normalization layers +# (typically the affine transformation) +_C.SOLVER.WEIGHT_DECAY_NORM = 0.0 + +_C.SOLVER.GAMMA = 0.1 +# The iteration number to decrease learning rate by GAMMA. +_C.SOLVER.STEPS = (30000,) +# Number of decays in WarmupStepWithFixedGammaLR schedule +_C.SOLVER.NUM_DECAYS = 3 + +_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000 +_C.SOLVER.WARMUP_ITERS = 1000 +_C.SOLVER.WARMUP_METHOD = "linear" +# Whether to rescale the interval for the learning schedule after warmup +_C.SOLVER.RESCALE_INTERVAL = False + +# Save a checkpoint after every this number of iterations +_C.SOLVER.CHECKPOINT_PERIOD = 5000 + +# Number of images per batch across all machines. This is also the number +# of training images per step (i.e. per iteration). If we use 16 GPUs +# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch. +# May be adjusted automatically if REFERENCE_WORLD_SIZE is set. +_C.SOLVER.IMS_PER_BATCH = 16 + +# The reference number of workers (GPUs) this config is meant to train with. +# It takes no effect when set to 0. +# With a non-zero value, it will be used by DefaultTrainer to compute a desired +# per-worker batch size, and then scale the other related configs (total batch size, +# learning rate, etc) to match the per-worker batch size. +# See documentation of `DefaultTrainer.auto_scale_workers` for details: +_C.SOLVER.REFERENCE_WORLD_SIZE = 0 + +# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for +# biases. This is not useful (at least for recent models). You should avoid +# changing these and they exist only to reproduce Detectron v1 training if +# desired. +_C.SOLVER.BIAS_LR_FACTOR = 1.0 +_C.SOLVER.WEIGHT_DECAY_BIAS = None # None means following WEIGHT_DECAY + +# Gradient clipping +_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False}) +# Type of gradient clipping, currently 2 values are supported: +# - "value": the absolute values of elements of each gradients are clipped +# - "norm": the norm of the gradient for each parameter is clipped thus +# affecting all elements in the parameter +_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value" +# Maximum absolute value used for clipping gradients +_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0 +# Floating point number p for L-p norm to be used with the "norm" +# gradient clipping type; for L-inf, please specify .inf +_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0 + +# Enable automatic mixed precision for training +# Note that this does not change model's inference behavior. +# To use AMP in inference, run inference under autocast() +_C.SOLVER.AMP = CN({"ENABLED": False}) + +# ---------------------------------------------------------------------------- # +# Specific test options +# ---------------------------------------------------------------------------- # +_C.TEST = CN() +# For end-to-end tests to verify the expected accuracy. +# Each item is [task, metric, value, tolerance] +# e.g.: [['bbox', 'AP', 38.5, 0.2]] +_C.TEST.EXPECTED_RESULTS = [] +# The period (in terms of steps) to evaluate the model during training. +# Set to 0 to disable. +_C.TEST.EVAL_PERIOD = 0 +# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval +# When empty, it will use the defaults in COCO. +# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS. +_C.TEST.KEYPOINT_OKS_SIGMAS = [] +# Maximum number of detections to return per image during inference (100 is +# based on the limit established for the COCO dataset). +_C.TEST.DETECTIONS_PER_IMAGE = 100 + +_C.TEST.AUG = CN({"ENABLED": False}) +_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200) +_C.TEST.AUG.MAX_SIZE = 4000 +_C.TEST.AUG.FLIP = True + +_C.TEST.PRECISE_BN = CN({"ENABLED": False}) +_C.TEST.PRECISE_BN.NUM_ITER = 200 + +# ---------------------------------------------------------------------------- # +# Misc options +# ---------------------------------------------------------------------------- # +# Directory where output files are written +_C.OUTPUT_DIR = "./output" +# Set seed to negative to fully randomize everything. +# Set seed to positive to use a fixed seed. Note that a fixed seed increases +# reproducibility but does not guarantee fully deterministic behavior. +# Disabling all parallelism further increases reproducibility. +_C.SEED = -1 +# Benchmark different cudnn algorithms. +# If input images have very different sizes, this option will have large overhead +# for about 10k iterations. It usually hurts total time, but can benefit for certain models. +# If input images have the same or similar sizes, benchmark is often helpful. +_C.CUDNN_BENCHMARK = False +# The period (in terms of steps) for minibatch visualization at train time. +# Set to 0 to disable. +_C.VIS_PERIOD = 0 + +# global config is for quick hack purposes. +# You can set them in command line or config files, +# and access it with: +# +# from custom_detectron2.config import global_cfg +# print(global_cfg.HACK) +# +# Do not commit any configs into it. +_C.GLOBAL = CN() +_C.GLOBAL.HACK = 1.0 diff --git a/custom_detectron2/config/instantiate.py b/custom_detectron2/config/instantiate.py new file mode 100644 index 0000000000000000000000000000000000000000..24528d71bf09c7ae768aa80df3cf413c8d02598b --- /dev/null +++ b/custom_detectron2/config/instantiate.py @@ -0,0 +1,88 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import collections.abc as abc +import dataclasses +import logging +from typing import Any + +from custom_detectron2.utils.registry import _convert_target_to_string, locate + +__all__ = ["dump_dataclass", "instantiate"] + + +def dump_dataclass(obj: Any): + """ + Dump a dataclass recursively into a dict that can be later instantiated. + + Args: + obj: a dataclass object + + Returns: + dict + """ + assert dataclasses.is_dataclass(obj) and not isinstance( + obj, type + ), "dump_dataclass() requires an instance of a dataclass." + ret = {"_target_": _convert_target_to_string(type(obj))} + for f in dataclasses.fields(obj): + v = getattr(obj, f.name) + if dataclasses.is_dataclass(v): + v = dump_dataclass(v) + if isinstance(v, (list, tuple)): + v = [dump_dataclass(x) if dataclasses.is_dataclass(x) else x for x in v] + ret[f.name] = v + return ret + + +def instantiate(cfg): + """ + Recursively instantiate objects defined in dictionaries by + "_target_" and arguments. + + Args: + cfg: a dict-like object with "_target_" that defines the caller, and + other keys that define the arguments + + Returns: + object instantiated by cfg + """ + from omegaconf import ListConfig, DictConfig, OmegaConf + + if isinstance(cfg, ListConfig): + lst = [instantiate(x) for x in cfg] + return ListConfig(lst, flags={"allow_objects": True}) + if isinstance(cfg, list): + # Specialize for list, because many classes take + # list[objects] as arguments, such as ResNet, DatasetMapper + return [instantiate(x) for x in cfg] + + # If input is a DictConfig backed by dataclasses (i.e. omegaconf's structured config), + # instantiate it to the actual dataclass. + if isinstance(cfg, DictConfig) and dataclasses.is_dataclass(cfg._metadata.object_type): + return OmegaConf.to_object(cfg) + + if isinstance(cfg, abc.Mapping) and "_target_" in cfg: + # conceptually equivalent to hydra.utils.instantiate(cfg) with _convert_=all, + # but faster: https://github.com/facebookresearch/hydra/issues/1200 + cfg = {k: instantiate(v) for k, v in cfg.items()} + cls = cfg.pop("_target_") + cls = instantiate(cls) + + if isinstance(cls, str): + cls_name = cls + cls = locate(cls_name) + assert cls is not None, cls_name + else: + try: + cls_name = cls.__module__ + "." + cls.__qualname__ + except Exception: + # target could be anything, so the above could fail + cls_name = str(cls) + assert callable(cls), f"_target_ {cls} does not define a callable object" + try: + return cls(**cfg) + except TypeError: + logger = logging.getLogger(__name__) + logger.error(f"Error when instantiating {cls_name}!") + raise + return cfg # return as-is if don't know what to do diff --git a/custom_detectron2/config/lazy.py b/custom_detectron2/config/lazy.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8944f206ea5d6d1e0fa98153f8655819101efa --- /dev/null +++ b/custom_detectron2/config/lazy.py @@ -0,0 +1,435 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import ast +import builtins +import collections.abc as abc +import importlib +import inspect +import logging +import os +import uuid +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import is_dataclass +from typing import List, Tuple, Union +import yaml +from omegaconf import DictConfig, ListConfig, OmegaConf, SCMode + +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.registry import _convert_target_to_string + +__all__ = ["LazyCall", "LazyConfig"] + + +class LazyCall: + """ + Wrap a callable so that when it's called, the call will not be executed, + but returns a dict that describes the call. + + LazyCall object has to be called with only keyword arguments. Positional + arguments are not yet supported. + + Examples: + :: + from custom_detectron2.config import instantiate, LazyCall + + layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32) + layer_cfg.out_channels = 64 # can edit it afterwards + layer = instantiate(layer_cfg) + """ + + def __init__(self, target): + if not (callable(target) or isinstance(target, (str, abc.Mapping))): + raise TypeError( + f"target of LazyCall must be a callable or defines a callable! Got {target}" + ) + self._target = target + + def __call__(self, **kwargs): + if is_dataclass(self._target): + # omegaconf object cannot hold dataclass type + # https://github.com/omry/omegaconf/issues/784 + target = _convert_target_to_string(self._target) + else: + target = self._target + kwargs["_target_"] = target + + return DictConfig(content=kwargs, flags={"allow_objects": True}) + + +def _visit_dict_config(cfg, func): + """ + Apply func recursively to all DictConfig in cfg. + """ + if isinstance(cfg, DictConfig): + func(cfg) + for v in cfg.values(): + _visit_dict_config(v, func) + elif isinstance(cfg, ListConfig): + for v in cfg: + _visit_dict_config(v, func) + + +def _validate_py_syntax(filename): + # see also https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/config.py + with PathManager.open(filename, "r") as f: + content = f.read() + try: + ast.parse(content) + except SyntaxError as e: + raise SyntaxError(f"Config file {filename} has syntax error!") from e + + +def _cast_to_config(obj): + # if given a dict, return DictConfig instead + if isinstance(obj, dict): + return DictConfig(obj, flags={"allow_objects": True}) + return obj + + +_CFG_PACKAGE_NAME = "detectron2._cfg_loader" +""" +A namespace to put all imported config into. +""" + + +def _random_package_name(filename): + # generate a random package name when loading config files + return _CFG_PACKAGE_NAME + str(uuid.uuid4())[:4] + "." + os.path.basename(filename) + + +@contextmanager +def _patch_import(): + """ + Enhance relative import statements in config files, so that they: + 1. locate files purely based on relative location, regardless of packages. + e.g. you can import file without having __init__ + 2. do not cache modules globally; modifications of module states has no side effect + 3. support other storage system through PathManager, so config files can be in the cloud + 4. imported dict are turned into omegaconf.DictConfig automatically + """ + old_import = builtins.__import__ + + def find_relative_file(original_file, relative_import_path, level): + # NOTE: "from . import x" is not handled. Because then it's unclear + # if such import should produce `x` as a python module or DictConfig. + # This can be discussed further if needed. + relative_import_err = """ +Relative import of directories is not allowed within config files. +Within a config file, relative import can only import other config files. +""".replace( + "\n", " " + ) + if not len(relative_import_path): + raise ImportError(relative_import_err) + + cur_file = os.path.dirname(original_file) + for _ in range(level - 1): + cur_file = os.path.dirname(cur_file) + cur_name = relative_import_path.lstrip(".") + for part in cur_name.split("."): + cur_file = os.path.join(cur_file, part) + if not cur_file.endswith(".py"): + cur_file += ".py" + if not PathManager.isfile(cur_file): + cur_file_no_suffix = cur_file[: -len(".py")] + if PathManager.isdir(cur_file_no_suffix): + raise ImportError(f"Cannot import from {cur_file_no_suffix}." + relative_import_err) + else: + raise ImportError( + f"Cannot import name {relative_import_path} from " + f"{original_file}: {cur_file} does not exist." + ) + return cur_file + + def new_import(name, globals=None, locals=None, fromlist=(), level=0): + if ( + # Only deal with relative imports inside config files + level != 0 + and globals is not None + and (globals.get("__package__", "") or "").startswith(_CFG_PACKAGE_NAME) + ): + cur_file = find_relative_file(globals["__file__"], name, level) + _validate_py_syntax(cur_file) + spec = importlib.machinery.ModuleSpec( + _random_package_name(cur_file), None, origin=cur_file + ) + module = importlib.util.module_from_spec(spec) + module.__file__ = cur_file + with PathManager.open(cur_file) as f: + content = f.read() + exec(compile(content, cur_file, "exec"), module.__dict__) + for name in fromlist: # turn imported dict into DictConfig automatically + val = _cast_to_config(module.__dict__[name]) + module.__dict__[name] = val + return module + return old_import(name, globals, locals, fromlist=fromlist, level=level) + + builtins.__import__ = new_import + yield new_import + builtins.__import__ = old_import + + +class LazyConfig: + """ + Provide methods to save, load, and overrides an omegaconf config object + which may contain definition of lazily-constructed objects. + """ + + @staticmethod + def load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Similar to :meth:`load()`, but load path relative to the caller's + source file. + + This has the same functionality as a relative import, except that this method + accepts filename as a string, so more characters are allowed in the filename. + """ + caller_frame = inspect.stack()[1] + caller_fname = caller_frame[0].f_code.co_filename + assert caller_fname != "", "load_rel Unable to find caller" + caller_dir = os.path.dirname(caller_fname) + filename = os.path.join(caller_dir, filename) + return LazyConfig.load(filename, keys) + + @staticmethod + def load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None): + """ + Load a config file. + + Args: + filename: absolute path or relative path w.r.t. the current working directory + keys: keys to load and return. If not given, return all keys + (whose values are config objects) in a dict. + """ + has_keys = keys is not None + filename = filename.replace("/./", "/") # redundant + if os.path.splitext(filename)[1] not in [".py", ".yaml", ".yml"]: + raise ValueError(f"Config file {filename} has to be a python or yaml file.") + if filename.endswith(".py"): + _validate_py_syntax(filename) + + with _patch_import(): + # Record the filename + module_namespace = { + "__file__": filename, + "__package__": _random_package_name(filename), + } + with PathManager.open(filename) as f: + content = f.read() + # Compile first with filename to: + # 1. make filename appears in stacktrace + # 2. make load_rel able to find its parent's (possibly remote) location + exec(compile(content, filename, "exec"), module_namespace) + + ret = module_namespace + else: + with PathManager.open(filename) as f: + obj = yaml.unsafe_load(f) + ret = OmegaConf.create(obj, flags={"allow_objects": True}) + + if has_keys: + if isinstance(keys, str): + return _cast_to_config(ret[keys]) + else: + return tuple(_cast_to_config(ret[a]) for a in keys) + else: + if filename.endswith(".py"): + # when not specified, only load those that are config objects + ret = DictConfig( + { + name: _cast_to_config(value) + for name, value in ret.items() + if isinstance(value, (DictConfig, ListConfig, dict)) + and not name.startswith("_") + }, + flags={"allow_objects": True}, + ) + return ret + + @staticmethod + def save(cfg, filename: str): + """ + Save a config object to a yaml file. + Note that when the config dictionary contains complex objects (e.g. lambda), + it can't be saved to yaml. In that case we will print an error and + attempt to save to a pkl file instead. + + Args: + cfg: an omegaconf config object + filename: yaml file name to save the config file + """ + logger = logging.getLogger(__name__) + try: + cfg = deepcopy(cfg) + except Exception: + pass + else: + # if it's deep-copyable, then... + def _replace_type_by_name(x): + if "_target_" in x and callable(x._target_): + try: + x._target_ = _convert_target_to_string(x._target_) + except AttributeError: + pass + + # not necessary, but makes yaml looks nicer + _visit_dict_config(cfg, _replace_type_by_name) + + save_pkl = False + try: + dict = OmegaConf.to_container( + cfg, + # Do not resolve interpolation when saving, i.e. do not turn ${a} into + # actual values when saving. + resolve=False, + # Save structures (dataclasses) in a format that can be instantiated later. + # Without this option, the type information of the dataclass will be erased. + structured_config_mode=SCMode.INSTANTIATE, + ) + dumped = yaml.dump(dict, default_flow_style=None, allow_unicode=True, width=9999) + with PathManager.open(filename, "w") as f: + f.write(dumped) + + try: + _ = yaml.unsafe_load(dumped) # test that it is loadable + except Exception: + logger.warning( + "The config contains objects that cannot serialize to a valid yaml. " + f"{filename} is human-readable but cannot be loaded." + ) + save_pkl = True + except Exception: + logger.exception("Unable to serialize the config to yaml. Error:") + save_pkl = True + + if save_pkl: + new_filename = filename + ".pkl" + # try: + # # retry by pickle + # with PathManager.open(new_filename, "wb") as f: + # cloudpickle.dump(cfg, f) + # logger.warning(f"Config is saved using cloudpickle at {new_filename}.") + # except Exception: + # pass + + @staticmethod + def apply_overrides(cfg, overrides: List[str]): + """ + In-place override contents of cfg. + + Args: + cfg: an omegaconf config object + overrides: list of strings in the format of "a=b" to override configs. + See https://hydra.cc/docs/next/advanced/override_grammar/basic/ + for syntax. + + Returns: + the cfg object + """ + + def safe_update(cfg, key, value): + parts = key.split(".") + for idx in range(1, len(parts)): + prefix = ".".join(parts[:idx]) + v = OmegaConf.select(cfg, prefix, default=None) + if v is None: + break + if not OmegaConf.is_config(v): + raise KeyError( + f"Trying to update key {key}, but {prefix} " + f"is not a config, but has type {type(v)}." + ) + OmegaConf.update(cfg, key, value, merge=True) + + try: + from hydra.core.override_parser.overrides_parser import OverridesParser + + has_hydra = True + except ImportError: + has_hydra = False + + if has_hydra: + parser = OverridesParser.create() + overrides = parser.parse_overrides(overrides) + for o in overrides: + key = o.key_or_group + value = o.value() + if o.is_delete(): + # TODO support this + raise NotImplementedError("deletion is not yet a supported override") + safe_update(cfg, key, value) + else: + # Fallback. Does not support all the features and error checking like hydra. + for o in overrides: + key, value = o.split("=") + try: + value = eval(value, {}) + except NameError: + pass + safe_update(cfg, key, value) + return cfg + + # @staticmethod + # def to_py(cfg, prefix: str = "cfg."): + # """ + # Try to convert a config object into Python-like psuedo code. + # + # Note that perfect conversion is not always possible. So the returned + # results are mainly meant to be human-readable, and not meant to be executed. + # + # Args: + # cfg: an omegaconf config object + # prefix: root name for the resulting code (default: "cfg.") + # + # + # Returns: + # str of formatted Python code + # """ + # import black + # + # cfg = OmegaConf.to_container(cfg, resolve=True) + # + # def _to_str(obj, prefix=None, inside_call=False): + # if prefix is None: + # prefix = [] + # if isinstance(obj, abc.Mapping) and "_target_" in obj: + # # Dict representing a function call + # target = _convert_target_to_string(obj.pop("_target_")) + # args = [] + # for k, v in sorted(obj.items()): + # args.append(f"{k}={_to_str(v, inside_call=True)}") + # args = ", ".join(args) + # call = f"{target}({args})" + # return "".join(prefix) + call + # elif isinstance(obj, abc.Mapping) and not inside_call: + # # Dict that is not inside a call is a list of top-level config objects that we + # # render as one object per line with dot separated prefixes + # key_list = [] + # for k, v in sorted(obj.items()): + # if isinstance(v, abc.Mapping) and "_target_" not in v: + # key_list.append(_to_str(v, prefix=prefix + [k + "."])) + # else: + # key = "".join(prefix) + k + # key_list.append(f"{key}={_to_str(v)}") + # return "\n".join(key_list) + # elif isinstance(obj, abc.Mapping): + # # Dict that is inside a call is rendered as a regular dict + # return ( + # "{" + # + ",".join( + # f"{repr(k)}: {_to_str(v, inside_call=inside_call)}" + # for k, v in sorted(obj.items()) + # ) + # + "}" + # ) + # elif isinstance(obj, list): + # return "[" + ",".join(_to_str(x, inside_call=inside_call) for x in obj) + "]" + # else: + # return repr(obj) + # + # py_str = _to_str(cfg, prefix=[prefix]) + # try: + # return black.format_str(py_str, mode=black.Mode()) + # except black.InvalidInput: + # return py_str diff --git a/custom_detectron2/data/__init__.py b/custom_detectron2/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..259f669b78bd05815cb8d3351fd6c5fc9a1b85a1 --- /dev/null +++ b/custom_detectron2/data/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from . import transforms # isort:skip + +from .build import ( + build_batch_data_loader, + build_detection_test_loader, + build_detection_train_loader, + get_detection_dataset_dicts, + load_proposals_into_dataset, + print_instances_class_histogram, +) +from .catalog import DatasetCatalog, MetadataCatalog, Metadata +from .common import DatasetFromList, MapDataset, ToIterableDataset +from .dataset_mapper import DatasetMapper + +# ensure the builtin datasets are registered +from . import datasets, samplers # isort:skip + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/data/benchmark.py b/custom_detectron2/data/benchmark.py new file mode 100644 index 0000000000000000000000000000000000000000..eda0bb5c751ee961a1cef8190d6ed966f06f190a --- /dev/null +++ b/custom_detectron2/data/benchmark.py @@ -0,0 +1,225 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +from itertools import count +from typing import List, Tuple +import torch +import tqdm +from fvcore.common.timer import Timer + +from custom_detectron2.utils import comm + +from .build import build_batch_data_loader +from .common import DatasetFromList, MapDataset +from .samplers import TrainingSampler + +logger = logging.getLogger(__name__) + + +class _EmptyMapDataset(torch.utils.data.Dataset): + """ + Map anything to emptiness. + """ + + def __init__(self, dataset): + self.ds = dataset + + def __len__(self): + return len(self.ds) + + def __getitem__(self, idx): + _ = self.ds[idx] + return [0] + + +def iter_benchmark( + iterator, num_iter: int, warmup: int = 5, max_time_seconds: float = 60 +) -> Tuple[float, List[float]]: + """ + Benchmark an iterator/iterable for `num_iter` iterations with an extra + `warmup` iterations of warmup. + End early if `max_time_seconds` time is spent on iterations. + + Returns: + float: average time (seconds) per iteration + list[float]: time spent on each iteration. Sometimes useful for further analysis. + """ + num_iter, warmup = int(num_iter), int(warmup) + + iterator = iter(iterator) + for _ in range(warmup): + next(iterator) + timer = Timer() + all_times = [] + for curr_iter in tqdm.trange(num_iter): + start = timer.seconds() + if start > max_time_seconds: + num_iter = curr_iter + break + next(iterator) + all_times.append(timer.seconds() - start) + avg = timer.seconds() / num_iter + return avg, all_times + + +class DataLoaderBenchmark: + """ + Some common benchmarks that help understand perf bottleneck of a standard dataloader + made of dataset, mapper and sampler. + """ + + def __init__( + self, + dataset, + *, + mapper, + sampler=None, + total_batch_size, + num_workers=0, + max_time_seconds: int = 90, + ): + """ + Args: + max_time_seconds (int): maximum time to spent for each benchmark + other args: same as in `build.py:build_detection_train_loader` + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False, serialize=True) + if sampler is None: + sampler = TrainingSampler(len(dataset)) + + self.dataset = dataset + self.mapper = mapper + self.sampler = sampler + self.total_batch_size = total_batch_size + self.num_workers = num_workers + self.per_gpu_batch_size = self.total_batch_size // comm.get_world_size() + + self.max_time_seconds = max_time_seconds + + def _benchmark(self, iterator, num_iter, warmup, msg=None): + avg, all_times = iter_benchmark(iterator, num_iter, warmup, self.max_time_seconds) + if msg is not None: + self._log_time(msg, avg, all_times) + return avg, all_times + + def _log_time(self, msg, avg, all_times, distributed=False): + percentiles = [np.percentile(all_times, k, interpolation="nearest") for k in [1, 5, 95, 99]] + if not distributed: + logger.info( + f"{msg}: avg={1.0/avg:.1f} it/s, " + f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, " + f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s." + ) + return + avg_per_gpu = comm.all_gather(avg) + percentiles_per_gpu = comm.all_gather(percentiles) + if comm.get_rank() > 0: + return + for idx, avg, percentiles in zip(count(), avg_per_gpu, percentiles_per_gpu): + logger.info( + f"GPU{idx} {msg}: avg={1.0/avg:.1f} it/s, " + f"p1={percentiles[0]:.2g}s, p5={percentiles[1]:.2g}s, " + f"p95={percentiles[2]:.2g}s, p99={percentiles[3]:.2g}s." + ) + + def benchmark_dataset(self, num_iter, warmup=5): + """ + Benchmark the speed of taking raw samples from the dataset. + """ + + def loader(): + while True: + for k in self.sampler: + yield self.dataset[k] + + self._benchmark(loader(), num_iter, warmup, "Dataset Alone") + + def benchmark_mapper(self, num_iter, warmup=5): + """ + Benchmark the speed of taking raw samples from the dataset and map + them in a single process. + """ + + def loader(): + while True: + for k in self.sampler: + yield self.mapper(self.dataset[k]) + + self._benchmark(loader(), num_iter, warmup, "Single Process Mapper (sec/sample)") + + def benchmark_workers(self, num_iter, warmup=10): + """ + Benchmark the dataloader by tuning num_workers to [0, 1, self.num_workers]. + """ + candidates = [0, 1] + if self.num_workers not in candidates: + candidates.append(self.num_workers) + + dataset = MapDataset(self.dataset, self.mapper) + for n in candidates: + loader = build_batch_data_loader( + dataset, + self.sampler, + self.total_batch_size, + num_workers=n, + ) + self._benchmark( + iter(loader), + num_iter * max(n, 1), + warmup * max(n, 1), + f"DataLoader ({n} workers, bs={self.per_gpu_batch_size})", + ) + del loader + + def benchmark_IPC(self, num_iter, warmup=10): + """ + Benchmark the dataloader where each worker outputs nothing. This + eliminates the IPC overhead compared to the regular dataloader. + + PyTorch multiprocessing's IPC only optimizes for torch tensors. + Large numpy arrays or other data structure may incur large IPC overhead. + """ + n = self.num_workers + dataset = _EmptyMapDataset(MapDataset(self.dataset, self.mapper)) + loader = build_batch_data_loader( + dataset, self.sampler, self.total_batch_size, num_workers=n + ) + self._benchmark( + iter(loader), + num_iter * max(n, 1), + warmup * max(n, 1), + f"DataLoader ({n} workers, bs={self.per_gpu_batch_size}) w/o comm", + ) + + def benchmark_distributed(self, num_iter, warmup=10): + """ + Benchmark the dataloader in each distributed worker, and log results of + all workers. This helps understand the final performance as well as + the variances among workers. + + It also prints startup time (first iter) of the dataloader. + """ + gpu = comm.get_world_size() + dataset = MapDataset(self.dataset, self.mapper) + n = self.num_workers + loader = build_batch_data_loader( + dataset, self.sampler, self.total_batch_size, num_workers=n + ) + + timer = Timer() + loader = iter(loader) + next(loader) + startup_time = timer.seconds() + logger.info("Dataloader startup time: {:.2f} seconds".format(startup_time)) + + comm.synchronize() + + avg, all_times = self._benchmark(loader, num_iter * max(n, 1), warmup * max(n, 1)) + del loader + self._log_time( + f"DataLoader ({gpu} GPUs x {n} workers, total bs={self.total_batch_size})", + avg, + all_times, + True, + ) diff --git a/custom_detectron2/data/build.py b/custom_detectron2/data/build.py new file mode 100644 index 0000000000000000000000000000000000000000..aaa44834b76e9347f0ad5c452f0d8e2f532d1795 --- /dev/null +++ b/custom_detectron2/data/build.py @@ -0,0 +1,556 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import numpy as np +import operator +import pickle +from typing import Any, Callable, Dict, List, Optional, Union +import torch +import torch.utils.data as torchdata +from tabulate import tabulate +from termcolor import colored + +from custom_detectron2.config import configurable +from custom_detectron2.structures import BoxMode +from custom_detectron2.utils.comm import get_world_size +from custom_detectron2.utils.env import seed_all_rng +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.logger import _log_api_usage, log_first_n + +from .catalog import DatasetCatalog, MetadataCatalog +from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset, ToIterableDataset +from .dataset_mapper import DatasetMapper +from .detection_utils import check_metadata_consistency +from .samplers import ( + InferenceSampler, + RandomSubsetTrainingSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) + +""" +This file contains the default logic to build a dataloader for training or testing. +""" + +__all__ = [ + "build_batch_data_loader", + "build_detection_train_loader", + "build_detection_test_loader", + "get_detection_dataset_dicts", + "load_proposals_into_dataset", + "print_instances_class_histogram", +] + + +def filter_images_with_only_crowd_annotations(dataset_dicts): + """ + Filter out images with none annotations or only crowd annotations + (i.e., images without non-crowd annotations). + A common training-time preprocessing on COCO dataset. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + + Returns: + list[dict]: the same format, but filtered. + """ + num_before = len(dataset_dicts) + + def valid(anns): + for ann in anns: + if ann.get("iscrowd", 0) == 0: + return True + return False + + dataset_dicts = [x for x in dataset_dicts if valid(x["annotations"])] + num_after = len(dataset_dicts) + logger = logging.getLogger(__name__) + logger.info( + "Removed {} images with no usable annotations. {} images left.".format( + num_before - num_after, num_after + ) + ) + return dataset_dicts + + +def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image): + """ + Filter out images with too few number of keypoints. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + + Returns: + list[dict]: the same format as dataset_dicts, but filtered. + """ + num_before = len(dataset_dicts) + + def visible_keypoints_in_image(dic): + # Each keypoints field has the format [x1, y1, v1, ...], where v is visibility + annotations = dic["annotations"] + return sum( + (np.array(ann["keypoints"][2::3]) > 0).sum() + for ann in annotations + if "keypoints" in ann + ) + + dataset_dicts = [ + x for x in dataset_dicts if visible_keypoints_in_image(x) >= min_keypoints_per_image + ] + num_after = len(dataset_dicts) + logger = logging.getLogger(__name__) + logger.info( + "Removed {} images with fewer than {} keypoints.".format( + num_before - num_after, min_keypoints_per_image + ) + ) + return dataset_dicts + + +def load_proposals_into_dataset(dataset_dicts, proposal_file): + """ + Load precomputed object proposals into the dataset. + + The proposal file should be a pickled dict with the following keys: + + - "ids": list[int] or list[str], the image ids + - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id + - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores + corresponding to the boxes. + - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 Dataset format. + proposal_file (str): file path of pre-computed proposals, in pkl format. + + Returns: + list[dict]: the same format as dataset_dicts, but added proposal field. + """ + logger = logging.getLogger(__name__) + logger.info("Loading proposals from: {}".format(proposal_file)) + + with PathManager.open(proposal_file, "rb") as f: + proposals = pickle.load(f, encoding="latin1") + + # Rename the key names in D1 proposal files + rename_keys = {"indexes": "ids", "scores": "objectness_logits"} + for key in rename_keys: + if key in proposals: + proposals[rename_keys[key]] = proposals.pop(key) + + # Fetch the indexes of all proposals that are in the dataset + # Convert image_id to str since they could be int. + img_ids = set({str(record["image_id"]) for record in dataset_dicts}) + id_to_index = {str(id): i for i, id in enumerate(proposals["ids"]) if str(id) in img_ids} + + # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS' + bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS + + for record in dataset_dicts: + # Get the index of the proposal + i = id_to_index[str(record["image_id"])] + + boxes = proposals["boxes"][i] + objectness_logits = proposals["objectness_logits"][i] + # Sort the proposals in descending order of the scores + inds = objectness_logits.argsort()[::-1] + record["proposal_boxes"] = boxes[inds] + record["proposal_objectness_logits"] = objectness_logits[inds] + record["proposal_bbox_mode"] = bbox_mode + + return dataset_dicts + + +def print_instances_class_histogram(dataset_dicts, class_names): + """ + Args: + dataset_dicts (list[dict]): list of dataset dicts. + class_names (list[str]): list of class names (zero-indexed). + """ + num_classes = len(class_names) + hist_bins = np.arange(num_classes + 1) + histogram = np.zeros((num_classes,), dtype=np.int) + for entry in dataset_dicts: + annos = entry["annotations"] + classes = np.asarray( + [x["category_id"] for x in annos if not x.get("iscrowd", 0)], dtype=np.int + ) + if len(classes): + assert classes.min() >= 0, f"Got an invalid category_id={classes.min()}" + assert ( + classes.max() < num_classes + ), f"Got an invalid category_id={classes.max()} for a dataset of {num_classes} classes" + histogram += np.histogram(classes, bins=hist_bins)[0] + + N_COLS = min(6, len(class_names) * 2) + + def short_name(x): + # make long class names shorter. useful for lvis + if len(x) > 13: + return x[:11] + ".." + return x + + data = list( + itertools.chain(*[[short_name(class_names[i]), int(v)] for i, v in enumerate(histogram)]) + ) + total_num_instances = sum(data[1::2]) + data.extend([None] * (N_COLS - (len(data) % N_COLS))) + if num_classes > 1: + data.extend(["total", total_num_instances]) + data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + data, + headers=["category", "#instances"] * (N_COLS // 2), + tablefmt="pipe", + numalign="left", + stralign="center", + ) + log_first_n( + logging.INFO, + "Distribution of instances among all {} categories:\n".format(num_classes) + + colored(table, "cyan"), + key="message", + ) + + +def get_detection_dataset_dicts( + names, + filter_empty=True, + min_keypoints=0, + proposal_files=None, + check_consistency=True, +): + """ + Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation. + + Args: + names (str or list[str]): a dataset name or a list of dataset names + filter_empty (bool): whether to filter out images without instance annotations + min_keypoints (int): filter out images with fewer keypoints than + `min_keypoints`. Set to 0 to do nothing. + proposal_files (list[str]): if given, a list of object proposal files + that match each dataset in `names`. + check_consistency (bool): whether to check if datasets have consistent metadata. + + Returns: + list[dict]: a list of dicts following the standard dataset dict format. + """ + if isinstance(names, str): + names = [names] + assert len(names), names + dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in names] + + if isinstance(dataset_dicts[0], torchdata.Dataset): + if len(dataset_dicts) > 1: + # ConcatDataset does not work for iterable style dataset. + # We could support concat for iterable as well, but it's often + # not a good idea to concat iterables anyway. + return torchdata.ConcatDataset(dataset_dicts) + return dataset_dicts[0] + + for dataset_name, dicts in zip(names, dataset_dicts): + assert len(dicts), "Dataset '{}' is empty!".format(dataset_name) + + if proposal_files is not None: + assert len(names) == len(proposal_files) + # load precomputed proposals from proposal files + dataset_dicts = [ + load_proposals_into_dataset(dataset_i_dicts, proposal_file) + for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files) + ] + + dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts)) + + has_instances = "annotations" in dataset_dicts[0] + if filter_empty and has_instances: + dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts) + if min_keypoints > 0 and has_instances: + dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints) + + if check_consistency and has_instances: + try: + class_names = MetadataCatalog.get(names[0]).thing_classes + check_metadata_consistency("thing_classes", names) + print_instances_class_histogram(dataset_dicts, class_names) + except AttributeError: # class names are not available for this dataset + pass + + assert len(dataset_dicts), "No valid data found in {}.".format(",".join(names)) + return dataset_dicts + + +def build_batch_data_loader( + dataset, + sampler, + total_batch_size, + *, + aspect_ratio_grouping=False, + num_workers=0, + collate_fn=None, +): + """ + Build a batched dataloader. The main differences from `torch.utils.data.DataLoader` are: + 1. support aspect ratio grouping options + 2. use no "batch collation", because this is common for detection training + + Args: + dataset (torch.utils.data.Dataset): a pytorch map-style or iterable dataset. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces indices. + Must be provided iff. ``dataset`` is a map-style dataset. + total_batch_size, aspect_ratio_grouping, num_workers, collate_fn: see + :func:`build_detection_train_loader`. + + Returns: + iterable[list]. Length of each list is the batch size of the current + GPU. Each element in the list comes from the dataset. + """ + world_size = get_world_size() + assert ( + total_batch_size > 0 and total_batch_size % world_size == 0 + ), "Total batch size ({}) must be divisible by the number of gpus ({}).".format( + total_batch_size, world_size + ) + batch_size = total_batch_size // world_size + + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + dataset = ToIterableDataset(dataset, sampler) + + if aspect_ratio_grouping: + data_loader = torchdata.DataLoader( + dataset, + num_workers=num_workers, + collate_fn=operator.itemgetter(0), # don't batch, but yield individual elements + worker_init_fn=worker_init_reset_seed, + ) # yield individual mapped dict + data_loader = AspectRatioGroupedDataset(data_loader, batch_size) + if collate_fn is None: + return data_loader + return MapDataset(data_loader, collate_fn) + else: + return torchdata.DataLoader( + dataset, + batch_size=batch_size, + drop_last=True, + num_workers=num_workers, + collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, + worker_init_fn=worker_init_reset_seed, + ) + + +def _train_loader_from_config(cfg, mapper=None, *, dataset=None, sampler=None): + if dataset is None: + dataset = get_detection_dataset_dicts( + cfg.DATASETS.TRAIN, + filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS, + min_keypoints=cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE + if cfg.MODEL.KEYPOINT_ON + else 0, + proposal_files=cfg.DATASETS.PROPOSAL_FILES_TRAIN if cfg.MODEL.LOAD_PROPOSALS else None, + ) + _log_api_usage("dataset." + cfg.DATASETS.TRAIN[0]) + + if mapper is None: + mapper = DatasetMapper(cfg, True) + + if sampler is None: + sampler_name = cfg.DATALOADER.SAMPLER_TRAIN + logger = logging.getLogger(__name__) + if isinstance(dataset, torchdata.IterableDataset): + logger.info("Not using any sampler since the dataset is IterableDataset.") + sampler = None + else: + logger.info("Using training sampler {}".format(sampler_name)) + if sampler_name == "TrainingSampler": + sampler = TrainingSampler(len(dataset)) + elif sampler_name == "RepeatFactorTrainingSampler": + repeat_factors = RepeatFactorTrainingSampler.repeat_factors_from_category_frequency( + dataset, cfg.DATALOADER.REPEAT_THRESHOLD + ) + sampler = RepeatFactorTrainingSampler(repeat_factors) + elif sampler_name == "RandomSubsetTrainingSampler": + sampler = RandomSubsetTrainingSampler( + len(dataset), cfg.DATALOADER.RANDOM_SUBSET_RATIO + ) + else: + raise ValueError("Unknown training sampler: {}".format(sampler_name)) + + return { + "dataset": dataset, + "sampler": sampler, + "mapper": mapper, + "total_batch_size": cfg.SOLVER.IMS_PER_BATCH, + "aspect_ratio_grouping": cfg.DATALOADER.ASPECT_RATIO_GROUPING, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + } + + +@configurable(from_config=_train_loader_from_config) +def build_detection_train_loader( + dataset, + *, + mapper, + sampler=None, + total_batch_size, + aspect_ratio_grouping=True, + num_workers=0, + collate_fn=None, +): + """ + Build a dataloader for object detection with some default features. + + Args: + dataset (list or torch.utils.data.Dataset): a list of dataset dicts, + or a pytorch dataset (either map-style or iterable). It can be obtained + by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper (callable): a callable which takes a sample (dict) from dataset and + returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=True)``. + sampler (torch.utils.data.sampler.Sampler or None): a sampler that produces + indices to be applied on ``dataset``. + If ``dataset`` is map-style, the default sampler is a :class:`TrainingSampler`, + which coordinates an infinite random shuffle sequence across all workers. + Sampler must be None if ``dataset`` is iterable. + total_batch_size (int): total batch size across all workers. + aspect_ratio_grouping (bool): whether to group images with similar + aspect ratio for efficiency. When enabled, it requires each + element in dataset be a dict with keys "width" and "height". + num_workers (int): number of parallel data loading workers + collate_fn: a function that determines how to do batching, same as the argument of + `torch.utils.data.DataLoader`. Defaults to do no collation and return a list of + data. No collation is OK for small batch size and simple data structures. + If your batch size is large and each sample contains too many small tensors, + it's more efficient to collate them in data loader. + + Returns: + torch.utils.data.DataLoader: + a dataloader. Each output from it is a ``list[mapped_element]`` of length + ``total_batch_size / num_workers``, where ``mapped_element`` is produced + by the ``mapper``. + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + if sampler is None: + sampler = TrainingSampler(len(dataset)) + assert isinstance(sampler, torchdata.Sampler), f"Expect a Sampler but got {type(sampler)}" + return build_batch_data_loader( + dataset, + sampler, + total_batch_size, + aspect_ratio_grouping=aspect_ratio_grouping, + num_workers=num_workers, + collate_fn=collate_fn, + ) + + +def _test_loader_from_config(cfg, dataset_name, mapper=None): + """ + Uses the given `dataset_name` argument (instead of the names in cfg), because the + standard practice is to evaluate each test set individually (not combining them). + """ + if isinstance(dataset_name, str): + dataset_name = [dataset_name] + + dataset = get_detection_dataset_dicts( + dataset_name, + filter_empty=False, + proposal_files=[ + cfg.DATASETS.PROPOSAL_FILES_TEST[list(cfg.DATASETS.TEST).index(x)] for x in dataset_name + ] + if cfg.MODEL.LOAD_PROPOSALS + else None, + ) + if mapper is None: + mapper = DatasetMapper(cfg, False) + return { + "dataset": dataset, + "mapper": mapper, + "num_workers": cfg.DATALOADER.NUM_WORKERS, + "sampler": InferenceSampler(len(dataset)) + if not isinstance(dataset, torchdata.IterableDataset) + else None, + } + + +@configurable(from_config=_test_loader_from_config) +def build_detection_test_loader( + dataset: Union[List[Any], torchdata.Dataset], + *, + mapper: Callable[[Dict[str, Any]], Any], + sampler: Optional[torchdata.Sampler] = None, + batch_size: int = 1, + num_workers: int = 0, + collate_fn: Optional[Callable[[List[Any]], Any]] = None, +) -> torchdata.DataLoader: + """ + Similar to `build_detection_train_loader`, with default batch size = 1, + and sampler = :class:`InferenceSampler`. This sampler coordinates all workers + to produce the exact set of all samples. + + Args: + dataset: a list of dataset dicts, + or a pytorch dataset (either map-style or iterable). They can be obtained + by using :func:`DatasetCatalog.get` or :func:`get_detection_dataset_dicts`. + mapper: a callable which takes a sample (dict) from dataset + and returns the format to be consumed by the model. + When using cfg, the default choice is ``DatasetMapper(cfg, is_train=False)``. + sampler: a sampler that produces + indices to be applied on ``dataset``. Default to :class:`InferenceSampler`, + which splits the dataset across all workers. Sampler must be None + if `dataset` is iterable. + batch_size: the batch size of the data loader to be created. + Default to 1 image per worker since this is the standard when reporting + inference time in papers. + num_workers: number of parallel data loading workers + collate_fn: same as the argument of `torch.utils.data.DataLoader`. + Defaults to do no collation and return a list of data. + + Returns: + DataLoader: a torch DataLoader, that loads the given detection + dataset, with test-time transformation and batching. + + Examples: + :: + data_loader = build_detection_test_loader( + DatasetRegistry.get("my_test"), + mapper=DatasetMapper(...)) + + # or, instantiate with a CfgNode: + data_loader = build_detection_test_loader(cfg, "my_test") + """ + if isinstance(dataset, list): + dataset = DatasetFromList(dataset, copy=False) + if mapper is not None: + dataset = MapDataset(dataset, mapper) + if isinstance(dataset, torchdata.IterableDataset): + assert sampler is None, "sampler must be None if dataset is IterableDataset" + else: + if sampler is None: + sampler = InferenceSampler(len(dataset)) + return torchdata.DataLoader( + dataset, + batch_size=batch_size, + sampler=sampler, + drop_last=False, + num_workers=num_workers, + collate_fn=trivial_batch_collator if collate_fn is None else collate_fn, + ) + + +def trivial_batch_collator(batch): + """ + A batch collator that does nothing. + """ + return batch + + +def worker_init_reset_seed(worker_id): + initial_seed = torch.initial_seed() % 2**31 + seed_all_rng(initial_seed + worker_id) diff --git a/custom_detectron2/data/catalog.py b/custom_detectron2/data/catalog.py new file mode 100644 index 0000000000000000000000000000000000000000..8a5507b8e3bd75eb2acba1b6218e7832090ca985 --- /dev/null +++ b/custom_detectron2/data/catalog.py @@ -0,0 +1,236 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import types +from collections import UserDict +from typing import List + +from custom_detectron2.utils.logger import log_first_n + +__all__ = ["DatasetCatalog", "MetadataCatalog", "Metadata"] + + +class _DatasetCatalog(UserDict): + """ + A global dictionary that stores information about the datasets and how to obtain them. + + It contains a mapping from strings + (which are names that identify a dataset, e.g. "coco_2014_train") + to a function which parses the dataset and returns the samples in the + format of `list[dict]`. + + The returned dicts should be in Detectron2 Dataset format (See DATASETS.md for details) + if used with the data loader functionalities in `data/build.py,data/detection_transform.py`. + + The purpose of having this catalog is to make it easy to choose + different datasets, by just using the strings in the config. + """ + + def register(self, name, func): + """ + Args: + name (str): the name that identifies a dataset, e.g. "coco_2014_train". + func (callable): a callable which takes no arguments and returns a list of dicts. + It must return the same results if called multiple times. + """ + assert callable(func), "You must register a function with `DatasetCatalog.register`!" + assert name not in self, "Dataset '{}' is already registered!".format(name) + self[name] = func + + def get(self, name): + """ + Call the registered function and return its results. + + Args: + name (str): the name that identifies a dataset, e.g. "coco_2014_train". + + Returns: + list[dict]: dataset annotations. + """ + try: + f = self[name] + except KeyError as e: + raise KeyError( + "Dataset '{}' is not registered! Available datasets are: {}".format( + name, ", ".join(list(self.keys())) + ) + ) from e + return f() + + def list(self) -> List[str]: + """ + List all registered datasets. + + Returns: + list[str] + """ + return list(self.keys()) + + def remove(self, name): + """ + Alias of ``pop``. + """ + self.pop(name) + + def __str__(self): + return "DatasetCatalog(registered datasets: {})".format(", ".join(self.keys())) + + __repr__ = __str__ + + +DatasetCatalog = _DatasetCatalog() +DatasetCatalog.__doc__ = ( + _DatasetCatalog.__doc__ + + """ + .. automethod:: detectron2.data.catalog.DatasetCatalog.register + .. automethod:: detectron2.data.catalog.DatasetCatalog.get +""" +) + + +class Metadata(types.SimpleNamespace): + """ + A class that supports simple attribute setter/getter. + It is intended for storing metadata of a dataset and make it accessible globally. + + Examples: + :: + # somewhere when you load the data: + MetadataCatalog.get("mydataset").thing_classes = ["person", "dog"] + + # somewhere when you print statistics or visualize: + classes = MetadataCatalog.get("mydataset").thing_classes + """ + + # the name of the dataset + # set default to N/A so that `self.name` in the errors will not trigger getattr again + name: str = "N/A" + + _RENAMED = { + "class_names": "thing_classes", + "dataset_id_to_contiguous_id": "thing_dataset_id_to_contiguous_id", + "stuff_class_names": "stuff_classes", + } + + def __getattr__(self, key): + if key in self._RENAMED: + log_first_n( + logging.WARNING, + "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]), + n=10, + ) + return getattr(self, self._RENAMED[key]) + + # "name" exists in every metadata + if len(self.__dict__) > 1: + raise AttributeError( + "Attribute '{}' does not exist in the metadata of dataset '{}'. Available " + "keys are {}.".format(key, self.name, str(self.__dict__.keys())) + ) + else: + raise AttributeError( + f"Attribute '{key}' does not exist in the metadata of dataset '{self.name}': " + "metadata is empty." + ) + + def __setattr__(self, key, val): + if key in self._RENAMED: + log_first_n( + logging.WARNING, + "Metadata '{}' was renamed to '{}'!".format(key, self._RENAMED[key]), + n=10, + ) + setattr(self, self._RENAMED[key], val) + + # Ensure that metadata of the same name stays consistent + try: + oldval = getattr(self, key) + assert oldval == val, ( + "Attribute '{}' in the metadata of '{}' cannot be set " + "to a different value!\n{} != {}".format(key, self.name, oldval, val) + ) + except AttributeError: + super().__setattr__(key, val) + + def as_dict(self): + """ + Returns all the metadata as a dict. + Note that modifications to the returned dict will not reflect on the Metadata object. + """ + return copy.copy(self.__dict__) + + def set(self, **kwargs): + """ + Set multiple metadata with kwargs. + """ + for k, v in kwargs.items(): + setattr(self, k, v) + return self + + def get(self, key, default=None): + """ + Access an attribute and return its value if exists. + Otherwise return default. + """ + try: + return getattr(self, key) + except AttributeError: + return default + + +class _MetadataCatalog(UserDict): + """ + MetadataCatalog is a global dictionary that provides access to + :class:`Metadata` of a given dataset. + + The metadata associated with a certain name is a singleton: once created, the + metadata will stay alive and will be returned by future calls to ``get(name)``. + + It's like global variables, so don't abuse it. + It's meant for storing knowledge that's constant and shared across the execution + of the program, e.g.: the class names in COCO. + """ + + def get(self, name): + """ + Args: + name (str): name of a dataset (e.g. coco_2014_train). + + Returns: + Metadata: The :class:`Metadata` instance associated with this name, + or create an empty one if none is available. + """ + assert len(name) + r = super().get(name, None) + if r is None: + r = self[name] = Metadata(name=name) + return r + + def list(self): + """ + List all registered metadata. + + Returns: + list[str]: keys (names of datasets) of all registered metadata + """ + return list(self.keys()) + + def remove(self, name): + """ + Alias of ``pop``. + """ + self.pop(name) + + def __str__(self): + return "MetadataCatalog(registered metadata: {})".format(", ".join(self.keys())) + + __repr__ = __str__ + + +MetadataCatalog = _MetadataCatalog() +MetadataCatalog.__doc__ = ( + _MetadataCatalog.__doc__ + + """ + .. automethod:: detectron2.data.catalog.MetadataCatalog.get +""" +) diff --git a/custom_detectron2/data/common.py b/custom_detectron2/data/common.py new file mode 100644 index 0000000000000000000000000000000000000000..42d15cfa7fa1bbf04afce3c616a590ac48554d2e --- /dev/null +++ b/custom_detectron2/data/common.py @@ -0,0 +1,301 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import copy +import itertools +import logging +import numpy as np +import pickle +import random +from typing import Callable, Union +import torch +import torch.utils.data as data +from torch.utils.data.sampler import Sampler + +from custom_detectron2.utils.serialize import PicklableWrapper + +__all__ = ["MapDataset", "DatasetFromList", "AspectRatioGroupedDataset", "ToIterableDataset"] + +logger = logging.getLogger(__name__) + + +def _shard_iterator_dataloader_worker(iterable): + # Shard the iterable if we're currently inside pytorch dataloader worker. + worker_info = data.get_worker_info() + if worker_info is None or worker_info.num_workers == 1: + # do nothing + yield from iterable + else: + yield from itertools.islice(iterable, worker_info.id, None, worker_info.num_workers) + + +class _MapIterableDataset(data.IterableDataset): + """ + Map a function over elements in an IterableDataset. + + Similar to pytorch's MapIterDataPipe, but support filtering when map_func + returns None. + + This class is not public-facing. Will be called by `MapDataset`. + """ + + def __init__(self, dataset, map_func): + self._dataset = dataset + self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work + + def __len__(self): + return len(self._dataset) + + def __iter__(self): + for x in map(self._map_func, self._dataset): + if x is not None: + yield x + + +class MapDataset(data.Dataset): + """ + Map a function over the elements in a dataset. + """ + + def __init__(self, dataset, map_func): + """ + Args: + dataset: a dataset where map function is applied. Can be either + map-style or iterable dataset. When given an iterable dataset, + the returned object will also be an iterable dataset. + map_func: a callable which maps the element in dataset. map_func can + return None to skip the data (e.g. in case of errors). + How None is handled depends on the style of `dataset`. + If `dataset` is map-style, it randomly tries other elements. + If `dataset` is iterable, it skips the data and tries the next. + """ + self._dataset = dataset + self._map_func = PicklableWrapper(map_func) # wrap so that a lambda will work + + self._rng = random.Random(42) + self._fallback_candidates = set(range(len(dataset))) + + def __new__(cls, dataset, map_func): + is_iterable = isinstance(dataset, data.IterableDataset) + if is_iterable: + return _MapIterableDataset(dataset, map_func) + else: + return super().__new__(cls) + + def __getnewargs__(self): + return self._dataset, self._map_func + + def __len__(self): + return len(self._dataset) + + def __getitem__(self, idx): + retry_count = 0 + cur_idx = int(idx) + + while True: + data = self._map_func(self._dataset[cur_idx]) + if data is not None: + self._fallback_candidates.add(cur_idx) + return data + + # _map_func fails for this idx, use a random new index from the pool + retry_count += 1 + self._fallback_candidates.discard(cur_idx) + cur_idx = self._rng.sample(self._fallback_candidates, k=1)[0] + + if retry_count >= 3: + logger = logging.getLogger(__name__) + logger.warning( + "Failed to apply `_map_func` for idx: {}, retry count: {}".format( + idx, retry_count + ) + ) + + +class _TorchSerializedList(object): + """ + A list-like object whose items are serialized and stored in a torch tensor. When + launching a process that uses TorchSerializedList with "fork" start method, + the subprocess can read the same buffer without triggering copy-on-access. When + launching a process that uses TorchSerializedList with "spawn/forkserver" start + method, the list will be pickled by a special ForkingPickler registered by PyTorch + that moves data to shared memory. In both cases, this allows parent and child + processes to share RAM for the list data, hence avoids the issue in + https://github.com/pytorch/pytorch/issues/13246. + + See also https://ppwwyyxx.com/blog/2022/Demystify-RAM-Usage-in-Multiprocess-DataLoader/ + on how it works. + """ + + def __init__(self, lst: list): + self._lst = lst + + def _serialize(data): + buffer = pickle.dumps(data, protocol=-1) + return np.frombuffer(buffer, dtype=np.uint8) + + logger.info( + "Serializing {} elements to byte tensors and concatenating them all ...".format( + len(self._lst) + ) + ) + self._lst = [_serialize(x) for x in self._lst] + self._addr = np.asarray([len(x) for x in self._lst], dtype=np.int64) + self._addr = torch.from_numpy(np.cumsum(self._addr)) + self._lst = torch.from_numpy(np.concatenate(self._lst)) + logger.info("Serialized dataset takes {:.2f} MiB".format(len(self._lst) / 1024**2)) + + def __len__(self): + return len(self._addr) + + def __getitem__(self, idx): + start_addr = 0 if idx == 0 else self._addr[idx - 1].item() + end_addr = self._addr[idx].item() + bytes = memoryview(self._lst[start_addr:end_addr].numpy()) + + # @lint-ignore PYTHONPICKLEISBAD + return pickle.loads(bytes) + + +_DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = _TorchSerializedList + + +@contextlib.contextmanager +def set_default_dataset_from_list_serialize_method(new): + """ + Context manager for using custom serialize function when creating DatasetFromList + """ + + global _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD + orig = _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD + _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = new + yield + _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD = orig + + +class DatasetFromList(data.Dataset): + """ + Wrap a list to a torch Dataset. It produces elements of the list as data. + """ + + def __init__( + self, + lst: list, + copy: bool = True, + serialize: Union[bool, Callable] = True, + ): + """ + Args: + lst (list): a list which contains elements to produce. + copy (bool): whether to deepcopy the element when producing it, + so that the result can be modified in place without affecting the + source in the list. + serialize (bool or callable): whether to serialize the stroage to other + backend. If `True`, the default serialize method will be used, if given + a callable, the callable will be used as serialize method. + """ + self._lst = lst + self._copy = copy + if not isinstance(serialize, (bool, Callable)): + raise TypeError(f"Unsupported type for argument `serailzie`: {serialize}") + self._serialize = serialize is not False + + if self._serialize: + serialize_method = ( + serialize + if isinstance(serialize, Callable) + else _DEFAULT_DATASET_FROM_LIST_SERIALIZE_METHOD + ) + logger.info(f"Serializing the dataset using: {serialize_method}") + self._lst = serialize_method(self._lst) + + def __len__(self): + return len(self._lst) + + def __getitem__(self, idx): + if self._copy and not self._serialize: + return copy.deepcopy(self._lst[idx]) + else: + return self._lst[idx] + + +class ToIterableDataset(data.IterableDataset): + """ + Convert an old indices-based (also called map-style) dataset + to an iterable-style dataset. + """ + + def __init__(self, dataset: data.Dataset, sampler: Sampler, shard_sampler: bool = True): + """ + Args: + dataset: an old-style dataset with ``__getitem__`` + sampler: a cheap iterable that produces indices to be applied on ``dataset``. + shard_sampler: whether to shard the sampler based on the current pytorch data loader + worker id. When an IterableDataset is forked by pytorch's DataLoader into multiple + workers, it is responsible for sharding its data based on worker id so that workers + don't produce identical data. + + Most samplers (like our TrainingSampler) do not shard based on dataloader worker id + and this argument should be set to True. But certain samplers may be already + sharded, in that case this argument should be set to False. + """ + assert not isinstance(dataset, data.IterableDataset), dataset + assert isinstance(sampler, Sampler), sampler + self.dataset = dataset + self.sampler = sampler + self.shard_sampler = shard_sampler + + def __iter__(self): + if not self.shard_sampler: + sampler = self.sampler + else: + # With map-style dataset, `DataLoader(dataset, sampler)` runs the + # sampler in main process only. But `DataLoader(ToIterableDataset(dataset, sampler))` + # will run sampler in every of the N worker. So we should only keep 1/N of the ids on + # each worker. The assumption is that sampler is cheap to iterate so it's fine to + # discard ids in workers. + sampler = _shard_iterator_dataloader_worker(self.sampler) + for idx in sampler: + yield self.dataset[idx] + + def __len__(self): + return len(self.sampler) + + +class AspectRatioGroupedDataset(data.IterableDataset): + """ + Batch data that have similar aspect ratio together. + In this implementation, images whose aspect ratio < (or >) 1 will + be batched together. + This improves training speed because the images then need less padding + to form a batch. + + It assumes the underlying dataset produces dicts with "width" and "height" keys. + It will then produce a list of original dicts with length = batch_size, + all with similar aspect ratios. + """ + + def __init__(self, dataset, batch_size): + """ + Args: + dataset: an iterable. Each element must be a dict with keys + "width" and "height", which will be used to batch data. + batch_size (int): + """ + self.dataset = dataset + self.batch_size = batch_size + self._buckets = [[] for _ in range(2)] + # Hard-coded two aspect ratio groups: w > h and w < h. + # Can add support for more aspect ratio groups, but doesn't seem useful + + def __iter__(self): + for d in self.dataset: + w, h = d["width"], d["height"] + bucket_id = 0 if w > h else 1 + bucket = self._buckets[bucket_id] + bucket.append(d) + if len(bucket) == self.batch_size: + data = bucket[:] + # Clear bucket first, because code after yield is not + # guaranteed to execute + del bucket[:] + yield data diff --git a/custom_detectron2/data/dataset_mapper.py b/custom_detectron2/data/dataset_mapper.py new file mode 100644 index 0000000000000000000000000000000000000000..884bbae6e2093974677c8ee8aebb2cae8f43d1b6 --- /dev/null +++ b/custom_detectron2/data/dataset_mapper.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import numpy as np +from typing import List, Optional, Union +import torch + +from custom_detectron2.config import configurable + +from . import detection_utils as utils +from . import transforms as T + +""" +This file contains the default mapping that's applied to "dataset dicts". +""" + +__all__ = ["DatasetMapper"] + + +class DatasetMapper: + """ + A callable which takes a dataset dict in Detectron2 Dataset format, + and map it into a format used by the model. + + This is the default callable to be used to map your dataset dict into training data. + You may need to follow it to implement your own one for customized logic, + such as a different way to read or transform images. + See :doc:`/tutorials/data_loading` for details. + + The callable currently does the following: + + 1. Read the image from "file_name" + 2. Applies cropping/geometric transforms to the image and annotations + 3. Prepare data and annotations to Tensor and :class:`Instances` + """ + + @configurable + def __init__( + self, + is_train: bool, + *, + augmentations: List[Union[T.Augmentation, T.Transform]], + image_format: str, + use_instance_mask: bool = False, + use_keypoint: bool = False, + instance_mask_format: str = "polygon", + keypoint_hflip_indices: Optional[np.ndarray] = None, + precomputed_proposal_topk: Optional[int] = None, + recompute_boxes: bool = False, + ): + """ + NOTE: this interface is experimental. + + Args: + is_train: whether it's used in training or inference + augmentations: a list of augmentations or deterministic transforms to apply + image_format: an image format supported by :func:`detection_utils.read_image`. + use_instance_mask: whether to process instance segmentation annotations, if available + use_keypoint: whether to process keypoint annotations if available + instance_mask_format: one of "polygon" or "bitmask". Process instance segmentation + masks into this format. + keypoint_hflip_indices: see :func:`detection_utils.create_keypoint_hflip_indices` + precomputed_proposal_topk: if given, will load pre-computed + proposals from dataset_dict and keep the top k proposals for each image. + recompute_boxes: whether to overwrite bounding box annotations + by computing tight bounding boxes from instance mask annotations. + """ + if recompute_boxes: + assert use_instance_mask, "recompute_boxes requires instance masks" + # fmt: off + self.is_train = is_train + self.augmentations = T.AugmentationList(augmentations) + self.image_format = image_format + self.use_instance_mask = use_instance_mask + self.instance_mask_format = instance_mask_format + self.use_keypoint = use_keypoint + self.keypoint_hflip_indices = keypoint_hflip_indices + self.proposal_topk = precomputed_proposal_topk + self.recompute_boxes = recompute_boxes + # fmt: on + logger = logging.getLogger(__name__) + mode = "training" if is_train else "inference" + logger.info(f"[DatasetMapper] Augmentations used in {mode}: {augmentations}") + + @classmethod + def from_config(cls, cfg, is_train: bool = True): + augs = utils.build_augmentation(cfg, is_train) + if cfg.INPUT.CROP.ENABLED and is_train: + augs.insert(0, T.RandomCrop(cfg.INPUT.CROP.TYPE, cfg.INPUT.CROP.SIZE)) + recompute_boxes = cfg.MODEL.MASK_ON + else: + recompute_boxes = False + + ret = { + "is_train": is_train, + "augmentations": augs, + "image_format": cfg.INPUT.FORMAT, + "use_instance_mask": cfg.MODEL.MASK_ON, + "instance_mask_format": cfg.INPUT.MASK_FORMAT, + "use_keypoint": cfg.MODEL.KEYPOINT_ON, + "recompute_boxes": recompute_boxes, + } + + if cfg.MODEL.KEYPOINT_ON: + ret["keypoint_hflip_indices"] = utils.create_keypoint_hflip_indices(cfg.DATASETS.TRAIN) + + if cfg.MODEL.LOAD_PROPOSALS: + ret["precomputed_proposal_topk"] = ( + cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN + if is_train + else cfg.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST + ) + return ret + + def _transform_annotations(self, dataset_dict, transforms, image_shape): + # USER: Modify this if you want to keep them for some reason. + for anno in dataset_dict["annotations"]: + if not self.use_instance_mask: + anno.pop("segmentation", None) + if not self.use_keypoint: + anno.pop("keypoints", None) + + # USER: Implement additional transformations if you have other types of data + annos = [ + utils.transform_instance_annotations( + obj, transforms, image_shape, keypoint_hflip_indices=self.keypoint_hflip_indices + ) + for obj in dataset_dict.pop("annotations") + if obj.get("iscrowd", 0) == 0 + ] + instances = utils.annotations_to_instances( + annos, image_shape, mask_format=self.instance_mask_format + ) + + # After transforms such as cropping are applied, the bounding box may no longer + # tightly bound the object. As an example, imagine a triangle object + # [(0,0), (2,0), (0,2)] cropped by a box [(1,0),(2,2)] (XYXY format). The tight + # bounding box of the cropped triangle should be [(1,0),(2,1)], which is not equal to + # the intersection of original bounding box and the cropping box. + if self.recompute_boxes: + instances.gt_boxes = instances.gt_masks.get_bounding_boxes() + dataset_dict["instances"] = utils.filter_empty_instances(instances) + + def __call__(self, dataset_dict): + """ + Args: + dataset_dict (dict): Metadata of one image, in Detectron2 Dataset format. + + Returns: + dict: a format that builtin models in detectron2 accept + """ + dataset_dict = copy.deepcopy(dataset_dict) # it will be modified by code below + # USER: Write your own image loading if it's not from a file + image = utils.read_image(dataset_dict["file_name"], format=self.image_format) + utils.check_image_size(dataset_dict, image) + + # USER: Remove if you don't do semantic/panoptic segmentation. + if "sem_seg_file_name" in dataset_dict: + sem_seg_gt = utils.read_image(dataset_dict.pop("sem_seg_file_name"), "L").squeeze(2) + else: + sem_seg_gt = None + + aug_input = T.AugInput(image, sem_seg=sem_seg_gt) + transforms = self.augmentations(aug_input) + image, sem_seg_gt = aug_input.image, aug_input.sem_seg + + image_shape = image.shape[:2] # h, w + # Pytorch's dataloader is efficient on torch.Tensor due to shared-memory, + # but not efficient on large generic data structures due to the use of pickle & mp.Queue. + # Therefore it's important to use torch.Tensor. + dataset_dict["image"] = torch.as_tensor(np.ascontiguousarray(image.transpose(2, 0, 1))) + if sem_seg_gt is not None: + dataset_dict["sem_seg"] = torch.as_tensor(sem_seg_gt.astype("long")) + + # USER: Remove if you don't use pre-computed proposals. + # Most users would not need this feature. + if self.proposal_topk is not None: + utils.transform_proposals( + dataset_dict, image_shape, transforms, proposal_topk=self.proposal_topk + ) + + if not self.is_train: + # USER: Modify this if you want to keep them for some reason. + dataset_dict.pop("annotations", None) + dataset_dict.pop("sem_seg_file_name", None) + return dataset_dict + + if "annotations" in dataset_dict: + self._transform_annotations(dataset_dict, transforms, image_shape) + + return dataset_dict diff --git a/custom_detectron2/data/datasets/README.md b/custom_detectron2/data/datasets/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9fb3e4f7afec17137c95c78be6ef06d520ec8032 --- /dev/null +++ b/custom_detectron2/data/datasets/README.md @@ -0,0 +1,9 @@ + + +### Common Datasets + +The dataset implemented here do not need to load the data into the final format. +It should provide the minimal data structure needed to use the dataset, so it can be very efficient. + +For example, for an image dataset, just provide the file names and labels, but don't read the images. +Let the downstream decide how to read. diff --git a/custom_detectron2/data/datasets/__init__.py b/custom_detectron2/data/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a44bedc15e5f0e762fc4d77efd6f1b07c6ff77d0 --- /dev/null +++ b/custom_detectron2/data/datasets/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .coco import load_coco_json, load_sem_seg, register_coco_instances, convert_to_coco_json +from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated +from .lvis import load_lvis_json, register_lvis_instances, get_lvis_instances_meta +from .pascal_voc import load_voc_instances, register_pascal_voc +from . import builtin as _builtin # ensure the builtin datasets are registered + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/data/datasets/builtin.py b/custom_detectron2/data/datasets/builtin.py new file mode 100644 index 0000000000000000000000000000000000000000..1e7409b577b28b3f793143291ab2381bef6e0b20 --- /dev/null +++ b/custom_detectron2/data/datasets/builtin.py @@ -0,0 +1,259 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + + +""" +This file registers pre-defined datasets at hard-coded paths, and their metadata. + +We hard-code metadata for common datasets. This will enable: +1. Consistency check when loading the datasets +2. Use models on these standard datasets directly and run demos, + without having to download the dataset annotations + +We hard-code some paths to the dataset that's assumed to +exist in "./datasets/". + +Users SHOULD NOT use this file to create new dataset / metadata for new dataset. +To add new dataset, refer to the tutorial "docs/DATASETS.md". +""" + +import os + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog + +from .builtin_meta import ADE20K_SEM_SEG_CATEGORIES, _get_builtin_metadata +from .cityscapes import load_cityscapes_instances, load_cityscapes_semantic +from .cityscapes_panoptic import register_all_cityscapes_panoptic +from .coco import load_sem_seg, register_coco_instances +from .coco_panoptic import register_coco_panoptic, register_coco_panoptic_separated +from .lvis import get_lvis_instances_meta, register_lvis_instances +from .pascal_voc import register_pascal_voc + +# ==== Predefined datasets and splits for COCO ========== + +_PREDEFINED_SPLITS_COCO = {} +_PREDEFINED_SPLITS_COCO["coco"] = { + "coco_2014_train": ("coco/train2014", "coco/annotations/instances_train2014.json"), + "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"), + "coco_2014_minival": ("coco/val2014", "coco/annotations/instances_minival2014.json"), + "coco_2014_valminusminival": ( + "coco/val2014", + "coco/annotations/instances_valminusminival2014.json", + ), + "coco_2017_train": ("coco/train2017", "coco/annotations/instances_train2017.json"), + "coco_2017_val": ("coco/val2017", "coco/annotations/instances_val2017.json"), + "coco_2017_test": ("coco/test2017", "coco/annotations/image_info_test2017.json"), + "coco_2017_test-dev": ("coco/test2017", "coco/annotations/image_info_test-dev2017.json"), + "coco_2017_val_100": ("coco/val2017", "coco/annotations/instances_val2017_100.json"), +} + +_PREDEFINED_SPLITS_COCO["coco_person"] = { + "keypoints_coco_2014_train": ( + "coco/train2014", + "coco/annotations/person_keypoints_train2014.json", + ), + "keypoints_coco_2014_val": ("coco/val2014", "coco/annotations/person_keypoints_val2014.json"), + "keypoints_coco_2014_minival": ( + "coco/val2014", + "coco/annotations/person_keypoints_minival2014.json", + ), + "keypoints_coco_2014_valminusminival": ( + "coco/val2014", + "coco/annotations/person_keypoints_valminusminival2014.json", + ), + "keypoints_coco_2017_train": ( + "coco/train2017", + "coco/annotations/person_keypoints_train2017.json", + ), + "keypoints_coco_2017_val": ("coco/val2017", "coco/annotations/person_keypoints_val2017.json"), + "keypoints_coco_2017_val_100": ( + "coco/val2017", + "coco/annotations/person_keypoints_val2017_100.json", + ), +} + + +_PREDEFINED_SPLITS_COCO_PANOPTIC = { + "coco_2017_train_panoptic": ( + # This is the original panoptic annotation directory + "coco/panoptic_train2017", + "coco/annotations/panoptic_train2017.json", + # This directory contains semantic annotations that are + # converted from panoptic annotations. + # It is used by PanopticFPN. + # You can use the script at detectron2/datasets/prepare_panoptic_fpn.py + # to create these directories. + "coco/panoptic_stuff_train2017", + ), + "coco_2017_val_panoptic": ( + "coco/panoptic_val2017", + "coco/annotations/panoptic_val2017.json", + "coco/panoptic_stuff_val2017", + ), + "coco_2017_val_100_panoptic": ( + "coco/panoptic_val2017_100", + "coco/annotations/panoptic_val2017_100.json", + "coco/panoptic_stuff_val2017_100", + ), +} + + +def register_all_coco(root): + for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_COCO.items(): + for key, (image_root, json_file) in splits_per_dataset.items(): + # Assume pre-defined datasets live in `./datasets`. + register_coco_instances( + key, + _get_builtin_metadata(dataset_name), + os.path.join(root, json_file) if "://" not in json_file else json_file, + os.path.join(root, image_root), + ) + + for ( + prefix, + (panoptic_root, panoptic_json, semantic_root), + ) in _PREDEFINED_SPLITS_COCO_PANOPTIC.items(): + prefix_instances = prefix[: -len("_panoptic")] + instances_meta = MetadataCatalog.get(prefix_instances) + image_root, instances_json = instances_meta.image_root, instances_meta.json_file + # The "separated" version of COCO panoptic segmentation dataset, + # e.g. used by Panoptic FPN + register_coco_panoptic_separated( + prefix, + _get_builtin_metadata("coco_panoptic_separated"), + image_root, + os.path.join(root, panoptic_root), + os.path.join(root, panoptic_json), + os.path.join(root, semantic_root), + instances_json, + ) + # The "standard" version of COCO panoptic segmentation dataset, + # e.g. used by Panoptic-DeepLab + register_coco_panoptic( + prefix, + _get_builtin_metadata("coco_panoptic_standard"), + image_root, + os.path.join(root, panoptic_root), + os.path.join(root, panoptic_json), + instances_json, + ) + + +# ==== Predefined datasets and splits for LVIS ========== + + +_PREDEFINED_SPLITS_LVIS = { + "lvis_v1": { + "lvis_v1_train": ("coco/", "lvis/lvis_v1_train.json"), + "lvis_v1_val": ("coco/", "lvis/lvis_v1_val.json"), + "lvis_v1_test_dev": ("coco/", "lvis/lvis_v1_image_info_test_dev.json"), + "lvis_v1_test_challenge": ("coco/", "lvis/lvis_v1_image_info_test_challenge.json"), + }, + "lvis_v0.5": { + "lvis_v0.5_train": ("coco/", "lvis/lvis_v0.5_train.json"), + "lvis_v0.5_val": ("coco/", "lvis/lvis_v0.5_val.json"), + "lvis_v0.5_val_rand_100": ("coco/", "lvis/lvis_v0.5_val_rand_100.json"), + "lvis_v0.5_test": ("coco/", "lvis/lvis_v0.5_image_info_test.json"), + }, + "lvis_v0.5_cocofied": { + "lvis_v0.5_train_cocofied": ("coco/", "lvis/lvis_v0.5_train_cocofied.json"), + "lvis_v0.5_val_cocofied": ("coco/", "lvis/lvis_v0.5_val_cocofied.json"), + }, +} + + +def register_all_lvis(root): + for dataset_name, splits_per_dataset in _PREDEFINED_SPLITS_LVIS.items(): + for key, (image_root, json_file) in splits_per_dataset.items(): + register_lvis_instances( + key, + get_lvis_instances_meta(dataset_name), + os.path.join(root, json_file) if "://" not in json_file else json_file, + os.path.join(root, image_root), + ) + + +# ==== Predefined splits for raw cityscapes images =========== +_RAW_CITYSCAPES_SPLITS = { + "cityscapes_fine_{task}_train": ("cityscapes/leftImg8bit/train/", "cityscapes/gtFine/train/"), + "cityscapes_fine_{task}_val": ("cityscapes/leftImg8bit/val/", "cityscapes/gtFine/val/"), + "cityscapes_fine_{task}_test": ("cityscapes/leftImg8bit/test/", "cityscapes/gtFine/test/"), +} + + +def register_all_cityscapes(root): + for key, (image_dir, gt_dir) in _RAW_CITYSCAPES_SPLITS.items(): + meta = _get_builtin_metadata("cityscapes") + image_dir = os.path.join(root, image_dir) + gt_dir = os.path.join(root, gt_dir) + + inst_key = key.format(task="instance_seg") + DatasetCatalog.register( + inst_key, + lambda x=image_dir, y=gt_dir: load_cityscapes_instances( + x, y, from_json=True, to_polygons=True + ), + ) + MetadataCatalog.get(inst_key).set( + image_dir=image_dir, gt_dir=gt_dir, evaluator_type="cityscapes_instance", **meta + ) + + sem_key = key.format(task="sem_seg") + DatasetCatalog.register( + sem_key, lambda x=image_dir, y=gt_dir: load_cityscapes_semantic(x, y) + ) + MetadataCatalog.get(sem_key).set( + image_dir=image_dir, + gt_dir=gt_dir, + evaluator_type="cityscapes_sem_seg", + ignore_label=255, + **meta, + ) + + +# ==== Predefined splits for PASCAL VOC =========== +def register_all_pascal_voc(root): + SPLITS = [ + ("voc_2007_trainval", "VOC2007", "trainval"), + ("voc_2007_train", "VOC2007", "train"), + ("voc_2007_val", "VOC2007", "val"), + ("voc_2007_test", "VOC2007", "test"), + ("voc_2012_trainval", "VOC2012", "trainval"), + ("voc_2012_train", "VOC2012", "train"), + ("voc_2012_val", "VOC2012", "val"), + ] + for name, dirname, split in SPLITS: + year = 2007 if "2007" in name else 2012 + register_pascal_voc(name, os.path.join(root, dirname), split, year) + MetadataCatalog.get(name).evaluator_type = "pascal_voc" + + +def register_all_ade20k(root): + root = os.path.join(root, "ADEChallengeData2016") + for name, dirname in [("train", "training"), ("val", "validation")]: + image_dir = os.path.join(root, "images", dirname) + gt_dir = os.path.join(root, "annotations_detectron2", dirname) + name = f"ade20k_sem_seg_{name}" + DatasetCatalog.register( + name, lambda x=image_dir, y=gt_dir: load_sem_seg(y, x, gt_ext="png", image_ext="jpg") + ) + MetadataCatalog.get(name).set( + stuff_classes=ADE20K_SEM_SEG_CATEGORIES[:], + image_root=image_dir, + sem_seg_root=gt_dir, + evaluator_type="sem_seg", + ignore_label=255, + ) + + +# True for open source; +# Internally at fb, we register them elsewhere +if __name__.endswith(".builtin"): + # Assume pre-defined datasets live in `./datasets`. + _root = os.path.expanduser(os.getenv("DETECTRON2_DATASETS", "datasets")) + register_all_coco(_root) + register_all_lvis(_root) + register_all_cityscapes(_root) + register_all_cityscapes_panoptic(_root) + register_all_pascal_voc(_root) + register_all_ade20k(_root) diff --git a/custom_detectron2/data/datasets/builtin_meta.py b/custom_detectron2/data/datasets/builtin_meta.py new file mode 100644 index 0000000000000000000000000000000000000000..63c7a1a31b31dd89b82011effee26471faccacf5 --- /dev/null +++ b/custom_detectron2/data/datasets/builtin_meta.py @@ -0,0 +1,350 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +Note: +For your custom dataset, there is no need to hard-code metadata anywhere in the code. +For example, for COCO-format dataset, metadata will be obtained automatically +when calling `load_coco_json`. For other dataset, metadata may also be obtained in other ways +during loading. + +However, we hard-coded metadata for a few common dataset here. +The only goal is to allow users who don't have these dataset to use pre-trained models. +Users don't have to download a COCO json (which contains metadata), in order to visualize a +COCO model (with correct class names and colors). +""" + + +# All coco categories, together with their nice-looking visualization colors +# It's from https://github.com/cocodataset/panopticapi/blob/master/panoptic_coco_categories.json +COCO_CATEGORIES = [ + {"color": [220, 20, 60], "isthing": 1, "id": 1, "name": "person"}, + {"color": [119, 11, 32], "isthing": 1, "id": 2, "name": "bicycle"}, + {"color": [0, 0, 142], "isthing": 1, "id": 3, "name": "car"}, + {"color": [0, 0, 230], "isthing": 1, "id": 4, "name": "motorcycle"}, + {"color": [106, 0, 228], "isthing": 1, "id": 5, "name": "airplane"}, + {"color": [0, 60, 100], "isthing": 1, "id": 6, "name": "bus"}, + {"color": [0, 80, 100], "isthing": 1, "id": 7, "name": "train"}, + {"color": [0, 0, 70], "isthing": 1, "id": 8, "name": "truck"}, + {"color": [0, 0, 192], "isthing": 1, "id": 9, "name": "boat"}, + {"color": [250, 170, 30], "isthing": 1, "id": 10, "name": "traffic light"}, + {"color": [100, 170, 30], "isthing": 1, "id": 11, "name": "fire hydrant"}, + {"color": [220, 220, 0], "isthing": 1, "id": 13, "name": "stop sign"}, + {"color": [175, 116, 175], "isthing": 1, "id": 14, "name": "parking meter"}, + {"color": [250, 0, 30], "isthing": 1, "id": 15, "name": "bench"}, + {"color": [165, 42, 42], "isthing": 1, "id": 16, "name": "bird"}, + {"color": [255, 77, 255], "isthing": 1, "id": 17, "name": "cat"}, + {"color": [0, 226, 252], "isthing": 1, "id": 18, "name": "dog"}, + {"color": [182, 182, 255], "isthing": 1, "id": 19, "name": "horse"}, + {"color": [0, 82, 0], "isthing": 1, "id": 20, "name": "sheep"}, + {"color": [120, 166, 157], "isthing": 1, "id": 21, "name": "cow"}, + {"color": [110, 76, 0], "isthing": 1, "id": 22, "name": "elephant"}, + {"color": [174, 57, 255], "isthing": 1, "id": 23, "name": "bear"}, + {"color": [199, 100, 0], "isthing": 1, "id": 24, "name": "zebra"}, + {"color": [72, 0, 118], "isthing": 1, "id": 25, "name": "giraffe"}, + {"color": [255, 179, 240], "isthing": 1, "id": 27, "name": "backpack"}, + {"color": [0, 125, 92], "isthing": 1, "id": 28, "name": "umbrella"}, + {"color": [209, 0, 151], "isthing": 1, "id": 31, "name": "handbag"}, + {"color": [188, 208, 182], "isthing": 1, "id": 32, "name": "tie"}, + {"color": [0, 220, 176], "isthing": 1, "id": 33, "name": "suitcase"}, + {"color": [255, 99, 164], "isthing": 1, "id": 34, "name": "frisbee"}, + {"color": [92, 0, 73], "isthing": 1, "id": 35, "name": "skis"}, + {"color": [133, 129, 255], "isthing": 1, "id": 36, "name": "snowboard"}, + {"color": [78, 180, 255], "isthing": 1, "id": 37, "name": "sports ball"}, + {"color": [0, 228, 0], "isthing": 1, "id": 38, "name": "kite"}, + {"color": [174, 255, 243], "isthing": 1, "id": 39, "name": "baseball bat"}, + {"color": [45, 89, 255], "isthing": 1, "id": 40, "name": "baseball glove"}, + {"color": [134, 134, 103], "isthing": 1, "id": 41, "name": "skateboard"}, + {"color": [145, 148, 174], "isthing": 1, "id": 42, "name": "surfboard"}, + {"color": [255, 208, 186], "isthing": 1, "id": 43, "name": "tennis racket"}, + {"color": [197, 226, 255], "isthing": 1, "id": 44, "name": "bottle"}, + {"color": [171, 134, 1], "isthing": 1, "id": 46, "name": "wine glass"}, + {"color": [109, 63, 54], "isthing": 1, "id": 47, "name": "cup"}, + {"color": [207, 138, 255], "isthing": 1, "id": 48, "name": "fork"}, + {"color": [151, 0, 95], "isthing": 1, "id": 49, "name": "knife"}, + {"color": [9, 80, 61], "isthing": 1, "id": 50, "name": "spoon"}, + {"color": [84, 105, 51], "isthing": 1, "id": 51, "name": "bowl"}, + {"color": [74, 65, 105], "isthing": 1, "id": 52, "name": "banana"}, + {"color": [166, 196, 102], "isthing": 1, "id": 53, "name": "apple"}, + {"color": [208, 195, 210], "isthing": 1, "id": 54, "name": "sandwich"}, + {"color": [255, 109, 65], "isthing": 1, "id": 55, "name": "orange"}, + {"color": [0, 143, 149], "isthing": 1, "id": 56, "name": "broccoli"}, + {"color": [179, 0, 194], "isthing": 1, "id": 57, "name": "carrot"}, + {"color": [209, 99, 106], "isthing": 1, "id": 58, "name": "hot dog"}, + {"color": [5, 121, 0], "isthing": 1, "id": 59, "name": "pizza"}, + {"color": [227, 255, 205], "isthing": 1, "id": 60, "name": "donut"}, + {"color": [147, 186, 208], "isthing": 1, "id": 61, "name": "cake"}, + {"color": [153, 69, 1], "isthing": 1, "id": 62, "name": "chair"}, + {"color": [3, 95, 161], "isthing": 1, "id": 63, "name": "couch"}, + {"color": [163, 255, 0], "isthing": 1, "id": 64, "name": "potted plant"}, + {"color": [119, 0, 170], "isthing": 1, "id": 65, "name": "bed"}, + {"color": [0, 182, 199], "isthing": 1, "id": 67, "name": "dining table"}, + {"color": [0, 165, 120], "isthing": 1, "id": 70, "name": "toilet"}, + {"color": [183, 130, 88], "isthing": 1, "id": 72, "name": "tv"}, + {"color": [95, 32, 0], "isthing": 1, "id": 73, "name": "laptop"}, + {"color": [130, 114, 135], "isthing": 1, "id": 74, "name": "mouse"}, + {"color": [110, 129, 133], "isthing": 1, "id": 75, "name": "remote"}, + {"color": [166, 74, 118], "isthing": 1, "id": 76, "name": "keyboard"}, + {"color": [219, 142, 185], "isthing": 1, "id": 77, "name": "cell phone"}, + {"color": [79, 210, 114], "isthing": 1, "id": 78, "name": "microwave"}, + {"color": [178, 90, 62], "isthing": 1, "id": 79, "name": "oven"}, + {"color": [65, 70, 15], "isthing": 1, "id": 80, "name": "toaster"}, + {"color": [127, 167, 115], "isthing": 1, "id": 81, "name": "sink"}, + {"color": [59, 105, 106], "isthing": 1, "id": 82, "name": "refrigerator"}, + {"color": [142, 108, 45], "isthing": 1, "id": 84, "name": "book"}, + {"color": [196, 172, 0], "isthing": 1, "id": 85, "name": "clock"}, + {"color": [95, 54, 80], "isthing": 1, "id": 86, "name": "vase"}, + {"color": [128, 76, 255], "isthing": 1, "id": 87, "name": "scissors"}, + {"color": [201, 57, 1], "isthing": 1, "id": 88, "name": "teddy bear"}, + {"color": [246, 0, 122], "isthing": 1, "id": 89, "name": "hair drier"}, + {"color": [191, 162, 208], "isthing": 1, "id": 90, "name": "toothbrush"}, + {"color": [255, 255, 128], "isthing": 0, "id": 92, "name": "banner"}, + {"color": [147, 211, 203], "isthing": 0, "id": 93, "name": "blanket"}, + {"color": [150, 100, 100], "isthing": 0, "id": 95, "name": "bridge"}, + {"color": [168, 171, 172], "isthing": 0, "id": 100, "name": "cardboard"}, + {"color": [146, 112, 198], "isthing": 0, "id": 107, "name": "counter"}, + {"color": [210, 170, 100], "isthing": 0, "id": 109, "name": "curtain"}, + {"color": [92, 136, 89], "isthing": 0, "id": 112, "name": "door-stuff"}, + {"color": [218, 88, 184], "isthing": 0, "id": 118, "name": "floor-wood"}, + {"color": [241, 129, 0], "isthing": 0, "id": 119, "name": "flower"}, + {"color": [217, 17, 255], "isthing": 0, "id": 122, "name": "fruit"}, + {"color": [124, 74, 181], "isthing": 0, "id": 125, "name": "gravel"}, + {"color": [70, 70, 70], "isthing": 0, "id": 128, "name": "house"}, + {"color": [255, 228, 255], "isthing": 0, "id": 130, "name": "light"}, + {"color": [154, 208, 0], "isthing": 0, "id": 133, "name": "mirror-stuff"}, + {"color": [193, 0, 92], "isthing": 0, "id": 138, "name": "net"}, + {"color": [76, 91, 113], "isthing": 0, "id": 141, "name": "pillow"}, + {"color": [255, 180, 195], "isthing": 0, "id": 144, "name": "platform"}, + {"color": [106, 154, 176], "isthing": 0, "id": 145, "name": "playingfield"}, + {"color": [230, 150, 140], "isthing": 0, "id": 147, "name": "railroad"}, + {"color": [60, 143, 255], "isthing": 0, "id": 148, "name": "river"}, + {"color": [128, 64, 128], "isthing": 0, "id": 149, "name": "road"}, + {"color": [92, 82, 55], "isthing": 0, "id": 151, "name": "roof"}, + {"color": [254, 212, 124], "isthing": 0, "id": 154, "name": "sand"}, + {"color": [73, 77, 174], "isthing": 0, "id": 155, "name": "sea"}, + {"color": [255, 160, 98], "isthing": 0, "id": 156, "name": "shelf"}, + {"color": [255, 255, 255], "isthing": 0, "id": 159, "name": "snow"}, + {"color": [104, 84, 109], "isthing": 0, "id": 161, "name": "stairs"}, + {"color": [169, 164, 131], "isthing": 0, "id": 166, "name": "tent"}, + {"color": [225, 199, 255], "isthing": 0, "id": 168, "name": "towel"}, + {"color": [137, 54, 74], "isthing": 0, "id": 171, "name": "wall-brick"}, + {"color": [135, 158, 223], "isthing": 0, "id": 175, "name": "wall-stone"}, + {"color": [7, 246, 231], "isthing": 0, "id": 176, "name": "wall-tile"}, + {"color": [107, 255, 200], "isthing": 0, "id": 177, "name": "wall-wood"}, + {"color": [58, 41, 149], "isthing": 0, "id": 178, "name": "water-other"}, + {"color": [183, 121, 142], "isthing": 0, "id": 180, "name": "window-blind"}, + {"color": [255, 73, 97], "isthing": 0, "id": 181, "name": "window-other"}, + {"color": [107, 142, 35], "isthing": 0, "id": 184, "name": "tree-merged"}, + {"color": [190, 153, 153], "isthing": 0, "id": 185, "name": "fence-merged"}, + {"color": [146, 139, 141], "isthing": 0, "id": 186, "name": "ceiling-merged"}, + {"color": [70, 130, 180], "isthing": 0, "id": 187, "name": "sky-other-merged"}, + {"color": [134, 199, 156], "isthing": 0, "id": 188, "name": "cabinet-merged"}, + {"color": [209, 226, 140], "isthing": 0, "id": 189, "name": "table-merged"}, + {"color": [96, 36, 108], "isthing": 0, "id": 190, "name": "floor-other-merged"}, + {"color": [96, 96, 96], "isthing": 0, "id": 191, "name": "pavement-merged"}, + {"color": [64, 170, 64], "isthing": 0, "id": 192, "name": "mountain-merged"}, + {"color": [152, 251, 152], "isthing": 0, "id": 193, "name": "grass-merged"}, + {"color": [208, 229, 228], "isthing": 0, "id": 194, "name": "dirt-merged"}, + {"color": [206, 186, 171], "isthing": 0, "id": 195, "name": "paper-merged"}, + {"color": [152, 161, 64], "isthing": 0, "id": 196, "name": "food-other-merged"}, + {"color": [116, 112, 0], "isthing": 0, "id": 197, "name": "building-other-merged"}, + {"color": [0, 114, 143], "isthing": 0, "id": 198, "name": "rock-merged"}, + {"color": [102, 102, 156], "isthing": 0, "id": 199, "name": "wall-other-merged"}, + {"color": [250, 141, 255], "isthing": 0, "id": 200, "name": "rug-merged"}, +] + +# fmt: off +COCO_PERSON_KEYPOINT_NAMES = ( + "nose", + "left_eye", "right_eye", + "left_ear", "right_ear", + "left_shoulder", "right_shoulder", + "left_elbow", "right_elbow", + "left_wrist", "right_wrist", + "left_hip", "right_hip", + "left_knee", "right_knee", + "left_ankle", "right_ankle", +) +# fmt: on + +# Pairs of keypoints that should be exchanged under horizontal flipping +COCO_PERSON_KEYPOINT_FLIP_MAP = ( + ("left_eye", "right_eye"), + ("left_ear", "right_ear"), + ("left_shoulder", "right_shoulder"), + ("left_elbow", "right_elbow"), + ("left_wrist", "right_wrist"), + ("left_hip", "right_hip"), + ("left_knee", "right_knee"), + ("left_ankle", "right_ankle"), +) + +# rules for pairs of keypoints to draw a line between, and the line color to use. +KEYPOINT_CONNECTION_RULES = [ + # face + ("left_ear", "left_eye", (102, 204, 255)), + ("right_ear", "right_eye", (51, 153, 255)), + ("left_eye", "nose", (102, 0, 204)), + ("nose", "right_eye", (51, 102, 255)), + # upper-body + ("left_shoulder", "right_shoulder", (255, 128, 0)), + ("left_shoulder", "left_elbow", (153, 255, 204)), + ("right_shoulder", "right_elbow", (128, 229, 255)), + ("left_elbow", "left_wrist", (153, 255, 153)), + ("right_elbow", "right_wrist", (102, 255, 224)), + # lower-body + ("left_hip", "right_hip", (255, 102, 0)), + ("left_hip", "left_knee", (255, 255, 77)), + ("right_hip", "right_knee", (153, 255, 204)), + ("left_knee", "left_ankle", (191, 255, 128)), + ("right_knee", "right_ankle", (255, 195, 77)), +] + +# All Cityscapes categories, together with their nice-looking visualization colors +# It's from https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/helpers/labels.py # noqa +CITYSCAPES_CATEGORIES = [ + {"color": (128, 64, 128), "isthing": 0, "id": 7, "trainId": 0, "name": "road"}, + {"color": (244, 35, 232), "isthing": 0, "id": 8, "trainId": 1, "name": "sidewalk"}, + {"color": (70, 70, 70), "isthing": 0, "id": 11, "trainId": 2, "name": "building"}, + {"color": (102, 102, 156), "isthing": 0, "id": 12, "trainId": 3, "name": "wall"}, + {"color": (190, 153, 153), "isthing": 0, "id": 13, "trainId": 4, "name": "fence"}, + {"color": (153, 153, 153), "isthing": 0, "id": 17, "trainId": 5, "name": "pole"}, + {"color": (250, 170, 30), "isthing": 0, "id": 19, "trainId": 6, "name": "traffic light"}, + {"color": (220, 220, 0), "isthing": 0, "id": 20, "trainId": 7, "name": "traffic sign"}, + {"color": (107, 142, 35), "isthing": 0, "id": 21, "trainId": 8, "name": "vegetation"}, + {"color": (152, 251, 152), "isthing": 0, "id": 22, "trainId": 9, "name": "terrain"}, + {"color": (70, 130, 180), "isthing": 0, "id": 23, "trainId": 10, "name": "sky"}, + {"color": (220, 20, 60), "isthing": 1, "id": 24, "trainId": 11, "name": "person"}, + {"color": (255, 0, 0), "isthing": 1, "id": 25, "trainId": 12, "name": "rider"}, + {"color": (0, 0, 142), "isthing": 1, "id": 26, "trainId": 13, "name": "car"}, + {"color": (0, 0, 70), "isthing": 1, "id": 27, "trainId": 14, "name": "truck"}, + {"color": (0, 60, 100), "isthing": 1, "id": 28, "trainId": 15, "name": "bus"}, + {"color": (0, 80, 100), "isthing": 1, "id": 31, "trainId": 16, "name": "train"}, + {"color": (0, 0, 230), "isthing": 1, "id": 32, "trainId": 17, "name": "motorcycle"}, + {"color": (119, 11, 32), "isthing": 1, "id": 33, "trainId": 18, "name": "bicycle"}, +] + +# fmt: off +ADE20K_SEM_SEG_CATEGORIES = [ + "wall", "building", "sky", "floor", "tree", "ceiling", "road, route", "bed", "window ", "grass", "cabinet", "sidewalk, pavement", "person", "earth, ground", "door", "table", "mountain, mount", "plant", "curtain", "chair", "car", "water", "painting, picture", "sofa", "shelf", "house", "sea", "mirror", "rug", "field", "armchair", "seat", "fence", "desk", "rock, stone", "wardrobe, closet, press", "lamp", "tub", "rail", "cushion", "base, pedestal, stand", "box", "column, pillar", "signboard, sign", "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink", "skyscraper", "fireplace", "refrigerator, icebox", "grandstand, covered stand", "path", "stairs", "runway", "case, display case, showcase, vitrine", "pool table, billiard table, snooker table", "pillow", "screen door, screen", "stairway, staircase", "river", "bridge, span", "bookcase", "blind, screen", "coffee table", "toilet, can, commode, crapper, pot, potty, stool, throne", "flower", "book", "hill", "bench", "countertop", "stove", "palm, palm tree", "kitchen island", "computer", "swivel chair", "boat", "bar", "arcade machine", "hovel, hut, hutch, shack, shanty", "bus", "towel", "light", "truck", "tower", "chandelier", "awning, sunshade, sunblind", "street lamp", "booth", "tv", "plane", "dirt track", "clothes", "pole", "land, ground, soil", "bannister, banister, balustrade, balusters, handrail", "escalator, moving staircase, moving stairway", "ottoman, pouf, pouffe, puff, hassock", "bottle", "buffet, counter, sideboard", "poster, posting, placard, notice, bill, card", "stage", "van", "ship", "fountain", "conveyer belt, conveyor belt, conveyer, conveyor, transporter", "canopy", "washer, automatic washer, washing machine", "plaything, toy", "pool", "stool", "barrel, cask", "basket, handbasket", "falls", "tent", "bag", "minibike, motorbike", "cradle", "oven", "ball", "food, solid food", "step, stair", "tank, storage tank", "trade name", "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen", "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase", "traffic light", "tray", "trash can", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board", "shower", "radiator", "glass, drinking glass", "clock", "flag", # noqa +] +# After processed by `prepare_ade20k_sem_seg.py`, id 255 means ignore +# fmt: on + + +def _get_coco_instances_meta(): + thing_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 1] + thing_colors = [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 1] + assert len(thing_ids) == 80, len(thing_ids) + # Mapping from the incontiguous COCO category id to an id in [0, 79] + thing_dataset_id_to_contiguous_id = {k: i for i, k in enumerate(thing_ids)} + thing_classes = [k["name"] for k in COCO_CATEGORIES if k["isthing"] == 1] + ret = { + "thing_dataset_id_to_contiguous_id": thing_dataset_id_to_contiguous_id, + "thing_classes": thing_classes, + "thing_colors": thing_colors, + } + return ret + + +def _get_coco_panoptic_separated_meta(): + """ + Returns metadata for "separated" version of the panoptic segmentation dataset. + """ + stuff_ids = [k["id"] for k in COCO_CATEGORIES if k["isthing"] == 0] + assert len(stuff_ids) == 53, len(stuff_ids) + + # For semantic segmentation, this mapping maps from contiguous stuff id + # (in [0, 53], used in models) to ids in the dataset (used for processing results) + # The id 0 is mapped to an extra category "thing". + stuff_dataset_id_to_contiguous_id = {k: i + 1 for i, k in enumerate(stuff_ids)} + # When converting COCO panoptic annotations to semantic annotations + # We label the "thing" category to 0 + stuff_dataset_id_to_contiguous_id[0] = 0 + + # 54 names for COCO stuff categories (including "things") + stuff_classes = ["things"] + [ + k["name"].replace("-other", "").replace("-merged", "") + for k in COCO_CATEGORIES + if k["isthing"] == 0 + ] + + # NOTE: I randomly picked a color for things + stuff_colors = [[82, 18, 128]] + [k["color"] for k in COCO_CATEGORIES if k["isthing"] == 0] + ret = { + "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id, + "stuff_classes": stuff_classes, + "stuff_colors": stuff_colors, + } + ret.update(_get_coco_instances_meta()) + return ret + + +def _get_builtin_metadata(dataset_name): + if dataset_name == "coco": + return _get_coco_instances_meta() + if dataset_name == "coco_panoptic_separated": + return _get_coco_panoptic_separated_meta() + elif dataset_name == "coco_panoptic_standard": + meta = {} + # The following metadata maps contiguous id from [0, #thing categories + + # #stuff categories) to their names and colors. We have to replica of the + # same name and color under "thing_*" and "stuff_*" because the current + # visualization function in D2 handles thing and class classes differently + # due to some heuristic used in Panoptic FPN. We keep the same naming to + # enable reusing existing visualization functions. + thing_classes = [k["name"] for k in COCO_CATEGORIES] + thing_colors = [k["color"] for k in COCO_CATEGORIES] + stuff_classes = [k["name"] for k in COCO_CATEGORIES] + stuff_colors = [k["color"] for k in COCO_CATEGORIES] + + meta["thing_classes"] = thing_classes + meta["thing_colors"] = thing_colors + meta["stuff_classes"] = stuff_classes + meta["stuff_colors"] = stuff_colors + + # Convert category id for training: + # category id: like semantic segmentation, it is the class id for each + # pixel. Since there are some classes not used in evaluation, the category + # id is not always contiguous and thus we have two set of category ids: + # - original category id: category id in the original dataset, mainly + # used for evaluation. + # - contiguous category id: [0, #classes), in order to train the linear + # softmax classifier. + thing_dataset_id_to_contiguous_id = {} + stuff_dataset_id_to_contiguous_id = {} + + for i, cat in enumerate(COCO_CATEGORIES): + if cat["isthing"]: + thing_dataset_id_to_contiguous_id[cat["id"]] = i + else: + stuff_dataset_id_to_contiguous_id[cat["id"]] = i + + meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id + meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id + + return meta + elif dataset_name == "coco_person": + return { + "thing_classes": ["person"], + "keypoint_names": COCO_PERSON_KEYPOINT_NAMES, + "keypoint_flip_map": COCO_PERSON_KEYPOINT_FLIP_MAP, + "keypoint_connection_rules": KEYPOINT_CONNECTION_RULES, + } + elif dataset_name == "cityscapes": + # fmt: off + CITYSCAPES_THING_CLASSES = [ + "person", "rider", "car", "truck", + "bus", "train", "motorcycle", "bicycle", + ] + CITYSCAPES_STUFF_CLASSES = [ + "road", "sidewalk", "building", "wall", "fence", "pole", "traffic light", + "traffic sign", "vegetation", "terrain", "sky", "person", "rider", "car", + "truck", "bus", "train", "motorcycle", "bicycle", + ] + # fmt: on + return { + "thing_classes": CITYSCAPES_THING_CLASSES, + "stuff_classes": CITYSCAPES_STUFF_CLASSES, + } + raise KeyError("No built-in metadata for dataset {}".format(dataset_name)) diff --git a/custom_detectron2/data/datasets/cityscapes.py b/custom_detectron2/data/datasets/cityscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..ddaa4e0e5057bc061edb5d965122a1b478c580ab --- /dev/null +++ b/custom_detectron2/data/datasets/cityscapes.py @@ -0,0 +1,329 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import functools +import json +import logging +import multiprocessing as mp +import numpy as np +import os +from itertools import chain +import custom_pycocotools.mask as mask_util +from PIL import Image + +from custom_detectron2.structures import BoxMode +from custom_detectron2.utils.comm import get_world_size +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.logger import setup_logger + +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + pass + + +logger = logging.getLogger(__name__) + + +def _get_cityscapes_files(image_dir, gt_dir): + files = [] + # scan through the directory + cities = PathManager.ls(image_dir) + logger.info(f"{len(cities)} cities found in '{image_dir}'.") + for city in cities: + city_img_dir = os.path.join(image_dir, city) + city_gt_dir = os.path.join(gt_dir, city) + for basename in PathManager.ls(city_img_dir): + image_file = os.path.join(city_img_dir, basename) + + suffix = "leftImg8bit.png" + assert basename.endswith(suffix), basename + basename = basename[: -len(suffix)] + + instance_file = os.path.join(city_gt_dir, basename + "gtFine_instanceIds.png") + label_file = os.path.join(city_gt_dir, basename + "gtFine_labelIds.png") + json_file = os.path.join(city_gt_dir, basename + "gtFine_polygons.json") + + files.append((image_file, instance_file, label_file, json_file)) + assert len(files), "No images found in {}".format(image_dir) + for f in files[0]: + assert PathManager.isfile(f), f + return files + + +def load_cityscapes_instances(image_dir, gt_dir, from_json=True, to_polygons=True): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train". + from_json (bool): whether to read annotations from the raw json file or the png files. + to_polygons (bool): whether to represent the segmentation as polygons + (COCO's format) instead of masks (cityscapes's format). + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + """ + if from_json: + assert to_polygons, ( + "Cityscapes's json annotations are in polygon format. " + "Converting to mask format is not supported now." + ) + files = _get_cityscapes_files(image_dir, gt_dir) + + logger.info("Preprocessing cityscapes annotations ...") + # This is still not fast: all workers will execute duplicate works and will + # take up to 10m on a 8GPU server. + pool = mp.Pool(processes=max(mp.cpu_count() // get_world_size() // 2, 4)) + + ret = pool.map( + functools.partial(_cityscapes_files_to_dict, from_json=from_json, to_polygons=to_polygons), + files, + ) + logger.info("Loaded {} images from {}".format(len(ret), image_dir)) + + # Map cityscape ids to contiguous ids + from cityscapesscripts.helpers.labels import labels + + labels = [l for l in labels if l.hasInstances and not l.ignoreInEval] + dataset_id_to_contiguous_id = {l.id: idx for idx, l in enumerate(labels)} + for dict_per_image in ret: + for anno in dict_per_image["annotations"]: + anno["category_id"] = dataset_id_to_contiguous_id[anno["category_id"]] + return ret + + +def load_cityscapes_semantic(image_dir, gt_dir): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., "~/cityscapes/gtFine/train". + + Returns: + list[dict]: a list of dict, each has "file_name" and + "sem_seg_file_name". + """ + ret = [] + # gt_dir is small and contain many small files. make sense to fetch to local first + gt_dir = PathManager.get_local_path(gt_dir) + for image_file, _, label_file, json_file in _get_cityscapes_files(image_dir, gt_dir): + label_file = label_file.replace("labelIds", "labelTrainIds") + + with PathManager.open(json_file, "r") as f: + jsonobj = json.load(f) + ret.append( + { + "file_name": image_file, + "sem_seg_file_name": label_file, + "height": jsonobj["imgHeight"], + "width": jsonobj["imgWidth"], + } + ) + assert len(ret), f"No images found in {image_dir}!" + assert PathManager.isfile( + ret[0]["sem_seg_file_name"] + ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa + return ret + + +def _cityscapes_files_to_dict(files, from_json, to_polygons): + """ + Parse cityscapes annotation files to a instance segmentation dataset dict. + + Args: + files (tuple): consists of (image_file, instance_id_file, label_id_file, json_file) + from_json (bool): whether to read annotations from the raw json file or the png files. + to_polygons (bool): whether to represent the segmentation as polygons + (COCO's format) instead of masks (cityscapes's format). + + Returns: + A dict in Detectron2 Dataset format. + """ + from cityscapesscripts.helpers.labels import id2label, name2label + + image_file, instance_id_file, _, json_file = files + + annos = [] + + if from_json: + from shapely.geometry import MultiPolygon, Polygon + + with PathManager.open(json_file, "r") as f: + jsonobj = json.load(f) + ret = { + "file_name": image_file, + "image_id": os.path.basename(image_file), + "height": jsonobj["imgHeight"], + "width": jsonobj["imgWidth"], + } + + # `polygons_union` contains the union of all valid polygons. + polygons_union = Polygon() + + # CityscapesScripts draw the polygons in sequential order + # and each polygon *overwrites* existing ones. See + # (https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/preparation/json2instanceImg.py) # noqa + # We use reverse order, and each polygon *avoids* early ones. + # This will resolve the ploygon overlaps in the same way as CityscapesScripts. + for obj in jsonobj["objects"][::-1]: + if "deleted" in obj: # cityscapes data format specific + continue + label_name = obj["label"] + + try: + label = name2label[label_name] + except KeyError: + if label_name.endswith("group"): # crowd area + label = name2label[label_name[: -len("group")]] + else: + raise + if label.id < 0: # cityscapes data format + continue + + # Cityscapes's raw annotations uses integer coordinates + # Therefore +0.5 here + poly_coord = np.asarray(obj["polygon"], dtype="f4") + 0.5 + # CityscapesScript uses PIL.ImageDraw.polygon to rasterize + # polygons for evaluation. This function operates in integer space + # and draws each pixel whose center falls into the polygon. + # Therefore it draws a polygon which is 0.5 "fatter" in expectation. + # We therefore dilate the input polygon by 0.5 as our input. + poly = Polygon(poly_coord).buffer(0.5, resolution=4) + + if not label.hasInstances or label.ignoreInEval: + # even if we won't store the polygon it still contributes to overlaps resolution + polygons_union = polygons_union.union(poly) + continue + + # Take non-overlapping part of the polygon + poly_wo_overlaps = poly.difference(polygons_union) + if poly_wo_overlaps.is_empty: + continue + polygons_union = polygons_union.union(poly) + + anno = {} + anno["iscrowd"] = label_name.endswith("group") + anno["category_id"] = label.id + + if isinstance(poly_wo_overlaps, Polygon): + poly_list = [poly_wo_overlaps] + elif isinstance(poly_wo_overlaps, MultiPolygon): + poly_list = poly_wo_overlaps.geoms + else: + raise NotImplementedError("Unknown geometric structure {}".format(poly_wo_overlaps)) + + poly_coord = [] + for poly_el in poly_list: + # COCO API can work only with exterior boundaries now, hence we store only them. + # TODO: store both exterior and interior boundaries once other parts of the + # codebase support holes in polygons. + poly_coord.append(list(chain(*poly_el.exterior.coords))) + anno["segmentation"] = poly_coord + (xmin, ymin, xmax, ymax) = poly_wo_overlaps.bounds + + anno["bbox"] = (xmin, ymin, xmax, ymax) + anno["bbox_mode"] = BoxMode.XYXY_ABS + + annos.append(anno) + else: + # See also the official annotation parsing scripts at + # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/instances2dict.py # noqa + with PathManager.open(instance_id_file, "rb") as f: + inst_image = np.asarray(Image.open(f), order="F") + # ids < 24 are stuff labels (filtering them first is about 5% faster) + flattened_ids = np.unique(inst_image[inst_image >= 24]) + + ret = { + "file_name": image_file, + "image_id": os.path.basename(image_file), + "height": inst_image.shape[0], + "width": inst_image.shape[1], + } + + for instance_id in flattened_ids: + # For non-crowd annotations, instance_id // 1000 is the label_id + # Crowd annotations have <1000 instance ids + label_id = instance_id // 1000 if instance_id >= 1000 else instance_id + label = id2label[label_id] + if not label.hasInstances or label.ignoreInEval: + continue + + anno = {} + anno["iscrowd"] = instance_id < 1000 + anno["category_id"] = label.id + + mask = np.asarray(inst_image == instance_id, dtype=np.uint8, order="F") + + inds = np.nonzero(mask) + ymin, ymax = inds[0].min(), inds[0].max() + xmin, xmax = inds[1].min(), inds[1].max() + anno["bbox"] = (xmin, ymin, xmax, ymax) + if xmax <= xmin or ymax <= ymin: + continue + anno["bbox_mode"] = BoxMode.XYXY_ABS + if to_polygons: + # This conversion comes from D4809743 and D5171122, + # when Mask-RCNN was first developed. + contours = cv2.findContours(mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)[ + -2 + ] + polygons = [c.reshape(-1).tolist() for c in contours if len(c) >= 3] + # opencv's can produce invalid polygons + if len(polygons) == 0: + continue + anno["segmentation"] = polygons + else: + anno["segmentation"] = mask_util.encode(mask[:, :, None])[0] + annos.append(anno) + ret["annotations"] = annos + return ret + + +if __name__ == "__main__": + """ + Test the cityscapes dataset loader. + + Usage: + python -m detectron2.data.datasets.cityscapes \ + cityscapes/leftImg8bit/train cityscapes/gtFine/train + """ + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("image_dir") + parser.add_argument("gt_dir") + parser.add_argument("--type", choices=["instance", "semantic"], default="instance") + args = parser.parse_args() + from custom_detectron2.data.catalog import Metadata + from custom_detectron2.utils.visualizer import Visualizer + from cityscapesscripts.helpers.labels import labels + + logger = setup_logger(name=__name__) + + dirname = "cityscapes-data-vis" + os.makedirs(dirname, exist_ok=True) + + if args.type == "instance": + dicts = load_cityscapes_instances( + args.image_dir, args.gt_dir, from_json=True, to_polygons=True + ) + logger.info("Done loading {} samples.".format(len(dicts))) + + thing_classes = [k.name for k in labels if k.hasInstances and not k.ignoreInEval] + meta = Metadata().set(thing_classes=thing_classes) + + else: + dicts = load_cityscapes_semantic(args.image_dir, args.gt_dir) + logger.info("Done loading {} samples.".format(len(dicts))) + + stuff_classes = [k.name for k in labels if k.trainId != 255] + stuff_colors = [k.color for k in labels if k.trainId != 255] + meta = Metadata().set(stuff_classes=stuff_classes, stuff_colors=stuff_colors) + + for d in dicts: + img = np.array(Image.open(PathManager.open(d["file_name"], "rb"))) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + # cv2.imshow("a", vis.get_image()[:, :, ::-1]) + # cv2.waitKey() + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/custom_detectron2/data/datasets/cityscapes_panoptic.py b/custom_detectron2/data/datasets/cityscapes_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..efb8a30ce9c9cf62bc6e585dc029c9b003f6569c --- /dev/null +++ b/custom_detectron2/data/datasets/cityscapes_panoptic.py @@ -0,0 +1,187 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import json +import logging +import os + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog +from custom_detectron2.data.datasets.builtin_meta import CITYSCAPES_CATEGORIES +from custom_detectron2.utils.file_io import PathManager + +""" +This file contains functions to register the Cityscapes panoptic dataset to the DatasetCatalog. +""" + + +logger = logging.getLogger(__name__) + + +def get_cityscapes_panoptic_files(image_dir, gt_dir, json_info): + files = [] + # scan through the directory + cities = PathManager.ls(image_dir) + logger.info(f"{len(cities)} cities found in '{image_dir}'.") + image_dict = {} + for city in cities: + city_img_dir = os.path.join(image_dir, city) + for basename in PathManager.ls(city_img_dir): + image_file = os.path.join(city_img_dir, basename) + + suffix = "_leftImg8bit.png" + assert basename.endswith(suffix), basename + basename = os.path.basename(basename)[: -len(suffix)] + + image_dict[basename] = image_file + + for ann in json_info["annotations"]: + image_file = image_dict.get(ann["image_id"], None) + assert image_file is not None, "No image {} found for annotation {}".format( + ann["image_id"], ann["file_name"] + ) + label_file = os.path.join(gt_dir, ann["file_name"]) + segments_info = ann["segments_info"] + + files.append((image_file, label_file, segments_info)) + + assert len(files), "No images found in {}".format(image_dir) + assert PathManager.isfile(files[0][0]), files[0][0] + assert PathManager.isfile(files[0][1]), files[0][1] + return files + + +def load_cityscapes_panoptic(image_dir, gt_dir, gt_json, meta): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/cityscapes/leftImg8bit/train". + gt_dir (str): path to the raw annotations. e.g., + "~/cityscapes/gtFine/cityscapes_panoptic_train". + gt_json (str): path to the json file. e.g., + "~/cityscapes/gtFine/cityscapes_panoptic_train.json". + meta (dict): dictionary containing "thing_dataset_id_to_contiguous_id" + and "stuff_dataset_id_to_contiguous_id" to map category ids to + contiguous ids for training. + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + """ + + def _convert_category_id(segment_info, meta): + if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]: + segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + else: + segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + return segment_info + + assert os.path.exists( + gt_json + ), "Please run `python cityscapesscripts/preparation/createPanopticImgs.py` to generate label files." # noqa + with open(gt_json) as f: + json_info = json.load(f) + files = get_cityscapes_panoptic_files(image_dir, gt_dir, json_info) + ret = [] + for image_file, label_file, segments_info in files: + sem_label_file = ( + image_file.replace("leftImg8bit", "gtFine").split(".")[0] + "_labelTrainIds.png" + ) + segments_info = [_convert_category_id(x, meta) for x in segments_info] + ret.append( + { + "file_name": image_file, + "image_id": "_".join( + os.path.splitext(os.path.basename(image_file))[0].split("_")[:3] + ), + "sem_seg_file_name": sem_label_file, + "pan_seg_file_name": label_file, + "segments_info": segments_info, + } + ) + assert len(ret), f"No images found in {image_dir}!" + assert PathManager.isfile( + ret[0]["sem_seg_file_name"] + ), "Please generate labelTrainIds.png with cityscapesscripts/preparation/createTrainIdLabelImgs.py" # noqa + assert PathManager.isfile( + ret[0]["pan_seg_file_name"] + ), "Please generate panoptic annotation with python cityscapesscripts/preparation/createPanopticImgs.py" # noqa + return ret + + +_RAW_CITYSCAPES_PANOPTIC_SPLITS = { + "cityscapes_fine_panoptic_train": ( + "cityscapes/leftImg8bit/train", + "cityscapes/gtFine/cityscapes_panoptic_train", + "cityscapes/gtFine/cityscapes_panoptic_train.json", + ), + "cityscapes_fine_panoptic_val": ( + "cityscapes/leftImg8bit/val", + "cityscapes/gtFine/cityscapes_panoptic_val", + "cityscapes/gtFine/cityscapes_panoptic_val.json", + ), + # "cityscapes_fine_panoptic_test": not supported yet +} + + +def register_all_cityscapes_panoptic(root): + meta = {} + # The following metadata maps contiguous id from [0, #thing categories + + # #stuff categories) to their names and colors. We have to replica of the + # same name and color under "thing_*" and "stuff_*" because the current + # visualization function in D2 handles thing and class classes differently + # due to some heuristic used in Panoptic FPN. We keep the same naming to + # enable reusing existing visualization functions. + thing_classes = [k["name"] for k in CITYSCAPES_CATEGORIES] + thing_colors = [k["color"] for k in CITYSCAPES_CATEGORIES] + stuff_classes = [k["name"] for k in CITYSCAPES_CATEGORIES] + stuff_colors = [k["color"] for k in CITYSCAPES_CATEGORIES] + + meta["thing_classes"] = thing_classes + meta["thing_colors"] = thing_colors + meta["stuff_classes"] = stuff_classes + meta["stuff_colors"] = stuff_colors + + # There are three types of ids in cityscapes panoptic segmentation: + # (1) category id: like semantic segmentation, it is the class id for each + # pixel. Since there are some classes not used in evaluation, the category + # id is not always contiguous and thus we have two set of category ids: + # - original category id: category id in the original dataset, mainly + # used for evaluation. + # - contiguous category id: [0, #classes), in order to train the classifier + # (2) instance id: this id is used to differentiate different instances from + # the same category. For "stuff" classes, the instance id is always 0; for + # "thing" classes, the instance id starts from 1 and 0 is reserved for + # ignored instances (e.g. crowd annotation). + # (3) panoptic id: this is the compact id that encode both category and + # instance id by: category_id * 1000 + instance_id. + thing_dataset_id_to_contiguous_id = {} + stuff_dataset_id_to_contiguous_id = {} + + for k in CITYSCAPES_CATEGORIES: + if k["isthing"] == 1: + thing_dataset_id_to_contiguous_id[k["id"]] = k["trainId"] + else: + stuff_dataset_id_to_contiguous_id[k["id"]] = k["trainId"] + + meta["thing_dataset_id_to_contiguous_id"] = thing_dataset_id_to_contiguous_id + meta["stuff_dataset_id_to_contiguous_id"] = stuff_dataset_id_to_contiguous_id + + for key, (image_dir, gt_dir, gt_json) in _RAW_CITYSCAPES_PANOPTIC_SPLITS.items(): + image_dir = os.path.join(root, image_dir) + gt_dir = os.path.join(root, gt_dir) + gt_json = os.path.join(root, gt_json) + + DatasetCatalog.register( + key, lambda x=image_dir, y=gt_dir, z=gt_json: load_cityscapes_panoptic(x, y, z, meta) + ) + MetadataCatalog.get(key).set( + panoptic_root=gt_dir, + image_root=image_dir, + panoptic_json=gt_json, + gt_dir=gt_dir.replace("cityscapes_panoptic_", ""), + evaluator_type="cityscapes_panoptic_seg", + ignore_label=255, + label_divisor=1000, + **meta, + ) diff --git a/custom_detectron2/data/datasets/coco.py b/custom_detectron2/data/datasets/coco.py new file mode 100644 index 0000000000000000000000000000000000000000..13d8bf49c454a0b29746e0c16794c12afb28ccf8 --- /dev/null +++ b/custom_detectron2/data/datasets/coco.py @@ -0,0 +1,539 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import datetime +import io +import json +import logging +import numpy as np +import os +import shutil +import custom_pycocotools.mask as mask_util +from fvcore.common.timer import Timer +from iopath.common.file_io import file_lock +from PIL import Image + +from custom_detectron2.structures import Boxes, BoxMode, PolygonMasks, RotatedBoxes +from custom_detectron2.utils.file_io import PathManager + +from .. import DatasetCatalog, MetadataCatalog + +""" +This file contains functions to parse COCO-format annotations into dicts in "Detectron2 format". +""" + + +logger = logging.getLogger(__name__) + +__all__ = ["load_coco_json", "load_sem_seg", "convert_to_coco_json", "register_coco_instances"] + + +def load_coco_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None): + """ + Load a json file with COCO's instances annotation format. + Currently supports instance detection, instance segmentation, + and person keypoints annotations. + + Args: + json_file (str): full path to the json file in COCO instances annotation format. + image_root (str or path-like): the directory where the images in this json file exists. + dataset_name (str or None): the name of the dataset (e.g., coco_2017_train). + When provided, this function will also do the following: + + * Put "thing_classes" into the metadata associated with this dataset. + * Map the category ids into a contiguous range (needed by standard dataset format), + and add "thing_dataset_id_to_contiguous_id" to the metadata associated + with this dataset. + + This option should usually be provided, unless users need to load + the original json content and apply more processing manually. + extra_annotation_keys (list[str]): list of per-annotation keys that should also be + loaded into the dataset dict (besides "iscrowd", "bbox", "keypoints", + "category_id", "segmentation"). The values for these keys will be returned as-is. + For example, the densepose annotations are loaded in this way. + + Returns: + list[dict]: a list of dicts in Detectron2 standard dataset dicts format (See + `Using Custom Datasets `_ ) when `dataset_name` is not None. + If `dataset_name` is None, the returned `category_ids` may be + incontiguous and may not conform to the Detectron2 standard format. + + Notes: + 1. This function does not read the image files. + The results do not have the "image" field. + """ + from custom_pycocotools.coco import COCO + + timer = Timer() + json_file = PathManager.get_local_path(json_file) + with contextlib.redirect_stdout(io.StringIO()): + coco_api = COCO(json_file) + if timer.seconds() > 1: + logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + + id_map = None + if dataset_name is not None: + meta = MetadataCatalog.get(dataset_name) + cat_ids = sorted(coco_api.getCatIds()) + cats = coco_api.loadCats(cat_ids) + # The categories in a custom json file may not be sorted. + thing_classes = [c["name"] for c in sorted(cats, key=lambda x: x["id"])] + meta.thing_classes = thing_classes + + # In COCO, certain category ids are artificially removed, + # and by convention they are always ignored. + # We deal with COCO's id issue and translate + # the category ids to contiguous ids in [0, 80). + + # It works by looking at the "categories" field in the json, therefore + # if users' own json also have incontiguous ids, we'll + # apply this mapping as well but print a warning. + if not (min(cat_ids) == 1 and max(cat_ids) == len(cat_ids)): + if "coco" not in dataset_name: + logger.warning( + """ +Category ids in annotations are not in [1, #categories]! We'll apply a mapping for you. +""" + ) + id_map = {v: i for i, v in enumerate(cat_ids)} + meta.thing_dataset_id_to_contiguous_id = id_map + + # sort indices for reproducible results + img_ids = sorted(coco_api.imgs.keys()) + # imgs is a list of dicts, each looks something like: + # {'license': 4, + # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', + # 'file_name': 'COCO_val2014_000000001268.jpg', + # 'height': 427, + # 'width': 640, + # 'date_captured': '2013-11-17 05:57:24', + # 'id': 1268} + imgs = coco_api.loadImgs(img_ids) + # anns is a list[list[dict]], where each dict is an annotation + # record for an object. The inner list enumerates the objects in an image + # and the outer list enumerates over images. Example of anns[0]: + # [{'segmentation': [[192.81, + # 247.09, + # ... + # 219.03, + # 249.06]], + # 'area': 1035.749, + # 'iscrowd': 0, + # 'image_id': 1268, + # 'bbox': [192.81, 224.8, 74.73, 33.43], + # 'category_id': 16, + # 'id': 42986}, + # ...] + anns = [coco_api.imgToAnns[img_id] for img_id in img_ids] + total_num_valid_anns = sum([len(x) for x in anns]) + total_num_anns = len(coco_api.anns) + if total_num_valid_anns < total_num_anns: + logger.warning( + f"{json_file} contains {total_num_anns} annotations, but only " + f"{total_num_valid_anns} of them match to images in the file." + ) + + if "minival" not in json_file: + # The popular valminusminival & minival annotations for COCO2014 contain this bug. + # However the ratio of buggy annotations there is tiny and does not affect accuracy. + # Therefore we explicitly white-list them. + ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique!".format( + json_file + ) + + imgs_anns = list(zip(imgs, anns)) + logger.info("Loaded {} images in COCO format from {}".format(len(imgs_anns), json_file)) + + dataset_dicts = [] + + ann_keys = ["iscrowd", "bbox", "keypoints", "category_id"] + (extra_annotation_keys or []) + + num_instances_without_valid_segmentation = 0 + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + record["file_name"] = os.path.join(image_root, img_dict["file_name"]) + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + image_id = record["image_id"] = img_dict["id"] + + objs = [] + for anno in anno_dict_list: + # Check that the image_id in this annotation is the same as + # the image_id we're looking at. + # This fails only when the data parsing logic or the annotation file is buggy. + + # The original COCO valminusminival2014 & minival2014 annotation files + # actually contains bugs that, together with certain ways of using COCO API, + # can trigger this assertion. + assert anno["image_id"] == image_id + + assert anno.get("ignore", 0) == 0, '"ignore" in COCO json file is not supported.' + + obj = {key: anno[key] for key in ann_keys if key in anno} + if "bbox" in obj and len(obj["bbox"]) == 0: + raise ValueError( + f"One annotation of image {image_id} contains empty 'bbox' value! " + "This json does not have valid COCO format." + ) + + segm = anno.get("segmentation", None) + if segm: # either list[list[float]] or dict(RLE) + if isinstance(segm, dict): + if isinstance(segm["counts"], list): + # convert to compressed RLE + segm = mask_util.frPyObjects(segm, *segm["size"]) + else: + # filter out invalid polygons (< 3 points) + segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] + if len(segm) == 0: + num_instances_without_valid_segmentation += 1 + continue # ignore this instance + obj["segmentation"] = segm + + keypts = anno.get("keypoints", None) + if keypts: # list[int] + for idx, v in enumerate(keypts): + if idx % 3 != 2: + # COCO's segmentation coordinates are floating points in [0, H or W], + # but keypoint coordinates are integers in [0, H-1 or W-1] + # Therefore we assume the coordinates are "pixel indices" and + # add 0.5 to convert to floating point coordinates. + keypts[idx] = v + 0.5 + obj["keypoints"] = keypts + + obj["bbox_mode"] = BoxMode.XYWH_ABS + if id_map: + annotation_category_id = obj["category_id"] + try: + obj["category_id"] = id_map[annotation_category_id] + except KeyError as e: + raise KeyError( + f"Encountered category_id={annotation_category_id} " + "but this id does not exist in 'categories' of the json file." + ) from e + objs.append(obj) + record["annotations"] = objs + dataset_dicts.append(record) + + if num_instances_without_valid_segmentation > 0: + logger.warning( + "Filtered out {} instances without valid segmentation. ".format( + num_instances_without_valid_segmentation + ) + + "There might be issues in your dataset generation process. Please " + "check https://detectron2.readthedocs.io/en/latest/tutorials/datasets.html carefully" + ) + return dataset_dicts + + +def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg"): + """ + Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are + treated as ground truth annotations and all files under "image_root" with "image_ext" extension + as input images. Ground truth and input images are matched using file paths relative to + "gt_root" and "image_root" respectively without taking into account file extensions. + This works for COCO as well as some other datasets. + + Args: + gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation + annotations are stored as images with integer values in pixels that represent + corresponding semantic labels. + image_root (str): the directory where the input images are. + gt_ext (str): file extension for ground truth annotations. + image_ext (str): file extension for input images. + + Returns: + list[dict]: + a list of dicts in detectron2 standard format without instance-level + annotation. + + Notes: + 1. This function does not read the image and ground truth files. + The results do not have the "image" and "sem_seg" fields. + """ + + # We match input images with ground truth based on their relative filepaths (without file + # extensions) starting from 'image_root' and 'gt_root' respectively. + def file2id(folder_path, file_path): + # extract relative path starting from `folder_path` + image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path)) + # remove file extension + image_id = os.path.splitext(image_id)[0] + return image_id + + input_files = sorted( + (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)), + key=lambda file_path: file2id(image_root, file_path), + ) + gt_files = sorted( + (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)), + key=lambda file_path: file2id(gt_root, file_path), + ) + + assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root) + + # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images + if len(input_files) != len(gt_files): + logger.warn( + "Directory {} and {} has {} and {} files, respectively.".format( + image_root, gt_root, len(input_files), len(gt_files) + ) + ) + input_basenames = [os.path.basename(f)[: -len(image_ext)] for f in input_files] + gt_basenames = [os.path.basename(f)[: -len(gt_ext)] for f in gt_files] + intersect = list(set(input_basenames) & set(gt_basenames)) + # sort, otherwise each worker may obtain a list[dict] in different order + intersect = sorted(intersect) + logger.warn("Will use their intersection of {} files.".format(len(intersect))) + input_files = [os.path.join(image_root, f + image_ext) for f in intersect] + gt_files = [os.path.join(gt_root, f + gt_ext) for f in intersect] + + logger.info( + "Loaded {} images with semantic segmentation from {}".format(len(input_files), image_root) + ) + + dataset_dicts = [] + for (img_path, gt_path) in zip(input_files, gt_files): + record = {} + record["file_name"] = img_path + record["sem_seg_file_name"] = gt_path + dataset_dicts.append(record) + + return dataset_dicts + + +def convert_to_coco_dict(dataset_name): + """ + Convert an instance detection/segmentation or keypoint detection dataset + in detectron2's standard format into COCO json format. + + Generic dataset description can be found here: + https://detectron2.readthedocs.io/tutorials/datasets.html#register-a-dataset + + COCO data format description can be found here: + http://cocodataset.org/#format-data + + Args: + dataset_name (str): + name of the source dataset + Must be registered in DatastCatalog and in detectron2's standard format. + Must have corresponding metadata "thing_classes" + Returns: + coco_dict: serializable dict in COCO json format + """ + + dataset_dicts = DatasetCatalog.get(dataset_name) + metadata = MetadataCatalog.get(dataset_name) + + # unmap the category mapping ids for COCO + if hasattr(metadata, "thing_dataset_id_to_contiguous_id"): + reverse_id_mapping = {v: k for k, v in metadata.thing_dataset_id_to_contiguous_id.items()} + reverse_id_mapper = lambda contiguous_id: reverse_id_mapping[contiguous_id] # noqa + else: + reverse_id_mapper = lambda contiguous_id: contiguous_id # noqa + + categories = [ + {"id": reverse_id_mapper(id), "name": name} + for id, name in enumerate(metadata.thing_classes) + ] + + logger.info("Converting dataset dicts into COCO format") + coco_images = [] + coco_annotations = [] + + for image_id, image_dict in enumerate(dataset_dicts): + coco_image = { + "id": image_dict.get("image_id", image_id), + "width": int(image_dict["width"]), + "height": int(image_dict["height"]), + "file_name": str(image_dict["file_name"]), + } + coco_images.append(coco_image) + + anns_per_image = image_dict.get("annotations", []) + for annotation in anns_per_image: + # create a new dict with only COCO fields + coco_annotation = {} + + # COCO requirement: XYWH box format for axis-align and XYWHA for rotated + bbox = annotation["bbox"] + if isinstance(bbox, np.ndarray): + if bbox.ndim != 1: + raise ValueError(f"bbox has to be 1-dimensional. Got shape={bbox.shape}.") + bbox = bbox.tolist() + if len(bbox) not in [4, 5]: + raise ValueError(f"bbox has to has length 4 or 5. Got {bbox}.") + from_bbox_mode = annotation["bbox_mode"] + to_bbox_mode = BoxMode.XYWH_ABS if len(bbox) == 4 else BoxMode.XYWHA_ABS + bbox = BoxMode.convert(bbox, from_bbox_mode, to_bbox_mode) + + # COCO requirement: instance area + if "segmentation" in annotation: + # Computing areas for instances by counting the pixels + segmentation = annotation["segmentation"] + # TODO: check segmentation type: RLE, BinaryMask or Polygon + if isinstance(segmentation, list): + polygons = PolygonMasks([segmentation]) + area = polygons.area()[0].item() + elif isinstance(segmentation, dict): # RLE + area = mask_util.area(segmentation).item() + else: + raise TypeError(f"Unknown segmentation type {type(segmentation)}!") + else: + # Computing areas using bounding boxes + if to_bbox_mode == BoxMode.XYWH_ABS: + bbox_xy = BoxMode.convert(bbox, to_bbox_mode, BoxMode.XYXY_ABS) + area = Boxes([bbox_xy]).area()[0].item() + else: + area = RotatedBoxes([bbox]).area()[0].item() + + if "keypoints" in annotation: + keypoints = annotation["keypoints"] # list[int] + for idx, v in enumerate(keypoints): + if idx % 3 != 2: + # COCO's segmentation coordinates are floating points in [0, H or W], + # but keypoint coordinates are integers in [0, H-1 or W-1] + # For COCO format consistency we substract 0.5 + # https://github.com/facebookresearch/detectron2/pull/175#issuecomment-551202163 + keypoints[idx] = v - 0.5 + if "num_keypoints" in annotation: + num_keypoints = annotation["num_keypoints"] + else: + num_keypoints = sum(kp > 0 for kp in keypoints[2::3]) + + # COCO requirement: + # linking annotations to images + # "id" field must start with 1 + coco_annotation["id"] = len(coco_annotations) + 1 + coco_annotation["image_id"] = coco_image["id"] + coco_annotation["bbox"] = [round(float(x), 3) for x in bbox] + coco_annotation["area"] = float(area) + coco_annotation["iscrowd"] = int(annotation.get("iscrowd", 0)) + coco_annotation["category_id"] = int(reverse_id_mapper(annotation["category_id"])) + + # Add optional fields + if "keypoints" in annotation: + coco_annotation["keypoints"] = keypoints + coco_annotation["num_keypoints"] = num_keypoints + + if "segmentation" in annotation: + seg = coco_annotation["segmentation"] = annotation["segmentation"] + if isinstance(seg, dict): # RLE + counts = seg["counts"] + if not isinstance(counts, str): + # make it json-serializable + seg["counts"] = counts.decode("ascii") + + coco_annotations.append(coco_annotation) + + logger.info( + "Conversion finished, " + f"#images: {len(coco_images)}, #annotations: {len(coco_annotations)}" + ) + + info = { + "date_created": str(datetime.datetime.now()), + "description": "Automatically generated COCO json file for Detectron2.", + } + coco_dict = {"info": info, "images": coco_images, "categories": categories, "licenses": None} + if len(coco_annotations) > 0: + coco_dict["annotations"] = coco_annotations + return coco_dict + + +def convert_to_coco_json(dataset_name, output_file, allow_cached=True): + """ + Converts dataset into COCO format and saves it to a json file. + dataset_name must be registered in DatasetCatalog and in detectron2's standard format. + + Args: + dataset_name: + reference from the config file to the catalogs + must be registered in DatasetCatalog and in detectron2's standard format + output_file: path of json file that will be saved to + allow_cached: if json file is already present then skip conversion + """ + + # TODO: The dataset or the conversion script *may* change, + # a checksum would be useful for validating the cached data + + PathManager.mkdirs(os.path.dirname(output_file)) + with file_lock(output_file): + if PathManager.exists(output_file) and allow_cached: + logger.warning( + f"Using previously cached COCO format annotations at '{output_file}'. " + "You need to clear the cache file if your dataset has been modified." + ) + else: + logger.info(f"Converting annotations of dataset '{dataset_name}' to COCO format ...)") + coco_dict = convert_to_coco_dict(dataset_name) + + logger.info(f"Caching COCO format annotations at '{output_file}' ...") + tmp_file = output_file + ".tmp" + with PathManager.open(tmp_file, "w") as f: + json.dump(coco_dict, f) + shutil.move(tmp_file, output_file) + + +def register_coco_instances(name, metadata, json_file, image_root): + """ + Register a dataset in COCO's json annotation format for + instance detection, instance segmentation and keypoint detection. + (i.e., Type 1 and 2 in http://cocodataset.org/#format-data. + `instances*.json` and `person_keypoints*.json` in the dataset). + + This is an example of how to register a new dataset. + You can do something similar to this function, to register new datasets. + + Args: + name (str): the name that identifies a dataset, e.g. "coco_2014_train". + metadata (dict): extra metadata associated with this dataset. You can + leave it as an empty dict. + json_file (str): path to the json instance annotation file. + image_root (str or path-like): directory which contains all the images. + """ + assert isinstance(name, str), name + assert isinstance(json_file, (str, os.PathLike)), json_file + assert isinstance(image_root, (str, os.PathLike)), image_root + # 1. register a function which returns dicts + DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name)) + + # 2. Optionally, add metadata about this dataset, + # since they might be useful in evaluation, visualization or logging + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="coco", **metadata + ) + + +if __name__ == "__main__": + """ + Test the COCO json dataset loader. + + Usage: + python -m detectron2.data.datasets.coco \ + path/to/json path/to/image_root dataset_name + + "dataset_name" can be "coco_2014_minival_100", or other + pre-registered ones + """ + from custom_detectron2.utils.logger import setup_logger + from custom_detectron2.utils.visualizer import Visualizer + import custom_detectron2.data.datasets # noqa # add pre-defined metadata + import sys + + logger = setup_logger(name=__name__) + assert sys.argv[3] in DatasetCatalog.list() + meta = MetadataCatalog.get(sys.argv[3]) + + dicts = load_coco_json(sys.argv[1], sys.argv[2], sys.argv[3]) + logger.info("Done loading {} samples.".format(len(dicts))) + + dirname = "coco-data-vis" + os.makedirs(dirname, exist_ok=True) + for d in dicts: + img = np.array(Image.open(d["file_name"])) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/custom_detectron2/data/datasets/coco_panoptic.py b/custom_detectron2/data/datasets/coco_panoptic.py new file mode 100644 index 0000000000000000000000000000000000000000..e194be84d25c20fa83e1f93a6bf81bd6fd991970 --- /dev/null +++ b/custom_detectron2/data/datasets/coco_panoptic.py @@ -0,0 +1,228 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import json +import os + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog +from custom_detectron2.utils.file_io import PathManager + +from .coco import load_coco_json, load_sem_seg + +__all__ = ["register_coco_panoptic", "register_coco_panoptic_separated"] + + +def load_coco_panoptic_json(json_file, image_dir, gt_dir, meta): + """ + Args: + image_dir (str): path to the raw dataset. e.g., "~/coco/train2017". + gt_dir (str): path to the raw annotations. e.g., "~/coco/panoptic_train2017". + json_file (str): path to the json file. e.g., "~/coco/annotations/panoptic_train2017.json". + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + """ + + def _convert_category_id(segment_info, meta): + if segment_info["category_id"] in meta["thing_dataset_id_to_contiguous_id"]: + segment_info["category_id"] = meta["thing_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + segment_info["isthing"] = True + else: + segment_info["category_id"] = meta["stuff_dataset_id_to_contiguous_id"][ + segment_info["category_id"] + ] + segment_info["isthing"] = False + return segment_info + + with PathManager.open(json_file) as f: + json_info = json.load(f) + + ret = [] + for ann in json_info["annotations"]: + image_id = int(ann["image_id"]) + # TODO: currently we assume image and label has the same filename but + # different extension, and images have extension ".jpg" for COCO. Need + # to make image extension a user-provided argument if we extend this + # function to support other COCO-like datasets. + image_file = os.path.join(image_dir, os.path.splitext(ann["file_name"])[0] + ".jpg") + label_file = os.path.join(gt_dir, ann["file_name"]) + segments_info = [_convert_category_id(x, meta) for x in ann["segments_info"]] + ret.append( + { + "file_name": image_file, + "image_id": image_id, + "pan_seg_file_name": label_file, + "segments_info": segments_info, + } + ) + assert len(ret), f"No images found in {image_dir}!" + assert PathManager.isfile(ret[0]["file_name"]), ret[0]["file_name"] + assert PathManager.isfile(ret[0]["pan_seg_file_name"]), ret[0]["pan_seg_file_name"] + return ret + + +def register_coco_panoptic( + name, metadata, image_root, panoptic_root, panoptic_json, instances_json=None +): + """ + Register a "standard" version of COCO panoptic segmentation dataset named `name`. + The dictionaries in this registered dataset follows detectron2's standard format. + Hence it's called "standard". + + Args: + name (str): the name that identifies a dataset, + e.g. "coco_2017_train_panoptic" + metadata (dict): extra metadata associated with this dataset. + image_root (str): directory which contains all the images + panoptic_root (str): directory which contains panoptic annotation images in COCO format + panoptic_json (str): path to the json panoptic annotation file in COCO format + sem_seg_root (none): not used, to be consistent with + `register_coco_panoptic_separated`. + instances_json (str): path to the json instance annotation file + """ + panoptic_name = name + DatasetCatalog.register( + panoptic_name, + lambda: load_coco_panoptic_json(panoptic_json, image_root, panoptic_root, metadata), + ) + MetadataCatalog.get(panoptic_name).set( + panoptic_root=panoptic_root, + image_root=image_root, + panoptic_json=panoptic_json, + json_file=instances_json, + evaluator_type="coco_panoptic_seg", + ignore_label=255, + label_divisor=1000, + **metadata, + ) + + +def register_coco_panoptic_separated( + name, metadata, image_root, panoptic_root, panoptic_json, sem_seg_root, instances_json +): + """ + Register a "separated" version of COCO panoptic segmentation dataset named `name`. + The annotations in this registered dataset will contain both instance annotations and + semantic annotations, each with its own contiguous ids. Hence it's called "separated". + + It follows the setting used by the PanopticFPN paper: + + 1. The instance annotations directly come from polygons in the COCO + instances annotation task, rather than from the masks in the COCO panoptic annotations. + + The two format have small differences: + Polygons in the instance annotations may have overlaps. + The mask annotations are produced by labeling the overlapped polygons + with depth ordering. + + 2. The semantic annotations are converted from panoptic annotations, where + all "things" are assigned a semantic id of 0. + All semantic categories will therefore have ids in contiguous + range [1, #stuff_categories]. + + This function will also register a pure semantic segmentation dataset + named ``name + '_stuffonly'``. + + Args: + name (str): the name that identifies a dataset, + e.g. "coco_2017_train_panoptic" + metadata (dict): extra metadata associated with this dataset. + image_root (str): directory which contains all the images + panoptic_root (str): directory which contains panoptic annotation images + panoptic_json (str): path to the json panoptic annotation file + sem_seg_root (str): directory which contains all the ground truth segmentation annotations. + instances_json (str): path to the json instance annotation file + """ + panoptic_name = name + "_separated" + DatasetCatalog.register( + panoptic_name, + lambda: merge_to_panoptic( + load_coco_json(instances_json, image_root, panoptic_name), + load_sem_seg(sem_seg_root, image_root), + ), + ) + MetadataCatalog.get(panoptic_name).set( + panoptic_root=panoptic_root, + image_root=image_root, + panoptic_json=panoptic_json, + sem_seg_root=sem_seg_root, + json_file=instances_json, # TODO rename + evaluator_type="coco_panoptic_seg", + ignore_label=255, + **metadata, + ) + + semantic_name = name + "_stuffonly" + DatasetCatalog.register(semantic_name, lambda: load_sem_seg(sem_seg_root, image_root)) + MetadataCatalog.get(semantic_name).set( + sem_seg_root=sem_seg_root, + image_root=image_root, + evaluator_type="sem_seg", + ignore_label=255, + **metadata, + ) + + +def merge_to_panoptic(detection_dicts, sem_seg_dicts): + """ + Create dataset dicts for panoptic segmentation, by + merging two dicts using "file_name" field to match their entries. + + Args: + detection_dicts (list[dict]): lists of dicts for object detection or instance segmentation. + sem_seg_dicts (list[dict]): lists of dicts for semantic segmentation. + + Returns: + list[dict] (one per input image): Each dict contains all (key, value) pairs from dicts in + both detection_dicts and sem_seg_dicts that correspond to the same image. + The function assumes that the same key in different dicts has the same value. + """ + results = [] + sem_seg_file_to_entry = {x["file_name"]: x for x in sem_seg_dicts} + assert len(sem_seg_file_to_entry) > 0 + + for det_dict in detection_dicts: + dic = copy.copy(det_dict) + dic.update(sem_seg_file_to_entry[dic["file_name"]]) + results.append(dic) + return results + + +if __name__ == "__main__": + """ + Test the COCO panoptic dataset loader. + + Usage: + python -m detectron2.data.datasets.coco_panoptic \ + path/to/image_root path/to/panoptic_root path/to/panoptic_json dataset_name 10 + + "dataset_name" can be "coco_2017_train_panoptic", or other + pre-registered ones + """ + from custom_detectron2.utils.logger import setup_logger + from custom_detectron2.utils.visualizer import Visualizer + import custom_detectron2.data.datasets # noqa # add pre-defined metadata + import sys + from PIL import Image + import numpy as np + + logger = setup_logger(name=__name__) + assert sys.argv[4] in DatasetCatalog.list() + meta = MetadataCatalog.get(sys.argv[4]) + + dicts = load_coco_panoptic_json(sys.argv[3], sys.argv[1], sys.argv[2], meta.as_dict()) + logger.info("Done loading {} samples.".format(len(dicts))) + + dirname = "coco-data-vis" + os.makedirs(dirname, exist_ok=True) + num_imgs_to_vis = int(sys.argv[5]) + for i, d in enumerate(dicts): + img = np.array(Image.open(d["file_name"])) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) + if i + 1 >= num_imgs_to_vis: + break diff --git a/custom_detectron2/data/datasets/lvis.py b/custom_detectron2/data/datasets/lvis.py new file mode 100644 index 0000000000000000000000000000000000000000..337917c1e053c97163a300022702afb8ec2215c5 --- /dev/null +++ b/custom_detectron2/data/datasets/lvis.py @@ -0,0 +1,241 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import os +from fvcore.common.timer import Timer + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog +from custom_detectron2.structures import BoxMode +from custom_detectron2.utils.file_io import PathManager + +from .builtin_meta import _get_coco_instances_meta +from .lvis_v0_5_categories import LVIS_CATEGORIES as LVIS_V0_5_CATEGORIES +from .lvis_v1_categories import LVIS_CATEGORIES as LVIS_V1_CATEGORIES +from .lvis_v1_category_image_count import LVIS_CATEGORY_IMAGE_COUNT as LVIS_V1_CATEGORY_IMAGE_COUNT + +""" +This file contains functions to parse LVIS-format annotations into dicts in the +"Detectron2 format". +""" + +logger = logging.getLogger(__name__) + +__all__ = ["load_lvis_json", "register_lvis_instances", "get_lvis_instances_meta"] + + +def register_lvis_instances(name, metadata, json_file, image_root): + """ + Register a dataset in LVIS's json annotation format for instance detection and segmentation. + + Args: + name (str): a name that identifies the dataset, e.g. "lvis_v0.5_train". + metadata (dict): extra metadata associated with this dataset. It can be an empty dict. + json_file (str): path to the json instance annotation file. + image_root (str or path-like): directory which contains all the images. + """ + DatasetCatalog.register(name, lambda: load_lvis_json(json_file, image_root, name)) + MetadataCatalog.get(name).set( + json_file=json_file, image_root=image_root, evaluator_type="lvis", **metadata + ) + + +def load_lvis_json(json_file, image_root, dataset_name=None, extra_annotation_keys=None): + """ + Load a json file in LVIS's annotation format. + + Args: + json_file (str): full path to the LVIS json annotation file. + image_root (str): the directory where the images in this json file exists. + dataset_name (str): the name of the dataset (e.g., "lvis_v0.5_train"). + If provided, this function will put "thing_classes" into the metadata + associated with this dataset. + extra_annotation_keys (list[str]): list of per-annotation keys that should also be + loaded into the dataset dict (besides "bbox", "bbox_mode", "category_id", + "segmentation"). The values for these keys will be returned as-is. + + Returns: + list[dict]: a list of dicts in Detectron2 standard format. (See + `Using Custom Datasets `_ ) + + Notes: + 1. This function does not read the image files. + The results do not have the "image" field. + """ + from lvis import LVIS + + json_file = PathManager.get_local_path(json_file) + + timer = Timer() + lvis_api = LVIS(json_file) + if timer.seconds() > 1: + logger.info("Loading {} takes {:.2f} seconds.".format(json_file, timer.seconds())) + + if dataset_name is not None: + meta = get_lvis_instances_meta(dataset_name) + MetadataCatalog.get(dataset_name).set(**meta) + + # sort indices for reproducible results + img_ids = sorted(lvis_api.imgs.keys()) + # imgs is a list of dicts, each looks something like: + # {'license': 4, + # 'url': 'http://farm6.staticflickr.com/5454/9413846304_881d5e5c3b_z.jpg', + # 'file_name': 'COCO_val2014_000000001268.jpg', + # 'height': 427, + # 'width': 640, + # 'date_captured': '2013-11-17 05:57:24', + # 'id': 1268} + imgs = lvis_api.load_imgs(img_ids) + # anns is a list[list[dict]], where each dict is an annotation + # record for an object. The inner list enumerates the objects in an image + # and the outer list enumerates over images. Example of anns[0]: + # [{'segmentation': [[192.81, + # 247.09, + # ... + # 219.03, + # 249.06]], + # 'area': 1035.749, + # 'image_id': 1268, + # 'bbox': [192.81, 224.8, 74.73, 33.43], + # 'category_id': 16, + # 'id': 42986}, + # ...] + anns = [lvis_api.img_ann_map[img_id] for img_id in img_ids] + + # Sanity check that each annotation has a unique id + ann_ids = [ann["id"] for anns_per_image in anns for ann in anns_per_image] + assert len(set(ann_ids)) == len(ann_ids), "Annotation ids in '{}' are not unique".format( + json_file + ) + + imgs_anns = list(zip(imgs, anns)) + + logger.info("Loaded {} images in the LVIS format from {}".format(len(imgs_anns), json_file)) + + if extra_annotation_keys: + logger.info( + "The following extra annotation keys will be loaded: {} ".format(extra_annotation_keys) + ) + else: + extra_annotation_keys = [] + + def get_file_name(img_root, img_dict): + # Determine the path including the split folder ("train2017", "val2017", "test2017") from + # the coco_url field. Example: + # 'coco_url': 'http://images.cocodataset.org/train2017/000000155379.jpg' + split_folder, file_name = img_dict["coco_url"].split("/")[-2:] + return os.path.join(img_root + split_folder, file_name) + + dataset_dicts = [] + + for (img_dict, anno_dict_list) in imgs_anns: + record = {} + record["file_name"] = get_file_name(image_root, img_dict) + record["height"] = img_dict["height"] + record["width"] = img_dict["width"] + record["not_exhaustive_category_ids"] = img_dict.get("not_exhaustive_category_ids", []) + record["neg_category_ids"] = img_dict.get("neg_category_ids", []) + image_id = record["image_id"] = img_dict["id"] + + objs = [] + for anno in anno_dict_list: + # Check that the image_id in this annotation is the same as + # the image_id we're looking at. + # This fails only when the data parsing logic or the annotation file is buggy. + assert anno["image_id"] == image_id + obj = {"bbox": anno["bbox"], "bbox_mode": BoxMode.XYWH_ABS} + # LVIS data loader can be used to load COCO dataset categories. In this case `meta` + # variable will have a field with COCO-specific category mapping. + if dataset_name is not None and "thing_dataset_id_to_contiguous_id" in meta: + obj["category_id"] = meta["thing_dataset_id_to_contiguous_id"][anno["category_id"]] + else: + obj["category_id"] = anno["category_id"] - 1 # Convert 1-indexed to 0-indexed + segm = anno["segmentation"] # list[list[float]] + # filter out invalid polygons (< 3 points) + valid_segm = [poly for poly in segm if len(poly) % 2 == 0 and len(poly) >= 6] + assert len(segm) == len( + valid_segm + ), "Annotation contains an invalid polygon with < 3 points" + assert len(segm) > 0 + obj["segmentation"] = segm + for extra_ann_key in extra_annotation_keys: + obj[extra_ann_key] = anno[extra_ann_key] + objs.append(obj) + record["annotations"] = objs + dataset_dicts.append(record) + + return dataset_dicts + + +def get_lvis_instances_meta(dataset_name): + """ + Load LVIS metadata. + + Args: + dataset_name (str): LVIS dataset name without the split name (e.g., "lvis_v0.5"). + + Returns: + dict: LVIS metadata with keys: thing_classes + """ + if "cocofied" in dataset_name: + return _get_coco_instances_meta() + if "v0.5" in dataset_name: + return _get_lvis_instances_meta_v0_5() + elif "v1" in dataset_name: + return _get_lvis_instances_meta_v1() + raise ValueError("No built-in metadata for dataset {}".format(dataset_name)) + + +def _get_lvis_instances_meta_v0_5(): + assert len(LVIS_V0_5_CATEGORIES) == 1230 + cat_ids = [k["id"] for k in LVIS_V0_5_CATEGORIES] + assert min(cat_ids) == 1 and max(cat_ids) == len( + cat_ids + ), "Category ids are not in [1, #categories], as expected" + # Ensure that the category list is sorted by id + lvis_categories = sorted(LVIS_V0_5_CATEGORIES, key=lambda x: x["id"]) + thing_classes = [k["synonyms"][0] for k in lvis_categories] + meta = {"thing_classes": thing_classes} + return meta + + +def _get_lvis_instances_meta_v1(): + assert len(LVIS_V1_CATEGORIES) == 1203 + cat_ids = [k["id"] for k in LVIS_V1_CATEGORIES] + assert min(cat_ids) == 1 and max(cat_ids) == len( + cat_ids + ), "Category ids are not in [1, #categories], as expected" + # Ensure that the category list is sorted by id + lvis_categories = sorted(LVIS_V1_CATEGORIES, key=lambda x: x["id"]) + thing_classes = [k["synonyms"][0] for k in lvis_categories] + meta = {"thing_classes": thing_classes, "class_image_count": LVIS_V1_CATEGORY_IMAGE_COUNT} + return meta + + +if __name__ == "__main__": + """ + Test the LVIS json dataset loader. + + Usage: + python -m detectron2.data.datasets.lvis \ + path/to/json path/to/image_root dataset_name vis_limit + """ + import sys + import numpy as np + from custom_detectron2.utils.logger import setup_logger + from PIL import Image + import custom_detectron2.data.datasets # noqa # add pre-defined metadata + from custom_detectron2.utils.visualizer import Visualizer + + logger = setup_logger(name=__name__) + meta = MetadataCatalog.get(sys.argv[3]) + + dicts = load_lvis_json(sys.argv[1], sys.argv[2], sys.argv[3]) + logger.info("Done loading {} samples.".format(len(dicts))) + + dirname = "lvis-data-vis" + os.makedirs(dirname, exist_ok=True) + for d in dicts[: int(sys.argv[4])]: + img = np.array(Image.open(d["file_name"])) + visualizer = Visualizer(img, metadata=meta) + vis = visualizer.draw_dataset_dict(d) + fpath = os.path.join(dirname, os.path.basename(d["file_name"])) + vis.save(fpath) diff --git a/custom_detectron2/data/datasets/lvis_v0_5_categories.py b/custom_detectron2/data/datasets/lvis_v0_5_categories.py new file mode 100644 index 0000000000000000000000000000000000000000..d3dab6198da614937b08682f4c9edf52bdf1d236 --- /dev/null +++ b/custom_detectron2/data/datasets/lvis_v0_5_categories.py @@ -0,0 +1,13 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Autogen with +# with open("lvis_v0.5_val.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["image_count"] +# del x["instance_count"] +# LVIS_CATEGORIES = repr(c) + " # noqa" + +# fmt: off +LVIS_CATEGORIES = [{'frequency': 'r', 'id': 1, 'synset': 'acorn.n.01', 'synonyms': ['acorn'], 'def': 'nut from an oak tree', 'name': 'acorn'}, {'frequency': 'c', 'id': 2, 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'id': 3, 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'id': 4, 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'c', 'id': 5, 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'id': 6, 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'r', 'id': 7, 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'id': 8, 'synset': 'almond.n.02', 'synonyms': ['almond'], 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'id': 9, 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'r', 'id': 10, 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'id': 11, 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'id': 12, 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'id': 13, 'synset': 'apple.n.01', 'synonyms': ['apple'], 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'id': 14, 'synset': 'apple_juice.n.01', 'synonyms': ['apple_juice'], 'def': 'the juice of apples', 'name': 'apple_juice'}, {'frequency': 'r', 'id': 15, 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'id': 16, 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'id': 17, 'synset': 'apron.n.01', 'synonyms': ['apron'], 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'id': 18, 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'c', 'id': 19, 'synset': 'armband.n.02', 'synonyms': ['armband'], 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'id': 20, 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'id': 21, 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'id': 22, 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'id': 23, 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'id': 24, 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'id': 25, 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'id': 26, 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'id': 27, 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'c', 'id': 28, 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'id': 29, 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'id': 30, 'synset': 'awning.n.01', 'synonyms': ['awning'], 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'id': 31, 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'f', 'id': 32, 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'id': 33, 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'id': 34, 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'id': 35, 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'id': 36, 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'id': 37, 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'id': 38, 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'id': 39, 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'id': 40, 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'id': 41, 'synset': 'ball.n.06', 'synonyms': ['ball'], 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'id': 42, 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'id': 43, 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'id': 44, 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'id': 45, 'synset': 'banana.n.02', 'synonyms': ['banana'], 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'r', 'id': 46, 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'id': 47, 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'c', 'id': 48, 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'id': 49, 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'id': 50, 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'id': 51, 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'id': 52, 'synset': 'barge.n.01', 'synonyms': ['barge'], 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'id': 53, 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'id': 54, 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'id': 55, 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'id': 56, 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'id': 57, 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'id': 58, 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'id': 59, 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'id': 60, 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'id': 61, 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'id': 62, 'synset': 'basket.n.03', 'synonyms': ['basketball_hoop'], 'def': 'metal hoop supporting a net through which players try to throw the basketball', 'name': 'basketball_hoop'}, {'frequency': 'c', 'id': 63, 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'id': 64, 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'r', 'id': 65, 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'id': 66, 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'id': 67, 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'id': 68, 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'id': 69, 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'id': 70, 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'id': 71, 'synset': 'battery.n.02', 'synonyms': ['battery'], 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'id': 72, 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'id': 73, 'synset': 'bead.n.01', 'synonyms': ['bead'], 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'r', 'id': 74, 'synset': 'beaker.n.01', 'synonyms': ['beaker'], 'def': 'a flatbottomed jar made of glass or plastic; used for chemistry', 'name': 'beaker'}, {'frequency': 'c', 'id': 75, 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'id': 76, 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'id': 77, 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'id': 78, 'synset': 'bear.n.01', 'synonyms': ['bear'], 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'id': 79, 'synset': 'bed.n.01', 'synonyms': ['bed'], 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'c', 'id': 80, 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'id': 81, 'synset': 'beef.n.01', 'synonyms': ['cow'], 'def': 'cattle that are reared for their meat', 'name': 'cow'}, {'frequency': 'c', 'id': 82, 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'id': 83, 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'id': 84, 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'id': 85, 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'id': 86, 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'id': 87, 'synset': 'bell.n.01', 'synonyms': ['bell'], 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'id': 88, 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'id': 89, 'synset': 'belt.n.02', 'synonyms': ['belt'], 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'id': 90, 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'id': 91, 'synset': 'bench.n.01', 'synonyms': ['bench'], 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'id': 92, 'synset': 'beret.n.01', 'synonyms': ['beret'], 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'id': 93, 'synset': 'bib.n.02', 'synonyms': ['bib'], 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'id': 94, 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'id': 95, 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'id': 96, 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'c', 'id': 97, 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'id': 98, 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'id': 99, 'synset': 'bird.n.01', 'synonyms': ['bird'], 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'r', 'id': 100, 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'r', 'id': 101, 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'id': 102, 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'id': 103, 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'id': 104, 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'id': 105, 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'id': 106, 'synset': 'biscuit.n.01', 'synonyms': ['biscuit_(bread)'], 'def': 'small round bread leavened with baking-powder or soda', 'name': 'biscuit_(bread)'}, {'frequency': 'r', 'id': 107, 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'id': 108, 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'id': 109, 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'id': 110, 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'id': 111, 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'id': 112, 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'id': 113, 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'c', 'id': 114, 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'c', 'id': 115, 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'id': 116, 'synset': 'boar.n.02', 'synonyms': ['boar'], 'def': 'an uncastrated male hog', 'name': 'boar'}, {'frequency': 'r', 'id': 117, 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'id': 118, 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'c', 'id': 119, 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'r', 'id': 120, 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'id': 121, 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'id': 122, 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'id': 123, 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'id': 124, 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'id': 125, 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'id': 126, 'synset': 'book.n.01', 'synonyms': ['book'], 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'r', 'id': 127, 'synset': 'book_bag.n.01', 'synonyms': ['book_bag'], 'def': 'a bag in which students carry their books', 'name': 'book_bag'}, {'frequency': 'c', 'id': 128, 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'id': 129, 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'id': 130, 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'id': 131, 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'id': 132, 'synset': 'boot.n.01', 'synonyms': ['boot'], 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'id': 133, 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'id': 134, 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'id': 135, 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'id': 136, 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'id': 137, 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'id': 138, 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'id': 139, 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'id': 140, 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'id': 141, 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'id': 142, 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'r', 'id': 143, 'synset': 'bowling_pin.n.01', 'synonyms': ['bowling_pin'], 'def': 'a club-shaped wooden object used in bowling', 'name': 'bowling_pin'}, {'frequency': 'r', 'id': 144, 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'id': 145, 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'id': 146, 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'id': 147, 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'id': 148, 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'id': 149, 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'r', 'id': 150, 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'c', 'id': 151, 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'id': 152, 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'c', 'id': 153, 'synset': 'bristle_brush.n.01', 'synonyms': ['bristle_brush'], 'def': 'a brush that is made with the short stiff hairs of an animal or plant', 'name': 'bristle_brush'}, {'frequency': 'f', 'id': 154, 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'id': 155, 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'id': 156, 'synset': 'broom.n.01', 'synonyms': ['broom'], 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'id': 157, 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'id': 158, 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'id': 159, 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'id': 160, 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'id': 161, 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'id': 162, 'synset': 'bull.n.11', 'synonyms': ['bull'], 'def': 'mature male cow', 'name': 'bull'}, {'frequency': 'r', 'id': 163, 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'id': 164, 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'id': 165, 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'id': 166, 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'id': 167, 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'id': 168, 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'r', 'id': 169, 'synset': 'bully_beef.n.01', 'synonyms': ['corned_beef', 'corn_beef'], 'def': 'beef cured or pickled in brine', 'name': 'corned_beef'}, {'frequency': 'f', 'id': 170, 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'id': 171, 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'id': 172, 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'id': 173, 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'id': 174, 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'id': 175, 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'c', 'id': 176, 'synset': 'butcher_knife.n.01', 'synonyms': ['butcher_knife'], 'def': 'a large sharp knife for cutting or trimming meat', 'name': 'butcher_knife'}, {'frequency': 'c', 'id': 177, 'synset': 'butter.n.01', 'synonyms': ['butter'], 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'id': 178, 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'id': 179, 'synset': 'button.n.01', 'synonyms': ['button'], 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'id': 180, 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'id': 181, 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'r', 'id': 182, 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'id': 183, 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'id': 184, 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'id': 185, 'synset': 'cake.n.03', 'synonyms': ['cake'], 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'id': 186, 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'id': 187, 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'id': 188, 'synset': 'calf.n.01', 'synonyms': ['calf'], 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'id': 189, 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'id': 190, 'synset': 'camel.n.01', 'synonyms': ['camel'], 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'id': 191, 'synset': 'camera.n.01', 'synonyms': ['camera'], 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'id': 192, 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'id': 193, 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'id': 194, 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'id': 195, 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'r', 'id': 196, 'synset': 'candelabrum.n.01', 'synonyms': ['candelabrum', 'candelabra'], 'def': 'branched candlestick; ornamental; has several lights', 'name': 'candelabrum'}, {'frequency': 'f', 'id': 197, 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'id': 198, 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'id': 199, 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'id': 200, 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'id': 201, 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'id': 202, 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'r', 'id': 203, 'synset': 'cannon.n.02', 'synonyms': ['cannon'], 'def': 'heavy gun fired from a tank', 'name': 'cannon'}, {'frequency': 'c', 'id': 204, 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'r', 'id': 205, 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'id': 206, 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'c', 'id': 207, 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'id': 208, 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'r', 'id': 209, 'synset': 'cape.n.02', 'synonyms': ['cape'], 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'id': 210, 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'id': 211, 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'id': 212, 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'def': 'a wheeled vehicle adapted to the rails of railroad', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'id': 213, 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'id': 214, 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'id': 215, 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'def': 'a card certifying the identity of the bearer', 'name': 'identity_card'}, {'frequency': 'c', 'id': 216, 'synset': 'card.n.03', 'synonyms': ['card'], 'def': 'a rectangular piece of paper used to send messages (e.g. greetings or pictures)', 'name': 'card'}, {'frequency': 'r', 'id': 217, 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'id': 218, 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'id': 219, 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'id': 220, 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'id': 221, 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'c', 'id': 222, 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'id': 223, 'synset': 'cart.n.01', 'synonyms': ['cart'], 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'id': 224, 'synset': 'carton.n.02', 'synonyms': ['carton'], 'def': 'a box made of cardboard; opens by flaps on top', 'name': 'carton'}, {'frequency': 'c', 'id': 225, 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'id': 226, 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'id': 227, 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'id': 228, 'synset': 'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'id': 229, 'synset': 'cat.n.01', 'synonyms': ['cat'], 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'c', 'id': 230, 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'r', 'id': 231, 'synset': 'caviar.n.01', 'synonyms': ['caviar', 'caviare'], 'def': "salted roe of sturgeon or other large fish; usually served as an hors d'oeuvre", 'name': 'caviar'}, {'frequency': 'c', 'id': 232, 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'id': 233, 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'c', 'id': 234, 'synset': 'celery.n.01', 'synonyms': ['celery'], 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'id': 235, 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'id': 236, 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'id': 237, 'synset': 'chair.n.01', 'synonyms': ['chair'], 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'id': 238, 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'id': 239, 'synset': 'champagne.n.01', 'synonyms': ['champagne'], 'def': 'a white sparkling wine produced in Champagne or resembling that produced there', 'name': 'champagne'}, {'frequency': 'f', 'id': 240, 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'id': 241, 'synset': 'chap.n.04', 'synonyms': ['chap'], 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'id': 242, 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'id': 243, 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'id': 244, 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'id': 245, 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'r', 'id': 246, 'synset': 'chest_of_drawers.n.01', 'synonyms': ['chest_of_drawers_(furniture)', 'bureau_(furniture)', 'chest_(furniture)'], 'def': 'furniture with drawers for keeping clothes', 'name': 'chest_of_drawers_(furniture)'}, {'frequency': 'c', 'id': 247, 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'id': 248, 'synset': 'chicken_wire.n.01', 'synonyms': ['chicken_wire'], 'def': 'a galvanized wire network with a hexagonal mesh; used to build fences', 'name': 'chicken_wire'}, {'frequency': 'r', 'id': 249, 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'r', 'id': 250, 'synset': 'chihuahua.n.03', 'synonyms': ['Chihuahua'], 'def': 'an old breed of tiny short-haired dog with protruding eyes from Mexico', 'name': 'Chihuahua'}, {'frequency': 'r', 'id': 251, 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'id': 252, 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'id': 253, 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'id': 254, 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'id': 255, 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'id': 256, 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'id': 257, 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'id': 258, 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'id': 259, 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'id': 260, 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'def': 'necklace that fits tightly around the neck', 'name': 'choker'}, {'frequency': 'f', 'id': 261, 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'c', 'id': 262, 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'id': 263, 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'id': 264, 'synset': 'chute.n.02', 'synonyms': ['slide'], 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'id': 265, 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'id': 266, 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'c', 'id': 267, 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'id': 268, 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'id': 269, 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'id': 270, 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'r', 'id': 271, 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'id': 272, 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'id': 273, 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'id': 274, 'synset': 'clip.n.03', 'synonyms': ['clip'], 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'id': 275, 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'f', 'id': 276, 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'id': 277, 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'id': 278, 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'id': 279, 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'id': 280, 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'id': 281, 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'id': 282, 'synset': 'coat.n.01', 'synonyms': ['coat'], 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'id': 283, 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'r', 'id': 284, 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'id': 285, 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'c', 'id': 286, 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'r', 'id': 287, 'synset': 'coffee_filter.n.01', 'synonyms': ['coffee_filter'], 'def': 'filter (usually of paper) that passes the coffee and retains the coffee grounds', 'name': 'coffee_filter'}, {'frequency': 'f', 'id': 288, 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'id': 289, 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'id': 290, 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'id': 291, 'synset': 'coil.n.05', 'synonyms': ['coil'], 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'id': 292, 'synset': 'coin.n.01', 'synonyms': ['coin'], 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'r', 'id': 293, 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 'id': 294, 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'id': 295, 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'id': 296, 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'id': 297, 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'id': 298, 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'f', 'id': 299, 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'r', 'id': 300, 'synset': 'concrete_mixer.n.01', 'synonyms': ['concrete_mixer', 'cement_mixer'], 'def': 'a machine with a large revolving drum in which cement/concrete is mixed', 'name': 'concrete_mixer'}, {'frequency': 'f', 'id': 301, 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'id': 302, 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'id': 303, 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'id': 304, 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'c', 'id': 305, 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'id': 306, 'synset': 'cookie_jar.n.01', 'synonyms': ['cookie_jar', 'cooky_jar'], 'def': 'a jar in which cookies are kept (and sometimes money is hidden)', 'name': 'cookie_jar'}, {'frequency': 'r', 'id': 307, 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'id': 308, 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'c', 'id': 309, 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'id': 310, 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'r', 'id': 311, 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'c', 'id': 312, 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'def': 'ears of corn that can be prepared and served for human food', 'name': 'edible_corn'}, {'frequency': 'r', 'id': 313, 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'id': 314, 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'id': 315, 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'id': 316, 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'r', 'id': 317, 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'r', 'id': 318, 'synset': 'cos.n.02', 'synonyms': ['romaine_lettuce'], 'def': 'lettuce with long dark-green leaves in a loosely packed elongated head', 'name': 'romaine_lettuce'}, {'frequency': 'c', 'id': 319, 'synset': 'costume.n.04', 'synonyms': ['costume'], 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'id': 320, 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'id': 321, 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'r', 'id': 322, 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'id': 323, 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'r', 'id': 324, 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'c', 'id': 325, 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'id': 326, 'synset': 'crape.n.01', 'synonyms': ['crape', 'crepe', 'French_pancake'], 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'id': 327, 'synset': 'crate.n.01', 'synonyms': ['crate'], 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'r', 'id': 328, 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'id': 329, 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'r', 'id': 330, 'synset': 'credit_card.n.01', 'synonyms': ['credit_card', 'charge_card', 'debit_card'], 'def': 'a card, usually plastic, used to pay for goods and services', 'name': 'credit_card'}, {'frequency': 'c', 'id': 331, 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'id': 332, 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'id': 333, 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'def': 'an earthen jar (made of baked clay)', 'name': 'crock_pot'}, {'frequency': 'f', 'id': 334, 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'id': 335, 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'r', 'id': 336, 'synset': 'crow.n.01', 'synonyms': ['crow'], 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'c', 'id': 337, 'synset': 'crown.n.04', 'synonyms': ['crown'], 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'id': 338, 'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'id': 339, 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'id': 340, 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'c', 'id': 341, 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'r', 'id': 342, 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'id': 343, 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'r', 'id': 344, 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'id': 345, 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'id': 346, 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'id': 347, 'synset': 'cup.n.01', 'synonyms': ['cup'], 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'id': 348, 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'def': 'a metal vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'c', 'id': 349, 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'id': 350, 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'id': 351, 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'id': 352, 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'id': 353, 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'id': 354, 'synset': 'custard.n.01', 'synonyms': ['custard'], 'def': 'sweetened mixture of milk and eggs baked or boiled or frozen', 'name': 'custard'}, {'frequency': 'c', 'id': 355, 'synset': 'cutter.n.06', 'synonyms': ['cutting_tool'], 'def': 'a cutting implement; a tool for cutting', 'name': 'cutting_tool'}, {'frequency': 'r', 'id': 356, 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'id': 357, 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'id': 358, 'synset': 'dachshund.n.01', 'synonyms': ['dachshund', 'dachsie', 'badger_dog'], 'def': 'small long-bodied short-legged breed of dog having a short sleek coat and long drooping ears', 'name': 'dachshund'}, {'frequency': 'r', 'id': 359, 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'id': 360, 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'id': 361, 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'id': 362, 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'id': 363, 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'id': 364, 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'id': 365, 'synset': 'desk.n.01', 'synonyms': ['desk'], 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'id': 366, 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'id': 367, 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'id': 368, 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'def': 'a daily written record of (usually personal) experiences and observations', 'name': 'diary'}, {'frequency': 'r', 'id': 369, 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'id': 370, 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'id': 371, 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'id': 372, 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'c', 'id': 373, 'synset': 'dish.n.01', 'synonyms': ['dish'], 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'id': 374, 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'id': 375, 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'def': 'a cloth for washing dishes', 'name': 'dishrag'}, {'frequency': 'c', 'id': 376, 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'id': 377, 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'id': 378, 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid'], 'def': 'a low-sudsing detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'r', 'id': 379, 'synset': 'diskette.n.01', 'synonyms': ['diskette', 'floppy', 'floppy_disk'], 'def': 'a small plastic magnetic disk enclosed in a stiff envelope used to store data', 'name': 'diskette'}, {'frequency': 'c', 'id': 380, 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'c', 'id': 381, 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'id': 382, 'synset': 'dog.n.01', 'synonyms': ['dog'], 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'id': 383, 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'c', 'id': 384, 'synset': 'doll.n.01', 'synonyms': ['doll'], 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'id': 385, 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'id': 386, 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'id': 387, 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'r', 'id': 388, 'synset': 'domino.n.03', 'synonyms': ['eye_mask'], 'def': 'a mask covering the upper part of the face but with holes for the eyes', 'name': 'eye_mask'}, {'frequency': 'r', 'id': 389, 'synset': 'doorbell.n.01', 'synonyms': ['doorbell', 'buzzer'], 'def': 'a button at an outer door that gives a ringing or buzzing signal when pushed', 'name': 'doorbell'}, {'frequency': 'f', 'id': 390, 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'id': 391, 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'id': 392, 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'id': 393, 'synset': 'dove.n.01', 'synonyms': ['dove'], 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'id': 394, 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'id': 395, 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'id': 396, 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'id': 397, 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'id': 398, 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'c', 'id': 399, 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'c', 'id': 400, 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'id': 401, 'synset': 'drill.n.01', 'synonyms': ['drill'], 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'id': 402, 'synset': 'drinking_fountain.n.01', 'synonyms': ['drinking_fountain'], 'def': 'a public fountain to provide a jet of drinking water', 'name': 'drinking_fountain'}, {'frequency': 'r', 'id': 403, 'synset': 'drone.n.04', 'synonyms': ['drone'], 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'id': 404, 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'id': 405, 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'id': 406, 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'id': 407, 'synset': 'duck.n.01', 'synonyms': ['duck'], 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'r', 'id': 408, 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'id': 409, 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'id': 410, 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'def': 'a large cylindrical bag of heavy cloth', 'name': 'duffel_bag'}, {'frequency': 'r', 'id': 411, 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'id': 412, 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'id': 413, 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'r', 'id': 414, 'synset': 'dutch_oven.n.02', 'synonyms': ['Dutch_oven'], 'def': 'iron or earthenware cooking pot; used for stews', 'name': 'Dutch_oven'}, {'frequency': 'c', 'id': 415, 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'id': 416, 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'id': 417, 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'id': 418, 'synset': 'earring.n.01', 'synonyms': ['earring'], 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'id': 419, 'synset': 'easel.n.01', 'synonyms': ['easel'], 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'id': 420, 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'id': 421, 'synset': 'eel.n.01', 'synonyms': ['eel'], 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'id': 422, 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'id': 423, 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'id': 424, 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'id': 425, 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'id': 426, 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'id': 427, 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'id': 428, 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'id': 429, 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'r', 'id': 430, 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'id': 431, 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'id': 432, 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'id': 433, 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'id': 434, 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'id': 435, 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'id': 436, 'synset': 'fan.n.01', 'synonyms': ['fan'], 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'id': 437, 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'id': 438, 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'id': 439, 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'id': 440, 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'r', 'id': 441, 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'id': 442, 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'id': 443, 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'id': 444, 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'id': 445, 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'id': 446, 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'id': 447, 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'c', 'id': 448, 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'c', 'id': 449, 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'id': 450, 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'id': 451, 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'id': 452, 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'c', 'id': 453, 'synset': 'fish.n.01', 'synonyms': ['fish'], 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'r', 'id': 454, 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'id': 455, 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'r', 'id': 456, 'synset': 'fishing_boat.n.01', 'synonyms': ['fishing_boat', 'fishing_vessel'], 'def': 'a vessel for fishing', 'name': 'fishing_boat'}, {'frequency': 'c', 'id': 457, 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'id': 458, 'synset': 'flag.n.01', 'synonyms': ['flag'], 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'id': 459, 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'id': 460, 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'id': 461, 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'r', 'id': 462, 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'id': 463, 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'id': 464, 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'id': 465, 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'id': 466, 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'id': 467, 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'id': 468, 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'r', 'id': 469, 'synset': 'foal.n.01', 'synonyms': ['foal'], 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'id': 470, 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'id': 471, 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'id': 472, 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'id': 473, 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'id': 474, 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'id': 475, 'synset': 'fork.n.01', 'synonyms': ['fork'], 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'r', 'id': 476, 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'r', 'id': 477, 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'r', 'id': 478, 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'id': 479, 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'def': 'anything that freshens', 'name': 'freshener'}, {'frequency': 'f', 'id': 480, 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'id': 481, 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'id': 482, 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'r', 'id': 483, 'synset': 'fruit_salad.n.01', 'synonyms': ['fruit_salad'], 'def': 'salad composed of fruits', 'name': 'fruit_salad'}, {'frequency': 'c', 'id': 484, 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'id': 485, 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'id': 486, 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'c', 'id': 487, 'synset': 'futon.n.01', 'synonyms': ['futon'], 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'id': 488, 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'id': 489, 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'id': 490, 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'id': 491, 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'id': 492, 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'id': 493, 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'id': 494, 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'id': 495, 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'r', 'id': 496, 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'id': 497, 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'id': 498, 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'c', 'id': 499, 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'id': 500, 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'id': 501, 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'id': 502, 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'id': 503, 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'id': 504, 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'id': 505, 'synset': 'globe.n.03', 'synonyms': ['globe'], 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'id': 506, 'synset': 'glove.n.02', 'synonyms': ['glove'], 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'id': 507, 'synset': 'goat.n.01', 'synonyms': ['goat'], 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'id': 508, 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'id': 509, 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'r', 'id': 510, 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'id': 511, 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'id': 512, 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'id': 513, 'synset': 'goose.n.01', 'synonyms': ['goose'], 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'id': 514, 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'id': 515, 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'r', 'id': 516, 'synset': 'gown.n.04', 'synonyms': ['surgical_gown', 'scrubs_(surgical_clothing)'], 'def': 'protective garment worn by surgeons during operations', 'name': 'surgical_gown'}, {'frequency': 'f', 'id': 517, 'synset': 'grape.n.01', 'synonyms': ['grape'], 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'r', 'id': 518, 'synset': 'grasshopper.n.01', 'synonyms': ['grasshopper'], 'def': 'plant-eating insect with hind legs adapted for leaping', 'name': 'grasshopper'}, {'frequency': 'c', 'id': 519, 'synset': 'grater.n.01', 'synonyms': ['grater'], 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'id': 520, 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'id': 521, 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'c', 'id': 522, 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'c', 'id': 523, 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'id': 524, 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'r', 'id': 525, 'synset': 'grillroom.n.01', 'synonyms': ['grillroom', 'grill_(restaurant)'], 'def': 'a restaurant where food is cooked on a grill', 'name': 'grillroom'}, {'frequency': 'r', 'id': 526, 'synset': 'grinder.n.04', 'synonyms': ['grinder_(tool)'], 'def': 'a machine tool that polishes metal', 'name': 'grinder_(tool)'}, {'frequency': 'r', 'id': 527, 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'id': 528, 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'id': 529, 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'r', 'id': 530, 'synset': 'guacamole.n.01', 'synonyms': ['guacamole'], 'def': 'a dip made of mashed avocado mixed with chopped onions and other seasonings', 'name': 'guacamole'}, {'frequency': 'f', 'id': 531, 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'id': 532, 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'id': 533, 'synset': 'gun.n.01', 'synonyms': ['gun'], 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'r', 'id': 534, 'synset': 'hair_spray.n.01', 'synonyms': ['hair_spray'], 'def': 'substance sprayed on the hair to hold it in place', 'name': 'hair_spray'}, {'frequency': 'c', 'id': 535, 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'id': 536, 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'id': 537, 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'f', 'id': 538, 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'id': 539, 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'id': 540, 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'r', 'id': 541, 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'id': 542, 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'r', 'id': 543, 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'c', 'id': 544, 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'id': 545, 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'id': 546, 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'id': 547, 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'id': 548, 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'id': 549, 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'id': 550, 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'id': 551, 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'id': 552, 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'id': 553, 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'id': 554, 'synset': 'hat.n.01', 'synonyms': ['hat'], 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'id': 555, 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'r', 'id': 556, 'synset': 'hatch.n.03', 'synonyms': ['hatch'], 'def': 'a movable barrier covering a hatchway', 'name': 'hatch'}, {'frequency': 'c', 'id': 557, 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'def': 'a garment that covers the head and face', 'name': 'veil'}, {'frequency': 'f', 'id': 558, 'synset': 'headband.n.01', 'synonyms': ['headband'], 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'id': 559, 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'id': 560, 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'id': 561, 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'id': 562, 'synset': 'headset.n.01', 'synonyms': ['headset'], 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'id': 563, 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'r', 'id': 564, 'synset': 'hearing_aid.n.02', 'synonyms': ['hearing_aid'], 'def': 'an acoustic device used to direct sound to the ear of a hearing-impaired person', 'name': 'hearing_aid'}, {'frequency': 'c', 'id': 565, 'synset': 'heart.n.02', 'synonyms': ['heart'], 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'id': 566, 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'id': 567, 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'id': 568, 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'id': 569, 'synset': 'heron.n.02', 'synonyms': ['heron'], 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'id': 570, 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'id': 571, 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'id': 572, 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'id': 573, 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'id': 574, 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'id': 575, 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'id': 576, 'synset': 'honey.n.01', 'synonyms': ['honey'], 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'id': 577, 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'id': 578, 'synset': 'hook.n.05', 'synonyms': ['hook'], 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'f', 'id': 579, 'synset': 'horse.n.01', 'synonyms': ['horse'], 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'id': 580, 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'id': 581, 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'id': 582, 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'id': 583, 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'id': 584, 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'id': 585, 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'r', 'id': 586, 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'id': 587, 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'c', 'id': 588, 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'id': 589, 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'id': 590, 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'id': 591, 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'id': 592, 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'id': 593, 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'r', 'id': 594, 'synset': 'ice_tea.n.01', 'synonyms': ['ice_tea', 'iced_tea'], 'def': 'strong tea served over ice', 'name': 'ice_tea'}, {'frequency': 'c', 'id': 595, 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'id': 596, 'synset': 'incense.n.01', 'synonyms': ['incense'], 'def': 'a substance that produces a fragrant odor when burned', 'name': 'incense'}, {'frequency': 'r', 'id': 597, 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'c', 'id': 598, 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'id': 599, 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'r', 'id': 600, 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'id': 601, 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'r', 'id': 602, 'synset': 'jam.n.01', 'synonyms': ['jam'], 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'id': 603, 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'id': 604, 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'id': 605, 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'id': 606, 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'id': 607, 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'c', 'id': 608, 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'id': 609, 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'r', 'id': 610, 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'id': 611, 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'id': 612, 'synset': 'keg.n.02', 'synonyms': ['keg'], 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'id': 613, 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'id': 614, 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'id': 615, 'synset': 'key.n.01', 'synonyms': ['key'], 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'id': 616, 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'r', 'id': 617, 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'id': 618, 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'id': 619, 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'c', 'id': 620, 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'id': 621, 'synset': 'kite.n.03', 'synonyms': ['kite'], 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'id': 622, 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'id': 623, 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'id': 624, 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'id': 625, 'synset': 'knife.n.01', 'synonyms': ['knife'], 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'id': 626, 'synset': 'knight.n.02', 'synonyms': ['knight_(chess_piece)', 'horse_(chess_piece)'], 'def': 'a chess game piece shaped to resemble the head of a horse', 'name': 'knight_(chess_piece)'}, {'frequency': 'r', 'id': 627, 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'id': 628, 'synset': 'knob.n.02', 'synonyms': ['knob'], 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'id': 629, 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'id': 630, 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'id': 631, 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'id': 632, 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'id': 633, 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'r', 'id': 634, 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'c', 'id': 635, 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'id': 636, 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'id': 637, 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'id': 638, 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'id': 639, 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'id': 640, 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'id': 641, 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'id': 642, 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'id': 643, 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'c', 'id': 644, 'synset': 'latch.n.02', 'synonyms': ['latch'], 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'id': 645, 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'id': 646, 'synset': 'leather.n.01', 'synonyms': ['leather'], 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'id': 647, 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'id': 648, 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'f', 'id': 649, 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'id': 650, 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'id': 651, 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'id': 652, 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'id': 653, 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'id': 654, 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'id': 655, 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'def': 'glass bulb or tube shaped electric device that emits light (DO NOT MARK LAMPS AS A WHOLE)', 'name': 'lightbulb'}, {'frequency': 'r', 'id': 656, 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'c', 'id': 657, 'synset': 'lime.n.06', 'synonyms': ['lime'], 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'id': 658, 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'r', 'id': 659, 'synset': 'linen.n.02', 'synonyms': ['linen_paper'], 'def': 'a high-quality paper made of linen fibers or with a linen finish', 'name': 'linen_paper'}, {'frequency': 'c', 'id': 660, 'synset': 'lion.n.01', 'synonyms': ['lion'], 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'id': 661, 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'c', 'id': 662, 'synset': 'lipstick.n.01', 'synonyms': ['lipstick', 'lip_rouge'], 'def': 'makeup that is used to color the lips', 'name': 'lipstick'}, {'frequency': 'r', 'id': 663, 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'def': 'an alcoholic beverage that is distilled rather than fermented', 'name': 'liquor'}, {'frequency': 'r', 'id': 664, 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'r', 'id': 665, 'synset': 'loafer.n.02', 'synonyms': ['Loafer_(type_of_shoe)'], 'def': 'a low leather step-in shoe', 'name': 'Loafer_(type_of_shoe)'}, {'frequency': 'f', 'id': 666, 'synset': 'log.n.01', 'synonyms': ['log'], 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'id': 667, 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'c', 'id': 668, 'synset': 'lotion.n.01', 'synonyms': ['lotion'], 'def': 'any of various cosmetic preparations that are applied to the skin', 'name': 'lotion'}, {'frequency': 'f', 'id': 669, 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'id': 670, 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'id': 671, 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'id': 672, 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'id': 673, 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'r', 'id': 674, 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'c', 'id': 675, 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'id': 676, 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'id': 677, 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'c', 'id': 678, 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'id': 679, 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'id': 680, 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'c', 'id': 681, 'synset': 'map.n.01', 'synonyms': ['map'], 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'c', 'id': 682, 'synset': 'marker.n.03', 'synonyms': ['marker'], 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'id': 683, 'synset': 'martini.n.01', 'synonyms': ['martini'], 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'id': 684, 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'id': 685, 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'id': 686, 'synset': 'masher.n.02', 'synonyms': ['masher'], 'def': 'a kitchen utensil used for mashing (e.g. potatoes)', 'name': 'masher'}, {'frequency': 'f', 'id': 687, 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'id': 688, 'synset': 'mast.n.01', 'synonyms': ['mast'], 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'id': 689, 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'id': 690, 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'id': 691, 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'id': 692, 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'id': 693, 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'id': 694, 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'id': 695, 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'r', 'id': 696, 'synset': 'melon.n.01', 'synonyms': ['melon'], 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'id': 697, 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'def': 'device for converting sound waves into electrical energy', 'name': 'microphone'}, {'frequency': 'r', 'id': 698, 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'id': 699, 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'id': 700, 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'c', 'id': 701, 'synset': 'milk.n.01', 'synonyms': ['milk'], 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'f', 'id': 702, 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'id': 703, 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'id': 704, 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'id': 705, 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'id': 706, 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'id': 707, 'synset': 'money.n.03', 'synonyms': ['money'], 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'id': 708, 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'id': 709, 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'id': 710, 'synset': 'motor.n.01', 'synonyms': ['motor'], 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'id': 711, 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'id': 712, 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'r', 'id': 713, 'synset': 'motorboat.n.01', 'synonyms': ['motorboat', 'powerboat'], 'def': 'a boat propelled by an internal-combustion engine', 'name': 'motorboat'}, {'frequency': 'f', 'id': 714, 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'id': 715, 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'r', 'id': 716, 'synset': 'mouse.n.01', 'synonyms': ['mouse_(animal_rodent)'], 'def': 'a small rodent with pointed snouts and small ears on elongated bodies with slender usually hairless tails', 'name': 'mouse_(animal_rodent)'}, {'frequency': 'f', 'id': 717, 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'def': 'a computer input device that controls an on-screen pointer', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'id': 718, 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'id': 719, 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'def': 'a sweet quick bread baked in a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'id': 720, 'synset': 'mug.n.04', 'synonyms': ['mug'], 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'id': 721, 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'id': 722, 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'r', 'id': 723, 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'id': 724, 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'r', 'id': 725, 'synset': 'nameplate.n.01', 'synonyms': ['nameplate'], 'def': 'a plate bearing a name', 'name': 'nameplate'}, {'frequency': 'f', 'id': 726, 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'id': 727, 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'id': 728, 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'id': 729, 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'r', 'id': 730, 'synset': 'needle.n.03', 'synonyms': ['needle'], 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'id': 731, 'synset': 'nest.n.01', 'synonyms': ['nest'], 'def': 'a structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'r', 'id': 732, 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'id': 733, 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'id': 734, 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'r', 'id': 735, 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'id': 736, 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'id': 737, 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'c', 'id': 738, 'synset': 'nut.n.03', 'synonyms': ['nut'], 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'id': 739, 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'c', 'id': 740, 'synset': 'oar.n.01', 'synonyms': ['oar'], 'def': 'an implement used to propel or steer a boat', 'name': 'oar'}, {'frequency': 'r', 'id': 741, 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'id': 742, 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'id': 743, 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'id': 744, 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'id': 745, 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'def': 'beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'id': 746, 'synset': 'onion.n.01', 'synonyms': ['onion'], 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'id': 747, 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'id': 748, 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'r', 'id': 749, 'synset': 'oregano.n.01', 'synonyms': ['oregano', 'marjoram'], 'def': 'aromatic Eurasian perennial herb used in cooking and baking', 'name': 'oregano'}, {'frequency': 'c', 'id': 750, 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'c', 'id': 751, 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'def': 'thick cushion used as a seat', 'name': 'ottoman'}, {'frequency': 'c', 'id': 752, 'synset': 'overall.n.01', 'synonyms': ['overalls_(clothing)'], 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'id': 753, 'synset': 'owl.n.01', 'synonyms': ['owl'], 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'id': 754, 'synset': 'packet.n.03', 'synonyms': ['packet'], 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'id': 755, 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'id': 756, 'synset': 'pad.n.04', 'synonyms': ['pad'], 'def': 'a flat mass of soft material used for protection, stuffing, or comfort', 'name': 'pad'}, {'frequency': 'c', 'id': 757, 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'id': 758, 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'r', 'id': 759, 'synset': 'paintbox.n.01', 'synonyms': ['paintbox'], 'def': "a box containing a collection of cubes or tubes of artists' paint", 'name': 'paintbox'}, {'frequency': 'c', 'id': 760, 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'id': 761, 'synset': 'painting.n.01', 'synonyms': ['painting'], 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'c', 'id': 762, 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'id': 763, 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'id': 764, 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 'cooking_pan'], 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'id': 765, 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'id': 766, 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'id': 767, 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'id': 768, 'synset': 'papaya.n.02', 'synonyms': ['papaya'], 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'r', 'id': 769, 'synset': 'paper_clip.n.01', 'synonyms': ['paperclip'], 'def': 'a wire or plastic clip for holding sheets of paper together', 'name': 'paperclip'}, {'frequency': 'f', 'id': 770, 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'id': 771, 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'id': 772, 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'id': 773, 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'id': 774, 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'r', 'id': 775, 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'id': 776, 'synset': 'parasail.n.01', 'synonyms': ['parasail_(sports)'], 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'r', 'id': 777, 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'r', 'id': 778, 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 'f', 'id': 779, 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'id': 780, 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'id': 781, 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'id': 782, 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'r', 'id': 783, 'synset': 'passport.n.02', 'synonyms': ['passport'], 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'id': 784, 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'id': 785, 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'id': 786, 'synset': 'pea.n.01', 'synonyms': ['pea_(food)'], 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'id': 787, 'synset': 'peach.n.03', 'synonyms': ['peach'], 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'id': 788, 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'c', 'id': 789, 'synset': 'pear.n.01', 'synonyms': ['pear'], 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'r', 'id': 790, 'synset': 'peeler.n.03', 'synonyms': ['peeler_(tool_for_fruit_and_vegetables)'], 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'id': 791, 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'id': 792, 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'id': 793, 'synset': 'pen.n.01', 'synonyms': ['pen'], 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'c', 'id': 794, 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'id': 795, 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'id': 796, 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'def': 'a rotary implement for sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'id': 797, 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'id': 798, 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'id': 799, 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'id': 800, 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, {'frequency': 'c', 'id': 801, 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'id': 802, 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'id': 803, 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'id': 804, 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'id': 805, 'synset': 'person.n.01', 'synonyms': ['baby', 'child', 'boy', 'girl', 'man', 'woman', 'person', 'human'], 'def': 'a human being', 'name': 'baby'}, {'frequency': 'r', 'id': 806, 'synset': 'pet.n.01', 'synonyms': ['pet'], 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'r', 'id': 807, 'synset': 'petfood.n.01', 'synonyms': ['petfood', 'pet-food'], 'def': 'food prepared for animal pets', 'name': 'petfood'}, {'frequency': 'r', 'id': 808, 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'def': 'long bench with backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'id': 809, 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'id': 810, 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'c', 'id': 811, 'synset': 'piano.n.01', 'synonyms': ['piano'], 'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'id': 812, 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'id': 813, 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'id': 814, 'synset': 'pie.n.01', 'synonyms': ['pie'], 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'id': 815, 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'id': 816, 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'id': 817, 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'id': 818, 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'def': 'a small slender (often pointed) piece of wood or metal used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'id': 819, 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'id': 820, 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'id': 821, 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'id': 822, 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'def': 'a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'id': 823, 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'id': 824, 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'id': 825, 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'r', 'id': 826, 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'id': 827, 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'id': 828, 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 'id': 829, 'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'id': 830, 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'id': 831, 'synset': 'plate.n.04', 'synonyms': ['plate'], 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'id': 832, 'synset': 'platter.n.01', 'synonyms': ['platter'], 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'id': 833, 'synset': 'playing_card.n.01', 'synonyms': ['playing_card'], 'def': 'one of a pack of cards that are used to play card games', 'name': 'playing_card'}, {'frequency': 'r', 'id': 834, 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'id': 835, 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'id': 836, 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'id': 837, 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'id': 838, 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'id': 839, 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'def': 'fire iron consisting of a metal rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'id': 840, 'synset': 'pole.n.01', 'synonyms': ['pole', 'post'], 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'r', 'id': 841, 'synset': 'police_van.n.01', 'synonyms': ['police_van', 'police_wagon', 'paddy_wagon', 'patrol_wagon'], 'def': 'van used by police to transport prisoners', 'name': 'police_van'}, {'frequency': 'f', 'id': 842, 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'id': 843, 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'id': 844, 'synset': 'pony.n.05', 'synonyms': ['pony'], 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'id': 845, 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'id': 846, 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'r', 'id': 847, 'synset': 'portrait.n.02', 'synonyms': ['portrait', 'portrayal'], 'def': 'any likeness of a person, in any medium', 'name': 'portrait'}, {'frequency': 'c', 'id': 848, 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'id': 849, 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'id': 850, 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'def': 'a sign posted in a public place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'id': 851, 'synset': 'pot.n.01', 'synonyms': ['pot'], 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 'f', 'id': 852, 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'id': 853, 'synset': 'potato.n.01', 'synonyms': ['potato'], 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'id': 854, 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'id': 855, 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'id': 856, 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'r', 'id': 857, 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'id': 858, 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'f', 'id': 859, 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'id': 860, 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'id': 861, 'synset': 'projector.n.02', 'synonyms': ['projector'], 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'id': 862, 'synset': 'propeller.n.01', 'synonyms': ['propeller', 'propellor'], 'def': 'a mechanical device that rotates to push against air or water', 'name': 'propeller'}, {'frequency': 'r', 'id': 863, 'synset': 'prune.n.01', 'synonyms': ['prune'], 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'id': 864, 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'id': 865, 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'id': 866, 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'id': 867, 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'id': 868, 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'id': 869, 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'id': 870, 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'r', 'id': 871, 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'id': 872, 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'id': 873, 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'def': 'a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'id': 874, 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'id': 875, 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'id': 876, 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'id': 877, 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'id': 878, 'synset': 'radar.n.01', 'synonyms': ['radar'], 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'c', 'id': 879, 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'id': 880, 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'id': 881, 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'id': 882, 'synset': 'raft.n.01', 'synonyms': ['raft'], 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'id': 883, 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'def': 'a cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, {'frequency': 'c', 'id': 884, 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'id': 885, 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'id': 886, 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'id': 887, 'synset': 'rat.n.01', 'synonyms': ['rat'], 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'id': 888, 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'id': 889, 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'id': 890, 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'def': 'car mirror that reflects the view out of the rear window', 'name': 'rearview_mirror'}, {'frequency': 'c', 'id': 891, 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'id': 892, 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'r', 'id': 893, 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically', 'name': 'record_player'}, {'frequency': 'r', 'id': 894, 'synset': 'red_cabbage.n.02', 'synonyms': ['red_cabbage'], 'def': 'compact head of purplish-red leaves', 'name': 'red_cabbage'}, {'frequency': 'f', 'id': 895, 'synset': 'reflector.n.01', 'synonyms': ['reflector'], 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'id': 896, 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'id': 897, 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'id': 898, 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'r', 'id': 899, 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'id': 900, 'synset': 'ring.n.08', 'synonyms': ['ring'], 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'id': 901, 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'id': 902, 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'id': 903, 'synset': 'robe.n.01', 'synonyms': ['robe'], 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'id': 904, 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'id': 905, 'synset': 'roller_skate.n.01', 'synonyms': ['roller_skate'], 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'id': 906, 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'id': 907, 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'id': 908, 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'id': 909, 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'id': 910, 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'id': 911, 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'id': 912, 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'id': 913, 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'id': 914, 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'id': 915, 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'def': 'a large bag (or pair of bags) hung over a saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'id': 916, 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'c', 'id': 917, 'synset': 'sail.n.01', 'synonyms': ['sail'], 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'c', 'id': 918, 'synset': 'salad.n.01', 'synonyms': ['salad'], 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'id': 919, 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'r', 'id': 920, 'synset': 'salami.n.01', 'synonyms': ['salami'], 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'r', 'id': 921, 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'id': 922, 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'r', 'id': 923, 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'id': 924, 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'id': 925, 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'id': 926, 'synset': 'sandwich.n.01', 'synonyms': ['sandwich'], 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'id': 927, 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'id': 928, 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'id': 929, 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'id': 930, 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'id': 931, 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'id': 932, 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'id': 933, 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'id': 934, 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'id': 935, 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'id': 936, 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'id': 937, 'synset': 'scissors.n.01', 'synonyms': ['scissors'], 'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'c', 'id': 938, 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'c', 'id': 939, 'synset': 'scrambled_eggs.n.01', 'synonyms': ['scrambled_eggs'], 'def': 'eggs beaten and cooked to a soft firm consistency while stirring', 'name': 'scrambled_eggs'}, {'frequency': 'r', 'id': 940, 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'r', 'id': 941, 'synset': 'scratcher.n.03', 'synonyms': ['scratcher'], 'def': 'a device used for scratching', 'name': 'scratcher'}, {'frequency': 'c', 'id': 942, 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'c', 'id': 943, 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'id': 944, 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'r', 'id': 945, 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'r', 'id': 946, 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'id': 947, 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'id': 948, 'synset': 'seashell.n.01', 'synonyms': ['seashell'], 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'r', 'id': 949, 'synset': 'seedling.n.01', 'synonyms': ['seedling'], 'def': 'young plant or tree grown from a seed', 'name': 'seedling'}, {'frequency': 'c', 'id': 950, 'synset': 'serving_dish.n.01', 'synonyms': ['serving_dish'], 'def': 'a dish used for serving food', 'name': 'serving_dish'}, {'frequency': 'r', 'id': 951, 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'r', 'id': 952, 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'id': 953, 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'r', 'id': 954, 'synset': 'shark.n.01', 'synonyms': ['shark'], 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'id': 955, 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'id': 956, 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'id': 957, 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'id': 958, 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'id': 959, 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'def': 'cloak consisting of an oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'id': 960, 'synset': 'shears.n.01', 'synonyms': ['shears'], 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'id': 961, 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'id': 962, 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'id': 963, 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 'sherbet'], 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'r', 'id': 964, 'synset': 'shield.n.02', 'synonyms': ['shield'], 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'id': 965, 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'id': 966, 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'c', 'id': 967, 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'id': 968, 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'id': 969, 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'id': 970, 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'def': 'a small glass adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'c', 'id': 971, 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'id': 972, 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'id': 973, 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'f', 'id': 974, 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'id': 975, 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'r', 'id': 976, 'synset': 'sieve.n.01', 'synonyms': ['sieve', 'screen_(sieve)'], 'def': 'a strainer for separating lumps from powdered material or grading particles', 'name': 'sieve'}, {'frequency': 'f', 'id': 977, 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'id': 978, 'synset': 'silo.n.01', 'synonyms': ['silo'], 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'id': 979, 'synset': 'sink.n.01', 'synonyms': ['sink'], 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'id': 980, 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'id': 981, 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'def': 'a long pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'id': 982, 'synset': 'ski.n.01', 'synonyms': ['ski'], 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'id': 983, 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'id': 984, 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'id': 985, 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'def': 'a pole with metal points used as an aid in skiing', 'name': 'ski_pole'}, {'frequency': 'f', 'id': 986, 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'c', 'id': 987, 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'id': 988, 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'id': 989, 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'id': 990, 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'id': 991, 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'id': 992, 'synset': 'snake.n.01', 'synonyms': ['snake', 'serpent'], 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'id': 993, 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'id': 994, 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'id': 995, 'synset': 'snowmobile.n.01', 'synonyms': ['snowmobile'], 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'id': 996, 'synset': 'soap.n.01', 'synonyms': ['soap'], 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'id': 997, 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'id': 998, 'synset': 'sock.n.01', 'synonyms': ['sock'], 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'r', 'id': 999, 'synset': 'soda_fountain.n.02', 'synonyms': ['soda_fountain'], 'def': 'an apparatus for dispensing soda water', 'name': 'soda_fountain'}, {'frequency': 'r', 'id': 1000, 'synset': 'soda_water.n.01', 'synonyms': ['carbonated_water', 'club_soda', 'seltzer', 'sparkling_water'], 'def': 'effervescent beverage artificially charged with carbon dioxide', 'name': 'carbonated_water'}, {'frequency': 'f', 'id': 1001, 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'id': 1002, 'synset': 'softball.n.01', 'synonyms': ['softball'], 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'id': 1003, 'synset': 'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'id': 1004, 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'c', 'id': 1005, 'synset': 'soup.n.01', 'synonyms': ['soup'], 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'id': 1006, 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'id': 1007, 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'id': 1008, 'synset': 'sour_cream.n.01', 'synonyms': ['sour_cream', 'soured_cream'], 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'id': 1009, 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'id': 1010, 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'id': 1011, 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'id': 1012, 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'id': 1013, 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'def': 'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'id': 1014, 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'id': 1015, 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'r', 'id': 1016, 'synset': 'spider.n.01', 'synonyms': ['spider'], 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'c', 'id': 1017, 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'id': 1018, 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'id': 1019, 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'id': 1020, 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'id': 1021, 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'c', 'id': 1022, 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'r', 'id': 1023, 'synset': 'starfish.n.01', 'synonyms': ['starfish', 'sea_star'], 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'id': 1024, 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'id': 1025, 'synset': 'steak.n.01', 'synonyms': ['steak_(food)'], 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'id': 1026, 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'r', 'id': 1027, 'synset': 'steamer.n.02', 'synonyms': ['steamer_(kitchen_appliance)'], 'def': 'a cooking utensil that can be used to cook food by steaming it', 'name': 'steamer_(kitchen_appliance)'}, {'frequency': 'f', 'id': 1028, 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'id': 1029, 'synset': 'stencil.n.01', 'synonyms': ['stencil'], 'def': 'a sheet of material (metal, plastic, etc.) that has been perforated with a pattern; ink or paint can pass through the perforations to create the printed pattern on the surface below', 'name': 'stencil'}, {'frequency': 'r', 'id': 1030, 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'id': 1031, 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'id': 1032, 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'id': 1033, 'synset': 'stew.n.02', 'synonyms': ['stew'], 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'id': 1034, 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'id': 1035, 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'c', 'id': 1036, 'synset': 'stocking.n.01', 'synonyms': ['stockings_(leg_wear)'], 'def': 'close-fitting hosiery to cover the foot and leg; come in matched pairs', 'name': 'stockings_(leg_wear)'}, {'frequency': 'f', 'id': 1037, 'synset': 'stool.n.01', 'synonyms': ['stool'], 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'id': 1038, 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'id': 1039, 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'def': 'a red light on the rear of a motor vehicle that signals when the brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'id': 1040, 'synset': 'stove.n.01', 'synonyms': ['stove', 'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'id': 1041, 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'id': 1042, 'synset': 'strap.n.01', 'synonyms': ['strap'], 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'id': 1043, 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'id': 1044, 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'id': 1045, 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'id': 1046, 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'id': 1047, 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'id': 1048, 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'def': 'a pointed tool for writing or drawing or engraving', 'name': 'stylus'}, {'frequency': 'r', 'id': 1049, 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'id': 1050, 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'def': 'a dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'id': 1051, 'synset': 'sugarcane.n.01', 'synonyms': ['sugarcane_(plant)'], 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'c', 'id': 1052, 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'id': 1053, 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'id': 1054, 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'id': 1055, 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'r', 'id': 1056, 'synset': 'sunscreen.n.01', 'synonyms': ['sunscreen', 'sunblock'], 'def': 'a cream spread on the skin; contains a chemical to filter out ultraviolet light and so protect from sunburn', 'name': 'sunscreen'}, {'frequency': 'f', 'id': 1057, 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'id': 1058, 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'id': 1059, 'synset': 'swab.n.02', 'synonyms': ['mop'], 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'id': 1060, 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'id': 1061, 'synset': 'sweatband.n.02', 'synonyms': ['sweatband'], 'def': 'a band of material tied around the forehead or wrist to absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'id': 1062, 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'id': 1063, 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'id': 1064, 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'id': 1065, 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'id': 1066, 'synset': 'sword.n.01', 'synonyms': ['sword'], 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'id': 1067, 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'id': 1068, 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'id': 1069, 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'id': 1070, 'synset': 'table.n.02', 'synonyms': ['table'], 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'id': 1071, 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'def': 'a lamp that sits on a table', 'name': 'table_lamp'}, {'frequency': 'f', 'id': 1072, 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'id': 1073, 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'id': 1074, 'synset': 'taco.n.02', 'synonyms': ['taco'], 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'id': 1075, 'synset': 'tag.n.02', 'synonyms': ['tag'], 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'id': 1076, 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'id': 1077, 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'id': 1078, 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'c', 'id': 1079, 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'id': 1080, 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'c', 'id': 1081, 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 'c', 'id': 1082, 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'def': 'measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'id': 1083, 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'id': 1084, 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'id': 1085, 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'id': 1086, 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'r', 'id': 1087, 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'id': 1088, 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'id': 1089, 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'c', 'id': 1090, 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'id': 1091, 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'id': 1092, 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'def': 'electronic device for communicating by voice over long distances', 'name': 'telephone'}, {'frequency': 'c', 'id': 1093, 'synset': 'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 'telephone_box', 'telephone_kiosk'], 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'id': 1094, 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'id': 1095, 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'id': 1096, 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'id': 1097, 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'id': 1098, 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'id': 1099, 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'id': 1100, 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'id': 1101, 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'id': 1102, 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'c', 'id': 1103, 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'def': 'a regulator for automatically regulating temperature by starting or stopping the supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'id': 1104, 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'id': 1105, 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'id': 1106, 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'id': 1107, 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'id': 1108, 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'id': 1109, 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'id': 1110, 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'id': 1111, 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'r', 'id': 1112, 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'id': 1113, 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'def': 'a soft thin (usually translucent) paper', 'name': 'tissue_paper'}, {'frequency': 'c', 'id': 1114, 'synset': 'toast.n.01', 'synonyms': ['toast_(food)'], 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'id': 1115, 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'c', 'id': 1116, 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'id': 1117, 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'id': 1118, 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'id': 1119, 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'c', 'id': 1120, 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'id': 1121, 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'id': 1122, 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'id': 1123, 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'c', 'id': 1124, 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 'name': 'toothpick'}, {'frequency': 'c', 'id': 1125, 'synset': 'top.n.09', 'synonyms': ['cover'], 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'id': 1126, 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'id': 1127, 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'id': 1128, 'synset': 'towel.n.01', 'synonyms': ['towel'], 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'id': 1129, 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'id': 1130, 'synset': 'toy.n.03', 'synonyms': ['toy'], 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'id': 1131, 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'id': 1132, 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'r', 'id': 1133, 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'c', 'id': 1134, 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, {'frequency': 'f', 'id': 1135, 'synset': 'train.n.01', 'synonyms': ['train_(railroad_vehicle)', 'railroad_train'], 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'id': 1136, 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'id': 1137, 'synset': 'tray.n.01', 'synonyms': ['tray'], 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'id': 1138, 'synset': 'tree_house.n.01', 'synonyms': ['tree_house'], 'def': '(NOT A TREE) a PLAYHOUSE built in the branches of a tree', 'name': 'tree_house'}, {'frequency': 'r', 'id': 1139, 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'id': 1140, 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'r', 'id': 1141, 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'c', 'id': 1142, 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'id': 1143, 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'id': 1144, 'synset': 'truck.n.01', 'synonyms': ['truck'], 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'id': 1145, 'synset': 'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'id': 1146, 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'def': 'luggage consisting of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'id': 1147, 'synset': 'tub.n.02', 'synonyms': ['vat'], 'def': 'a large open vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'id': 1148, 'synset': 'turban.n.01', 'synonyms': ['turban'], 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'r', 'id': 1149, 'synset': 'turkey.n.01', 'synonyms': ['turkey_(bird)'], 'def': 'large gallinaceous bird with fan-shaped tail; widely domesticated for food', 'name': 'turkey_(bird)'}, {'frequency': 'c', 'id': 1150, 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'id': 1151, 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'id': 1152, 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'r', 'id': 1153, 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'r', 'id': 1154, 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'id': 1155, 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, {'frequency': 'c', 'id': 1156, 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'id': 1157, 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'c', 'id': 1158, 'synset': 'urinal.n.01', 'synonyms': ['urinal'], 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'r', 'id': 1159, 'synset': 'urn.n.01', 'synonyms': ['urn'], 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'id': 1160, 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'c', 'id': 1161, 'synset': 'valve.n.03', 'synonyms': ['valve'], 'def': 'control consisting of a mechanical device for controlling the flow of a fluid', 'name': 'valve'}, {'frequency': 'f', 'id': 1162, 'synset': 'vase.n.01', 'synonyms': ['vase'], 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'id': 1163, 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'id': 1164, 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'c', 'id': 1165, 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'id': 1166, 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, {'frequency': 'r', 'id': 1167, 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'id': 1168, 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'r', 'id': 1169, 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'def': 'an inflated ball used in playing volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'id': 1170, 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'id': 1171, 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'id': 1172, 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'id': 1173, 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'id': 1174, 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'id': 1175, 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'id': 1176, 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'id': 1177, 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, {'frequency': 'c', 'id': 1178, 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'id': 1179, 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'id': 1180, 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'def': 'a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'id': 1181, 'synset': 'wasabi.n.02', 'synonyms': ['wasabi'], 'def': 'the thick green root of the wasabi plant that the Japanese use in cooking and that tastes like strong horseradish', 'name': 'wasabi'}, {'frequency': 'c', 'id': 1182, 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'id': 1183, 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'id': 1184, 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'id': 1185, 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'id': 1186, 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'id': 1187, 'synset': 'water_filter.n.01', 'synonyms': ['water_filter'], 'def': 'a filter to remove impurities from the water supply', 'name': 'water_filter'}, {'frequency': 'r', 'id': 1188, 'synset': 'water_heater.n.01', 'synonyms': ['water_heater', 'hot-water_heater'], 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'r', 'id': 1189, 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'id': 1190, 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'id': 1191, 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'def': 'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'id': 1192, 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'id': 1193, 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'id': 1194, 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'c', 'id': 1195, 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'id': 1196, 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'id': 1197, 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'id': 1198, 'synset': 'wedding_cake.n.01', 'synonyms': ['wedding_cake', 'bridecake'], 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'id': 1199, 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'id': 1200, 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 'wet_suit'}, {'frequency': 'f', 'id': 1201, 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'id': 1202, 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'id': 1203, 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'r', 'id': 1204, 'synset': 'whiskey.n.01', 'synonyms': ['whiskey'], 'def': 'a liquor made from fermented mash of grain', 'name': 'whiskey'}, {'frequency': 'r', 'id': 1205, 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'r', 'id': 1206, 'synset': 'wick.n.02', 'synonyms': ['wick'], 'def': 'a loosely woven cord in a candle or oil lamp that is lit on fire', 'name': 'wick'}, {'frequency': 'c', 'id': 1207, 'synset': 'wig.n.01', 'synonyms': ['wig'], 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'id': 1208, 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'id': 1209, 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'def': 'a mill that is powered by the wind', 'name': 'windmill'}, {'frequency': 'c', 'id': 1210, 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'id': 1211, 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'id': 1212, 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'id': 1213, 'synset': 'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'r', 'id': 1214, 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'id': 1215, 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'r', 'id': 1216, 'synset': 'wing_chair.n.01', 'synonyms': ['wing_chair'], 'def': 'easy chair having wings on each side of a high back', 'name': 'wing_chair'}, {'frequency': 'c', 'id': 1217, 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'id': 1218, 'synset': 'wok.n.01', 'synonyms': ['wok'], 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'id': 1219, 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'def': 'a wild carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'id': 1220, 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'id': 1221, 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'id': 1222, 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'c', 'id': 1223, 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'id': 1224, 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'r', 'id': 1225, 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'def': 'an expensive vessel propelled by sail or power and used for cruising or racing', 'name': 'yacht'}, {'frequency': 'r', 'id': 1226, 'synset': 'yak.n.02', 'synonyms': ['yak'], 'def': 'large long-haired wild ox of Tibet often domesticated', 'name': 'yak'}, {'frequency': 'c', 'id': 1227, 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'r', 'id': 1228, 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'id': 1229, 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'id': 1230, 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa +# fmt: on diff --git a/custom_detectron2/data/datasets/lvis_v1_categories.py b/custom_detectron2/data/datasets/lvis_v1_categories.py new file mode 100644 index 0000000000000000000000000000000000000000..7374e6968bb006f5d8c49e75d9d3b31ea3d77d05 --- /dev/null +++ b/custom_detectron2/data/datasets/lvis_v1_categories.py @@ -0,0 +1,16 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Autogen with +# with open("lvis_v1_val.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["image_count"] +# del x["instance_count"] +# LVIS_CATEGORIES = repr(c) + " # noqa" +# with open("/tmp/lvis_categories.py", "wt") as f: +# f.write(f"LVIS_CATEGORIES = {LVIS_CATEGORIES}") +# Then paste the contents of that file below + +# fmt: off +LVIS_CATEGORIES = [{'frequency': 'c', 'synset': 'aerosol.n.02', 'synonyms': ['aerosol_can', 'spray_can'], 'id': 1, 'def': 'a dispenser that holds a substance under pressure', 'name': 'aerosol_can'}, {'frequency': 'f', 'synset': 'air_conditioner.n.01', 'synonyms': ['air_conditioner'], 'id': 2, 'def': 'a machine that keeps air cool and dry', 'name': 'air_conditioner'}, {'frequency': 'f', 'synset': 'airplane.n.01', 'synonyms': ['airplane', 'aeroplane'], 'id': 3, 'def': 'an aircraft that has a fixed wing and is powered by propellers or jets', 'name': 'airplane'}, {'frequency': 'f', 'synset': 'alarm_clock.n.01', 'synonyms': ['alarm_clock'], 'id': 4, 'def': 'a clock that wakes a sleeper at some preset time', 'name': 'alarm_clock'}, {'frequency': 'c', 'synset': 'alcohol.n.01', 'synonyms': ['alcohol', 'alcoholic_beverage'], 'id': 5, 'def': 'a liquor or brew containing alcohol as the active agent', 'name': 'alcohol'}, {'frequency': 'c', 'synset': 'alligator.n.02', 'synonyms': ['alligator', 'gator'], 'id': 6, 'def': 'amphibious reptiles related to crocodiles but with shorter broader snouts', 'name': 'alligator'}, {'frequency': 'c', 'synset': 'almond.n.02', 'synonyms': ['almond'], 'id': 7, 'def': 'oval-shaped edible seed of the almond tree', 'name': 'almond'}, {'frequency': 'c', 'synset': 'ambulance.n.01', 'synonyms': ['ambulance'], 'id': 8, 'def': 'a vehicle that takes people to and from hospitals', 'name': 'ambulance'}, {'frequency': 'c', 'synset': 'amplifier.n.01', 'synonyms': ['amplifier'], 'id': 9, 'def': 'electronic equipment that increases strength of signals', 'name': 'amplifier'}, {'frequency': 'c', 'synset': 'anklet.n.03', 'synonyms': ['anklet', 'ankle_bracelet'], 'id': 10, 'def': 'an ornament worn around the ankle', 'name': 'anklet'}, {'frequency': 'f', 'synset': 'antenna.n.01', 'synonyms': ['antenna', 'aerial', 'transmitting_aerial'], 'id': 11, 'def': 'an electrical device that sends or receives radio or television signals', 'name': 'antenna'}, {'frequency': 'f', 'synset': 'apple.n.01', 'synonyms': ['apple'], 'id': 12, 'def': 'fruit with red or yellow or green skin and sweet to tart crisp whitish flesh', 'name': 'apple'}, {'frequency': 'r', 'synset': 'applesauce.n.01', 'synonyms': ['applesauce'], 'id': 13, 'def': 'puree of stewed apples usually sweetened and spiced', 'name': 'applesauce'}, {'frequency': 'r', 'synset': 'apricot.n.02', 'synonyms': ['apricot'], 'id': 14, 'def': 'downy yellow to rosy-colored fruit resembling a small peach', 'name': 'apricot'}, {'frequency': 'f', 'synset': 'apron.n.01', 'synonyms': ['apron'], 'id': 15, 'def': 'a garment of cloth that is tied about the waist and worn to protect clothing', 'name': 'apron'}, {'frequency': 'c', 'synset': 'aquarium.n.01', 'synonyms': ['aquarium', 'fish_tank'], 'id': 16, 'def': 'a tank/pool/bowl filled with water for keeping live fish and underwater animals', 'name': 'aquarium'}, {'frequency': 'r', 'synset': 'arctic.n.02', 'synonyms': ['arctic_(type_of_shoe)', 'galosh', 'golosh', 'rubber_(type_of_shoe)', 'gumshoe'], 'id': 17, 'def': 'a waterproof overshoe that protects shoes from water or snow', 'name': 'arctic_(type_of_shoe)'}, {'frequency': 'c', 'synset': 'armband.n.02', 'synonyms': ['armband'], 'id': 18, 'def': 'a band worn around the upper arm', 'name': 'armband'}, {'frequency': 'f', 'synset': 'armchair.n.01', 'synonyms': ['armchair'], 'id': 19, 'def': 'chair with a support on each side for arms', 'name': 'armchair'}, {'frequency': 'r', 'synset': 'armoire.n.01', 'synonyms': ['armoire'], 'id': 20, 'def': 'a large wardrobe or cabinet', 'name': 'armoire'}, {'frequency': 'r', 'synset': 'armor.n.01', 'synonyms': ['armor', 'armour'], 'id': 21, 'def': 'protective covering made of metal and used in combat', 'name': 'armor'}, {'frequency': 'c', 'synset': 'artichoke.n.02', 'synonyms': ['artichoke'], 'id': 22, 'def': 'a thistlelike flower head with edible fleshy leaves and heart', 'name': 'artichoke'}, {'frequency': 'f', 'synset': 'ashcan.n.01', 'synonyms': ['trash_can', 'garbage_can', 'wastebin', 'dustbin', 'trash_barrel', 'trash_bin'], 'id': 23, 'def': 'a bin that holds rubbish until it is collected', 'name': 'trash_can'}, {'frequency': 'c', 'synset': 'ashtray.n.01', 'synonyms': ['ashtray'], 'id': 24, 'def': "a receptacle for the ash from smokers' cigars or cigarettes", 'name': 'ashtray'}, {'frequency': 'c', 'synset': 'asparagus.n.02', 'synonyms': ['asparagus'], 'id': 25, 'def': 'edible young shoots of the asparagus plant', 'name': 'asparagus'}, {'frequency': 'c', 'synset': 'atomizer.n.01', 'synonyms': ['atomizer', 'atomiser', 'spray', 'sprayer', 'nebulizer', 'nebuliser'], 'id': 26, 'def': 'a dispenser that turns a liquid (such as perfume) into a fine mist', 'name': 'atomizer'}, {'frequency': 'f', 'synset': 'avocado.n.01', 'synonyms': ['avocado'], 'id': 27, 'def': 'a pear-shaped fruit with green or blackish skin and rich yellowish pulp enclosing a single large seed', 'name': 'avocado'}, {'frequency': 'c', 'synset': 'award.n.02', 'synonyms': ['award', 'accolade'], 'id': 28, 'def': 'a tangible symbol signifying approval or distinction', 'name': 'award'}, {'frequency': 'f', 'synset': 'awning.n.01', 'synonyms': ['awning'], 'id': 29, 'def': 'a canopy made of canvas to shelter people or things from rain or sun', 'name': 'awning'}, {'frequency': 'r', 'synset': 'ax.n.01', 'synonyms': ['ax', 'axe'], 'id': 30, 'def': 'an edge tool with a heavy bladed head mounted across a handle', 'name': 'ax'}, {'frequency': 'r', 'synset': 'baboon.n.01', 'synonyms': ['baboon'], 'id': 31, 'def': 'large terrestrial monkeys having doglike muzzles', 'name': 'baboon'}, {'frequency': 'f', 'synset': 'baby_buggy.n.01', 'synonyms': ['baby_buggy', 'baby_carriage', 'perambulator', 'pram', 'stroller'], 'id': 32, 'def': 'a small vehicle with four wheels in which a baby or child is pushed around', 'name': 'baby_buggy'}, {'frequency': 'c', 'synset': 'backboard.n.01', 'synonyms': ['basketball_backboard'], 'id': 33, 'def': 'a raised vertical board with basket attached; used to play basketball', 'name': 'basketball_backboard'}, {'frequency': 'f', 'synset': 'backpack.n.01', 'synonyms': ['backpack', 'knapsack', 'packsack', 'rucksack', 'haversack'], 'id': 34, 'def': 'a bag carried by a strap on your back or shoulder', 'name': 'backpack'}, {'frequency': 'f', 'synset': 'bag.n.04', 'synonyms': ['handbag', 'purse', 'pocketbook'], 'id': 35, 'def': 'a container used for carrying money and small personal items or accessories', 'name': 'handbag'}, {'frequency': 'f', 'synset': 'bag.n.06', 'synonyms': ['suitcase', 'baggage', 'luggage'], 'id': 36, 'def': 'cases used to carry belongings when traveling', 'name': 'suitcase'}, {'frequency': 'c', 'synset': 'bagel.n.01', 'synonyms': ['bagel', 'beigel'], 'id': 37, 'def': 'glazed yeast-raised doughnut-shaped roll with hard crust', 'name': 'bagel'}, {'frequency': 'r', 'synset': 'bagpipe.n.01', 'synonyms': ['bagpipe'], 'id': 38, 'def': 'a tubular wind instrument; the player blows air into a bag and squeezes it out', 'name': 'bagpipe'}, {'frequency': 'r', 'synset': 'baguet.n.01', 'synonyms': ['baguet', 'baguette'], 'id': 39, 'def': 'narrow French stick loaf', 'name': 'baguet'}, {'frequency': 'r', 'synset': 'bait.n.02', 'synonyms': ['bait', 'lure'], 'id': 40, 'def': 'something used to lure fish or other animals into danger so they can be trapped or killed', 'name': 'bait'}, {'frequency': 'f', 'synset': 'ball.n.06', 'synonyms': ['ball'], 'id': 41, 'def': 'a spherical object used as a plaything', 'name': 'ball'}, {'frequency': 'r', 'synset': 'ballet_skirt.n.01', 'synonyms': ['ballet_skirt', 'tutu'], 'id': 42, 'def': 'very short skirt worn by ballerinas', 'name': 'ballet_skirt'}, {'frequency': 'f', 'synset': 'balloon.n.01', 'synonyms': ['balloon'], 'id': 43, 'def': 'large tough nonrigid bag filled with gas or heated air', 'name': 'balloon'}, {'frequency': 'c', 'synset': 'bamboo.n.02', 'synonyms': ['bamboo'], 'id': 44, 'def': 'woody tropical grass having hollow woody stems', 'name': 'bamboo'}, {'frequency': 'f', 'synset': 'banana.n.02', 'synonyms': ['banana'], 'id': 45, 'def': 'elongated crescent-shaped yellow fruit with soft sweet flesh', 'name': 'banana'}, {'frequency': 'c', 'synset': 'band_aid.n.01', 'synonyms': ['Band_Aid'], 'id': 46, 'def': 'trade name for an adhesive bandage to cover small cuts or blisters', 'name': 'Band_Aid'}, {'frequency': 'c', 'synset': 'bandage.n.01', 'synonyms': ['bandage'], 'id': 47, 'def': 'a piece of soft material that covers and protects an injured part of the body', 'name': 'bandage'}, {'frequency': 'f', 'synset': 'bandanna.n.01', 'synonyms': ['bandanna', 'bandana'], 'id': 48, 'def': 'large and brightly colored handkerchief; often used as a neckerchief', 'name': 'bandanna'}, {'frequency': 'r', 'synset': 'banjo.n.01', 'synonyms': ['banjo'], 'id': 49, 'def': 'a stringed instrument of the guitar family with a long neck and circular body', 'name': 'banjo'}, {'frequency': 'f', 'synset': 'banner.n.01', 'synonyms': ['banner', 'streamer'], 'id': 50, 'def': 'long strip of cloth or paper used for decoration or advertising', 'name': 'banner'}, {'frequency': 'r', 'synset': 'barbell.n.01', 'synonyms': ['barbell'], 'id': 51, 'def': 'a bar to which heavy discs are attached at each end; used in weightlifting', 'name': 'barbell'}, {'frequency': 'r', 'synset': 'barge.n.01', 'synonyms': ['barge'], 'id': 52, 'def': 'a flatbottom boat for carrying heavy loads (especially on canals)', 'name': 'barge'}, {'frequency': 'f', 'synset': 'barrel.n.02', 'synonyms': ['barrel', 'cask'], 'id': 53, 'def': 'a cylindrical container that holds liquids', 'name': 'barrel'}, {'frequency': 'c', 'synset': 'barrette.n.01', 'synonyms': ['barrette'], 'id': 54, 'def': "a pin for holding women's hair in place", 'name': 'barrette'}, {'frequency': 'c', 'synset': 'barrow.n.03', 'synonyms': ['barrow', 'garden_cart', 'lawn_cart', 'wheelbarrow'], 'id': 55, 'def': 'a cart for carrying small loads; has handles and one or more wheels', 'name': 'barrow'}, {'frequency': 'f', 'synset': 'base.n.03', 'synonyms': ['baseball_base'], 'id': 56, 'def': 'a place that the runner must touch before scoring', 'name': 'baseball_base'}, {'frequency': 'f', 'synset': 'baseball.n.02', 'synonyms': ['baseball'], 'id': 57, 'def': 'a ball used in playing baseball', 'name': 'baseball'}, {'frequency': 'f', 'synset': 'baseball_bat.n.01', 'synonyms': ['baseball_bat'], 'id': 58, 'def': 'an implement used in baseball by the batter', 'name': 'baseball_bat'}, {'frequency': 'f', 'synset': 'baseball_cap.n.01', 'synonyms': ['baseball_cap', 'jockey_cap', 'golf_cap'], 'id': 59, 'def': 'a cap with a bill', 'name': 'baseball_cap'}, {'frequency': 'f', 'synset': 'baseball_glove.n.01', 'synonyms': ['baseball_glove', 'baseball_mitt'], 'id': 60, 'def': 'the handwear used by fielders in playing baseball', 'name': 'baseball_glove'}, {'frequency': 'f', 'synset': 'basket.n.01', 'synonyms': ['basket', 'handbasket'], 'id': 61, 'def': 'a container that is usually woven and has handles', 'name': 'basket'}, {'frequency': 'c', 'synset': 'basketball.n.02', 'synonyms': ['basketball'], 'id': 62, 'def': 'an inflated ball used in playing basketball', 'name': 'basketball'}, {'frequency': 'r', 'synset': 'bass_horn.n.01', 'synonyms': ['bass_horn', 'sousaphone', 'tuba'], 'id': 63, 'def': 'the lowest brass wind instrument', 'name': 'bass_horn'}, {'frequency': 'c', 'synset': 'bat.n.01', 'synonyms': ['bat_(animal)'], 'id': 64, 'def': 'nocturnal mouselike mammal with forelimbs modified to form membranous wings', 'name': 'bat_(animal)'}, {'frequency': 'f', 'synset': 'bath_mat.n.01', 'synonyms': ['bath_mat'], 'id': 65, 'def': 'a heavy towel or mat to stand on while drying yourself after a bath', 'name': 'bath_mat'}, {'frequency': 'f', 'synset': 'bath_towel.n.01', 'synonyms': ['bath_towel'], 'id': 66, 'def': 'a large towel; to dry yourself after a bath', 'name': 'bath_towel'}, {'frequency': 'c', 'synset': 'bathrobe.n.01', 'synonyms': ['bathrobe'], 'id': 67, 'def': 'a loose-fitting robe of towelling; worn after a bath or swim', 'name': 'bathrobe'}, {'frequency': 'f', 'synset': 'bathtub.n.01', 'synonyms': ['bathtub', 'bathing_tub'], 'id': 68, 'def': 'a large open container that you fill with water and use to wash the body', 'name': 'bathtub'}, {'frequency': 'r', 'synset': 'batter.n.02', 'synonyms': ['batter_(food)'], 'id': 69, 'def': 'a liquid or semiliquid mixture, as of flour, eggs, and milk, used in cooking', 'name': 'batter_(food)'}, {'frequency': 'c', 'synset': 'battery.n.02', 'synonyms': ['battery'], 'id': 70, 'def': 'a portable device that produces electricity', 'name': 'battery'}, {'frequency': 'r', 'synset': 'beach_ball.n.01', 'synonyms': ['beachball'], 'id': 71, 'def': 'large and light ball; for play at the seaside', 'name': 'beachball'}, {'frequency': 'c', 'synset': 'bead.n.01', 'synonyms': ['bead'], 'id': 72, 'def': 'a small ball with a hole through the middle used for ornamentation, jewellery, etc.', 'name': 'bead'}, {'frequency': 'c', 'synset': 'bean_curd.n.01', 'synonyms': ['bean_curd', 'tofu'], 'id': 73, 'def': 'cheeselike food made of curdled soybean milk', 'name': 'bean_curd'}, {'frequency': 'c', 'synset': 'beanbag.n.01', 'synonyms': ['beanbag'], 'id': 74, 'def': 'a bag filled with dried beans or similar items; used in games or to sit on', 'name': 'beanbag'}, {'frequency': 'f', 'synset': 'beanie.n.01', 'synonyms': ['beanie', 'beany'], 'id': 75, 'def': 'a small skullcap; formerly worn by schoolboys and college freshmen', 'name': 'beanie'}, {'frequency': 'f', 'synset': 'bear.n.01', 'synonyms': ['bear'], 'id': 76, 'def': 'large carnivorous or omnivorous mammals with shaggy coats and claws', 'name': 'bear'}, {'frequency': 'f', 'synset': 'bed.n.01', 'synonyms': ['bed'], 'id': 77, 'def': 'a piece of furniture that provides a place to sleep', 'name': 'bed'}, {'frequency': 'r', 'synset': 'bedpan.n.01', 'synonyms': ['bedpan'], 'id': 78, 'def': 'a shallow vessel used by a bedridden patient for defecation and urination', 'name': 'bedpan'}, {'frequency': 'f', 'synset': 'bedspread.n.01', 'synonyms': ['bedspread', 'bedcover', 'bed_covering', 'counterpane', 'spread'], 'id': 79, 'def': 'decorative cover for a bed', 'name': 'bedspread'}, {'frequency': 'f', 'synset': 'beef.n.01', 'synonyms': ['cow'], 'id': 80, 'def': 'cattle/cow', 'name': 'cow'}, {'frequency': 'f', 'synset': 'beef.n.02', 'synonyms': ['beef_(food)', 'boeuf_(food)'], 'id': 81, 'def': 'meat from an adult domestic bovine', 'name': 'beef_(food)'}, {'frequency': 'r', 'synset': 'beeper.n.01', 'synonyms': ['beeper', 'pager'], 'id': 82, 'def': 'an device that beeps when the person carrying it is being paged', 'name': 'beeper'}, {'frequency': 'f', 'synset': 'beer_bottle.n.01', 'synonyms': ['beer_bottle'], 'id': 83, 'def': 'a bottle that holds beer', 'name': 'beer_bottle'}, {'frequency': 'c', 'synset': 'beer_can.n.01', 'synonyms': ['beer_can'], 'id': 84, 'def': 'a can that holds beer', 'name': 'beer_can'}, {'frequency': 'r', 'synset': 'beetle.n.01', 'synonyms': ['beetle'], 'id': 85, 'def': 'insect with hard wing covers', 'name': 'beetle'}, {'frequency': 'f', 'synset': 'bell.n.01', 'synonyms': ['bell'], 'id': 86, 'def': 'a hollow device made of metal that makes a ringing sound when struck', 'name': 'bell'}, {'frequency': 'f', 'synset': 'bell_pepper.n.02', 'synonyms': ['bell_pepper', 'capsicum'], 'id': 87, 'def': 'large bell-shaped sweet pepper in green or red or yellow or orange or black varieties', 'name': 'bell_pepper'}, {'frequency': 'f', 'synset': 'belt.n.02', 'synonyms': ['belt'], 'id': 88, 'def': 'a band to tie or buckle around the body (usually at the waist)', 'name': 'belt'}, {'frequency': 'f', 'synset': 'belt_buckle.n.01', 'synonyms': ['belt_buckle'], 'id': 89, 'def': 'the buckle used to fasten a belt', 'name': 'belt_buckle'}, {'frequency': 'f', 'synset': 'bench.n.01', 'synonyms': ['bench'], 'id': 90, 'def': 'a long seat for more than one person', 'name': 'bench'}, {'frequency': 'c', 'synset': 'beret.n.01', 'synonyms': ['beret'], 'id': 91, 'def': 'a cap with no brim or bill; made of soft cloth', 'name': 'beret'}, {'frequency': 'c', 'synset': 'bib.n.02', 'synonyms': ['bib'], 'id': 92, 'def': 'a napkin tied under the chin of a child while eating', 'name': 'bib'}, {'frequency': 'r', 'synset': 'bible.n.01', 'synonyms': ['Bible'], 'id': 93, 'def': 'the sacred writings of the Christian religions', 'name': 'Bible'}, {'frequency': 'f', 'synset': 'bicycle.n.01', 'synonyms': ['bicycle', 'bike_(bicycle)'], 'id': 94, 'def': 'a wheeled vehicle that has two wheels and is moved by foot pedals', 'name': 'bicycle'}, {'frequency': 'f', 'synset': 'bill.n.09', 'synonyms': ['visor', 'vizor'], 'id': 95, 'def': 'a brim that projects to the front to shade the eyes', 'name': 'visor'}, {'frequency': 'f', 'synset': 'billboard.n.01', 'synonyms': ['billboard'], 'id': 96, 'def': 'large outdoor signboard', 'name': 'billboard'}, {'frequency': 'c', 'synset': 'binder.n.03', 'synonyms': ['binder', 'ring-binder'], 'id': 97, 'def': 'holds loose papers or magazines', 'name': 'binder'}, {'frequency': 'c', 'synset': 'binoculars.n.01', 'synonyms': ['binoculars', 'field_glasses', 'opera_glasses'], 'id': 98, 'def': 'an optical instrument designed for simultaneous use by both eyes', 'name': 'binoculars'}, {'frequency': 'f', 'synset': 'bird.n.01', 'synonyms': ['bird'], 'id': 99, 'def': 'animal characterized by feathers and wings', 'name': 'bird'}, {'frequency': 'c', 'synset': 'bird_feeder.n.01', 'synonyms': ['birdfeeder'], 'id': 100, 'def': 'an outdoor device that supplies food for wild birds', 'name': 'birdfeeder'}, {'frequency': 'c', 'synset': 'birdbath.n.01', 'synonyms': ['birdbath'], 'id': 101, 'def': 'an ornamental basin (usually in a garden) for birds to bathe in', 'name': 'birdbath'}, {'frequency': 'c', 'synset': 'birdcage.n.01', 'synonyms': ['birdcage'], 'id': 102, 'def': 'a cage in which a bird can be kept', 'name': 'birdcage'}, {'frequency': 'c', 'synset': 'birdhouse.n.01', 'synonyms': ['birdhouse'], 'id': 103, 'def': 'a shelter for birds', 'name': 'birdhouse'}, {'frequency': 'f', 'synset': 'birthday_cake.n.01', 'synonyms': ['birthday_cake'], 'id': 104, 'def': 'decorated cake served at a birthday party', 'name': 'birthday_cake'}, {'frequency': 'r', 'synset': 'birthday_card.n.01', 'synonyms': ['birthday_card'], 'id': 105, 'def': 'a card expressing a birthday greeting', 'name': 'birthday_card'}, {'frequency': 'r', 'synset': 'black_flag.n.01', 'synonyms': ['pirate_flag'], 'id': 106, 'def': 'a flag usually bearing a white skull and crossbones on a black background', 'name': 'pirate_flag'}, {'frequency': 'c', 'synset': 'black_sheep.n.02', 'synonyms': ['black_sheep'], 'id': 107, 'def': 'sheep with a black coat', 'name': 'black_sheep'}, {'frequency': 'c', 'synset': 'blackberry.n.01', 'synonyms': ['blackberry'], 'id': 108, 'def': 'large sweet black or very dark purple edible aggregate fruit', 'name': 'blackberry'}, {'frequency': 'f', 'synset': 'blackboard.n.01', 'synonyms': ['blackboard', 'chalkboard'], 'id': 109, 'def': 'sheet of slate; for writing with chalk', 'name': 'blackboard'}, {'frequency': 'f', 'synset': 'blanket.n.01', 'synonyms': ['blanket'], 'id': 110, 'def': 'bedding that keeps a person warm in bed', 'name': 'blanket'}, {'frequency': 'c', 'synset': 'blazer.n.01', 'synonyms': ['blazer', 'sport_jacket', 'sport_coat', 'sports_jacket', 'sports_coat'], 'id': 111, 'def': 'lightweight jacket; often striped in the colors of a club or school', 'name': 'blazer'}, {'frequency': 'f', 'synset': 'blender.n.01', 'synonyms': ['blender', 'liquidizer', 'liquidiser'], 'id': 112, 'def': 'an electrically powered mixer that mix or chop or liquefy foods', 'name': 'blender'}, {'frequency': 'r', 'synset': 'blimp.n.02', 'synonyms': ['blimp'], 'id': 113, 'def': 'a small nonrigid airship used for observation or as a barrage balloon', 'name': 'blimp'}, {'frequency': 'f', 'synset': 'blinker.n.01', 'synonyms': ['blinker', 'flasher'], 'id': 114, 'def': 'a light that flashes on and off; used as a signal or to send messages', 'name': 'blinker'}, {'frequency': 'f', 'synset': 'blouse.n.01', 'synonyms': ['blouse'], 'id': 115, 'def': 'a top worn by women', 'name': 'blouse'}, {'frequency': 'f', 'synset': 'blueberry.n.02', 'synonyms': ['blueberry'], 'id': 116, 'def': 'sweet edible dark-blue berries of blueberry plants', 'name': 'blueberry'}, {'frequency': 'r', 'synset': 'board.n.09', 'synonyms': ['gameboard'], 'id': 117, 'def': 'a flat portable surface (usually rectangular) designed for board games', 'name': 'gameboard'}, {'frequency': 'f', 'synset': 'boat.n.01', 'synonyms': ['boat', 'ship_(boat)'], 'id': 118, 'def': 'a vessel for travel on water', 'name': 'boat'}, {'frequency': 'r', 'synset': 'bob.n.05', 'synonyms': ['bob', 'bobber', 'bobfloat'], 'id': 119, 'def': 'a small float usually made of cork; attached to a fishing line', 'name': 'bob'}, {'frequency': 'c', 'synset': 'bobbin.n.01', 'synonyms': ['bobbin', 'spool', 'reel'], 'id': 120, 'def': 'a thing around which thread/tape/film or other flexible materials can be wound', 'name': 'bobbin'}, {'frequency': 'c', 'synset': 'bobby_pin.n.01', 'synonyms': ['bobby_pin', 'hairgrip'], 'id': 121, 'def': 'a flat wire hairpin used to hold bobbed hair in place', 'name': 'bobby_pin'}, {'frequency': 'c', 'synset': 'boiled_egg.n.01', 'synonyms': ['boiled_egg', 'coddled_egg'], 'id': 122, 'def': 'egg cooked briefly in the shell in gently boiling water', 'name': 'boiled_egg'}, {'frequency': 'r', 'synset': 'bolo_tie.n.01', 'synonyms': ['bolo_tie', 'bolo', 'bola_tie', 'bola'], 'id': 123, 'def': 'a cord fastened around the neck with an ornamental clasp and worn as a necktie', 'name': 'bolo_tie'}, {'frequency': 'c', 'synset': 'bolt.n.03', 'synonyms': ['deadbolt'], 'id': 124, 'def': 'the part of a lock that is engaged or withdrawn with a key', 'name': 'deadbolt'}, {'frequency': 'f', 'synset': 'bolt.n.06', 'synonyms': ['bolt'], 'id': 125, 'def': 'a screw that screws into a nut to form a fastener', 'name': 'bolt'}, {'frequency': 'r', 'synset': 'bonnet.n.01', 'synonyms': ['bonnet'], 'id': 126, 'def': 'a hat tied under the chin', 'name': 'bonnet'}, {'frequency': 'f', 'synset': 'book.n.01', 'synonyms': ['book'], 'id': 127, 'def': 'a written work or composition that has been published', 'name': 'book'}, {'frequency': 'c', 'synset': 'bookcase.n.01', 'synonyms': ['bookcase'], 'id': 128, 'def': 'a piece of furniture with shelves for storing books', 'name': 'bookcase'}, {'frequency': 'c', 'synset': 'booklet.n.01', 'synonyms': ['booklet', 'brochure', 'leaflet', 'pamphlet'], 'id': 129, 'def': 'a small book usually having a paper cover', 'name': 'booklet'}, {'frequency': 'r', 'synset': 'bookmark.n.01', 'synonyms': ['bookmark', 'bookmarker'], 'id': 130, 'def': 'a marker (a piece of paper or ribbon) placed between the pages of a book', 'name': 'bookmark'}, {'frequency': 'r', 'synset': 'boom.n.04', 'synonyms': ['boom_microphone', 'microphone_boom'], 'id': 131, 'def': 'a pole carrying an overhead microphone projected over a film or tv set', 'name': 'boom_microphone'}, {'frequency': 'f', 'synset': 'boot.n.01', 'synonyms': ['boot'], 'id': 132, 'def': 'footwear that covers the whole foot and lower leg', 'name': 'boot'}, {'frequency': 'f', 'synset': 'bottle.n.01', 'synonyms': ['bottle'], 'id': 133, 'def': 'a glass or plastic vessel used for storing drinks or other liquids', 'name': 'bottle'}, {'frequency': 'c', 'synset': 'bottle_opener.n.01', 'synonyms': ['bottle_opener'], 'id': 134, 'def': 'an opener for removing caps or corks from bottles', 'name': 'bottle_opener'}, {'frequency': 'c', 'synset': 'bouquet.n.01', 'synonyms': ['bouquet'], 'id': 135, 'def': 'an arrangement of flowers that is usually given as a present', 'name': 'bouquet'}, {'frequency': 'r', 'synset': 'bow.n.04', 'synonyms': ['bow_(weapon)'], 'id': 136, 'def': 'a weapon for shooting arrows', 'name': 'bow_(weapon)'}, {'frequency': 'f', 'synset': 'bow.n.08', 'synonyms': ['bow_(decorative_ribbons)'], 'id': 137, 'def': 'a decorative interlacing of ribbons', 'name': 'bow_(decorative_ribbons)'}, {'frequency': 'f', 'synset': 'bow_tie.n.01', 'synonyms': ['bow-tie', 'bowtie'], 'id': 138, 'def': "a man's tie that ties in a bow", 'name': 'bow-tie'}, {'frequency': 'f', 'synset': 'bowl.n.03', 'synonyms': ['bowl'], 'id': 139, 'def': 'a dish that is round and open at the top for serving foods', 'name': 'bowl'}, {'frequency': 'r', 'synset': 'bowl.n.08', 'synonyms': ['pipe_bowl'], 'id': 140, 'def': 'a small round container that is open at the top for holding tobacco', 'name': 'pipe_bowl'}, {'frequency': 'c', 'synset': 'bowler_hat.n.01', 'synonyms': ['bowler_hat', 'bowler', 'derby_hat', 'derby', 'plug_hat'], 'id': 141, 'def': 'a felt hat that is round and hard with a narrow brim', 'name': 'bowler_hat'}, {'frequency': 'r', 'synset': 'bowling_ball.n.01', 'synonyms': ['bowling_ball'], 'id': 142, 'def': 'a large ball with finger holes used in the sport of bowling', 'name': 'bowling_ball'}, {'frequency': 'f', 'synset': 'box.n.01', 'synonyms': ['box'], 'id': 143, 'def': 'a (usually rectangular) container; may have a lid', 'name': 'box'}, {'frequency': 'r', 'synset': 'boxing_glove.n.01', 'synonyms': ['boxing_glove'], 'id': 144, 'def': 'large glove coverings the fists of a fighter worn for the sport of boxing', 'name': 'boxing_glove'}, {'frequency': 'c', 'synset': 'brace.n.06', 'synonyms': ['suspenders'], 'id': 145, 'def': 'elastic straps that hold trousers up (usually used in the plural)', 'name': 'suspenders'}, {'frequency': 'f', 'synset': 'bracelet.n.02', 'synonyms': ['bracelet', 'bangle'], 'id': 146, 'def': 'jewelry worn around the wrist for decoration', 'name': 'bracelet'}, {'frequency': 'r', 'synset': 'brass.n.07', 'synonyms': ['brass_plaque'], 'id': 147, 'def': 'a memorial made of brass', 'name': 'brass_plaque'}, {'frequency': 'c', 'synset': 'brassiere.n.01', 'synonyms': ['brassiere', 'bra', 'bandeau'], 'id': 148, 'def': 'an undergarment worn by women to support their breasts', 'name': 'brassiere'}, {'frequency': 'c', 'synset': 'bread-bin.n.01', 'synonyms': ['bread-bin', 'breadbox'], 'id': 149, 'def': 'a container used to keep bread or cake in', 'name': 'bread-bin'}, {'frequency': 'f', 'synset': 'bread.n.01', 'synonyms': ['bread'], 'id': 150, 'def': 'food made from dough of flour or meal and usually raised with yeast or baking powder and then baked', 'name': 'bread'}, {'frequency': 'r', 'synset': 'breechcloth.n.01', 'synonyms': ['breechcloth', 'breechclout', 'loincloth'], 'id': 151, 'def': 'a garment that provides covering for the loins', 'name': 'breechcloth'}, {'frequency': 'f', 'synset': 'bridal_gown.n.01', 'synonyms': ['bridal_gown', 'wedding_gown', 'wedding_dress'], 'id': 152, 'def': 'a gown worn by the bride at a wedding', 'name': 'bridal_gown'}, {'frequency': 'c', 'synset': 'briefcase.n.01', 'synonyms': ['briefcase'], 'id': 153, 'def': 'a case with a handle; for carrying papers or files or books', 'name': 'briefcase'}, {'frequency': 'f', 'synset': 'broccoli.n.01', 'synonyms': ['broccoli'], 'id': 154, 'def': 'plant with dense clusters of tight green flower buds', 'name': 'broccoli'}, {'frequency': 'r', 'synset': 'brooch.n.01', 'synonyms': ['broach'], 'id': 155, 'def': 'a decorative pin worn by women', 'name': 'broach'}, {'frequency': 'c', 'synset': 'broom.n.01', 'synonyms': ['broom'], 'id': 156, 'def': 'bundle of straws or twigs attached to a long handle; used for cleaning', 'name': 'broom'}, {'frequency': 'c', 'synset': 'brownie.n.03', 'synonyms': ['brownie'], 'id': 157, 'def': 'square or bar of very rich chocolate cake usually with nuts', 'name': 'brownie'}, {'frequency': 'c', 'synset': 'brussels_sprouts.n.01', 'synonyms': ['brussels_sprouts'], 'id': 158, 'def': 'the small edible cabbage-like buds growing along a stalk', 'name': 'brussels_sprouts'}, {'frequency': 'r', 'synset': 'bubble_gum.n.01', 'synonyms': ['bubble_gum'], 'id': 159, 'def': 'a kind of chewing gum that can be blown into bubbles', 'name': 'bubble_gum'}, {'frequency': 'f', 'synset': 'bucket.n.01', 'synonyms': ['bucket', 'pail'], 'id': 160, 'def': 'a roughly cylindrical vessel that is open at the top', 'name': 'bucket'}, {'frequency': 'r', 'synset': 'buggy.n.01', 'synonyms': ['horse_buggy'], 'id': 161, 'def': 'a small lightweight carriage; drawn by a single horse', 'name': 'horse_buggy'}, {'frequency': 'c', 'synset': 'bull.n.11', 'synonyms': ['horned_cow'], 'id': 162, 'def': 'a cow with horns', 'name': 'bull'}, {'frequency': 'c', 'synset': 'bulldog.n.01', 'synonyms': ['bulldog'], 'id': 163, 'def': 'a thickset short-haired dog with a large head and strong undershot lower jaw', 'name': 'bulldog'}, {'frequency': 'r', 'synset': 'bulldozer.n.01', 'synonyms': ['bulldozer', 'dozer'], 'id': 164, 'def': 'large powerful tractor; a large blade in front flattens areas of ground', 'name': 'bulldozer'}, {'frequency': 'c', 'synset': 'bullet_train.n.01', 'synonyms': ['bullet_train'], 'id': 165, 'def': 'a high-speed passenger train', 'name': 'bullet_train'}, {'frequency': 'c', 'synset': 'bulletin_board.n.02', 'synonyms': ['bulletin_board', 'notice_board'], 'id': 166, 'def': 'a board that hangs on a wall; displays announcements', 'name': 'bulletin_board'}, {'frequency': 'r', 'synset': 'bulletproof_vest.n.01', 'synonyms': ['bulletproof_vest'], 'id': 167, 'def': 'a vest capable of resisting the impact of a bullet', 'name': 'bulletproof_vest'}, {'frequency': 'c', 'synset': 'bullhorn.n.01', 'synonyms': ['bullhorn', 'megaphone'], 'id': 168, 'def': 'a portable loudspeaker with built-in microphone and amplifier', 'name': 'bullhorn'}, {'frequency': 'f', 'synset': 'bun.n.01', 'synonyms': ['bun', 'roll'], 'id': 169, 'def': 'small rounded bread either plain or sweet', 'name': 'bun'}, {'frequency': 'c', 'synset': 'bunk_bed.n.01', 'synonyms': ['bunk_bed'], 'id': 170, 'def': 'beds built one above the other', 'name': 'bunk_bed'}, {'frequency': 'f', 'synset': 'buoy.n.01', 'synonyms': ['buoy'], 'id': 171, 'def': 'a float attached by rope to the seabed to mark channels in a harbor or underwater hazards', 'name': 'buoy'}, {'frequency': 'r', 'synset': 'burrito.n.01', 'synonyms': ['burrito'], 'id': 172, 'def': 'a flour tortilla folded around a filling', 'name': 'burrito'}, {'frequency': 'f', 'synset': 'bus.n.01', 'synonyms': ['bus_(vehicle)', 'autobus', 'charabanc', 'double-decker', 'motorbus', 'motorcoach'], 'id': 173, 'def': 'a vehicle carrying many passengers; used for public transport', 'name': 'bus_(vehicle)'}, {'frequency': 'c', 'synset': 'business_card.n.01', 'synonyms': ['business_card'], 'id': 174, 'def': "a card on which are printed the person's name and business affiliation", 'name': 'business_card'}, {'frequency': 'f', 'synset': 'butter.n.01', 'synonyms': ['butter'], 'id': 175, 'def': 'an edible emulsion of fat globules made by churning milk or cream; for cooking and table use', 'name': 'butter'}, {'frequency': 'c', 'synset': 'butterfly.n.01', 'synonyms': ['butterfly'], 'id': 176, 'def': 'insect typically having a slender body with knobbed antennae and broad colorful wings', 'name': 'butterfly'}, {'frequency': 'f', 'synset': 'button.n.01', 'synonyms': ['button'], 'id': 177, 'def': 'a round fastener sewn to shirts and coats etc to fit through buttonholes', 'name': 'button'}, {'frequency': 'f', 'synset': 'cab.n.03', 'synonyms': ['cab_(taxi)', 'taxi', 'taxicab'], 'id': 178, 'def': 'a car that takes passengers where they want to go in exchange for money', 'name': 'cab_(taxi)'}, {'frequency': 'r', 'synset': 'cabana.n.01', 'synonyms': ['cabana'], 'id': 179, 'def': 'a small tent used as a dressing room beside the sea or a swimming pool', 'name': 'cabana'}, {'frequency': 'c', 'synset': 'cabin_car.n.01', 'synonyms': ['cabin_car', 'caboose'], 'id': 180, 'def': 'a car on a freight train for use of the train crew; usually the last car on the train', 'name': 'cabin_car'}, {'frequency': 'f', 'synset': 'cabinet.n.01', 'synonyms': ['cabinet'], 'id': 181, 'def': 'a piece of furniture resembling a cupboard with doors and shelves and drawers', 'name': 'cabinet'}, {'frequency': 'r', 'synset': 'cabinet.n.03', 'synonyms': ['locker', 'storage_locker'], 'id': 182, 'def': 'a storage compartment for clothes and valuables; usually it has a lock', 'name': 'locker'}, {'frequency': 'f', 'synset': 'cake.n.03', 'synonyms': ['cake'], 'id': 183, 'def': 'baked goods made from or based on a mixture of flour, sugar, eggs, and fat', 'name': 'cake'}, {'frequency': 'c', 'synset': 'calculator.n.02', 'synonyms': ['calculator'], 'id': 184, 'def': 'a small machine that is used for mathematical calculations', 'name': 'calculator'}, {'frequency': 'f', 'synset': 'calendar.n.02', 'synonyms': ['calendar'], 'id': 185, 'def': 'a list or register of events (appointments/social events/court cases, etc)', 'name': 'calendar'}, {'frequency': 'c', 'synset': 'calf.n.01', 'synonyms': ['calf'], 'id': 186, 'def': 'young of domestic cattle', 'name': 'calf'}, {'frequency': 'c', 'synset': 'camcorder.n.01', 'synonyms': ['camcorder'], 'id': 187, 'def': 'a portable television camera and videocassette recorder', 'name': 'camcorder'}, {'frequency': 'c', 'synset': 'camel.n.01', 'synonyms': ['camel'], 'id': 188, 'def': 'cud-chewing mammal used as a draft or saddle animal in desert regions', 'name': 'camel'}, {'frequency': 'f', 'synset': 'camera.n.01', 'synonyms': ['camera'], 'id': 189, 'def': 'equipment for taking photographs', 'name': 'camera'}, {'frequency': 'c', 'synset': 'camera_lens.n.01', 'synonyms': ['camera_lens'], 'id': 190, 'def': 'a lens that focuses the image in a camera', 'name': 'camera_lens'}, {'frequency': 'c', 'synset': 'camper.n.02', 'synonyms': ['camper_(vehicle)', 'camping_bus', 'motor_home'], 'id': 191, 'def': 'a recreational vehicle equipped for camping out while traveling', 'name': 'camper_(vehicle)'}, {'frequency': 'f', 'synset': 'can.n.01', 'synonyms': ['can', 'tin_can'], 'id': 192, 'def': 'airtight sealed metal container for food or drink or paint etc.', 'name': 'can'}, {'frequency': 'c', 'synset': 'can_opener.n.01', 'synonyms': ['can_opener', 'tin_opener'], 'id': 193, 'def': 'a device for cutting cans open', 'name': 'can_opener'}, {'frequency': 'f', 'synset': 'candle.n.01', 'synonyms': ['candle', 'candlestick'], 'id': 194, 'def': 'stick of wax with a wick in the middle', 'name': 'candle'}, {'frequency': 'f', 'synset': 'candlestick.n.01', 'synonyms': ['candle_holder'], 'id': 195, 'def': 'a holder with sockets for candles', 'name': 'candle_holder'}, {'frequency': 'r', 'synset': 'candy_bar.n.01', 'synonyms': ['candy_bar'], 'id': 196, 'def': 'a candy shaped as a bar', 'name': 'candy_bar'}, {'frequency': 'c', 'synset': 'candy_cane.n.01', 'synonyms': ['candy_cane'], 'id': 197, 'def': 'a hard candy in the shape of a rod (usually with stripes)', 'name': 'candy_cane'}, {'frequency': 'c', 'synset': 'cane.n.01', 'synonyms': ['walking_cane'], 'id': 198, 'def': 'a stick that people can lean on to help them walk', 'name': 'walking_cane'}, {'frequency': 'c', 'synset': 'canister.n.02', 'synonyms': ['canister', 'cannister'], 'id': 199, 'def': 'metal container for storing dry foods such as tea or flour', 'name': 'canister'}, {'frequency': 'c', 'synset': 'canoe.n.01', 'synonyms': ['canoe'], 'id': 200, 'def': 'small and light boat; pointed at both ends; propelled with a paddle', 'name': 'canoe'}, {'frequency': 'c', 'synset': 'cantaloup.n.02', 'synonyms': ['cantaloup', 'cantaloupe'], 'id': 201, 'def': 'the fruit of a cantaloup vine; small to medium-sized melon with yellowish flesh', 'name': 'cantaloup'}, {'frequency': 'r', 'synset': 'canteen.n.01', 'synonyms': ['canteen'], 'id': 202, 'def': 'a flask for carrying water; used by soldiers or travelers', 'name': 'canteen'}, {'frequency': 'f', 'synset': 'cap.n.01', 'synonyms': ['cap_(headwear)'], 'id': 203, 'def': 'a tight-fitting headwear', 'name': 'cap_(headwear)'}, {'frequency': 'f', 'synset': 'cap.n.02', 'synonyms': ['bottle_cap', 'cap_(container_lid)'], 'id': 204, 'def': 'a top (as for a bottle)', 'name': 'bottle_cap'}, {'frequency': 'c', 'synset': 'cape.n.02', 'synonyms': ['cape'], 'id': 205, 'def': 'a sleeveless garment like a cloak but shorter', 'name': 'cape'}, {'frequency': 'c', 'synset': 'cappuccino.n.01', 'synonyms': ['cappuccino', 'coffee_cappuccino'], 'id': 206, 'def': 'equal parts of espresso and steamed milk', 'name': 'cappuccino'}, {'frequency': 'f', 'synset': 'car.n.01', 'synonyms': ['car_(automobile)', 'auto_(automobile)', 'automobile'], 'id': 207, 'def': 'a motor vehicle with four wheels', 'name': 'car_(automobile)'}, {'frequency': 'f', 'synset': 'car.n.02', 'synonyms': ['railcar_(part_of_a_train)', 'railway_car_(part_of_a_train)', 'railroad_car_(part_of_a_train)'], 'id': 208, 'def': 'a wheeled vehicle adapted to the rails of railroad (mark each individual railcar separately)', 'name': 'railcar_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'car.n.04', 'synonyms': ['elevator_car'], 'id': 209, 'def': 'where passengers ride up and down', 'name': 'elevator_car'}, {'frequency': 'r', 'synset': 'car_battery.n.01', 'synonyms': ['car_battery', 'automobile_battery'], 'id': 210, 'def': 'a battery in a motor vehicle', 'name': 'car_battery'}, {'frequency': 'c', 'synset': 'card.n.02', 'synonyms': ['identity_card'], 'id': 211, 'def': 'a card certifying the identity of the bearer', 'name': 'identity_card'}, {'frequency': 'c', 'synset': 'card.n.03', 'synonyms': ['card'], 'id': 212, 'def': 'a rectangular piece of paper used to send messages (e.g. greetings or pictures)', 'name': 'card'}, {'frequency': 'c', 'synset': 'cardigan.n.01', 'synonyms': ['cardigan'], 'id': 213, 'def': 'knitted jacket that is fastened up the front with buttons or a zipper', 'name': 'cardigan'}, {'frequency': 'r', 'synset': 'cargo_ship.n.01', 'synonyms': ['cargo_ship', 'cargo_vessel'], 'id': 214, 'def': 'a ship designed to carry cargo', 'name': 'cargo_ship'}, {'frequency': 'r', 'synset': 'carnation.n.01', 'synonyms': ['carnation'], 'id': 215, 'def': 'plant with pink to purple-red spice-scented usually double flowers', 'name': 'carnation'}, {'frequency': 'c', 'synset': 'carriage.n.02', 'synonyms': ['horse_carriage'], 'id': 216, 'def': 'a vehicle with wheels drawn by one or more horses', 'name': 'horse_carriage'}, {'frequency': 'f', 'synset': 'carrot.n.01', 'synonyms': ['carrot'], 'id': 217, 'def': 'deep orange edible root of the cultivated carrot plant', 'name': 'carrot'}, {'frequency': 'f', 'synset': 'carryall.n.01', 'synonyms': ['tote_bag'], 'id': 218, 'def': 'a capacious bag or basket', 'name': 'tote_bag'}, {'frequency': 'c', 'synset': 'cart.n.01', 'synonyms': ['cart'], 'id': 219, 'def': 'a heavy open wagon usually having two wheels and drawn by an animal', 'name': 'cart'}, {'frequency': 'c', 'synset': 'carton.n.02', 'synonyms': ['carton'], 'id': 220, 'def': 'a container made of cardboard for holding food or drink', 'name': 'carton'}, {'frequency': 'c', 'synset': 'cash_register.n.01', 'synonyms': ['cash_register', 'register_(for_cash_transactions)'], 'id': 221, 'def': 'a cashbox with an adding machine to register transactions', 'name': 'cash_register'}, {'frequency': 'r', 'synset': 'casserole.n.01', 'synonyms': ['casserole'], 'id': 222, 'def': 'food cooked and served in a casserole', 'name': 'casserole'}, {'frequency': 'r', 'synset': 'cassette.n.01', 'synonyms': ['cassette'], 'id': 223, 'def': 'a container that holds a magnetic tape used for recording or playing sound or video', 'name': 'cassette'}, {'frequency': 'c', 'synset': 'cast.n.05', 'synonyms': ['cast', 'plaster_cast', 'plaster_bandage'], 'id': 224, 'def': 'bandage consisting of a firm covering that immobilizes broken bones while they heal', 'name': 'cast'}, {'frequency': 'f', 'synset': 'cat.n.01', 'synonyms': ['cat'], 'id': 225, 'def': 'a domestic house cat', 'name': 'cat'}, {'frequency': 'f', 'synset': 'cauliflower.n.02', 'synonyms': ['cauliflower'], 'id': 226, 'def': 'edible compact head of white undeveloped flowers', 'name': 'cauliflower'}, {'frequency': 'c', 'synset': 'cayenne.n.02', 'synonyms': ['cayenne_(spice)', 'cayenne_pepper_(spice)', 'red_pepper_(spice)'], 'id': 227, 'def': 'ground pods and seeds of pungent red peppers of the genus Capsicum', 'name': 'cayenne_(spice)'}, {'frequency': 'c', 'synset': 'cd_player.n.01', 'synonyms': ['CD_player'], 'id': 228, 'def': 'electronic equipment for playing compact discs (CDs)', 'name': 'CD_player'}, {'frequency': 'f', 'synset': 'celery.n.01', 'synonyms': ['celery'], 'id': 229, 'def': 'widely cultivated herb with aromatic leaf stalks that are eaten raw or cooked', 'name': 'celery'}, {'frequency': 'f', 'synset': 'cellular_telephone.n.01', 'synonyms': ['cellular_telephone', 'cellular_phone', 'cellphone', 'mobile_phone', 'smart_phone'], 'id': 230, 'def': 'a hand-held mobile telephone', 'name': 'cellular_telephone'}, {'frequency': 'r', 'synset': 'chain_mail.n.01', 'synonyms': ['chain_mail', 'ring_mail', 'chain_armor', 'chain_armour', 'ring_armor', 'ring_armour'], 'id': 231, 'def': '(Middle Ages) flexible armor made of interlinked metal rings', 'name': 'chain_mail'}, {'frequency': 'f', 'synset': 'chair.n.01', 'synonyms': ['chair'], 'id': 232, 'def': 'a seat for one person, with a support for the back', 'name': 'chair'}, {'frequency': 'r', 'synset': 'chaise_longue.n.01', 'synonyms': ['chaise_longue', 'chaise', 'daybed'], 'id': 233, 'def': 'a long chair; for reclining', 'name': 'chaise_longue'}, {'frequency': 'r', 'synset': 'chalice.n.01', 'synonyms': ['chalice'], 'id': 234, 'def': 'a bowl-shaped drinking vessel; especially the Eucharistic cup', 'name': 'chalice'}, {'frequency': 'f', 'synset': 'chandelier.n.01', 'synonyms': ['chandelier'], 'id': 235, 'def': 'branched lighting fixture; often ornate; hangs from the ceiling', 'name': 'chandelier'}, {'frequency': 'r', 'synset': 'chap.n.04', 'synonyms': ['chap'], 'id': 236, 'def': 'leather leggings without a seat; worn over trousers by cowboys to protect their legs', 'name': 'chap'}, {'frequency': 'r', 'synset': 'checkbook.n.01', 'synonyms': ['checkbook', 'chequebook'], 'id': 237, 'def': 'a book issued to holders of checking accounts', 'name': 'checkbook'}, {'frequency': 'r', 'synset': 'checkerboard.n.01', 'synonyms': ['checkerboard'], 'id': 238, 'def': 'a board having 64 squares of two alternating colors', 'name': 'checkerboard'}, {'frequency': 'c', 'synset': 'cherry.n.03', 'synonyms': ['cherry'], 'id': 239, 'def': 'a red fruit with a single hard stone', 'name': 'cherry'}, {'frequency': 'r', 'synset': 'chessboard.n.01', 'synonyms': ['chessboard'], 'id': 240, 'def': 'a checkerboard used to play chess', 'name': 'chessboard'}, {'frequency': 'c', 'synset': 'chicken.n.02', 'synonyms': ['chicken_(animal)'], 'id': 241, 'def': 'a domestic fowl bred for flesh or eggs', 'name': 'chicken_(animal)'}, {'frequency': 'c', 'synset': 'chickpea.n.01', 'synonyms': ['chickpea', 'garbanzo'], 'id': 242, 'def': 'the seed of the chickpea plant; usually dried', 'name': 'chickpea'}, {'frequency': 'c', 'synset': 'chili.n.02', 'synonyms': ['chili_(vegetable)', 'chili_pepper_(vegetable)', 'chilli_(vegetable)', 'chilly_(vegetable)', 'chile_(vegetable)'], 'id': 243, 'def': 'very hot and finely tapering pepper of special pungency', 'name': 'chili_(vegetable)'}, {'frequency': 'r', 'synset': 'chime.n.01', 'synonyms': ['chime', 'gong'], 'id': 244, 'def': 'an instrument consisting of a set of bells that are struck with a hammer', 'name': 'chime'}, {'frequency': 'r', 'synset': 'chinaware.n.01', 'synonyms': ['chinaware'], 'id': 245, 'def': 'dishware made of high quality porcelain', 'name': 'chinaware'}, {'frequency': 'c', 'synset': 'chip.n.04', 'synonyms': ['crisp_(potato_chip)', 'potato_chip'], 'id': 246, 'def': 'a thin crisp slice of potato fried in deep fat', 'name': 'crisp_(potato_chip)'}, {'frequency': 'r', 'synset': 'chip.n.06', 'synonyms': ['poker_chip'], 'id': 247, 'def': 'a small disk-shaped counter used to represent money when gambling', 'name': 'poker_chip'}, {'frequency': 'c', 'synset': 'chocolate_bar.n.01', 'synonyms': ['chocolate_bar'], 'id': 248, 'def': 'a bar of chocolate candy', 'name': 'chocolate_bar'}, {'frequency': 'c', 'synset': 'chocolate_cake.n.01', 'synonyms': ['chocolate_cake'], 'id': 249, 'def': 'cake containing chocolate', 'name': 'chocolate_cake'}, {'frequency': 'r', 'synset': 'chocolate_milk.n.01', 'synonyms': ['chocolate_milk'], 'id': 250, 'def': 'milk flavored with chocolate syrup', 'name': 'chocolate_milk'}, {'frequency': 'r', 'synset': 'chocolate_mousse.n.01', 'synonyms': ['chocolate_mousse'], 'id': 251, 'def': 'dessert mousse made with chocolate', 'name': 'chocolate_mousse'}, {'frequency': 'f', 'synset': 'choker.n.03', 'synonyms': ['choker', 'collar', 'neckband'], 'id': 252, 'def': 'shirt collar, animal collar, or tight-fitting necklace', 'name': 'choker'}, {'frequency': 'f', 'synset': 'chopping_board.n.01', 'synonyms': ['chopping_board', 'cutting_board', 'chopping_block'], 'id': 253, 'def': 'a wooden board where meats or vegetables can be cut', 'name': 'chopping_board'}, {'frequency': 'f', 'synset': 'chopstick.n.01', 'synonyms': ['chopstick'], 'id': 254, 'def': 'one of a pair of slender sticks used as oriental tableware to eat food with', 'name': 'chopstick'}, {'frequency': 'f', 'synset': 'christmas_tree.n.05', 'synonyms': ['Christmas_tree'], 'id': 255, 'def': 'an ornamented evergreen used as a Christmas decoration', 'name': 'Christmas_tree'}, {'frequency': 'c', 'synset': 'chute.n.02', 'synonyms': ['slide'], 'id': 256, 'def': 'sloping channel through which things can descend', 'name': 'slide'}, {'frequency': 'r', 'synset': 'cider.n.01', 'synonyms': ['cider', 'cyder'], 'id': 257, 'def': 'a beverage made from juice pressed from apples', 'name': 'cider'}, {'frequency': 'r', 'synset': 'cigar_box.n.01', 'synonyms': ['cigar_box'], 'id': 258, 'def': 'a box for holding cigars', 'name': 'cigar_box'}, {'frequency': 'f', 'synset': 'cigarette.n.01', 'synonyms': ['cigarette'], 'id': 259, 'def': 'finely ground tobacco wrapped in paper; for smoking', 'name': 'cigarette'}, {'frequency': 'c', 'synset': 'cigarette_case.n.01', 'synonyms': ['cigarette_case', 'cigarette_pack'], 'id': 260, 'def': 'a small flat case for holding cigarettes', 'name': 'cigarette_case'}, {'frequency': 'f', 'synset': 'cistern.n.02', 'synonyms': ['cistern', 'water_tank'], 'id': 261, 'def': 'a tank that holds the water used to flush a toilet', 'name': 'cistern'}, {'frequency': 'r', 'synset': 'clarinet.n.01', 'synonyms': ['clarinet'], 'id': 262, 'def': 'a single-reed instrument with a straight tube', 'name': 'clarinet'}, {'frequency': 'c', 'synset': 'clasp.n.01', 'synonyms': ['clasp'], 'id': 263, 'def': 'a fastener (as a buckle or hook) that is used to hold two things together', 'name': 'clasp'}, {'frequency': 'c', 'synset': 'cleansing_agent.n.01', 'synonyms': ['cleansing_agent', 'cleanser', 'cleaner'], 'id': 264, 'def': 'a preparation used in cleaning something', 'name': 'cleansing_agent'}, {'frequency': 'r', 'synset': 'cleat.n.02', 'synonyms': ['cleat_(for_securing_rope)'], 'id': 265, 'def': 'a fastener (usually with two projecting horns) around which a rope can be secured', 'name': 'cleat_(for_securing_rope)'}, {'frequency': 'r', 'synset': 'clementine.n.01', 'synonyms': ['clementine'], 'id': 266, 'def': 'a variety of mandarin orange', 'name': 'clementine'}, {'frequency': 'c', 'synset': 'clip.n.03', 'synonyms': ['clip'], 'id': 267, 'def': 'any of various small fasteners used to hold loose articles together', 'name': 'clip'}, {'frequency': 'c', 'synset': 'clipboard.n.01', 'synonyms': ['clipboard'], 'id': 268, 'def': 'a small writing board with a clip at the top for holding papers', 'name': 'clipboard'}, {'frequency': 'r', 'synset': 'clipper.n.03', 'synonyms': ['clippers_(for_plants)'], 'id': 269, 'def': 'shears for cutting grass or shrubbery (often used in the plural)', 'name': 'clippers_(for_plants)'}, {'frequency': 'r', 'synset': 'cloak.n.02', 'synonyms': ['cloak'], 'id': 270, 'def': 'a loose outer garment', 'name': 'cloak'}, {'frequency': 'f', 'synset': 'clock.n.01', 'synonyms': ['clock', 'timepiece', 'timekeeper'], 'id': 271, 'def': 'a timepiece that shows the time of day', 'name': 'clock'}, {'frequency': 'f', 'synset': 'clock_tower.n.01', 'synonyms': ['clock_tower'], 'id': 272, 'def': 'a tower with a large clock visible high up on an outside face', 'name': 'clock_tower'}, {'frequency': 'c', 'synset': 'clothes_hamper.n.01', 'synonyms': ['clothes_hamper', 'laundry_basket', 'clothes_basket'], 'id': 273, 'def': 'a hamper that holds dirty clothes to be washed or wet clothes to be dried', 'name': 'clothes_hamper'}, {'frequency': 'c', 'synset': 'clothespin.n.01', 'synonyms': ['clothespin', 'clothes_peg'], 'id': 274, 'def': 'wood or plastic fastener; for holding clothes on a clothesline', 'name': 'clothespin'}, {'frequency': 'r', 'synset': 'clutch_bag.n.01', 'synonyms': ['clutch_bag'], 'id': 275, 'def': "a woman's strapless purse that is carried in the hand", 'name': 'clutch_bag'}, {'frequency': 'f', 'synset': 'coaster.n.03', 'synonyms': ['coaster'], 'id': 276, 'def': 'a covering (plate or mat) that protects the surface of a table', 'name': 'coaster'}, {'frequency': 'f', 'synset': 'coat.n.01', 'synonyms': ['coat'], 'id': 277, 'def': 'an outer garment that has sleeves and covers the body from shoulder down', 'name': 'coat'}, {'frequency': 'c', 'synset': 'coat_hanger.n.01', 'synonyms': ['coat_hanger', 'clothes_hanger', 'dress_hanger'], 'id': 278, 'def': "a hanger that is shaped like a person's shoulders", 'name': 'coat_hanger'}, {'frequency': 'c', 'synset': 'coatrack.n.01', 'synonyms': ['coatrack', 'hatrack'], 'id': 279, 'def': 'a rack with hooks for temporarily holding coats and hats', 'name': 'coatrack'}, {'frequency': 'c', 'synset': 'cock.n.04', 'synonyms': ['cock', 'rooster'], 'id': 280, 'def': 'adult male chicken', 'name': 'cock'}, {'frequency': 'r', 'synset': 'cockroach.n.01', 'synonyms': ['cockroach'], 'id': 281, 'def': 'any of numerous chiefly nocturnal insects; some are domestic pests', 'name': 'cockroach'}, {'frequency': 'r', 'synset': 'cocoa.n.01', 'synonyms': ['cocoa_(beverage)', 'hot_chocolate_(beverage)', 'drinking_chocolate'], 'id': 282, 'def': 'a beverage made from cocoa powder and milk and sugar; usually drunk hot', 'name': 'cocoa_(beverage)'}, {'frequency': 'c', 'synset': 'coconut.n.02', 'synonyms': ['coconut', 'cocoanut'], 'id': 283, 'def': 'large hard-shelled brown oval nut with a fibrous husk', 'name': 'coconut'}, {'frequency': 'f', 'synset': 'coffee_maker.n.01', 'synonyms': ['coffee_maker', 'coffee_machine'], 'id': 284, 'def': 'a kitchen appliance for brewing coffee automatically', 'name': 'coffee_maker'}, {'frequency': 'f', 'synset': 'coffee_table.n.01', 'synonyms': ['coffee_table', 'cocktail_table'], 'id': 285, 'def': 'low table where magazines can be placed and coffee or cocktails are served', 'name': 'coffee_table'}, {'frequency': 'c', 'synset': 'coffeepot.n.01', 'synonyms': ['coffeepot'], 'id': 286, 'def': 'tall pot in which coffee is brewed', 'name': 'coffeepot'}, {'frequency': 'r', 'synset': 'coil.n.05', 'synonyms': ['coil'], 'id': 287, 'def': 'tubing that is wound in a spiral', 'name': 'coil'}, {'frequency': 'c', 'synset': 'coin.n.01', 'synonyms': ['coin'], 'id': 288, 'def': 'a flat metal piece (usually a disc) used as money', 'name': 'coin'}, {'frequency': 'c', 'synset': 'colander.n.01', 'synonyms': ['colander', 'cullender'], 'id': 289, 'def': 'bowl-shaped strainer; used to wash or drain foods', 'name': 'colander'}, {'frequency': 'c', 'synset': 'coleslaw.n.01', 'synonyms': ['coleslaw', 'slaw'], 'id': 290, 'def': 'basically shredded cabbage', 'name': 'coleslaw'}, {'frequency': 'r', 'synset': 'coloring_material.n.01', 'synonyms': ['coloring_material', 'colouring_material'], 'id': 291, 'def': 'any material used for its color', 'name': 'coloring_material'}, {'frequency': 'r', 'synset': 'combination_lock.n.01', 'synonyms': ['combination_lock'], 'id': 292, 'def': 'lock that can be opened only by turning dials in a special sequence', 'name': 'combination_lock'}, {'frequency': 'c', 'synset': 'comforter.n.04', 'synonyms': ['pacifier', 'teething_ring'], 'id': 293, 'def': 'device used for an infant to suck or bite on', 'name': 'pacifier'}, {'frequency': 'r', 'synset': 'comic_book.n.01', 'synonyms': ['comic_book'], 'id': 294, 'def': 'a magazine devoted to comic strips', 'name': 'comic_book'}, {'frequency': 'r', 'synset': 'compass.n.01', 'synonyms': ['compass'], 'id': 295, 'def': 'navigational instrument for finding directions', 'name': 'compass'}, {'frequency': 'f', 'synset': 'computer_keyboard.n.01', 'synonyms': ['computer_keyboard', 'keyboard_(computer)'], 'id': 296, 'def': 'a keyboard that is a data input device for computers', 'name': 'computer_keyboard'}, {'frequency': 'f', 'synset': 'condiment.n.01', 'synonyms': ['condiment'], 'id': 297, 'def': 'a preparation (a sauce or relish or spice) to enhance flavor or enjoyment', 'name': 'condiment'}, {'frequency': 'f', 'synset': 'cone.n.01', 'synonyms': ['cone', 'traffic_cone'], 'id': 298, 'def': 'a cone-shaped object used to direct traffic', 'name': 'cone'}, {'frequency': 'f', 'synset': 'control.n.09', 'synonyms': ['control', 'controller'], 'id': 299, 'def': 'a mechanism that controls the operation of a machine', 'name': 'control'}, {'frequency': 'r', 'synset': 'convertible.n.01', 'synonyms': ['convertible_(automobile)'], 'id': 300, 'def': 'a car that has top that can be folded or removed', 'name': 'convertible_(automobile)'}, {'frequency': 'r', 'synset': 'convertible.n.03', 'synonyms': ['sofa_bed'], 'id': 301, 'def': 'a sofa that can be converted into a bed', 'name': 'sofa_bed'}, {'frequency': 'r', 'synset': 'cooker.n.01', 'synonyms': ['cooker'], 'id': 302, 'def': 'a utensil for cooking', 'name': 'cooker'}, {'frequency': 'f', 'synset': 'cookie.n.01', 'synonyms': ['cookie', 'cooky', 'biscuit_(cookie)'], 'id': 303, 'def': "any of various small flat sweet cakes (`biscuit' is the British term)", 'name': 'cookie'}, {'frequency': 'r', 'synset': 'cooking_utensil.n.01', 'synonyms': ['cooking_utensil'], 'id': 304, 'def': 'a kitchen utensil made of material that does not melt easily; used for cooking', 'name': 'cooking_utensil'}, {'frequency': 'f', 'synset': 'cooler.n.01', 'synonyms': ['cooler_(for_food)', 'ice_chest'], 'id': 305, 'def': 'an insulated box for storing food often with ice', 'name': 'cooler_(for_food)'}, {'frequency': 'f', 'synset': 'cork.n.04', 'synonyms': ['cork_(bottle_plug)', 'bottle_cork'], 'id': 306, 'def': 'the plug in the mouth of a bottle (especially a wine bottle)', 'name': 'cork_(bottle_plug)'}, {'frequency': 'r', 'synset': 'corkboard.n.01', 'synonyms': ['corkboard'], 'id': 307, 'def': 'a sheet consisting of cork granules', 'name': 'corkboard'}, {'frequency': 'c', 'synset': 'corkscrew.n.01', 'synonyms': ['corkscrew', 'bottle_screw'], 'id': 308, 'def': 'a bottle opener that pulls corks', 'name': 'corkscrew'}, {'frequency': 'f', 'synset': 'corn.n.03', 'synonyms': ['edible_corn', 'corn', 'maize'], 'id': 309, 'def': 'ears or kernels of corn that can be prepared and served for human food (only mark individual ears or kernels)', 'name': 'edible_corn'}, {'frequency': 'r', 'synset': 'cornbread.n.01', 'synonyms': ['cornbread'], 'id': 310, 'def': 'bread made primarily of cornmeal', 'name': 'cornbread'}, {'frequency': 'c', 'synset': 'cornet.n.01', 'synonyms': ['cornet', 'horn', 'trumpet'], 'id': 311, 'def': 'a brass musical instrument with a narrow tube and a flared bell and many valves', 'name': 'cornet'}, {'frequency': 'c', 'synset': 'cornice.n.01', 'synonyms': ['cornice', 'valance', 'valance_board', 'pelmet'], 'id': 312, 'def': 'a decorative framework to conceal curtain fixtures at the top of a window casing', 'name': 'cornice'}, {'frequency': 'r', 'synset': 'cornmeal.n.01', 'synonyms': ['cornmeal'], 'id': 313, 'def': 'coarsely ground corn', 'name': 'cornmeal'}, {'frequency': 'c', 'synset': 'corset.n.01', 'synonyms': ['corset', 'girdle'], 'id': 314, 'def': "a woman's close-fitting foundation garment", 'name': 'corset'}, {'frequency': 'c', 'synset': 'costume.n.04', 'synonyms': ['costume'], 'id': 315, 'def': 'the attire characteristic of a country or a time or a social class', 'name': 'costume'}, {'frequency': 'r', 'synset': 'cougar.n.01', 'synonyms': ['cougar', 'puma', 'catamount', 'mountain_lion', 'panther'], 'id': 316, 'def': 'large American feline resembling a lion', 'name': 'cougar'}, {'frequency': 'r', 'synset': 'coverall.n.01', 'synonyms': ['coverall'], 'id': 317, 'def': 'a loose-fitting protective garment that is worn over other clothing', 'name': 'coverall'}, {'frequency': 'c', 'synset': 'cowbell.n.01', 'synonyms': ['cowbell'], 'id': 318, 'def': 'a bell hung around the neck of cow so that the cow can be easily located', 'name': 'cowbell'}, {'frequency': 'f', 'synset': 'cowboy_hat.n.01', 'synonyms': ['cowboy_hat', 'ten-gallon_hat'], 'id': 319, 'def': 'a hat with a wide brim and a soft crown; worn by American ranch hands', 'name': 'cowboy_hat'}, {'frequency': 'c', 'synset': 'crab.n.01', 'synonyms': ['crab_(animal)'], 'id': 320, 'def': 'decapod having eyes on short stalks and a broad flattened shell and pincers', 'name': 'crab_(animal)'}, {'frequency': 'r', 'synset': 'crab.n.05', 'synonyms': ['crabmeat'], 'id': 321, 'def': 'the edible flesh of any of various crabs', 'name': 'crabmeat'}, {'frequency': 'c', 'synset': 'cracker.n.01', 'synonyms': ['cracker'], 'id': 322, 'def': 'a thin crisp wafer', 'name': 'cracker'}, {'frequency': 'r', 'synset': 'crape.n.01', 'synonyms': ['crape', 'crepe', 'French_pancake'], 'id': 323, 'def': 'small very thin pancake', 'name': 'crape'}, {'frequency': 'f', 'synset': 'crate.n.01', 'synonyms': ['crate'], 'id': 324, 'def': 'a rugged box (usually made of wood); used for shipping', 'name': 'crate'}, {'frequency': 'c', 'synset': 'crayon.n.01', 'synonyms': ['crayon', 'wax_crayon'], 'id': 325, 'def': 'writing or drawing implement made of a colored stick of composition wax', 'name': 'crayon'}, {'frequency': 'r', 'synset': 'cream_pitcher.n.01', 'synonyms': ['cream_pitcher'], 'id': 326, 'def': 'a small pitcher for serving cream', 'name': 'cream_pitcher'}, {'frequency': 'c', 'synset': 'crescent_roll.n.01', 'synonyms': ['crescent_roll', 'croissant'], 'id': 327, 'def': 'very rich flaky crescent-shaped roll', 'name': 'crescent_roll'}, {'frequency': 'c', 'synset': 'crib.n.01', 'synonyms': ['crib', 'cot'], 'id': 328, 'def': 'baby bed with high sides made of slats', 'name': 'crib'}, {'frequency': 'c', 'synset': 'crock.n.03', 'synonyms': ['crock_pot', 'earthenware_jar'], 'id': 329, 'def': 'an earthen jar (made of baked clay) or a modern electric crockpot', 'name': 'crock_pot'}, {'frequency': 'f', 'synset': 'crossbar.n.01', 'synonyms': ['crossbar'], 'id': 330, 'def': 'a horizontal bar that goes across something', 'name': 'crossbar'}, {'frequency': 'r', 'synset': 'crouton.n.01', 'synonyms': ['crouton'], 'id': 331, 'def': 'a small piece of toasted or fried bread; served in soup or salads', 'name': 'crouton'}, {'frequency': 'c', 'synset': 'crow.n.01', 'synonyms': ['crow'], 'id': 332, 'def': 'black birds having a raucous call', 'name': 'crow'}, {'frequency': 'r', 'synset': 'crowbar.n.01', 'synonyms': ['crowbar', 'wrecking_bar', 'pry_bar'], 'id': 333, 'def': 'a heavy iron lever with one end forged into a wedge', 'name': 'crowbar'}, {'frequency': 'c', 'synset': 'crown.n.04', 'synonyms': ['crown'], 'id': 334, 'def': 'an ornamental jeweled headdress signifying sovereignty', 'name': 'crown'}, {'frequency': 'c', 'synset': 'crucifix.n.01', 'synonyms': ['crucifix'], 'id': 335, 'def': 'representation of the cross on which Jesus died', 'name': 'crucifix'}, {'frequency': 'c', 'synset': 'cruise_ship.n.01', 'synonyms': ['cruise_ship', 'cruise_liner'], 'id': 336, 'def': 'a passenger ship used commercially for pleasure cruises', 'name': 'cruise_ship'}, {'frequency': 'c', 'synset': 'cruiser.n.01', 'synonyms': ['police_cruiser', 'patrol_car', 'police_car', 'squad_car'], 'id': 337, 'def': 'a car in which policemen cruise the streets', 'name': 'police_cruiser'}, {'frequency': 'f', 'synset': 'crumb.n.03', 'synonyms': ['crumb'], 'id': 338, 'def': 'small piece of e.g. bread or cake', 'name': 'crumb'}, {'frequency': 'c', 'synset': 'crutch.n.01', 'synonyms': ['crutch'], 'id': 339, 'def': 'a wooden or metal staff that fits under the armpit and reaches to the ground', 'name': 'crutch'}, {'frequency': 'c', 'synset': 'cub.n.03', 'synonyms': ['cub_(animal)'], 'id': 340, 'def': 'the young of certain carnivorous mammals such as the bear or wolf or lion', 'name': 'cub_(animal)'}, {'frequency': 'c', 'synset': 'cube.n.05', 'synonyms': ['cube', 'square_block'], 'id': 341, 'def': 'a block in the (approximate) shape of a cube', 'name': 'cube'}, {'frequency': 'f', 'synset': 'cucumber.n.02', 'synonyms': ['cucumber', 'cuke'], 'id': 342, 'def': 'cylindrical green fruit with thin green rind and white flesh eaten as a vegetable', 'name': 'cucumber'}, {'frequency': 'c', 'synset': 'cufflink.n.01', 'synonyms': ['cufflink'], 'id': 343, 'def': 'jewelry consisting of linked buttons used to fasten the cuffs of a shirt', 'name': 'cufflink'}, {'frequency': 'f', 'synset': 'cup.n.01', 'synonyms': ['cup'], 'id': 344, 'def': 'a small open container usually used for drinking; usually has a handle', 'name': 'cup'}, {'frequency': 'c', 'synset': 'cup.n.08', 'synonyms': ['trophy_cup'], 'id': 345, 'def': 'a metal award or cup-shaped vessel with handles that is awarded as a trophy to a competition winner', 'name': 'trophy_cup'}, {'frequency': 'f', 'synset': 'cupboard.n.01', 'synonyms': ['cupboard', 'closet'], 'id': 346, 'def': 'a small room (or recess) or cabinet used for storage space', 'name': 'cupboard'}, {'frequency': 'f', 'synset': 'cupcake.n.01', 'synonyms': ['cupcake'], 'id': 347, 'def': 'small cake baked in a muffin tin', 'name': 'cupcake'}, {'frequency': 'r', 'synset': 'curler.n.01', 'synonyms': ['hair_curler', 'hair_roller', 'hair_crimper'], 'id': 348, 'def': 'a cylindrical tube around which the hair is wound to curl it', 'name': 'hair_curler'}, {'frequency': 'r', 'synset': 'curling_iron.n.01', 'synonyms': ['curling_iron'], 'id': 349, 'def': 'a cylindrical home appliance that heats hair that has been curled around it', 'name': 'curling_iron'}, {'frequency': 'f', 'synset': 'curtain.n.01', 'synonyms': ['curtain', 'drapery'], 'id': 350, 'def': 'hanging cloth used as a blind (especially for a window)', 'name': 'curtain'}, {'frequency': 'f', 'synset': 'cushion.n.03', 'synonyms': ['cushion'], 'id': 351, 'def': 'a soft bag filled with air or padding such as feathers or foam rubber', 'name': 'cushion'}, {'frequency': 'r', 'synset': 'cylinder.n.04', 'synonyms': ['cylinder'], 'id': 352, 'def': 'a cylindrical container', 'name': 'cylinder'}, {'frequency': 'r', 'synset': 'cymbal.n.01', 'synonyms': ['cymbal'], 'id': 353, 'def': 'a percussion instrument consisting of a concave brass disk', 'name': 'cymbal'}, {'frequency': 'r', 'synset': 'dagger.n.01', 'synonyms': ['dagger'], 'id': 354, 'def': 'a short knife with a pointed blade used for piercing or stabbing', 'name': 'dagger'}, {'frequency': 'r', 'synset': 'dalmatian.n.02', 'synonyms': ['dalmatian'], 'id': 355, 'def': 'a large breed having a smooth white coat with black or brown spots', 'name': 'dalmatian'}, {'frequency': 'c', 'synset': 'dartboard.n.01', 'synonyms': ['dartboard'], 'id': 356, 'def': 'a circular board of wood or cork used as the target in the game of darts', 'name': 'dartboard'}, {'frequency': 'r', 'synset': 'date.n.08', 'synonyms': ['date_(fruit)'], 'id': 357, 'def': 'sweet edible fruit of the date palm with a single long woody seed', 'name': 'date_(fruit)'}, {'frequency': 'f', 'synset': 'deck_chair.n.01', 'synonyms': ['deck_chair', 'beach_chair'], 'id': 358, 'def': 'a folding chair for use outdoors; a wooden frame supports a length of canvas', 'name': 'deck_chair'}, {'frequency': 'c', 'synset': 'deer.n.01', 'synonyms': ['deer', 'cervid'], 'id': 359, 'def': "distinguished from Bovidae by the male's having solid deciduous antlers", 'name': 'deer'}, {'frequency': 'c', 'synset': 'dental_floss.n.01', 'synonyms': ['dental_floss', 'floss'], 'id': 360, 'def': 'a soft thread for cleaning the spaces between the teeth', 'name': 'dental_floss'}, {'frequency': 'f', 'synset': 'desk.n.01', 'synonyms': ['desk'], 'id': 361, 'def': 'a piece of furniture with a writing surface and usually drawers or other compartments', 'name': 'desk'}, {'frequency': 'r', 'synset': 'detergent.n.01', 'synonyms': ['detergent'], 'id': 362, 'def': 'a surface-active chemical widely used in industry and laundering', 'name': 'detergent'}, {'frequency': 'c', 'synset': 'diaper.n.01', 'synonyms': ['diaper'], 'id': 363, 'def': 'garment consisting of a folded cloth drawn up between the legs and fastened at the waist', 'name': 'diaper'}, {'frequency': 'r', 'synset': 'diary.n.01', 'synonyms': ['diary', 'journal'], 'id': 364, 'def': 'yearly planner book', 'name': 'diary'}, {'frequency': 'r', 'synset': 'die.n.01', 'synonyms': ['die', 'dice'], 'id': 365, 'def': 'a small cube with 1 to 6 spots on the six faces; used in gambling', 'name': 'die'}, {'frequency': 'r', 'synset': 'dinghy.n.01', 'synonyms': ['dinghy', 'dory', 'rowboat'], 'id': 366, 'def': 'a small boat of shallow draft with seats and oars with which it is propelled', 'name': 'dinghy'}, {'frequency': 'f', 'synset': 'dining_table.n.01', 'synonyms': ['dining_table'], 'id': 367, 'def': 'a table at which meals are served', 'name': 'dining_table'}, {'frequency': 'r', 'synset': 'dinner_jacket.n.01', 'synonyms': ['tux', 'tuxedo'], 'id': 368, 'def': 'semiformal evening dress for men', 'name': 'tux'}, {'frequency': 'f', 'synset': 'dish.n.01', 'synonyms': ['dish'], 'id': 369, 'def': 'a piece of dishware normally used as a container for holding or serving food', 'name': 'dish'}, {'frequency': 'c', 'synset': 'dish.n.05', 'synonyms': ['dish_antenna'], 'id': 370, 'def': 'directional antenna consisting of a parabolic reflector', 'name': 'dish_antenna'}, {'frequency': 'c', 'synset': 'dishrag.n.01', 'synonyms': ['dishrag', 'dishcloth'], 'id': 371, 'def': 'a cloth for washing dishes or cleaning in general', 'name': 'dishrag'}, {'frequency': 'f', 'synset': 'dishtowel.n.01', 'synonyms': ['dishtowel', 'tea_towel'], 'id': 372, 'def': 'a towel for drying dishes', 'name': 'dishtowel'}, {'frequency': 'f', 'synset': 'dishwasher.n.01', 'synonyms': ['dishwasher', 'dishwashing_machine'], 'id': 373, 'def': 'a machine for washing dishes', 'name': 'dishwasher'}, {'frequency': 'r', 'synset': 'dishwasher_detergent.n.01', 'synonyms': ['dishwasher_detergent', 'dishwashing_detergent', 'dishwashing_liquid', 'dishsoap'], 'id': 374, 'def': 'dishsoap or dish detergent designed for use in dishwashers', 'name': 'dishwasher_detergent'}, {'frequency': 'f', 'synset': 'dispenser.n.01', 'synonyms': ['dispenser'], 'id': 375, 'def': 'a container so designed that the contents can be used in prescribed amounts', 'name': 'dispenser'}, {'frequency': 'r', 'synset': 'diving_board.n.01', 'synonyms': ['diving_board'], 'id': 376, 'def': 'a springboard from which swimmers can dive', 'name': 'diving_board'}, {'frequency': 'f', 'synset': 'dixie_cup.n.01', 'synonyms': ['Dixie_cup', 'paper_cup'], 'id': 377, 'def': 'a disposable cup made of paper; for holding drinks', 'name': 'Dixie_cup'}, {'frequency': 'f', 'synset': 'dog.n.01', 'synonyms': ['dog'], 'id': 378, 'def': 'a common domesticated dog', 'name': 'dog'}, {'frequency': 'f', 'synset': 'dog_collar.n.01', 'synonyms': ['dog_collar'], 'id': 379, 'def': 'a collar for a dog', 'name': 'dog_collar'}, {'frequency': 'f', 'synset': 'doll.n.01', 'synonyms': ['doll'], 'id': 380, 'def': 'a toy replica of a HUMAN (NOT AN ANIMAL)', 'name': 'doll'}, {'frequency': 'r', 'synset': 'dollar.n.02', 'synonyms': ['dollar', 'dollar_bill', 'one_dollar_bill'], 'id': 381, 'def': 'a piece of paper money worth one dollar', 'name': 'dollar'}, {'frequency': 'r', 'synset': 'dollhouse.n.01', 'synonyms': ['dollhouse', "doll's_house"], 'id': 382, 'def': "a house so small that it is likened to a child's plaything", 'name': 'dollhouse'}, {'frequency': 'c', 'synset': 'dolphin.n.02', 'synonyms': ['dolphin'], 'id': 383, 'def': 'any of various small toothed whales with a beaklike snout; larger than porpoises', 'name': 'dolphin'}, {'frequency': 'c', 'synset': 'domestic_ass.n.01', 'synonyms': ['domestic_ass', 'donkey'], 'id': 384, 'def': 'domestic beast of burden descended from the African wild ass; patient but stubborn', 'name': 'domestic_ass'}, {'frequency': 'f', 'synset': 'doorknob.n.01', 'synonyms': ['doorknob', 'doorhandle'], 'id': 385, 'def': "a knob used to open a door (often called `doorhandle' in Great Britain)", 'name': 'doorknob'}, {'frequency': 'c', 'synset': 'doormat.n.02', 'synonyms': ['doormat', 'welcome_mat'], 'id': 386, 'def': 'a mat placed outside an exterior door for wiping the shoes before entering', 'name': 'doormat'}, {'frequency': 'f', 'synset': 'doughnut.n.02', 'synonyms': ['doughnut', 'donut'], 'id': 387, 'def': 'a small ring-shaped friedcake', 'name': 'doughnut'}, {'frequency': 'r', 'synset': 'dove.n.01', 'synonyms': ['dove'], 'id': 388, 'def': 'any of numerous small pigeons', 'name': 'dove'}, {'frequency': 'r', 'synset': 'dragonfly.n.01', 'synonyms': ['dragonfly'], 'id': 389, 'def': 'slender-bodied non-stinging insect having iridescent wings that are outspread at rest', 'name': 'dragonfly'}, {'frequency': 'f', 'synset': 'drawer.n.01', 'synonyms': ['drawer'], 'id': 390, 'def': 'a boxlike container in a piece of furniture; made so as to slide in and out', 'name': 'drawer'}, {'frequency': 'c', 'synset': 'drawers.n.01', 'synonyms': ['underdrawers', 'boxers', 'boxershorts'], 'id': 391, 'def': 'underpants worn by men', 'name': 'underdrawers'}, {'frequency': 'f', 'synset': 'dress.n.01', 'synonyms': ['dress', 'frock'], 'id': 392, 'def': 'a one-piece garment for a woman; has skirt and bodice', 'name': 'dress'}, {'frequency': 'c', 'synset': 'dress_hat.n.01', 'synonyms': ['dress_hat', 'high_hat', 'opera_hat', 'silk_hat', 'top_hat'], 'id': 393, 'def': "a man's hat with a tall crown; usually covered with silk or with beaver fur", 'name': 'dress_hat'}, {'frequency': 'f', 'synset': 'dress_suit.n.01', 'synonyms': ['dress_suit'], 'id': 394, 'def': 'formalwear consisting of full evening dress for men', 'name': 'dress_suit'}, {'frequency': 'f', 'synset': 'dresser.n.05', 'synonyms': ['dresser'], 'id': 395, 'def': 'a cabinet with shelves', 'name': 'dresser'}, {'frequency': 'c', 'synset': 'drill.n.01', 'synonyms': ['drill'], 'id': 396, 'def': 'a tool with a sharp rotating point for making holes in hard materials', 'name': 'drill'}, {'frequency': 'r', 'synset': 'drone.n.04', 'synonyms': ['drone'], 'id': 397, 'def': 'an aircraft without a pilot that is operated by remote control', 'name': 'drone'}, {'frequency': 'r', 'synset': 'dropper.n.01', 'synonyms': ['dropper', 'eye_dropper'], 'id': 398, 'def': 'pipet consisting of a small tube with a vacuum bulb at one end for drawing liquid in and releasing it a drop at a time', 'name': 'dropper'}, {'frequency': 'c', 'synset': 'drum.n.01', 'synonyms': ['drum_(musical_instrument)'], 'id': 399, 'def': 'a musical percussion instrument; usually consists of a hollow cylinder with a membrane stretched across each end', 'name': 'drum_(musical_instrument)'}, {'frequency': 'r', 'synset': 'drumstick.n.02', 'synonyms': ['drumstick'], 'id': 400, 'def': 'a stick used for playing a drum', 'name': 'drumstick'}, {'frequency': 'f', 'synset': 'duck.n.01', 'synonyms': ['duck'], 'id': 401, 'def': 'small web-footed broad-billed swimming bird', 'name': 'duck'}, {'frequency': 'c', 'synset': 'duckling.n.02', 'synonyms': ['duckling'], 'id': 402, 'def': 'young duck', 'name': 'duckling'}, {'frequency': 'c', 'synset': 'duct_tape.n.01', 'synonyms': ['duct_tape'], 'id': 403, 'def': 'a wide silvery adhesive tape', 'name': 'duct_tape'}, {'frequency': 'f', 'synset': 'duffel_bag.n.01', 'synonyms': ['duffel_bag', 'duffle_bag', 'duffel', 'duffle'], 'id': 404, 'def': 'a large cylindrical bag of heavy cloth (does not include suitcases)', 'name': 'duffel_bag'}, {'frequency': 'r', 'synset': 'dumbbell.n.01', 'synonyms': ['dumbbell'], 'id': 405, 'def': 'an exercising weight with two ball-like ends connected by a short handle', 'name': 'dumbbell'}, {'frequency': 'c', 'synset': 'dumpster.n.01', 'synonyms': ['dumpster'], 'id': 406, 'def': 'a container designed to receive and transport and dump waste', 'name': 'dumpster'}, {'frequency': 'r', 'synset': 'dustpan.n.02', 'synonyms': ['dustpan'], 'id': 407, 'def': 'a short-handled receptacle into which dust can be swept', 'name': 'dustpan'}, {'frequency': 'c', 'synset': 'eagle.n.01', 'synonyms': ['eagle'], 'id': 408, 'def': 'large birds of prey noted for their broad wings and strong soaring flight', 'name': 'eagle'}, {'frequency': 'f', 'synset': 'earphone.n.01', 'synonyms': ['earphone', 'earpiece', 'headphone'], 'id': 409, 'def': 'device for listening to audio that is held over or inserted into the ear', 'name': 'earphone'}, {'frequency': 'r', 'synset': 'earplug.n.01', 'synonyms': ['earplug'], 'id': 410, 'def': 'a soft plug that is inserted into the ear canal to block sound', 'name': 'earplug'}, {'frequency': 'f', 'synset': 'earring.n.01', 'synonyms': ['earring'], 'id': 411, 'def': 'jewelry to ornament the ear', 'name': 'earring'}, {'frequency': 'c', 'synset': 'easel.n.01', 'synonyms': ['easel'], 'id': 412, 'def': "an upright tripod for displaying something (usually an artist's canvas)", 'name': 'easel'}, {'frequency': 'r', 'synset': 'eclair.n.01', 'synonyms': ['eclair'], 'id': 413, 'def': 'oblong cream puff', 'name': 'eclair'}, {'frequency': 'r', 'synset': 'eel.n.01', 'synonyms': ['eel'], 'id': 414, 'def': 'an elongate fish with fatty flesh', 'name': 'eel'}, {'frequency': 'f', 'synset': 'egg.n.02', 'synonyms': ['egg', 'eggs'], 'id': 415, 'def': 'oval reproductive body of a fowl (especially a hen) used as food', 'name': 'egg'}, {'frequency': 'r', 'synset': 'egg_roll.n.01', 'synonyms': ['egg_roll', 'spring_roll'], 'id': 416, 'def': 'minced vegetables and meat wrapped in a pancake and fried', 'name': 'egg_roll'}, {'frequency': 'c', 'synset': 'egg_yolk.n.01', 'synonyms': ['egg_yolk', 'yolk_(egg)'], 'id': 417, 'def': 'the yellow spherical part of an egg', 'name': 'egg_yolk'}, {'frequency': 'c', 'synset': 'eggbeater.n.02', 'synonyms': ['eggbeater', 'eggwhisk'], 'id': 418, 'def': 'a mixer for beating eggs or whipping cream', 'name': 'eggbeater'}, {'frequency': 'c', 'synset': 'eggplant.n.01', 'synonyms': ['eggplant', 'aubergine'], 'id': 419, 'def': 'egg-shaped vegetable having a shiny skin typically dark purple', 'name': 'eggplant'}, {'frequency': 'r', 'synset': 'electric_chair.n.01', 'synonyms': ['electric_chair'], 'id': 420, 'def': 'a chair-shaped instrument of execution by electrocution', 'name': 'electric_chair'}, {'frequency': 'f', 'synset': 'electric_refrigerator.n.01', 'synonyms': ['refrigerator'], 'id': 421, 'def': 'a refrigerator in which the coolant is pumped around by an electric motor', 'name': 'refrigerator'}, {'frequency': 'f', 'synset': 'elephant.n.01', 'synonyms': ['elephant'], 'id': 422, 'def': 'a common elephant', 'name': 'elephant'}, {'frequency': 'c', 'synset': 'elk.n.01', 'synonyms': ['elk', 'moose'], 'id': 423, 'def': 'large northern deer with enormous flattened antlers in the male', 'name': 'elk'}, {'frequency': 'c', 'synset': 'envelope.n.01', 'synonyms': ['envelope'], 'id': 424, 'def': 'a flat (usually rectangular) container for a letter, thin package, etc.', 'name': 'envelope'}, {'frequency': 'c', 'synset': 'eraser.n.01', 'synonyms': ['eraser'], 'id': 425, 'def': 'an implement used to erase something', 'name': 'eraser'}, {'frequency': 'r', 'synset': 'escargot.n.01', 'synonyms': ['escargot'], 'id': 426, 'def': 'edible snail usually served in the shell with a sauce of melted butter and garlic', 'name': 'escargot'}, {'frequency': 'r', 'synset': 'eyepatch.n.01', 'synonyms': ['eyepatch'], 'id': 427, 'def': 'a protective cloth covering for an injured eye', 'name': 'eyepatch'}, {'frequency': 'r', 'synset': 'falcon.n.01', 'synonyms': ['falcon'], 'id': 428, 'def': 'birds of prey having long pointed powerful wings adapted for swift flight', 'name': 'falcon'}, {'frequency': 'f', 'synset': 'fan.n.01', 'synonyms': ['fan'], 'id': 429, 'def': 'a device for creating a current of air by movement of a surface or surfaces', 'name': 'fan'}, {'frequency': 'f', 'synset': 'faucet.n.01', 'synonyms': ['faucet', 'spigot', 'tap'], 'id': 430, 'def': 'a regulator for controlling the flow of a liquid from a reservoir', 'name': 'faucet'}, {'frequency': 'r', 'synset': 'fedora.n.01', 'synonyms': ['fedora'], 'id': 431, 'def': 'a hat made of felt with a creased crown', 'name': 'fedora'}, {'frequency': 'r', 'synset': 'ferret.n.02', 'synonyms': ['ferret'], 'id': 432, 'def': 'domesticated albino variety of the European polecat bred for hunting rats and rabbits', 'name': 'ferret'}, {'frequency': 'c', 'synset': 'ferris_wheel.n.01', 'synonyms': ['Ferris_wheel'], 'id': 433, 'def': 'a large wheel with suspended seats that remain upright as the wheel rotates', 'name': 'Ferris_wheel'}, {'frequency': 'c', 'synset': 'ferry.n.01', 'synonyms': ['ferry', 'ferryboat'], 'id': 434, 'def': 'a boat that transports people or vehicles across a body of water and operates on a regular schedule', 'name': 'ferry'}, {'frequency': 'r', 'synset': 'fig.n.04', 'synonyms': ['fig_(fruit)'], 'id': 435, 'def': 'fleshy sweet pear-shaped yellowish or purple fruit eaten fresh or preserved or dried', 'name': 'fig_(fruit)'}, {'frequency': 'c', 'synset': 'fighter.n.02', 'synonyms': ['fighter_jet', 'fighter_aircraft', 'attack_aircraft'], 'id': 436, 'def': 'a high-speed military or naval airplane designed to destroy enemy targets', 'name': 'fighter_jet'}, {'frequency': 'f', 'synset': 'figurine.n.01', 'synonyms': ['figurine'], 'id': 437, 'def': 'a small carved or molded figure', 'name': 'figurine'}, {'frequency': 'c', 'synset': 'file.n.03', 'synonyms': ['file_cabinet', 'filing_cabinet'], 'id': 438, 'def': 'office furniture consisting of a container for keeping papers in order', 'name': 'file_cabinet'}, {'frequency': 'r', 'synset': 'file.n.04', 'synonyms': ['file_(tool)'], 'id': 439, 'def': 'a steel hand tool with small sharp teeth on some or all of its surfaces; used for smoothing wood or metal', 'name': 'file_(tool)'}, {'frequency': 'f', 'synset': 'fire_alarm.n.02', 'synonyms': ['fire_alarm', 'smoke_alarm'], 'id': 440, 'def': 'an alarm that is tripped off by fire or smoke', 'name': 'fire_alarm'}, {'frequency': 'f', 'synset': 'fire_engine.n.01', 'synonyms': ['fire_engine', 'fire_truck'], 'id': 441, 'def': 'large trucks that carry firefighters and equipment to the site of a fire', 'name': 'fire_engine'}, {'frequency': 'f', 'synset': 'fire_extinguisher.n.01', 'synonyms': ['fire_extinguisher', 'extinguisher'], 'id': 442, 'def': 'a manually operated device for extinguishing small fires', 'name': 'fire_extinguisher'}, {'frequency': 'c', 'synset': 'fire_hose.n.01', 'synonyms': ['fire_hose'], 'id': 443, 'def': 'a large hose that carries water from a fire hydrant to the site of the fire', 'name': 'fire_hose'}, {'frequency': 'f', 'synset': 'fireplace.n.01', 'synonyms': ['fireplace'], 'id': 444, 'def': 'an open recess in a wall at the base of a chimney where a fire can be built', 'name': 'fireplace'}, {'frequency': 'f', 'synset': 'fireplug.n.01', 'synonyms': ['fireplug', 'fire_hydrant', 'hydrant'], 'id': 445, 'def': 'an upright hydrant for drawing water to use in fighting a fire', 'name': 'fireplug'}, {'frequency': 'r', 'synset': 'first-aid_kit.n.01', 'synonyms': ['first-aid_kit'], 'id': 446, 'def': 'kit consisting of a set of bandages and medicines for giving first aid', 'name': 'first-aid_kit'}, {'frequency': 'f', 'synset': 'fish.n.01', 'synonyms': ['fish'], 'id': 447, 'def': 'any of various mostly cold-blooded aquatic vertebrates usually having scales and breathing through gills', 'name': 'fish'}, {'frequency': 'c', 'synset': 'fish.n.02', 'synonyms': ['fish_(food)'], 'id': 448, 'def': 'the flesh of fish used as food', 'name': 'fish_(food)'}, {'frequency': 'r', 'synset': 'fishbowl.n.02', 'synonyms': ['fishbowl', 'goldfish_bowl'], 'id': 449, 'def': 'a transparent bowl in which small fish are kept', 'name': 'fishbowl'}, {'frequency': 'c', 'synset': 'fishing_rod.n.01', 'synonyms': ['fishing_rod', 'fishing_pole'], 'id': 450, 'def': 'a rod that is used in fishing to extend the fishing line', 'name': 'fishing_rod'}, {'frequency': 'f', 'synset': 'flag.n.01', 'synonyms': ['flag'], 'id': 451, 'def': 'emblem usually consisting of a rectangular piece of cloth of distinctive design (do not include pole)', 'name': 'flag'}, {'frequency': 'f', 'synset': 'flagpole.n.02', 'synonyms': ['flagpole', 'flagstaff'], 'id': 452, 'def': 'a tall staff or pole on which a flag is raised', 'name': 'flagpole'}, {'frequency': 'c', 'synset': 'flamingo.n.01', 'synonyms': ['flamingo'], 'id': 453, 'def': 'large pink web-footed bird with down-bent bill', 'name': 'flamingo'}, {'frequency': 'c', 'synset': 'flannel.n.01', 'synonyms': ['flannel'], 'id': 454, 'def': 'a soft light woolen fabric; used for clothing', 'name': 'flannel'}, {'frequency': 'c', 'synset': 'flap.n.01', 'synonyms': ['flap'], 'id': 455, 'def': 'any broad thin covering attached at one edge, such as a mud flap next to a wheel or a flap on an airplane wing', 'name': 'flap'}, {'frequency': 'r', 'synset': 'flash.n.10', 'synonyms': ['flash', 'flashbulb'], 'id': 456, 'def': 'a lamp for providing momentary light to take a photograph', 'name': 'flash'}, {'frequency': 'c', 'synset': 'flashlight.n.01', 'synonyms': ['flashlight', 'torch'], 'id': 457, 'def': 'a small portable battery-powered electric lamp', 'name': 'flashlight'}, {'frequency': 'r', 'synset': 'fleece.n.03', 'synonyms': ['fleece'], 'id': 458, 'def': 'a soft bulky fabric with deep pile; used chiefly for clothing', 'name': 'fleece'}, {'frequency': 'f', 'synset': 'flip-flop.n.02', 'synonyms': ['flip-flop_(sandal)'], 'id': 459, 'def': 'a backless sandal held to the foot by a thong between two toes', 'name': 'flip-flop_(sandal)'}, {'frequency': 'c', 'synset': 'flipper.n.01', 'synonyms': ['flipper_(footwear)', 'fin_(footwear)'], 'id': 460, 'def': 'a shoe to aid a person in swimming', 'name': 'flipper_(footwear)'}, {'frequency': 'f', 'synset': 'flower_arrangement.n.01', 'synonyms': ['flower_arrangement', 'floral_arrangement'], 'id': 461, 'def': 'a decorative arrangement of flowers', 'name': 'flower_arrangement'}, {'frequency': 'c', 'synset': 'flute.n.02', 'synonyms': ['flute_glass', 'champagne_flute'], 'id': 462, 'def': 'a tall narrow wineglass', 'name': 'flute_glass'}, {'frequency': 'c', 'synset': 'foal.n.01', 'synonyms': ['foal'], 'id': 463, 'def': 'a young horse', 'name': 'foal'}, {'frequency': 'c', 'synset': 'folding_chair.n.01', 'synonyms': ['folding_chair'], 'id': 464, 'def': 'a chair that can be folded flat for storage', 'name': 'folding_chair'}, {'frequency': 'c', 'synset': 'food_processor.n.01', 'synonyms': ['food_processor'], 'id': 465, 'def': 'a kitchen appliance for shredding, blending, chopping, or slicing food', 'name': 'food_processor'}, {'frequency': 'c', 'synset': 'football.n.02', 'synonyms': ['football_(American)'], 'id': 466, 'def': 'the inflated oblong ball used in playing American football', 'name': 'football_(American)'}, {'frequency': 'r', 'synset': 'football_helmet.n.01', 'synonyms': ['football_helmet'], 'id': 467, 'def': 'a padded helmet with a face mask to protect the head of football players', 'name': 'football_helmet'}, {'frequency': 'c', 'synset': 'footstool.n.01', 'synonyms': ['footstool', 'footrest'], 'id': 468, 'def': 'a low seat or a stool to rest the feet of a seated person', 'name': 'footstool'}, {'frequency': 'f', 'synset': 'fork.n.01', 'synonyms': ['fork'], 'id': 469, 'def': 'cutlery used for serving and eating food', 'name': 'fork'}, {'frequency': 'c', 'synset': 'forklift.n.01', 'synonyms': ['forklift'], 'id': 470, 'def': 'an industrial vehicle with a power operated fork in front that can be inserted under loads to lift and move them', 'name': 'forklift'}, {'frequency': 'c', 'synset': 'freight_car.n.01', 'synonyms': ['freight_car'], 'id': 471, 'def': 'a railway car that carries freight', 'name': 'freight_car'}, {'frequency': 'c', 'synset': 'french_toast.n.01', 'synonyms': ['French_toast'], 'id': 472, 'def': 'bread slice dipped in egg and milk and fried', 'name': 'French_toast'}, {'frequency': 'c', 'synset': 'freshener.n.01', 'synonyms': ['freshener', 'air_freshener'], 'id': 473, 'def': 'anything that freshens air by removing or covering odor', 'name': 'freshener'}, {'frequency': 'f', 'synset': 'frisbee.n.01', 'synonyms': ['frisbee'], 'id': 474, 'def': 'a light, plastic disk propelled with a flip of the wrist for recreation or competition', 'name': 'frisbee'}, {'frequency': 'c', 'synset': 'frog.n.01', 'synonyms': ['frog', 'toad', 'toad_frog'], 'id': 475, 'def': 'a tailless stout-bodied amphibians with long hind limbs for leaping', 'name': 'frog'}, {'frequency': 'c', 'synset': 'fruit_juice.n.01', 'synonyms': ['fruit_juice'], 'id': 476, 'def': 'drink produced by squeezing or crushing fruit', 'name': 'fruit_juice'}, {'frequency': 'f', 'synset': 'frying_pan.n.01', 'synonyms': ['frying_pan', 'frypan', 'skillet'], 'id': 477, 'def': 'a pan used for frying foods', 'name': 'frying_pan'}, {'frequency': 'r', 'synset': 'fudge.n.01', 'synonyms': ['fudge'], 'id': 478, 'def': 'soft creamy candy', 'name': 'fudge'}, {'frequency': 'r', 'synset': 'funnel.n.02', 'synonyms': ['funnel'], 'id': 479, 'def': 'a cone-shaped utensil used to channel a substance into a container with a small mouth', 'name': 'funnel'}, {'frequency': 'r', 'synset': 'futon.n.01', 'synonyms': ['futon'], 'id': 480, 'def': 'a pad that is used for sleeping on the floor or on a raised frame', 'name': 'futon'}, {'frequency': 'r', 'synset': 'gag.n.02', 'synonyms': ['gag', 'muzzle'], 'id': 481, 'def': "restraint put into a person's mouth to prevent speaking or shouting", 'name': 'gag'}, {'frequency': 'r', 'synset': 'garbage.n.03', 'synonyms': ['garbage'], 'id': 482, 'def': 'a receptacle where waste can be discarded', 'name': 'garbage'}, {'frequency': 'c', 'synset': 'garbage_truck.n.01', 'synonyms': ['garbage_truck'], 'id': 483, 'def': 'a truck for collecting domestic refuse', 'name': 'garbage_truck'}, {'frequency': 'c', 'synset': 'garden_hose.n.01', 'synonyms': ['garden_hose'], 'id': 484, 'def': 'a hose used for watering a lawn or garden', 'name': 'garden_hose'}, {'frequency': 'c', 'synset': 'gargle.n.01', 'synonyms': ['gargle', 'mouthwash'], 'id': 485, 'def': 'a medicated solution used for gargling and rinsing the mouth', 'name': 'gargle'}, {'frequency': 'r', 'synset': 'gargoyle.n.02', 'synonyms': ['gargoyle'], 'id': 486, 'def': 'an ornament consisting of a grotesquely carved figure of a person or animal', 'name': 'gargoyle'}, {'frequency': 'c', 'synset': 'garlic.n.02', 'synonyms': ['garlic', 'ail'], 'id': 487, 'def': 'aromatic bulb used as seasoning', 'name': 'garlic'}, {'frequency': 'r', 'synset': 'gasmask.n.01', 'synonyms': ['gasmask', 'respirator', 'gas_helmet'], 'id': 488, 'def': 'a protective face mask with a filter', 'name': 'gasmask'}, {'frequency': 'c', 'synset': 'gazelle.n.01', 'synonyms': ['gazelle'], 'id': 489, 'def': 'small swift graceful antelope of Africa and Asia having lustrous eyes', 'name': 'gazelle'}, {'frequency': 'c', 'synset': 'gelatin.n.02', 'synonyms': ['gelatin', 'jelly'], 'id': 490, 'def': 'an edible jelly made with gelatin and used as a dessert or salad base or a coating for foods', 'name': 'gelatin'}, {'frequency': 'r', 'synset': 'gem.n.02', 'synonyms': ['gemstone'], 'id': 491, 'def': 'a crystalline rock that can be cut and polished for jewelry', 'name': 'gemstone'}, {'frequency': 'r', 'synset': 'generator.n.02', 'synonyms': ['generator'], 'id': 492, 'def': 'engine that converts mechanical energy into electrical energy by electromagnetic induction', 'name': 'generator'}, {'frequency': 'c', 'synset': 'giant_panda.n.01', 'synonyms': ['giant_panda', 'panda', 'panda_bear'], 'id': 493, 'def': 'large black-and-white herbivorous mammal of bamboo forests of China and Tibet', 'name': 'giant_panda'}, {'frequency': 'c', 'synset': 'gift_wrap.n.01', 'synonyms': ['gift_wrap'], 'id': 494, 'def': 'attractive wrapping paper suitable for wrapping gifts', 'name': 'gift_wrap'}, {'frequency': 'c', 'synset': 'ginger.n.03', 'synonyms': ['ginger', 'gingerroot'], 'id': 495, 'def': 'the root of the common ginger plant; used fresh as a seasoning', 'name': 'ginger'}, {'frequency': 'f', 'synset': 'giraffe.n.01', 'synonyms': ['giraffe'], 'id': 496, 'def': 'tall animal having a spotted coat and small horns and very long neck and legs', 'name': 'giraffe'}, {'frequency': 'c', 'synset': 'girdle.n.02', 'synonyms': ['cincture', 'sash', 'waistband', 'waistcloth'], 'id': 497, 'def': 'a band of material around the waist that strengthens a skirt or trousers', 'name': 'cincture'}, {'frequency': 'f', 'synset': 'glass.n.02', 'synonyms': ['glass_(drink_container)', 'drinking_glass'], 'id': 498, 'def': 'a container for holding liquids while drinking', 'name': 'glass_(drink_container)'}, {'frequency': 'c', 'synset': 'globe.n.03', 'synonyms': ['globe'], 'id': 499, 'def': 'a sphere on which a map (especially of the earth) is represented', 'name': 'globe'}, {'frequency': 'f', 'synset': 'glove.n.02', 'synonyms': ['glove'], 'id': 500, 'def': 'handwear covering the hand', 'name': 'glove'}, {'frequency': 'c', 'synset': 'goat.n.01', 'synonyms': ['goat'], 'id': 501, 'def': 'a common goat', 'name': 'goat'}, {'frequency': 'f', 'synset': 'goggles.n.01', 'synonyms': ['goggles'], 'id': 502, 'def': 'tight-fitting spectacles worn to protect the eyes', 'name': 'goggles'}, {'frequency': 'r', 'synset': 'goldfish.n.01', 'synonyms': ['goldfish'], 'id': 503, 'def': 'small golden or orange-red freshwater fishes used as pond or aquarium pets', 'name': 'goldfish'}, {'frequency': 'c', 'synset': 'golf_club.n.02', 'synonyms': ['golf_club', 'golf-club'], 'id': 504, 'def': 'golf equipment used by a golfer to hit a golf ball', 'name': 'golf_club'}, {'frequency': 'c', 'synset': 'golfcart.n.01', 'synonyms': ['golfcart'], 'id': 505, 'def': 'a small motor vehicle in which golfers can ride between shots', 'name': 'golfcart'}, {'frequency': 'r', 'synset': 'gondola.n.02', 'synonyms': ['gondola_(boat)'], 'id': 506, 'def': 'long narrow flat-bottomed boat propelled by sculling; traditionally used on canals of Venice', 'name': 'gondola_(boat)'}, {'frequency': 'c', 'synset': 'goose.n.01', 'synonyms': ['goose'], 'id': 507, 'def': 'loud, web-footed long-necked aquatic birds usually larger than ducks', 'name': 'goose'}, {'frequency': 'r', 'synset': 'gorilla.n.01', 'synonyms': ['gorilla'], 'id': 508, 'def': 'largest ape', 'name': 'gorilla'}, {'frequency': 'r', 'synset': 'gourd.n.02', 'synonyms': ['gourd'], 'id': 509, 'def': 'any of numerous inedible fruits with hard rinds', 'name': 'gourd'}, {'frequency': 'f', 'synset': 'grape.n.01', 'synonyms': ['grape'], 'id': 510, 'def': 'any of various juicy fruit with green or purple skins; grow in clusters', 'name': 'grape'}, {'frequency': 'c', 'synset': 'grater.n.01', 'synonyms': ['grater'], 'id': 511, 'def': 'utensil with sharp perforations for shredding foods (as vegetables or cheese)', 'name': 'grater'}, {'frequency': 'c', 'synset': 'gravestone.n.01', 'synonyms': ['gravestone', 'headstone', 'tombstone'], 'id': 512, 'def': 'a stone that is used to mark a grave', 'name': 'gravestone'}, {'frequency': 'r', 'synset': 'gravy_boat.n.01', 'synonyms': ['gravy_boat', 'gravy_holder'], 'id': 513, 'def': 'a dish (often boat-shaped) for serving gravy or sauce', 'name': 'gravy_boat'}, {'frequency': 'f', 'synset': 'green_bean.n.02', 'synonyms': ['green_bean'], 'id': 514, 'def': 'a common bean plant cultivated for its slender green edible pods', 'name': 'green_bean'}, {'frequency': 'f', 'synset': 'green_onion.n.01', 'synonyms': ['green_onion', 'spring_onion', 'scallion'], 'id': 515, 'def': 'a young onion before the bulb has enlarged', 'name': 'green_onion'}, {'frequency': 'r', 'synset': 'griddle.n.01', 'synonyms': ['griddle'], 'id': 516, 'def': 'cooking utensil consisting of a flat heated surface on which food is cooked', 'name': 'griddle'}, {'frequency': 'f', 'synset': 'grill.n.02', 'synonyms': ['grill', 'grille', 'grillwork', 'radiator_grille'], 'id': 517, 'def': 'a framework of metal bars used as a partition or a grate', 'name': 'grill'}, {'frequency': 'r', 'synset': 'grits.n.01', 'synonyms': ['grits', 'hominy_grits'], 'id': 518, 'def': 'coarsely ground corn boiled as a breakfast dish', 'name': 'grits'}, {'frequency': 'c', 'synset': 'grizzly.n.01', 'synonyms': ['grizzly', 'grizzly_bear'], 'id': 519, 'def': 'powerful brownish-yellow bear of the uplands of western North America', 'name': 'grizzly'}, {'frequency': 'c', 'synset': 'grocery_bag.n.01', 'synonyms': ['grocery_bag'], 'id': 520, 'def': "a sack for holding customer's groceries", 'name': 'grocery_bag'}, {'frequency': 'f', 'synset': 'guitar.n.01', 'synonyms': ['guitar'], 'id': 521, 'def': 'a stringed instrument usually having six strings; played by strumming or plucking', 'name': 'guitar'}, {'frequency': 'c', 'synset': 'gull.n.02', 'synonyms': ['gull', 'seagull'], 'id': 522, 'def': 'mostly white aquatic bird having long pointed wings and short legs', 'name': 'gull'}, {'frequency': 'c', 'synset': 'gun.n.01', 'synonyms': ['gun'], 'id': 523, 'def': 'a weapon that discharges a bullet at high velocity from a metal tube', 'name': 'gun'}, {'frequency': 'f', 'synset': 'hairbrush.n.01', 'synonyms': ['hairbrush'], 'id': 524, 'def': "a brush used to groom a person's hair", 'name': 'hairbrush'}, {'frequency': 'c', 'synset': 'hairnet.n.01', 'synonyms': ['hairnet'], 'id': 525, 'def': 'a small net that someone wears over their hair to keep it in place', 'name': 'hairnet'}, {'frequency': 'c', 'synset': 'hairpin.n.01', 'synonyms': ['hairpin'], 'id': 526, 'def': "a double pronged pin used to hold women's hair in place", 'name': 'hairpin'}, {'frequency': 'r', 'synset': 'halter.n.03', 'synonyms': ['halter_top'], 'id': 527, 'def': "a woman's top that fastens behind the back and neck leaving the back and arms uncovered", 'name': 'halter_top'}, {'frequency': 'f', 'synset': 'ham.n.01', 'synonyms': ['ham', 'jambon', 'gammon'], 'id': 528, 'def': 'meat cut from the thigh of a hog (usually smoked)', 'name': 'ham'}, {'frequency': 'c', 'synset': 'hamburger.n.01', 'synonyms': ['hamburger', 'beefburger', 'burger'], 'id': 529, 'def': 'a sandwich consisting of a patty of minced beef served on a bun', 'name': 'hamburger'}, {'frequency': 'c', 'synset': 'hammer.n.02', 'synonyms': ['hammer'], 'id': 530, 'def': 'a hand tool with a heavy head and a handle; used to deliver an impulsive force by striking', 'name': 'hammer'}, {'frequency': 'c', 'synset': 'hammock.n.02', 'synonyms': ['hammock'], 'id': 531, 'def': 'a hanging bed of canvas or rope netting (usually suspended between two trees)', 'name': 'hammock'}, {'frequency': 'r', 'synset': 'hamper.n.02', 'synonyms': ['hamper'], 'id': 532, 'def': 'a basket usually with a cover', 'name': 'hamper'}, {'frequency': 'c', 'synset': 'hamster.n.01', 'synonyms': ['hamster'], 'id': 533, 'def': 'short-tailed burrowing rodent with large cheek pouches', 'name': 'hamster'}, {'frequency': 'f', 'synset': 'hand_blower.n.01', 'synonyms': ['hair_dryer'], 'id': 534, 'def': 'a hand-held electric blower that can blow warm air onto the hair', 'name': 'hair_dryer'}, {'frequency': 'r', 'synset': 'hand_glass.n.01', 'synonyms': ['hand_glass', 'hand_mirror'], 'id': 535, 'def': 'a mirror intended to be held in the hand', 'name': 'hand_glass'}, {'frequency': 'f', 'synset': 'hand_towel.n.01', 'synonyms': ['hand_towel', 'face_towel'], 'id': 536, 'def': 'a small towel used to dry the hands or face', 'name': 'hand_towel'}, {'frequency': 'c', 'synset': 'handcart.n.01', 'synonyms': ['handcart', 'pushcart', 'hand_truck'], 'id': 537, 'def': 'wheeled vehicle that can be pushed by a person', 'name': 'handcart'}, {'frequency': 'r', 'synset': 'handcuff.n.01', 'synonyms': ['handcuff'], 'id': 538, 'def': 'shackle that consists of a metal loop that can be locked around the wrist', 'name': 'handcuff'}, {'frequency': 'c', 'synset': 'handkerchief.n.01', 'synonyms': ['handkerchief'], 'id': 539, 'def': 'a square piece of cloth used for wiping the eyes or nose or as a costume accessory', 'name': 'handkerchief'}, {'frequency': 'f', 'synset': 'handle.n.01', 'synonyms': ['handle', 'grip', 'handgrip'], 'id': 540, 'def': 'the appendage to an object that is designed to be held in order to use or move it', 'name': 'handle'}, {'frequency': 'r', 'synset': 'handsaw.n.01', 'synonyms': ['handsaw', "carpenter's_saw"], 'id': 541, 'def': 'a saw used with one hand for cutting wood', 'name': 'handsaw'}, {'frequency': 'r', 'synset': 'hardback.n.01', 'synonyms': ['hardback_book', 'hardcover_book'], 'id': 542, 'def': 'a book with cardboard or cloth or leather covers', 'name': 'hardback_book'}, {'frequency': 'r', 'synset': 'harmonium.n.01', 'synonyms': ['harmonium', 'organ_(musical_instrument)', 'reed_organ_(musical_instrument)'], 'id': 543, 'def': 'a free-reed instrument in which air is forced through the reeds by bellows', 'name': 'harmonium'}, {'frequency': 'f', 'synset': 'hat.n.01', 'synonyms': ['hat'], 'id': 544, 'def': 'headwear that protects the head from bad weather, sun, or worn for fashion', 'name': 'hat'}, {'frequency': 'r', 'synset': 'hatbox.n.01', 'synonyms': ['hatbox'], 'id': 545, 'def': 'a round piece of luggage for carrying hats', 'name': 'hatbox'}, {'frequency': 'c', 'synset': 'head_covering.n.01', 'synonyms': ['veil'], 'id': 546, 'def': 'a garment that covers the head OR face', 'name': 'veil'}, {'frequency': 'f', 'synset': 'headband.n.01', 'synonyms': ['headband'], 'id': 547, 'def': 'a band worn around or over the head', 'name': 'headband'}, {'frequency': 'f', 'synset': 'headboard.n.01', 'synonyms': ['headboard'], 'id': 548, 'def': 'a vertical board or panel forming the head of a bedstead', 'name': 'headboard'}, {'frequency': 'f', 'synset': 'headlight.n.01', 'synonyms': ['headlight', 'headlamp'], 'id': 549, 'def': 'a powerful light with reflector; attached to the front of an automobile or locomotive', 'name': 'headlight'}, {'frequency': 'c', 'synset': 'headscarf.n.01', 'synonyms': ['headscarf'], 'id': 550, 'def': 'a kerchief worn over the head and tied under the chin', 'name': 'headscarf'}, {'frequency': 'r', 'synset': 'headset.n.01', 'synonyms': ['headset'], 'id': 551, 'def': 'receiver consisting of a pair of headphones', 'name': 'headset'}, {'frequency': 'c', 'synset': 'headstall.n.01', 'synonyms': ['headstall_(for_horses)', 'headpiece_(for_horses)'], 'id': 552, 'def': "the band that is the part of a bridle that fits around a horse's head", 'name': 'headstall_(for_horses)'}, {'frequency': 'c', 'synset': 'heart.n.02', 'synonyms': ['heart'], 'id': 553, 'def': 'a muscular organ; its contractions move the blood through the body', 'name': 'heart'}, {'frequency': 'c', 'synset': 'heater.n.01', 'synonyms': ['heater', 'warmer'], 'id': 554, 'def': 'device that heats water or supplies warmth to a room', 'name': 'heater'}, {'frequency': 'c', 'synset': 'helicopter.n.01', 'synonyms': ['helicopter'], 'id': 555, 'def': 'an aircraft without wings that obtains its lift from the rotation of overhead blades', 'name': 'helicopter'}, {'frequency': 'f', 'synset': 'helmet.n.02', 'synonyms': ['helmet'], 'id': 556, 'def': 'a protective headgear made of hard material to resist blows', 'name': 'helmet'}, {'frequency': 'r', 'synset': 'heron.n.02', 'synonyms': ['heron'], 'id': 557, 'def': 'grey or white wading bird with long neck and long legs and (usually) long bill', 'name': 'heron'}, {'frequency': 'c', 'synset': 'highchair.n.01', 'synonyms': ['highchair', 'feeding_chair'], 'id': 558, 'def': 'a chair for feeding a very young child', 'name': 'highchair'}, {'frequency': 'f', 'synset': 'hinge.n.01', 'synonyms': ['hinge'], 'id': 559, 'def': 'a joint that holds two parts together so that one can swing relative to the other', 'name': 'hinge'}, {'frequency': 'r', 'synset': 'hippopotamus.n.01', 'synonyms': ['hippopotamus'], 'id': 560, 'def': 'massive thick-skinned animal living in or around rivers of tropical Africa', 'name': 'hippopotamus'}, {'frequency': 'r', 'synset': 'hockey_stick.n.01', 'synonyms': ['hockey_stick'], 'id': 561, 'def': 'sports implement consisting of a stick used by hockey players to move the puck', 'name': 'hockey_stick'}, {'frequency': 'c', 'synset': 'hog.n.03', 'synonyms': ['hog', 'pig'], 'id': 562, 'def': 'domestic swine', 'name': 'hog'}, {'frequency': 'f', 'synset': 'home_plate.n.01', 'synonyms': ['home_plate_(baseball)', 'home_base_(baseball)'], 'id': 563, 'def': '(baseball) a rubber slab where the batter stands; it must be touched by a base runner in order to score', 'name': 'home_plate_(baseball)'}, {'frequency': 'c', 'synset': 'honey.n.01', 'synonyms': ['honey'], 'id': 564, 'def': 'a sweet yellow liquid produced by bees', 'name': 'honey'}, {'frequency': 'f', 'synset': 'hood.n.06', 'synonyms': ['fume_hood', 'exhaust_hood'], 'id': 565, 'def': 'metal covering leading to a vent that exhausts smoke or fumes', 'name': 'fume_hood'}, {'frequency': 'f', 'synset': 'hook.n.05', 'synonyms': ['hook'], 'id': 566, 'def': 'a curved or bent implement for suspending or pulling something', 'name': 'hook'}, {'frequency': 'r', 'synset': 'hookah.n.01', 'synonyms': ['hookah', 'narghile', 'nargileh', 'sheesha', 'shisha', 'water_pipe'], 'id': 567, 'def': 'a tobacco pipe with a long flexible tube connected to a container where the smoke is cooled by passing through water', 'name': 'hookah'}, {'frequency': 'r', 'synset': 'hornet.n.01', 'synonyms': ['hornet'], 'id': 568, 'def': 'large stinging wasp', 'name': 'hornet'}, {'frequency': 'f', 'synset': 'horse.n.01', 'synonyms': ['horse'], 'id': 569, 'def': 'a common horse', 'name': 'horse'}, {'frequency': 'f', 'synset': 'hose.n.03', 'synonyms': ['hose', 'hosepipe'], 'id': 570, 'def': 'a flexible pipe for conveying a liquid or gas', 'name': 'hose'}, {'frequency': 'r', 'synset': 'hot-air_balloon.n.01', 'synonyms': ['hot-air_balloon'], 'id': 571, 'def': 'balloon for travel through the air in a basket suspended below a large bag of heated air', 'name': 'hot-air_balloon'}, {'frequency': 'r', 'synset': 'hot_plate.n.01', 'synonyms': ['hotplate'], 'id': 572, 'def': 'a portable electric appliance for heating or cooking or keeping food warm', 'name': 'hotplate'}, {'frequency': 'c', 'synset': 'hot_sauce.n.01', 'synonyms': ['hot_sauce'], 'id': 573, 'def': 'a pungent peppery sauce', 'name': 'hot_sauce'}, {'frequency': 'r', 'synset': 'hourglass.n.01', 'synonyms': ['hourglass'], 'id': 574, 'def': 'a sandglass timer that runs for sixty minutes', 'name': 'hourglass'}, {'frequency': 'r', 'synset': 'houseboat.n.01', 'synonyms': ['houseboat'], 'id': 575, 'def': 'a barge that is designed and equipped for use as a dwelling', 'name': 'houseboat'}, {'frequency': 'c', 'synset': 'hummingbird.n.01', 'synonyms': ['hummingbird'], 'id': 576, 'def': 'tiny American bird having brilliant iridescent plumage and long slender bills', 'name': 'hummingbird'}, {'frequency': 'r', 'synset': 'hummus.n.01', 'synonyms': ['hummus', 'humus', 'hommos', 'hoummos', 'humous'], 'id': 577, 'def': 'a thick spread made from mashed chickpeas', 'name': 'hummus'}, {'frequency': 'f', 'synset': 'ice_bear.n.01', 'synonyms': ['polar_bear'], 'id': 578, 'def': 'white bear of Arctic regions', 'name': 'polar_bear'}, {'frequency': 'c', 'synset': 'ice_cream.n.01', 'synonyms': ['icecream'], 'id': 579, 'def': 'frozen dessert containing cream and sugar and flavoring', 'name': 'icecream'}, {'frequency': 'r', 'synset': 'ice_lolly.n.01', 'synonyms': ['popsicle'], 'id': 580, 'def': 'ice cream or water ice on a small wooden stick', 'name': 'popsicle'}, {'frequency': 'c', 'synset': 'ice_maker.n.01', 'synonyms': ['ice_maker'], 'id': 581, 'def': 'an appliance included in some electric refrigerators for making ice cubes', 'name': 'ice_maker'}, {'frequency': 'r', 'synset': 'ice_pack.n.01', 'synonyms': ['ice_pack', 'ice_bag'], 'id': 582, 'def': 'a waterproof bag filled with ice: applied to the body (especially the head) to cool or reduce swelling', 'name': 'ice_pack'}, {'frequency': 'r', 'synset': 'ice_skate.n.01', 'synonyms': ['ice_skate'], 'id': 583, 'def': 'skate consisting of a boot with a steel blade fitted to the sole', 'name': 'ice_skate'}, {'frequency': 'c', 'synset': 'igniter.n.01', 'synonyms': ['igniter', 'ignitor', 'lighter'], 'id': 584, 'def': 'a substance or device used to start a fire', 'name': 'igniter'}, {'frequency': 'r', 'synset': 'inhaler.n.01', 'synonyms': ['inhaler', 'inhalator'], 'id': 585, 'def': 'a dispenser that produces a chemical vapor to be inhaled through mouth or nose', 'name': 'inhaler'}, {'frequency': 'f', 'synset': 'ipod.n.01', 'synonyms': ['iPod'], 'id': 586, 'def': 'a pocket-sized device used to play music files', 'name': 'iPod'}, {'frequency': 'c', 'synset': 'iron.n.04', 'synonyms': ['iron_(for_clothing)', 'smoothing_iron_(for_clothing)'], 'id': 587, 'def': 'home appliance consisting of a flat metal base that is heated and used to smooth cloth', 'name': 'iron_(for_clothing)'}, {'frequency': 'c', 'synset': 'ironing_board.n.01', 'synonyms': ['ironing_board'], 'id': 588, 'def': 'narrow padded board on collapsible supports; used for ironing clothes', 'name': 'ironing_board'}, {'frequency': 'f', 'synset': 'jacket.n.01', 'synonyms': ['jacket'], 'id': 589, 'def': 'a waist-length coat', 'name': 'jacket'}, {'frequency': 'c', 'synset': 'jam.n.01', 'synonyms': ['jam'], 'id': 590, 'def': 'preserve of crushed fruit', 'name': 'jam'}, {'frequency': 'f', 'synset': 'jar.n.01', 'synonyms': ['jar'], 'id': 591, 'def': 'a vessel (usually cylindrical) with a wide mouth and without handles', 'name': 'jar'}, {'frequency': 'f', 'synset': 'jean.n.01', 'synonyms': ['jean', 'blue_jean', 'denim'], 'id': 592, 'def': '(usually plural) close-fitting trousers of heavy denim for manual work or casual wear', 'name': 'jean'}, {'frequency': 'c', 'synset': 'jeep.n.01', 'synonyms': ['jeep', 'landrover'], 'id': 593, 'def': 'a car suitable for traveling over rough terrain', 'name': 'jeep'}, {'frequency': 'r', 'synset': 'jelly_bean.n.01', 'synonyms': ['jelly_bean', 'jelly_egg'], 'id': 594, 'def': 'sugar-glazed jellied candy', 'name': 'jelly_bean'}, {'frequency': 'f', 'synset': 'jersey.n.03', 'synonyms': ['jersey', 'T-shirt', 'tee_shirt'], 'id': 595, 'def': 'a close-fitting pullover shirt', 'name': 'jersey'}, {'frequency': 'c', 'synset': 'jet.n.01', 'synonyms': ['jet_plane', 'jet-propelled_plane'], 'id': 596, 'def': 'an airplane powered by one or more jet engines', 'name': 'jet_plane'}, {'frequency': 'r', 'synset': 'jewel.n.01', 'synonyms': ['jewel', 'gem', 'precious_stone'], 'id': 597, 'def': 'a precious or semiprecious stone incorporated into a piece of jewelry', 'name': 'jewel'}, {'frequency': 'c', 'synset': 'jewelry.n.01', 'synonyms': ['jewelry', 'jewellery'], 'id': 598, 'def': 'an adornment (as a bracelet or ring or necklace) made of precious metals and set with gems (or imitation gems)', 'name': 'jewelry'}, {'frequency': 'r', 'synset': 'joystick.n.02', 'synonyms': ['joystick'], 'id': 599, 'def': 'a control device for computers consisting of a vertical handle that can move freely in two directions', 'name': 'joystick'}, {'frequency': 'c', 'synset': 'jump_suit.n.01', 'synonyms': ['jumpsuit'], 'id': 600, 'def': "one-piece garment fashioned after a parachutist's uniform", 'name': 'jumpsuit'}, {'frequency': 'c', 'synset': 'kayak.n.01', 'synonyms': ['kayak'], 'id': 601, 'def': 'a small canoe consisting of a light frame made watertight with animal skins', 'name': 'kayak'}, {'frequency': 'r', 'synset': 'keg.n.02', 'synonyms': ['keg'], 'id': 602, 'def': 'small cask or barrel', 'name': 'keg'}, {'frequency': 'r', 'synset': 'kennel.n.01', 'synonyms': ['kennel', 'doghouse'], 'id': 603, 'def': 'outbuilding that serves as a shelter for a dog', 'name': 'kennel'}, {'frequency': 'c', 'synset': 'kettle.n.01', 'synonyms': ['kettle', 'boiler'], 'id': 604, 'def': 'a metal pot for stewing or boiling; usually has a lid', 'name': 'kettle'}, {'frequency': 'f', 'synset': 'key.n.01', 'synonyms': ['key'], 'id': 605, 'def': 'metal instrument used to unlock a lock', 'name': 'key'}, {'frequency': 'r', 'synset': 'keycard.n.01', 'synonyms': ['keycard'], 'id': 606, 'def': 'a plastic card used to gain access typically to a door', 'name': 'keycard'}, {'frequency': 'c', 'synset': 'kilt.n.01', 'synonyms': ['kilt'], 'id': 607, 'def': 'a knee-length pleated tartan skirt worn by men as part of the traditional dress in the Highlands of northern Scotland', 'name': 'kilt'}, {'frequency': 'c', 'synset': 'kimono.n.01', 'synonyms': ['kimono'], 'id': 608, 'def': 'a loose robe; imitated from robes originally worn by Japanese', 'name': 'kimono'}, {'frequency': 'f', 'synset': 'kitchen_sink.n.01', 'synonyms': ['kitchen_sink'], 'id': 609, 'def': 'a sink in a kitchen', 'name': 'kitchen_sink'}, {'frequency': 'r', 'synset': 'kitchen_table.n.01', 'synonyms': ['kitchen_table'], 'id': 610, 'def': 'a table in the kitchen', 'name': 'kitchen_table'}, {'frequency': 'f', 'synset': 'kite.n.03', 'synonyms': ['kite'], 'id': 611, 'def': 'plaything consisting of a light frame covered with tissue paper; flown in wind at end of a string', 'name': 'kite'}, {'frequency': 'c', 'synset': 'kitten.n.01', 'synonyms': ['kitten', 'kitty'], 'id': 612, 'def': 'young domestic cat', 'name': 'kitten'}, {'frequency': 'c', 'synset': 'kiwi.n.03', 'synonyms': ['kiwi_fruit'], 'id': 613, 'def': 'fuzzy brown egg-shaped fruit with slightly tart green flesh', 'name': 'kiwi_fruit'}, {'frequency': 'f', 'synset': 'knee_pad.n.01', 'synonyms': ['knee_pad'], 'id': 614, 'def': 'protective garment consisting of a pad worn by football or baseball or hockey players', 'name': 'knee_pad'}, {'frequency': 'f', 'synset': 'knife.n.01', 'synonyms': ['knife'], 'id': 615, 'def': 'tool with a blade and point used as a cutting instrument', 'name': 'knife'}, {'frequency': 'r', 'synset': 'knitting_needle.n.01', 'synonyms': ['knitting_needle'], 'id': 616, 'def': 'needle consisting of a slender rod with pointed ends; usually used in pairs', 'name': 'knitting_needle'}, {'frequency': 'f', 'synset': 'knob.n.02', 'synonyms': ['knob'], 'id': 617, 'def': 'a round handle often found on a door', 'name': 'knob'}, {'frequency': 'r', 'synset': 'knocker.n.05', 'synonyms': ['knocker_(on_a_door)', 'doorknocker'], 'id': 618, 'def': 'a device (usually metal and ornamental) attached by a hinge to a door', 'name': 'knocker_(on_a_door)'}, {'frequency': 'r', 'synset': 'koala.n.01', 'synonyms': ['koala', 'koala_bear'], 'id': 619, 'def': 'sluggish tailless Australian marsupial with grey furry ears and coat', 'name': 'koala'}, {'frequency': 'r', 'synset': 'lab_coat.n.01', 'synonyms': ['lab_coat', 'laboratory_coat'], 'id': 620, 'def': 'a light coat worn to protect clothing from substances used while working in a laboratory', 'name': 'lab_coat'}, {'frequency': 'f', 'synset': 'ladder.n.01', 'synonyms': ['ladder'], 'id': 621, 'def': 'steps consisting of two parallel members connected by rungs', 'name': 'ladder'}, {'frequency': 'c', 'synset': 'ladle.n.01', 'synonyms': ['ladle'], 'id': 622, 'def': 'a spoon-shaped vessel with a long handle frequently used to transfer liquids', 'name': 'ladle'}, {'frequency': 'c', 'synset': 'ladybug.n.01', 'synonyms': ['ladybug', 'ladybeetle', 'ladybird_beetle'], 'id': 623, 'def': 'small round bright-colored and spotted beetle, typically red and black', 'name': 'ladybug'}, {'frequency': 'f', 'synset': 'lamb.n.01', 'synonyms': ['lamb_(animal)'], 'id': 624, 'def': 'young sheep', 'name': 'lamb_(animal)'}, {'frequency': 'r', 'synset': 'lamb_chop.n.01', 'synonyms': ['lamb-chop', 'lambchop'], 'id': 625, 'def': 'chop cut from a lamb', 'name': 'lamb-chop'}, {'frequency': 'f', 'synset': 'lamp.n.02', 'synonyms': ['lamp'], 'id': 626, 'def': 'a piece of furniture holding one or more electric light bulbs', 'name': 'lamp'}, {'frequency': 'f', 'synset': 'lamppost.n.01', 'synonyms': ['lamppost'], 'id': 627, 'def': 'a metal post supporting an outdoor lamp (such as a streetlight)', 'name': 'lamppost'}, {'frequency': 'f', 'synset': 'lampshade.n.01', 'synonyms': ['lampshade'], 'id': 628, 'def': 'a protective ornamental shade used to screen a light bulb from direct view', 'name': 'lampshade'}, {'frequency': 'c', 'synset': 'lantern.n.01', 'synonyms': ['lantern'], 'id': 629, 'def': 'light in a transparent protective case', 'name': 'lantern'}, {'frequency': 'f', 'synset': 'lanyard.n.02', 'synonyms': ['lanyard', 'laniard'], 'id': 630, 'def': 'a cord worn around the neck to hold a knife or whistle, etc.', 'name': 'lanyard'}, {'frequency': 'f', 'synset': 'laptop.n.01', 'synonyms': ['laptop_computer', 'notebook_computer'], 'id': 631, 'def': 'a portable computer small enough to use in your lap', 'name': 'laptop_computer'}, {'frequency': 'r', 'synset': 'lasagna.n.01', 'synonyms': ['lasagna', 'lasagne'], 'id': 632, 'def': 'baked dish of layers of lasagna pasta with sauce and cheese and meat or vegetables', 'name': 'lasagna'}, {'frequency': 'f', 'synset': 'latch.n.02', 'synonyms': ['latch'], 'id': 633, 'def': 'a bar that can be lowered or slid into a groove to fasten a door or gate', 'name': 'latch'}, {'frequency': 'r', 'synset': 'lawn_mower.n.01', 'synonyms': ['lawn_mower'], 'id': 634, 'def': 'garden tool for mowing grass on lawns', 'name': 'lawn_mower'}, {'frequency': 'r', 'synset': 'leather.n.01', 'synonyms': ['leather'], 'id': 635, 'def': 'an animal skin made smooth and flexible by removing the hair and then tanning', 'name': 'leather'}, {'frequency': 'c', 'synset': 'legging.n.01', 'synonyms': ['legging_(clothing)', 'leging_(clothing)', 'leg_covering'], 'id': 636, 'def': 'a garment covering the leg (usually extending from the knee to the ankle)', 'name': 'legging_(clothing)'}, {'frequency': 'c', 'synset': 'lego.n.01', 'synonyms': ['Lego', 'Lego_set'], 'id': 637, 'def': "a child's plastic construction set for making models from blocks", 'name': 'Lego'}, {'frequency': 'r', 'synset': 'legume.n.02', 'synonyms': ['legume'], 'id': 638, 'def': 'the fruit or seed of bean or pea plants', 'name': 'legume'}, {'frequency': 'f', 'synset': 'lemon.n.01', 'synonyms': ['lemon'], 'id': 639, 'def': 'yellow oval fruit with juicy acidic flesh', 'name': 'lemon'}, {'frequency': 'r', 'synset': 'lemonade.n.01', 'synonyms': ['lemonade'], 'id': 640, 'def': 'sweetened beverage of diluted lemon juice', 'name': 'lemonade'}, {'frequency': 'f', 'synset': 'lettuce.n.02', 'synonyms': ['lettuce'], 'id': 641, 'def': 'leafy plant commonly eaten in salad or on sandwiches', 'name': 'lettuce'}, {'frequency': 'f', 'synset': 'license_plate.n.01', 'synonyms': ['license_plate', 'numberplate'], 'id': 642, 'def': "a plate mounted on the front and back of car and bearing the car's registration number", 'name': 'license_plate'}, {'frequency': 'f', 'synset': 'life_buoy.n.01', 'synonyms': ['life_buoy', 'lifesaver', 'life_belt', 'life_ring'], 'id': 643, 'def': 'a ring-shaped life preserver used to prevent drowning (NOT a life-jacket or vest)', 'name': 'life_buoy'}, {'frequency': 'f', 'synset': 'life_jacket.n.01', 'synonyms': ['life_jacket', 'life_vest'], 'id': 644, 'def': 'life preserver consisting of a sleeveless jacket of buoyant or inflatable design', 'name': 'life_jacket'}, {'frequency': 'f', 'synset': 'light_bulb.n.01', 'synonyms': ['lightbulb'], 'id': 645, 'def': 'lightblub/source of light', 'name': 'lightbulb'}, {'frequency': 'r', 'synset': 'lightning_rod.n.02', 'synonyms': ['lightning_rod', 'lightning_conductor'], 'id': 646, 'def': 'a metallic conductor that is attached to a high point and leads to the ground', 'name': 'lightning_rod'}, {'frequency': 'f', 'synset': 'lime.n.06', 'synonyms': ['lime'], 'id': 647, 'def': 'the green acidic fruit of any of various lime trees', 'name': 'lime'}, {'frequency': 'r', 'synset': 'limousine.n.01', 'synonyms': ['limousine'], 'id': 648, 'def': 'long luxurious car; usually driven by a chauffeur', 'name': 'limousine'}, {'frequency': 'c', 'synset': 'lion.n.01', 'synonyms': ['lion'], 'id': 649, 'def': 'large gregarious predatory cat of Africa and India', 'name': 'lion'}, {'frequency': 'c', 'synset': 'lip_balm.n.01', 'synonyms': ['lip_balm'], 'id': 650, 'def': 'a balm applied to the lips', 'name': 'lip_balm'}, {'frequency': 'r', 'synset': 'liquor.n.01', 'synonyms': ['liquor', 'spirits', 'hard_liquor', 'liqueur', 'cordial'], 'id': 651, 'def': 'liquor or beer', 'name': 'liquor'}, {'frequency': 'c', 'synset': 'lizard.n.01', 'synonyms': ['lizard'], 'id': 652, 'def': 'a reptile with usually two pairs of legs and a tapering tail', 'name': 'lizard'}, {'frequency': 'f', 'synset': 'log.n.01', 'synonyms': ['log'], 'id': 653, 'def': 'a segment of the trunk of a tree when stripped of branches', 'name': 'log'}, {'frequency': 'c', 'synset': 'lollipop.n.02', 'synonyms': ['lollipop'], 'id': 654, 'def': 'hard candy on a stick', 'name': 'lollipop'}, {'frequency': 'f', 'synset': 'loudspeaker.n.01', 'synonyms': ['speaker_(stero_equipment)'], 'id': 655, 'def': 'electronic device that produces sound often as part of a stereo system', 'name': 'speaker_(stero_equipment)'}, {'frequency': 'c', 'synset': 'love_seat.n.01', 'synonyms': ['loveseat'], 'id': 656, 'def': 'small sofa that seats two people', 'name': 'loveseat'}, {'frequency': 'r', 'synset': 'machine_gun.n.01', 'synonyms': ['machine_gun'], 'id': 657, 'def': 'a rapidly firing automatic gun', 'name': 'machine_gun'}, {'frequency': 'f', 'synset': 'magazine.n.02', 'synonyms': ['magazine'], 'id': 658, 'def': 'a paperback periodic publication', 'name': 'magazine'}, {'frequency': 'f', 'synset': 'magnet.n.01', 'synonyms': ['magnet'], 'id': 659, 'def': 'a device that attracts iron and produces a magnetic field', 'name': 'magnet'}, {'frequency': 'c', 'synset': 'mail_slot.n.01', 'synonyms': ['mail_slot'], 'id': 660, 'def': 'a slot (usually in a door) through which mail can be delivered', 'name': 'mail_slot'}, {'frequency': 'f', 'synset': 'mailbox.n.01', 'synonyms': ['mailbox_(at_home)', 'letter_box_(at_home)'], 'id': 661, 'def': 'a private box for delivery of mail', 'name': 'mailbox_(at_home)'}, {'frequency': 'r', 'synset': 'mallard.n.01', 'synonyms': ['mallard'], 'id': 662, 'def': 'wild dabbling duck from which domestic ducks are descended', 'name': 'mallard'}, {'frequency': 'r', 'synset': 'mallet.n.01', 'synonyms': ['mallet'], 'id': 663, 'def': 'a sports implement with a long handle and a hammer-like head used to hit a ball', 'name': 'mallet'}, {'frequency': 'r', 'synset': 'mammoth.n.01', 'synonyms': ['mammoth'], 'id': 664, 'def': 'any of numerous extinct elephants widely distributed in the Pleistocene', 'name': 'mammoth'}, {'frequency': 'r', 'synset': 'manatee.n.01', 'synonyms': ['manatee'], 'id': 665, 'def': 'sirenian mammal of tropical coastal waters of America', 'name': 'manatee'}, {'frequency': 'c', 'synset': 'mandarin.n.05', 'synonyms': ['mandarin_orange'], 'id': 666, 'def': 'a somewhat flat reddish-orange loose skinned citrus of China', 'name': 'mandarin_orange'}, {'frequency': 'c', 'synset': 'manger.n.01', 'synonyms': ['manger', 'trough'], 'id': 667, 'def': 'a container (usually in a barn or stable) from which cattle or horses feed', 'name': 'manger'}, {'frequency': 'f', 'synset': 'manhole.n.01', 'synonyms': ['manhole'], 'id': 668, 'def': 'a hole (usually with a flush cover) through which a person can gain access to an underground structure', 'name': 'manhole'}, {'frequency': 'f', 'synset': 'map.n.01', 'synonyms': ['map'], 'id': 669, 'def': "a diagrammatic representation of the earth's surface (or part of it)", 'name': 'map'}, {'frequency': 'f', 'synset': 'marker.n.03', 'synonyms': ['marker'], 'id': 670, 'def': 'a writing implement for making a mark', 'name': 'marker'}, {'frequency': 'r', 'synset': 'martini.n.01', 'synonyms': ['martini'], 'id': 671, 'def': 'a cocktail made of gin (or vodka) with dry vermouth', 'name': 'martini'}, {'frequency': 'r', 'synset': 'mascot.n.01', 'synonyms': ['mascot'], 'id': 672, 'def': 'a person or animal that is adopted by a team or other group as a symbolic figure', 'name': 'mascot'}, {'frequency': 'c', 'synset': 'mashed_potato.n.01', 'synonyms': ['mashed_potato'], 'id': 673, 'def': 'potato that has been peeled and boiled and then mashed', 'name': 'mashed_potato'}, {'frequency': 'r', 'synset': 'masher.n.02', 'synonyms': ['masher'], 'id': 674, 'def': 'a kitchen utensil used for mashing (e.g. potatoes)', 'name': 'masher'}, {'frequency': 'f', 'synset': 'mask.n.04', 'synonyms': ['mask', 'facemask'], 'id': 675, 'def': 'a protective covering worn over the face', 'name': 'mask'}, {'frequency': 'f', 'synset': 'mast.n.01', 'synonyms': ['mast'], 'id': 676, 'def': 'a vertical spar for supporting sails', 'name': 'mast'}, {'frequency': 'c', 'synset': 'mat.n.03', 'synonyms': ['mat_(gym_equipment)', 'gym_mat'], 'id': 677, 'def': 'sports equipment consisting of a piece of thick padding on the floor for gymnastics', 'name': 'mat_(gym_equipment)'}, {'frequency': 'r', 'synset': 'matchbox.n.01', 'synonyms': ['matchbox'], 'id': 678, 'def': 'a box for holding matches', 'name': 'matchbox'}, {'frequency': 'f', 'synset': 'mattress.n.01', 'synonyms': ['mattress'], 'id': 679, 'def': 'a thick pad filled with resilient material used as a bed or part of a bed', 'name': 'mattress'}, {'frequency': 'c', 'synset': 'measuring_cup.n.01', 'synonyms': ['measuring_cup'], 'id': 680, 'def': 'graduated cup used to measure liquid or granular ingredients', 'name': 'measuring_cup'}, {'frequency': 'c', 'synset': 'measuring_stick.n.01', 'synonyms': ['measuring_stick', 'ruler_(measuring_stick)', 'measuring_rod'], 'id': 681, 'def': 'measuring instrument having a sequence of marks at regular intervals', 'name': 'measuring_stick'}, {'frequency': 'c', 'synset': 'meatball.n.01', 'synonyms': ['meatball'], 'id': 682, 'def': 'ground meat formed into a ball and fried or simmered in broth', 'name': 'meatball'}, {'frequency': 'c', 'synset': 'medicine.n.02', 'synonyms': ['medicine'], 'id': 683, 'def': 'something that treats or prevents or alleviates the symptoms of disease', 'name': 'medicine'}, {'frequency': 'c', 'synset': 'melon.n.01', 'synonyms': ['melon'], 'id': 684, 'def': 'fruit of the gourd family having a hard rind and sweet juicy flesh', 'name': 'melon'}, {'frequency': 'f', 'synset': 'microphone.n.01', 'synonyms': ['microphone'], 'id': 685, 'def': 'device for converting sound waves into electrical energy', 'name': 'microphone'}, {'frequency': 'r', 'synset': 'microscope.n.01', 'synonyms': ['microscope'], 'id': 686, 'def': 'magnifier of the image of small objects', 'name': 'microscope'}, {'frequency': 'f', 'synset': 'microwave.n.02', 'synonyms': ['microwave_oven'], 'id': 687, 'def': 'kitchen appliance that cooks food by passing an electromagnetic wave through it', 'name': 'microwave_oven'}, {'frequency': 'r', 'synset': 'milestone.n.01', 'synonyms': ['milestone', 'milepost'], 'id': 688, 'def': 'stone post at side of a road to show distances', 'name': 'milestone'}, {'frequency': 'f', 'synset': 'milk.n.01', 'synonyms': ['milk'], 'id': 689, 'def': 'a white nutritious liquid secreted by mammals and used as food by human beings', 'name': 'milk'}, {'frequency': 'r', 'synset': 'milk_can.n.01', 'synonyms': ['milk_can'], 'id': 690, 'def': 'can for transporting milk', 'name': 'milk_can'}, {'frequency': 'r', 'synset': 'milkshake.n.01', 'synonyms': ['milkshake'], 'id': 691, 'def': 'frothy drink of milk and flavoring and sometimes fruit or ice cream', 'name': 'milkshake'}, {'frequency': 'f', 'synset': 'minivan.n.01', 'synonyms': ['minivan'], 'id': 692, 'def': 'a small box-shaped passenger van', 'name': 'minivan'}, {'frequency': 'r', 'synset': 'mint.n.05', 'synonyms': ['mint_candy'], 'id': 693, 'def': 'a candy that is flavored with a mint oil', 'name': 'mint_candy'}, {'frequency': 'f', 'synset': 'mirror.n.01', 'synonyms': ['mirror'], 'id': 694, 'def': 'polished surface that forms images by reflecting light', 'name': 'mirror'}, {'frequency': 'c', 'synset': 'mitten.n.01', 'synonyms': ['mitten'], 'id': 695, 'def': 'glove that encases the thumb separately and the other four fingers together', 'name': 'mitten'}, {'frequency': 'c', 'synset': 'mixer.n.04', 'synonyms': ['mixer_(kitchen_tool)', 'stand_mixer'], 'id': 696, 'def': 'a kitchen utensil that is used for mixing foods', 'name': 'mixer_(kitchen_tool)'}, {'frequency': 'c', 'synset': 'money.n.03', 'synonyms': ['money'], 'id': 697, 'def': 'the official currency issued by a government or national bank', 'name': 'money'}, {'frequency': 'f', 'synset': 'monitor.n.04', 'synonyms': ['monitor_(computer_equipment) computer_monitor'], 'id': 698, 'def': 'a computer monitor', 'name': 'monitor_(computer_equipment) computer_monitor'}, {'frequency': 'c', 'synset': 'monkey.n.01', 'synonyms': ['monkey'], 'id': 699, 'def': 'any of various long-tailed primates', 'name': 'monkey'}, {'frequency': 'f', 'synset': 'motor.n.01', 'synonyms': ['motor'], 'id': 700, 'def': 'machine that converts other forms of energy into mechanical energy and so imparts motion', 'name': 'motor'}, {'frequency': 'f', 'synset': 'motor_scooter.n.01', 'synonyms': ['motor_scooter', 'scooter'], 'id': 701, 'def': 'a wheeled vehicle with small wheels and a low-powered engine', 'name': 'motor_scooter'}, {'frequency': 'r', 'synset': 'motor_vehicle.n.01', 'synonyms': ['motor_vehicle', 'automotive_vehicle'], 'id': 702, 'def': 'a self-propelled wheeled vehicle that does not run on rails', 'name': 'motor_vehicle'}, {'frequency': 'f', 'synset': 'motorcycle.n.01', 'synonyms': ['motorcycle'], 'id': 703, 'def': 'a motor vehicle with two wheels and a strong frame', 'name': 'motorcycle'}, {'frequency': 'f', 'synset': 'mound.n.01', 'synonyms': ['mound_(baseball)', "pitcher's_mound"], 'id': 704, 'def': '(baseball) the slight elevation on which the pitcher stands', 'name': 'mound_(baseball)'}, {'frequency': 'f', 'synset': 'mouse.n.04', 'synonyms': ['mouse_(computer_equipment)', 'computer_mouse'], 'id': 705, 'def': 'a computer input device that controls an on-screen pointer (does not include trackpads / touchpads)', 'name': 'mouse_(computer_equipment)'}, {'frequency': 'f', 'synset': 'mousepad.n.01', 'synonyms': ['mousepad'], 'id': 706, 'def': 'a small portable pad that provides an operating surface for a computer mouse', 'name': 'mousepad'}, {'frequency': 'c', 'synset': 'muffin.n.01', 'synonyms': ['muffin'], 'id': 707, 'def': 'a sweet quick bread baked in a cup-shaped pan', 'name': 'muffin'}, {'frequency': 'f', 'synset': 'mug.n.04', 'synonyms': ['mug'], 'id': 708, 'def': 'with handle and usually cylindrical', 'name': 'mug'}, {'frequency': 'f', 'synset': 'mushroom.n.02', 'synonyms': ['mushroom'], 'id': 709, 'def': 'a common mushroom', 'name': 'mushroom'}, {'frequency': 'r', 'synset': 'music_stool.n.01', 'synonyms': ['music_stool', 'piano_stool'], 'id': 710, 'def': 'a stool for piano players; usually adjustable in height', 'name': 'music_stool'}, {'frequency': 'c', 'synset': 'musical_instrument.n.01', 'synonyms': ['musical_instrument', 'instrument_(musical)'], 'id': 711, 'def': 'any of various devices or contrivances that can be used to produce musical tones or sounds', 'name': 'musical_instrument'}, {'frequency': 'r', 'synset': 'nailfile.n.01', 'synonyms': ['nailfile'], 'id': 712, 'def': 'a small flat file for shaping the nails', 'name': 'nailfile'}, {'frequency': 'f', 'synset': 'napkin.n.01', 'synonyms': ['napkin', 'table_napkin', 'serviette'], 'id': 713, 'def': 'a small piece of table linen or paper that is used to wipe the mouth and to cover the lap in order to protect clothing', 'name': 'napkin'}, {'frequency': 'r', 'synset': 'neckerchief.n.01', 'synonyms': ['neckerchief'], 'id': 714, 'def': 'a kerchief worn around the neck', 'name': 'neckerchief'}, {'frequency': 'f', 'synset': 'necklace.n.01', 'synonyms': ['necklace'], 'id': 715, 'def': 'jewelry consisting of a cord or chain (often bearing gems) worn about the neck as an ornament', 'name': 'necklace'}, {'frequency': 'f', 'synset': 'necktie.n.01', 'synonyms': ['necktie', 'tie_(necktie)'], 'id': 716, 'def': 'neckwear consisting of a long narrow piece of material worn under a collar and tied in knot at the front', 'name': 'necktie'}, {'frequency': 'c', 'synset': 'needle.n.03', 'synonyms': ['needle'], 'id': 717, 'def': 'a sharp pointed implement (usually metal)', 'name': 'needle'}, {'frequency': 'c', 'synset': 'nest.n.01', 'synonyms': ['nest'], 'id': 718, 'def': 'a structure in which animals lay eggs or give birth to their young', 'name': 'nest'}, {'frequency': 'f', 'synset': 'newspaper.n.01', 'synonyms': ['newspaper', 'paper_(newspaper)'], 'id': 719, 'def': 'a daily or weekly publication on folded sheets containing news, articles, and advertisements', 'name': 'newspaper'}, {'frequency': 'c', 'synset': 'newsstand.n.01', 'synonyms': ['newsstand'], 'id': 720, 'def': 'a stall where newspapers and other periodicals are sold', 'name': 'newsstand'}, {'frequency': 'c', 'synset': 'nightwear.n.01', 'synonyms': ['nightshirt', 'nightwear', 'sleepwear', 'nightclothes'], 'id': 721, 'def': 'garments designed to be worn in bed', 'name': 'nightshirt'}, {'frequency': 'r', 'synset': 'nosebag.n.01', 'synonyms': ['nosebag_(for_animals)', 'feedbag'], 'id': 722, 'def': 'a canvas bag that is used to feed an animal (such as a horse); covers the muzzle and fastens at the top of the head', 'name': 'nosebag_(for_animals)'}, {'frequency': 'c', 'synset': 'noseband.n.01', 'synonyms': ['noseband_(for_animals)', 'nosepiece_(for_animals)'], 'id': 723, 'def': "a strap that is the part of a bridle that goes over the animal's nose", 'name': 'noseband_(for_animals)'}, {'frequency': 'f', 'synset': 'notebook.n.01', 'synonyms': ['notebook'], 'id': 724, 'def': 'a book with blank pages for recording notes or memoranda', 'name': 'notebook'}, {'frequency': 'c', 'synset': 'notepad.n.01', 'synonyms': ['notepad'], 'id': 725, 'def': 'a pad of paper for keeping notes', 'name': 'notepad'}, {'frequency': 'f', 'synset': 'nut.n.03', 'synonyms': ['nut'], 'id': 726, 'def': 'a small metal block (usually square or hexagonal) with internal screw thread to be fitted onto a bolt', 'name': 'nut'}, {'frequency': 'r', 'synset': 'nutcracker.n.01', 'synonyms': ['nutcracker'], 'id': 727, 'def': 'a hand tool used to crack nuts open', 'name': 'nutcracker'}, {'frequency': 'f', 'synset': 'oar.n.01', 'synonyms': ['oar'], 'id': 728, 'def': 'an implement used to propel or steer a boat', 'name': 'oar'}, {'frequency': 'r', 'synset': 'octopus.n.01', 'synonyms': ['octopus_(food)'], 'id': 729, 'def': 'tentacles of octopus prepared as food', 'name': 'octopus_(food)'}, {'frequency': 'r', 'synset': 'octopus.n.02', 'synonyms': ['octopus_(animal)'], 'id': 730, 'def': 'bottom-living cephalopod having a soft oval body with eight long tentacles', 'name': 'octopus_(animal)'}, {'frequency': 'c', 'synset': 'oil_lamp.n.01', 'synonyms': ['oil_lamp', 'kerosene_lamp', 'kerosine_lamp'], 'id': 731, 'def': 'a lamp that burns oil (as kerosine) for light', 'name': 'oil_lamp'}, {'frequency': 'c', 'synset': 'olive_oil.n.01', 'synonyms': ['olive_oil'], 'id': 732, 'def': 'oil from olives', 'name': 'olive_oil'}, {'frequency': 'r', 'synset': 'omelet.n.01', 'synonyms': ['omelet', 'omelette'], 'id': 733, 'def': 'beaten eggs cooked until just set; may be folded around e.g. ham or cheese or jelly', 'name': 'omelet'}, {'frequency': 'f', 'synset': 'onion.n.01', 'synonyms': ['onion'], 'id': 734, 'def': 'the bulb of an onion plant', 'name': 'onion'}, {'frequency': 'f', 'synset': 'orange.n.01', 'synonyms': ['orange_(fruit)'], 'id': 735, 'def': 'orange (FRUIT of an orange tree)', 'name': 'orange_(fruit)'}, {'frequency': 'c', 'synset': 'orange_juice.n.01', 'synonyms': ['orange_juice'], 'id': 736, 'def': 'bottled or freshly squeezed juice of oranges', 'name': 'orange_juice'}, {'frequency': 'c', 'synset': 'ostrich.n.02', 'synonyms': ['ostrich'], 'id': 737, 'def': 'fast-running African flightless bird with two-toed feet; largest living bird', 'name': 'ostrich'}, {'frequency': 'f', 'synset': 'ottoman.n.03', 'synonyms': ['ottoman', 'pouf', 'pouffe', 'hassock'], 'id': 738, 'def': 'a thick standalone cushion used as a seat or footrest, often next to a chair', 'name': 'ottoman'}, {'frequency': 'f', 'synset': 'oven.n.01', 'synonyms': ['oven'], 'id': 739, 'def': 'kitchen appliance used for baking or roasting', 'name': 'oven'}, {'frequency': 'c', 'synset': 'overall.n.01', 'synonyms': ['overalls_(clothing)'], 'id': 740, 'def': 'work clothing consisting of denim trousers usually with a bib and shoulder straps', 'name': 'overalls_(clothing)'}, {'frequency': 'c', 'synset': 'owl.n.01', 'synonyms': ['owl'], 'id': 741, 'def': 'nocturnal bird of prey with hawk-like beak and claws and large head with front-facing eyes', 'name': 'owl'}, {'frequency': 'c', 'synset': 'packet.n.03', 'synonyms': ['packet'], 'id': 742, 'def': 'a small package or bundle', 'name': 'packet'}, {'frequency': 'r', 'synset': 'pad.n.03', 'synonyms': ['inkpad', 'inking_pad', 'stamp_pad'], 'id': 743, 'def': 'absorbent material saturated with ink used to transfer ink evenly to a rubber stamp', 'name': 'inkpad'}, {'frequency': 'c', 'synset': 'pad.n.04', 'synonyms': ['pad'], 'id': 744, 'def': 'mostly arm/knee pads labeled', 'name': 'pad'}, {'frequency': 'f', 'synset': 'paddle.n.04', 'synonyms': ['paddle', 'boat_paddle'], 'id': 745, 'def': 'a short light oar used without an oarlock to propel a canoe or small boat', 'name': 'paddle'}, {'frequency': 'c', 'synset': 'padlock.n.01', 'synonyms': ['padlock'], 'id': 746, 'def': 'a detachable, portable lock', 'name': 'padlock'}, {'frequency': 'c', 'synset': 'paintbrush.n.01', 'synonyms': ['paintbrush'], 'id': 747, 'def': 'a brush used as an applicator to apply paint', 'name': 'paintbrush'}, {'frequency': 'f', 'synset': 'painting.n.01', 'synonyms': ['painting'], 'id': 748, 'def': 'graphic art consisting of an artistic composition made by applying paints to a surface', 'name': 'painting'}, {'frequency': 'f', 'synset': 'pajama.n.02', 'synonyms': ['pajamas', 'pyjamas'], 'id': 749, 'def': 'loose-fitting nightclothes worn for sleeping or lounging', 'name': 'pajamas'}, {'frequency': 'c', 'synset': 'palette.n.02', 'synonyms': ['palette', 'pallet'], 'id': 750, 'def': 'board that provides a flat surface on which artists mix paints and the range of colors used', 'name': 'palette'}, {'frequency': 'f', 'synset': 'pan.n.01', 'synonyms': ['pan_(for_cooking)', 'cooking_pan'], 'id': 751, 'def': 'cooking utensil consisting of a wide metal vessel', 'name': 'pan_(for_cooking)'}, {'frequency': 'r', 'synset': 'pan.n.03', 'synonyms': ['pan_(metal_container)'], 'id': 752, 'def': 'shallow container made of metal', 'name': 'pan_(metal_container)'}, {'frequency': 'c', 'synset': 'pancake.n.01', 'synonyms': ['pancake'], 'id': 753, 'def': 'a flat cake of thin batter fried on both sides on a griddle', 'name': 'pancake'}, {'frequency': 'r', 'synset': 'pantyhose.n.01', 'synonyms': ['pantyhose'], 'id': 754, 'def': "a woman's tights consisting of underpants and stockings", 'name': 'pantyhose'}, {'frequency': 'r', 'synset': 'papaya.n.02', 'synonyms': ['papaya'], 'id': 755, 'def': 'large oval melon-like tropical fruit with yellowish flesh', 'name': 'papaya'}, {'frequency': 'f', 'synset': 'paper_plate.n.01', 'synonyms': ['paper_plate'], 'id': 756, 'def': 'a disposable plate made of cardboard', 'name': 'paper_plate'}, {'frequency': 'f', 'synset': 'paper_towel.n.01', 'synonyms': ['paper_towel'], 'id': 757, 'def': 'a disposable towel made of absorbent paper', 'name': 'paper_towel'}, {'frequency': 'r', 'synset': 'paperback_book.n.01', 'synonyms': ['paperback_book', 'paper-back_book', 'softback_book', 'soft-cover_book'], 'id': 758, 'def': 'a book with paper covers', 'name': 'paperback_book'}, {'frequency': 'r', 'synset': 'paperweight.n.01', 'synonyms': ['paperweight'], 'id': 759, 'def': 'a weight used to hold down a stack of papers', 'name': 'paperweight'}, {'frequency': 'c', 'synset': 'parachute.n.01', 'synonyms': ['parachute'], 'id': 760, 'def': 'rescue equipment consisting of a device that fills with air and retards your fall', 'name': 'parachute'}, {'frequency': 'c', 'synset': 'parakeet.n.01', 'synonyms': ['parakeet', 'parrakeet', 'parroket', 'paraquet', 'paroquet', 'parroquet'], 'id': 761, 'def': 'any of numerous small slender long-tailed parrots', 'name': 'parakeet'}, {'frequency': 'c', 'synset': 'parasail.n.01', 'synonyms': ['parasail_(sports)'], 'id': 762, 'def': 'parachute that will lift a person up into the air when it is towed by a motorboat or a car', 'name': 'parasail_(sports)'}, {'frequency': 'c', 'synset': 'parasol.n.01', 'synonyms': ['parasol', 'sunshade'], 'id': 763, 'def': 'a handheld collapsible source of shade', 'name': 'parasol'}, {'frequency': 'r', 'synset': 'parchment.n.01', 'synonyms': ['parchment'], 'id': 764, 'def': 'a superior paper resembling sheepskin', 'name': 'parchment'}, {'frequency': 'c', 'synset': 'parka.n.01', 'synonyms': ['parka', 'anorak'], 'id': 765, 'def': "a kind of heavy jacket (`windcheater' is a British term)", 'name': 'parka'}, {'frequency': 'f', 'synset': 'parking_meter.n.01', 'synonyms': ['parking_meter'], 'id': 766, 'def': 'a coin-operated timer located next to a parking space', 'name': 'parking_meter'}, {'frequency': 'c', 'synset': 'parrot.n.01', 'synonyms': ['parrot'], 'id': 767, 'def': 'usually brightly colored tropical birds with short hooked beaks and the ability to mimic sounds', 'name': 'parrot'}, {'frequency': 'c', 'synset': 'passenger_car.n.01', 'synonyms': ['passenger_car_(part_of_a_train)', 'coach_(part_of_a_train)'], 'id': 768, 'def': 'a railcar where passengers ride', 'name': 'passenger_car_(part_of_a_train)'}, {'frequency': 'r', 'synset': 'passenger_ship.n.01', 'synonyms': ['passenger_ship'], 'id': 769, 'def': 'a ship built to carry passengers', 'name': 'passenger_ship'}, {'frequency': 'c', 'synset': 'passport.n.02', 'synonyms': ['passport'], 'id': 770, 'def': 'a document issued by a country to a citizen allowing that person to travel abroad and re-enter the home country', 'name': 'passport'}, {'frequency': 'f', 'synset': 'pastry.n.02', 'synonyms': ['pastry'], 'id': 771, 'def': 'any of various baked foods made of dough or batter', 'name': 'pastry'}, {'frequency': 'r', 'synset': 'patty.n.01', 'synonyms': ['patty_(food)'], 'id': 772, 'def': 'small flat mass of chopped food', 'name': 'patty_(food)'}, {'frequency': 'c', 'synset': 'pea.n.01', 'synonyms': ['pea_(food)'], 'id': 773, 'def': 'seed of a pea plant used for food', 'name': 'pea_(food)'}, {'frequency': 'c', 'synset': 'peach.n.03', 'synonyms': ['peach'], 'id': 774, 'def': 'downy juicy fruit with sweet yellowish or whitish flesh', 'name': 'peach'}, {'frequency': 'c', 'synset': 'peanut_butter.n.01', 'synonyms': ['peanut_butter'], 'id': 775, 'def': 'a spread made from ground peanuts', 'name': 'peanut_butter'}, {'frequency': 'f', 'synset': 'pear.n.01', 'synonyms': ['pear'], 'id': 776, 'def': 'sweet juicy gritty-textured fruit available in many varieties', 'name': 'pear'}, {'frequency': 'c', 'synset': 'peeler.n.03', 'synonyms': ['peeler_(tool_for_fruit_and_vegetables)'], 'id': 777, 'def': 'a device for peeling vegetables or fruits', 'name': 'peeler_(tool_for_fruit_and_vegetables)'}, {'frequency': 'r', 'synset': 'peg.n.04', 'synonyms': ['wooden_leg', 'pegleg'], 'id': 778, 'def': 'a prosthesis that replaces a missing leg', 'name': 'wooden_leg'}, {'frequency': 'r', 'synset': 'pegboard.n.01', 'synonyms': ['pegboard'], 'id': 779, 'def': 'a board perforated with regularly spaced holes into which pegs can be fitted', 'name': 'pegboard'}, {'frequency': 'c', 'synset': 'pelican.n.01', 'synonyms': ['pelican'], 'id': 780, 'def': 'large long-winged warm-water seabird having a large bill with a distensible pouch for fish', 'name': 'pelican'}, {'frequency': 'f', 'synset': 'pen.n.01', 'synonyms': ['pen'], 'id': 781, 'def': 'a writing implement with a point from which ink flows', 'name': 'pen'}, {'frequency': 'f', 'synset': 'pencil.n.01', 'synonyms': ['pencil'], 'id': 782, 'def': 'a thin cylindrical pointed writing implement made of wood and graphite', 'name': 'pencil'}, {'frequency': 'r', 'synset': 'pencil_box.n.01', 'synonyms': ['pencil_box', 'pencil_case'], 'id': 783, 'def': 'a box for holding pencils', 'name': 'pencil_box'}, {'frequency': 'r', 'synset': 'pencil_sharpener.n.01', 'synonyms': ['pencil_sharpener'], 'id': 784, 'def': 'a rotary implement for sharpening the point on pencils', 'name': 'pencil_sharpener'}, {'frequency': 'r', 'synset': 'pendulum.n.01', 'synonyms': ['pendulum'], 'id': 785, 'def': 'an apparatus consisting of an object mounted so that it swings freely under the influence of gravity', 'name': 'pendulum'}, {'frequency': 'c', 'synset': 'penguin.n.01', 'synonyms': ['penguin'], 'id': 786, 'def': 'short-legged flightless birds of cold southern regions having webbed feet and wings modified as flippers', 'name': 'penguin'}, {'frequency': 'r', 'synset': 'pennant.n.02', 'synonyms': ['pennant'], 'id': 787, 'def': 'a flag longer than it is wide (and often tapering)', 'name': 'pennant'}, {'frequency': 'r', 'synset': 'penny.n.02', 'synonyms': ['penny_(coin)'], 'id': 788, 'def': 'a coin worth one-hundredth of the value of the basic unit', 'name': 'penny_(coin)'}, {'frequency': 'f', 'synset': 'pepper.n.03', 'synonyms': ['pepper', 'peppercorn'], 'id': 789, 'def': 'pungent seasoning from the berry of the common pepper plant; whole or ground', 'name': 'pepper'}, {'frequency': 'c', 'synset': 'pepper_mill.n.01', 'synonyms': ['pepper_mill', 'pepper_grinder'], 'id': 790, 'def': 'a mill for grinding pepper', 'name': 'pepper_mill'}, {'frequency': 'c', 'synset': 'perfume.n.02', 'synonyms': ['perfume'], 'id': 791, 'def': 'a toiletry that emits and diffuses a fragrant odor', 'name': 'perfume'}, {'frequency': 'r', 'synset': 'persimmon.n.02', 'synonyms': ['persimmon'], 'id': 792, 'def': 'orange fruit resembling a plum; edible when fully ripe', 'name': 'persimmon'}, {'frequency': 'f', 'synset': 'person.n.01', 'synonyms': ['person', 'baby', 'child', 'boy', 'girl', 'man', 'woman', 'human'], 'id': 793, 'def': 'a human being', 'name': 'person'}, {'frequency': 'c', 'synset': 'pet.n.01', 'synonyms': ['pet'], 'id': 794, 'def': 'a domesticated animal kept for companionship or amusement', 'name': 'pet'}, {'frequency': 'c', 'synset': 'pew.n.01', 'synonyms': ['pew_(church_bench)', 'church_bench'], 'id': 795, 'def': 'long bench with backs; used in church by the congregation', 'name': 'pew_(church_bench)'}, {'frequency': 'r', 'synset': 'phonebook.n.01', 'synonyms': ['phonebook', 'telephone_book', 'telephone_directory'], 'id': 796, 'def': 'a directory containing an alphabetical list of telephone subscribers and their telephone numbers', 'name': 'phonebook'}, {'frequency': 'c', 'synset': 'phonograph_record.n.01', 'synonyms': ['phonograph_record', 'phonograph_recording', 'record_(phonograph_recording)'], 'id': 797, 'def': 'sound recording consisting of a typically black disk with a continuous groove', 'name': 'phonograph_record'}, {'frequency': 'f', 'synset': 'piano.n.01', 'synonyms': ['piano'], 'id': 798, 'def': 'a keyboard instrument that is played by depressing keys that cause hammers to strike tuned strings and produce sounds', 'name': 'piano'}, {'frequency': 'f', 'synset': 'pickle.n.01', 'synonyms': ['pickle'], 'id': 799, 'def': 'vegetables (especially cucumbers) preserved in brine or vinegar', 'name': 'pickle'}, {'frequency': 'f', 'synset': 'pickup.n.01', 'synonyms': ['pickup_truck'], 'id': 800, 'def': 'a light truck with an open body and low sides and a tailboard', 'name': 'pickup_truck'}, {'frequency': 'c', 'synset': 'pie.n.01', 'synonyms': ['pie'], 'id': 801, 'def': 'dish baked in pastry-lined pan often with a pastry top', 'name': 'pie'}, {'frequency': 'c', 'synset': 'pigeon.n.01', 'synonyms': ['pigeon'], 'id': 802, 'def': 'wild and domesticated birds having a heavy body and short legs', 'name': 'pigeon'}, {'frequency': 'r', 'synset': 'piggy_bank.n.01', 'synonyms': ['piggy_bank', 'penny_bank'], 'id': 803, 'def': "a child's coin bank (often shaped like a pig)", 'name': 'piggy_bank'}, {'frequency': 'f', 'synset': 'pillow.n.01', 'synonyms': ['pillow'], 'id': 804, 'def': 'a cushion to support the head of a sleeping person', 'name': 'pillow'}, {'frequency': 'r', 'synset': 'pin.n.09', 'synonyms': ['pin_(non_jewelry)'], 'id': 805, 'def': 'a small slender (often pointed) piece of wood or metal used to support or fasten or attach things', 'name': 'pin_(non_jewelry)'}, {'frequency': 'f', 'synset': 'pineapple.n.02', 'synonyms': ['pineapple'], 'id': 806, 'def': 'large sweet fleshy tropical fruit with a tuft of stiff leaves', 'name': 'pineapple'}, {'frequency': 'c', 'synset': 'pinecone.n.01', 'synonyms': ['pinecone'], 'id': 807, 'def': 'the seed-producing cone of a pine tree', 'name': 'pinecone'}, {'frequency': 'r', 'synset': 'ping-pong_ball.n.01', 'synonyms': ['ping-pong_ball'], 'id': 808, 'def': 'light hollow ball used in playing table tennis', 'name': 'ping-pong_ball'}, {'frequency': 'r', 'synset': 'pinwheel.n.03', 'synonyms': ['pinwheel'], 'id': 809, 'def': 'a toy consisting of vanes of colored paper or plastic that is pinned to a stick and spins when it is pointed into the wind', 'name': 'pinwheel'}, {'frequency': 'r', 'synset': 'pipe.n.01', 'synonyms': ['tobacco_pipe'], 'id': 810, 'def': 'a tube with a small bowl at one end; used for smoking tobacco', 'name': 'tobacco_pipe'}, {'frequency': 'f', 'synset': 'pipe.n.02', 'synonyms': ['pipe', 'piping'], 'id': 811, 'def': 'a long tube made of metal or plastic that is used to carry water or oil or gas etc.', 'name': 'pipe'}, {'frequency': 'r', 'synset': 'pistol.n.01', 'synonyms': ['pistol', 'handgun'], 'id': 812, 'def': 'a firearm that is held and fired with one hand', 'name': 'pistol'}, {'frequency': 'c', 'synset': 'pita.n.01', 'synonyms': ['pita_(bread)', 'pocket_bread'], 'id': 813, 'def': 'usually small round bread that can open into a pocket for filling', 'name': 'pita_(bread)'}, {'frequency': 'f', 'synset': 'pitcher.n.02', 'synonyms': ['pitcher_(vessel_for_liquid)', 'ewer'], 'id': 814, 'def': 'an open vessel with a handle and a spout for pouring', 'name': 'pitcher_(vessel_for_liquid)'}, {'frequency': 'r', 'synset': 'pitchfork.n.01', 'synonyms': ['pitchfork'], 'id': 815, 'def': 'a long-handled hand tool with sharp widely spaced prongs for lifting and pitching hay', 'name': 'pitchfork'}, {'frequency': 'f', 'synset': 'pizza.n.01', 'synonyms': ['pizza'], 'id': 816, 'def': 'Italian open pie made of thin bread dough spread with a spiced mixture of e.g. tomato sauce and cheese', 'name': 'pizza'}, {'frequency': 'f', 'synset': 'place_mat.n.01', 'synonyms': ['place_mat'], 'id': 817, 'def': 'a mat placed on a table for an individual place setting', 'name': 'place_mat'}, {'frequency': 'f', 'synset': 'plate.n.04', 'synonyms': ['plate'], 'id': 818, 'def': 'dish on which food is served or from which food is eaten', 'name': 'plate'}, {'frequency': 'c', 'synset': 'platter.n.01', 'synonyms': ['platter'], 'id': 819, 'def': 'a large shallow dish used for serving food', 'name': 'platter'}, {'frequency': 'r', 'synset': 'playpen.n.01', 'synonyms': ['playpen'], 'id': 820, 'def': 'a portable enclosure in which babies may be left to play', 'name': 'playpen'}, {'frequency': 'c', 'synset': 'pliers.n.01', 'synonyms': ['pliers', 'plyers'], 'id': 821, 'def': 'a gripping hand tool with two hinged arms and (usually) serrated jaws', 'name': 'pliers'}, {'frequency': 'r', 'synset': 'plow.n.01', 'synonyms': ['plow_(farm_equipment)', 'plough_(farm_equipment)'], 'id': 822, 'def': 'a farm tool having one or more heavy blades to break the soil and cut a furrow prior to sowing', 'name': 'plow_(farm_equipment)'}, {'frequency': 'r', 'synset': 'plume.n.02', 'synonyms': ['plume'], 'id': 823, 'def': 'a feather or cluster of feathers worn as an ornament', 'name': 'plume'}, {'frequency': 'r', 'synset': 'pocket_watch.n.01', 'synonyms': ['pocket_watch'], 'id': 824, 'def': 'a watch that is carried in a small watch pocket', 'name': 'pocket_watch'}, {'frequency': 'c', 'synset': 'pocketknife.n.01', 'synonyms': ['pocketknife'], 'id': 825, 'def': 'a knife with a blade that folds into the handle; suitable for carrying in the pocket', 'name': 'pocketknife'}, {'frequency': 'c', 'synset': 'poker.n.01', 'synonyms': ['poker_(fire_stirring_tool)', 'stove_poker', 'fire_hook'], 'id': 826, 'def': 'fire iron consisting of a metal rod with a handle; used to stir a fire', 'name': 'poker_(fire_stirring_tool)'}, {'frequency': 'f', 'synset': 'pole.n.01', 'synonyms': ['pole', 'post'], 'id': 827, 'def': 'a long (usually round) rod of wood or metal or plastic', 'name': 'pole'}, {'frequency': 'f', 'synset': 'polo_shirt.n.01', 'synonyms': ['polo_shirt', 'sport_shirt'], 'id': 828, 'def': 'a shirt with short sleeves designed for comfort and casual wear', 'name': 'polo_shirt'}, {'frequency': 'r', 'synset': 'poncho.n.01', 'synonyms': ['poncho'], 'id': 829, 'def': 'a blanket-like cloak with a hole in the center for the head', 'name': 'poncho'}, {'frequency': 'c', 'synset': 'pony.n.05', 'synonyms': ['pony'], 'id': 830, 'def': 'any of various breeds of small gentle horses usually less than five feet high at the shoulder', 'name': 'pony'}, {'frequency': 'r', 'synset': 'pool_table.n.01', 'synonyms': ['pool_table', 'billiard_table', 'snooker_table'], 'id': 831, 'def': 'game equipment consisting of a heavy table on which pool is played', 'name': 'pool_table'}, {'frequency': 'f', 'synset': 'pop.n.02', 'synonyms': ['pop_(soda)', 'soda_(pop)', 'tonic', 'soft_drink'], 'id': 832, 'def': 'a sweet drink containing carbonated water and flavoring', 'name': 'pop_(soda)'}, {'frequency': 'c', 'synset': 'postbox.n.01', 'synonyms': ['postbox_(public)', 'mailbox_(public)'], 'id': 833, 'def': 'public box for deposit of mail', 'name': 'postbox_(public)'}, {'frequency': 'c', 'synset': 'postcard.n.01', 'synonyms': ['postcard', 'postal_card', 'mailing-card'], 'id': 834, 'def': 'a card for sending messages by post without an envelope', 'name': 'postcard'}, {'frequency': 'f', 'synset': 'poster.n.01', 'synonyms': ['poster', 'placard'], 'id': 835, 'def': 'a sign posted in a public place as an advertisement', 'name': 'poster'}, {'frequency': 'f', 'synset': 'pot.n.01', 'synonyms': ['pot'], 'id': 836, 'def': 'metal or earthenware cooking vessel that is usually round and deep; often has a handle and lid', 'name': 'pot'}, {'frequency': 'f', 'synset': 'pot.n.04', 'synonyms': ['flowerpot'], 'id': 837, 'def': 'a container in which plants are cultivated', 'name': 'flowerpot'}, {'frequency': 'f', 'synset': 'potato.n.01', 'synonyms': ['potato'], 'id': 838, 'def': 'an edible tuber native to South America', 'name': 'potato'}, {'frequency': 'c', 'synset': 'potholder.n.01', 'synonyms': ['potholder'], 'id': 839, 'def': 'an insulated pad for holding hot pots', 'name': 'potholder'}, {'frequency': 'c', 'synset': 'pottery.n.01', 'synonyms': ['pottery', 'clayware'], 'id': 840, 'def': 'ceramic ware made from clay and baked in a kiln', 'name': 'pottery'}, {'frequency': 'c', 'synset': 'pouch.n.01', 'synonyms': ['pouch'], 'id': 841, 'def': 'a small or medium size container for holding or carrying things', 'name': 'pouch'}, {'frequency': 'c', 'synset': 'power_shovel.n.01', 'synonyms': ['power_shovel', 'excavator', 'digger'], 'id': 842, 'def': 'a machine for excavating', 'name': 'power_shovel'}, {'frequency': 'c', 'synset': 'prawn.n.01', 'synonyms': ['prawn', 'shrimp'], 'id': 843, 'def': 'any of various edible decapod crustaceans', 'name': 'prawn'}, {'frequency': 'c', 'synset': 'pretzel.n.01', 'synonyms': ['pretzel'], 'id': 844, 'def': 'glazed and salted cracker typically in the shape of a loose knot', 'name': 'pretzel'}, {'frequency': 'f', 'synset': 'printer.n.03', 'synonyms': ['printer', 'printing_machine'], 'id': 845, 'def': 'a machine that prints', 'name': 'printer'}, {'frequency': 'c', 'synset': 'projectile.n.01', 'synonyms': ['projectile_(weapon)', 'missile'], 'id': 846, 'def': 'a weapon that is forcibly thrown or projected at a targets', 'name': 'projectile_(weapon)'}, {'frequency': 'c', 'synset': 'projector.n.02', 'synonyms': ['projector'], 'id': 847, 'def': 'an optical instrument that projects an enlarged image onto a screen', 'name': 'projector'}, {'frequency': 'f', 'synset': 'propeller.n.01', 'synonyms': ['propeller', 'propellor'], 'id': 848, 'def': 'a mechanical device that rotates to push against air or water', 'name': 'propeller'}, {'frequency': 'r', 'synset': 'prune.n.01', 'synonyms': ['prune'], 'id': 849, 'def': 'dried plum', 'name': 'prune'}, {'frequency': 'r', 'synset': 'pudding.n.01', 'synonyms': ['pudding'], 'id': 850, 'def': 'any of various soft thick unsweetened baked dishes', 'name': 'pudding'}, {'frequency': 'r', 'synset': 'puffer.n.02', 'synonyms': ['puffer_(fish)', 'pufferfish', 'blowfish', 'globefish'], 'id': 851, 'def': 'fishes whose elongated spiny body can inflate itself with water or air to form a globe', 'name': 'puffer_(fish)'}, {'frequency': 'r', 'synset': 'puffin.n.01', 'synonyms': ['puffin'], 'id': 852, 'def': 'seabirds having short necks and brightly colored compressed bills', 'name': 'puffin'}, {'frequency': 'r', 'synset': 'pug.n.01', 'synonyms': ['pug-dog'], 'id': 853, 'def': 'small compact smooth-coated breed of Asiatic origin having a tightly curled tail and broad flat wrinkled muzzle', 'name': 'pug-dog'}, {'frequency': 'c', 'synset': 'pumpkin.n.02', 'synonyms': ['pumpkin'], 'id': 854, 'def': 'usually large pulpy deep-yellow round fruit of the squash family maturing in late summer or early autumn', 'name': 'pumpkin'}, {'frequency': 'r', 'synset': 'punch.n.03', 'synonyms': ['puncher'], 'id': 855, 'def': 'a tool for making holes or indentations', 'name': 'puncher'}, {'frequency': 'r', 'synset': 'puppet.n.01', 'synonyms': ['puppet', 'marionette'], 'id': 856, 'def': 'a small figure of a person operated from above with strings by a puppeteer', 'name': 'puppet'}, {'frequency': 'c', 'synset': 'puppy.n.01', 'synonyms': ['puppy'], 'id': 857, 'def': 'a young dog', 'name': 'puppy'}, {'frequency': 'r', 'synset': 'quesadilla.n.01', 'synonyms': ['quesadilla'], 'id': 858, 'def': 'a tortilla that is filled with cheese and heated', 'name': 'quesadilla'}, {'frequency': 'r', 'synset': 'quiche.n.02', 'synonyms': ['quiche'], 'id': 859, 'def': 'a tart filled with rich unsweetened custard; often contains other ingredients (as cheese or ham or seafood or vegetables)', 'name': 'quiche'}, {'frequency': 'f', 'synset': 'quilt.n.01', 'synonyms': ['quilt', 'comforter'], 'id': 860, 'def': 'bedding made of two layers of cloth filled with stuffing and stitched together', 'name': 'quilt'}, {'frequency': 'c', 'synset': 'rabbit.n.01', 'synonyms': ['rabbit'], 'id': 861, 'def': 'any of various burrowing animals of the family Leporidae having long ears and short tails', 'name': 'rabbit'}, {'frequency': 'r', 'synset': 'racer.n.02', 'synonyms': ['race_car', 'racing_car'], 'id': 862, 'def': 'a fast car that competes in races', 'name': 'race_car'}, {'frequency': 'c', 'synset': 'racket.n.04', 'synonyms': ['racket', 'racquet'], 'id': 863, 'def': 'a sports implement used to strike a ball in various games', 'name': 'racket'}, {'frequency': 'r', 'synset': 'radar.n.01', 'synonyms': ['radar'], 'id': 864, 'def': 'measuring instrument in which the echo of a pulse of microwave radiation is used to detect and locate distant objects', 'name': 'radar'}, {'frequency': 'f', 'synset': 'radiator.n.03', 'synonyms': ['radiator'], 'id': 865, 'def': 'a mechanism consisting of a metal honeycomb through which hot fluids circulate', 'name': 'radiator'}, {'frequency': 'c', 'synset': 'radio_receiver.n.01', 'synonyms': ['radio_receiver', 'radio_set', 'radio', 'tuner_(radio)'], 'id': 866, 'def': 'an electronic receiver that detects and demodulates and amplifies transmitted radio signals', 'name': 'radio_receiver'}, {'frequency': 'c', 'synset': 'radish.n.03', 'synonyms': ['radish', 'daikon'], 'id': 867, 'def': 'pungent edible root of any of various cultivated radish plants', 'name': 'radish'}, {'frequency': 'c', 'synset': 'raft.n.01', 'synonyms': ['raft'], 'id': 868, 'def': 'a flat float (usually made of logs or planks) that can be used for transport or as a platform for swimmers', 'name': 'raft'}, {'frequency': 'r', 'synset': 'rag_doll.n.01', 'synonyms': ['rag_doll'], 'id': 869, 'def': 'a cloth doll that is stuffed and (usually) painted', 'name': 'rag_doll'}, {'frequency': 'c', 'synset': 'raincoat.n.01', 'synonyms': ['raincoat', 'waterproof_jacket'], 'id': 870, 'def': 'a water-resistant coat', 'name': 'raincoat'}, {'frequency': 'c', 'synset': 'ram.n.05', 'synonyms': ['ram_(animal)'], 'id': 871, 'def': 'uncastrated adult male sheep', 'name': 'ram_(animal)'}, {'frequency': 'c', 'synset': 'raspberry.n.02', 'synonyms': ['raspberry'], 'id': 872, 'def': 'red or black edible aggregate berries usually smaller than the related blackberries', 'name': 'raspberry'}, {'frequency': 'r', 'synset': 'rat.n.01', 'synonyms': ['rat'], 'id': 873, 'def': 'any of various long-tailed rodents similar to but larger than a mouse', 'name': 'rat'}, {'frequency': 'c', 'synset': 'razorblade.n.01', 'synonyms': ['razorblade'], 'id': 874, 'def': 'a blade that has very sharp edge', 'name': 'razorblade'}, {'frequency': 'c', 'synset': 'reamer.n.01', 'synonyms': ['reamer_(juicer)', 'juicer', 'juice_reamer'], 'id': 875, 'def': 'a squeezer with a conical ridged center that is used for squeezing juice from citrus fruit', 'name': 'reamer_(juicer)'}, {'frequency': 'f', 'synset': 'rearview_mirror.n.01', 'synonyms': ['rearview_mirror'], 'id': 876, 'def': 'vehicle mirror (side or rearview)', 'name': 'rearview_mirror'}, {'frequency': 'c', 'synset': 'receipt.n.02', 'synonyms': ['receipt'], 'id': 877, 'def': 'an acknowledgment (usually tangible) that payment has been made', 'name': 'receipt'}, {'frequency': 'c', 'synset': 'recliner.n.01', 'synonyms': ['recliner', 'reclining_chair', 'lounger_(chair)'], 'id': 878, 'def': 'an armchair whose back can be lowered and foot can be raised to allow the sitter to recline in it', 'name': 'recliner'}, {'frequency': 'c', 'synset': 'record_player.n.01', 'synonyms': ['record_player', 'phonograph_(record_player)', 'turntable'], 'id': 879, 'def': 'machine in which rotating records cause a stylus to vibrate and the vibrations are amplified acoustically or electronically', 'name': 'record_player'}, {'frequency': 'f', 'synset': 'reflector.n.01', 'synonyms': ['reflector'], 'id': 880, 'def': 'device that reflects light, radiation, etc.', 'name': 'reflector'}, {'frequency': 'f', 'synset': 'remote_control.n.01', 'synonyms': ['remote_control'], 'id': 881, 'def': 'a device that can be used to control a machine or apparatus from a distance', 'name': 'remote_control'}, {'frequency': 'c', 'synset': 'rhinoceros.n.01', 'synonyms': ['rhinoceros'], 'id': 882, 'def': 'massive powerful herbivorous odd-toed ungulate of southeast Asia and Africa having very thick skin and one or two horns on the snout', 'name': 'rhinoceros'}, {'frequency': 'r', 'synset': 'rib.n.03', 'synonyms': ['rib_(food)'], 'id': 883, 'def': 'cut of meat including one or more ribs', 'name': 'rib_(food)'}, {'frequency': 'c', 'synset': 'rifle.n.01', 'synonyms': ['rifle'], 'id': 884, 'def': 'a shoulder firearm with a long barrel', 'name': 'rifle'}, {'frequency': 'f', 'synset': 'ring.n.08', 'synonyms': ['ring'], 'id': 885, 'def': 'jewelry consisting of a circlet of precious metal (often set with jewels) worn on the finger', 'name': 'ring'}, {'frequency': 'r', 'synset': 'river_boat.n.01', 'synonyms': ['river_boat'], 'id': 886, 'def': 'a boat used on rivers or to ply a river', 'name': 'river_boat'}, {'frequency': 'r', 'synset': 'road_map.n.02', 'synonyms': ['road_map'], 'id': 887, 'def': '(NOT A ROAD) a MAP showing roads (for automobile travel)', 'name': 'road_map'}, {'frequency': 'c', 'synset': 'robe.n.01', 'synonyms': ['robe'], 'id': 888, 'def': 'any loose flowing garment', 'name': 'robe'}, {'frequency': 'c', 'synset': 'rocking_chair.n.01', 'synonyms': ['rocking_chair'], 'id': 889, 'def': 'a chair mounted on rockers', 'name': 'rocking_chair'}, {'frequency': 'r', 'synset': 'rodent.n.01', 'synonyms': ['rodent'], 'id': 890, 'def': 'relatively small placental mammals having a single pair of constantly growing incisor teeth specialized for gnawing', 'name': 'rodent'}, {'frequency': 'r', 'synset': 'roller_skate.n.01', 'synonyms': ['roller_skate'], 'id': 891, 'def': 'a shoe with pairs of rollers (small hard wheels) fixed to the sole', 'name': 'roller_skate'}, {'frequency': 'r', 'synset': 'rollerblade.n.01', 'synonyms': ['Rollerblade'], 'id': 892, 'def': 'an in-line variant of a roller skate', 'name': 'Rollerblade'}, {'frequency': 'c', 'synset': 'rolling_pin.n.01', 'synonyms': ['rolling_pin'], 'id': 893, 'def': 'utensil consisting of a cylinder (usually of wood) with a handle at each end; used to roll out dough', 'name': 'rolling_pin'}, {'frequency': 'r', 'synset': 'root_beer.n.01', 'synonyms': ['root_beer'], 'id': 894, 'def': 'carbonated drink containing extracts of roots and herbs', 'name': 'root_beer'}, {'frequency': 'c', 'synset': 'router.n.02', 'synonyms': ['router_(computer_equipment)'], 'id': 895, 'def': 'a device that forwards data packets between computer networks', 'name': 'router_(computer_equipment)'}, {'frequency': 'f', 'synset': 'rubber_band.n.01', 'synonyms': ['rubber_band', 'elastic_band'], 'id': 896, 'def': 'a narrow band of elastic rubber used to hold things (such as papers) together', 'name': 'rubber_band'}, {'frequency': 'c', 'synset': 'runner.n.08', 'synonyms': ['runner_(carpet)'], 'id': 897, 'def': 'a long narrow carpet', 'name': 'runner_(carpet)'}, {'frequency': 'f', 'synset': 'sack.n.01', 'synonyms': ['plastic_bag', 'paper_bag'], 'id': 898, 'def': "a bag made of paper or plastic for holding customer's purchases", 'name': 'plastic_bag'}, {'frequency': 'f', 'synset': 'saddle.n.01', 'synonyms': ['saddle_(on_an_animal)'], 'id': 899, 'def': 'a seat for the rider of a horse or camel', 'name': 'saddle_(on_an_animal)'}, {'frequency': 'f', 'synset': 'saddle_blanket.n.01', 'synonyms': ['saddle_blanket', 'saddlecloth', 'horse_blanket'], 'id': 900, 'def': 'stable gear consisting of a blanket placed under the saddle', 'name': 'saddle_blanket'}, {'frequency': 'c', 'synset': 'saddlebag.n.01', 'synonyms': ['saddlebag'], 'id': 901, 'def': 'a large bag (or pair of bags) hung over a saddle', 'name': 'saddlebag'}, {'frequency': 'r', 'synset': 'safety_pin.n.01', 'synonyms': ['safety_pin'], 'id': 902, 'def': 'a pin in the form of a clasp; has a guard so the point of the pin will not stick the user', 'name': 'safety_pin'}, {'frequency': 'f', 'synset': 'sail.n.01', 'synonyms': ['sail'], 'id': 903, 'def': 'a large piece of fabric by means of which wind is used to propel a sailing vessel', 'name': 'sail'}, {'frequency': 'f', 'synset': 'salad.n.01', 'synonyms': ['salad'], 'id': 904, 'def': 'food mixtures either arranged on a plate or tossed and served with a moist dressing; usually consisting of or including greens', 'name': 'salad'}, {'frequency': 'r', 'synset': 'salad_plate.n.01', 'synonyms': ['salad_plate', 'salad_bowl'], 'id': 905, 'def': 'a plate or bowl for individual servings of salad', 'name': 'salad_plate'}, {'frequency': 'c', 'synset': 'salami.n.01', 'synonyms': ['salami'], 'id': 906, 'def': 'highly seasoned fatty sausage of pork and beef usually dried', 'name': 'salami'}, {'frequency': 'c', 'synset': 'salmon.n.01', 'synonyms': ['salmon_(fish)'], 'id': 907, 'def': 'any of various large food and game fishes of northern waters', 'name': 'salmon_(fish)'}, {'frequency': 'r', 'synset': 'salmon.n.03', 'synonyms': ['salmon_(food)'], 'id': 908, 'def': 'flesh of any of various marine or freshwater fish of the family Salmonidae', 'name': 'salmon_(food)'}, {'frequency': 'c', 'synset': 'salsa.n.01', 'synonyms': ['salsa'], 'id': 909, 'def': 'spicy sauce of tomatoes and onions and chili peppers to accompany Mexican foods', 'name': 'salsa'}, {'frequency': 'f', 'synset': 'saltshaker.n.01', 'synonyms': ['saltshaker'], 'id': 910, 'def': 'a shaker with a perforated top for sprinkling salt', 'name': 'saltshaker'}, {'frequency': 'f', 'synset': 'sandal.n.01', 'synonyms': ['sandal_(type_of_shoe)'], 'id': 911, 'def': 'a shoe consisting of a sole fastened by straps to the foot', 'name': 'sandal_(type_of_shoe)'}, {'frequency': 'f', 'synset': 'sandwich.n.01', 'synonyms': ['sandwich'], 'id': 912, 'def': 'two (or more) slices of bread with a filling between them', 'name': 'sandwich'}, {'frequency': 'r', 'synset': 'satchel.n.01', 'synonyms': ['satchel'], 'id': 913, 'def': 'luggage consisting of a small case with a flat bottom and (usually) a shoulder strap', 'name': 'satchel'}, {'frequency': 'r', 'synset': 'saucepan.n.01', 'synonyms': ['saucepan'], 'id': 914, 'def': 'a deep pan with a handle; used for stewing or boiling', 'name': 'saucepan'}, {'frequency': 'f', 'synset': 'saucer.n.02', 'synonyms': ['saucer'], 'id': 915, 'def': 'a small shallow dish for holding a cup at the table', 'name': 'saucer'}, {'frequency': 'f', 'synset': 'sausage.n.01', 'synonyms': ['sausage'], 'id': 916, 'def': 'highly seasoned minced meat stuffed in casings', 'name': 'sausage'}, {'frequency': 'r', 'synset': 'sawhorse.n.01', 'synonyms': ['sawhorse', 'sawbuck'], 'id': 917, 'def': 'a framework for holding wood that is being sawed', 'name': 'sawhorse'}, {'frequency': 'r', 'synset': 'sax.n.02', 'synonyms': ['saxophone'], 'id': 918, 'def': "a wind instrument with a `J'-shaped form typically made of brass", 'name': 'saxophone'}, {'frequency': 'f', 'synset': 'scale.n.07', 'synonyms': ['scale_(measuring_instrument)'], 'id': 919, 'def': 'a measuring instrument for weighing; shows amount of mass', 'name': 'scale_(measuring_instrument)'}, {'frequency': 'r', 'synset': 'scarecrow.n.01', 'synonyms': ['scarecrow', 'strawman'], 'id': 920, 'def': 'an effigy in the shape of a man to frighten birds away from seeds', 'name': 'scarecrow'}, {'frequency': 'f', 'synset': 'scarf.n.01', 'synonyms': ['scarf'], 'id': 921, 'def': 'a garment worn around the head or neck or shoulders for warmth or decoration', 'name': 'scarf'}, {'frequency': 'c', 'synset': 'school_bus.n.01', 'synonyms': ['school_bus'], 'id': 922, 'def': 'a bus used to transport children to or from school', 'name': 'school_bus'}, {'frequency': 'f', 'synset': 'scissors.n.01', 'synonyms': ['scissors'], 'id': 923, 'def': 'a tool having two crossed pivoting blades with looped handles', 'name': 'scissors'}, {'frequency': 'f', 'synset': 'scoreboard.n.01', 'synonyms': ['scoreboard'], 'id': 924, 'def': 'a large board for displaying the score of a contest (and some other information)', 'name': 'scoreboard'}, {'frequency': 'r', 'synset': 'scraper.n.01', 'synonyms': ['scraper'], 'id': 925, 'def': 'any of various hand tools for scraping', 'name': 'scraper'}, {'frequency': 'c', 'synset': 'screwdriver.n.01', 'synonyms': ['screwdriver'], 'id': 926, 'def': 'a hand tool for driving screws; has a tip that fits into the head of a screw', 'name': 'screwdriver'}, {'frequency': 'f', 'synset': 'scrub_brush.n.01', 'synonyms': ['scrubbing_brush'], 'id': 927, 'def': 'a brush with short stiff bristles for heavy cleaning', 'name': 'scrubbing_brush'}, {'frequency': 'c', 'synset': 'sculpture.n.01', 'synonyms': ['sculpture'], 'id': 928, 'def': 'a three-dimensional work of art', 'name': 'sculpture'}, {'frequency': 'c', 'synset': 'seabird.n.01', 'synonyms': ['seabird', 'seafowl'], 'id': 929, 'def': 'a bird that frequents coastal waters and the open ocean: gulls; pelicans; gannets; cormorants; albatrosses; petrels; etc.', 'name': 'seabird'}, {'frequency': 'c', 'synset': 'seahorse.n.02', 'synonyms': ['seahorse'], 'id': 930, 'def': 'small fish with horse-like heads bent sharply downward and curled tails', 'name': 'seahorse'}, {'frequency': 'r', 'synset': 'seaplane.n.01', 'synonyms': ['seaplane', 'hydroplane'], 'id': 931, 'def': 'an airplane that can land on or take off from water', 'name': 'seaplane'}, {'frequency': 'c', 'synset': 'seashell.n.01', 'synonyms': ['seashell'], 'id': 932, 'def': 'the shell of a marine organism', 'name': 'seashell'}, {'frequency': 'c', 'synset': 'sewing_machine.n.01', 'synonyms': ['sewing_machine'], 'id': 933, 'def': 'a textile machine used as a home appliance for sewing', 'name': 'sewing_machine'}, {'frequency': 'c', 'synset': 'shaker.n.03', 'synonyms': ['shaker'], 'id': 934, 'def': 'a container in which something can be shaken', 'name': 'shaker'}, {'frequency': 'c', 'synset': 'shampoo.n.01', 'synonyms': ['shampoo'], 'id': 935, 'def': 'cleansing agent consisting of soaps or detergents used for washing the hair', 'name': 'shampoo'}, {'frequency': 'c', 'synset': 'shark.n.01', 'synonyms': ['shark'], 'id': 936, 'def': 'typically large carnivorous fishes with sharpe teeth', 'name': 'shark'}, {'frequency': 'r', 'synset': 'sharpener.n.01', 'synonyms': ['sharpener'], 'id': 937, 'def': 'any implement that is used to make something (an edge or a point) sharper', 'name': 'sharpener'}, {'frequency': 'r', 'synset': 'sharpie.n.03', 'synonyms': ['Sharpie'], 'id': 938, 'def': 'a pen with indelible ink that will write on any surface', 'name': 'Sharpie'}, {'frequency': 'r', 'synset': 'shaver.n.03', 'synonyms': ['shaver_(electric)', 'electric_shaver', 'electric_razor'], 'id': 939, 'def': 'a razor powered by an electric motor', 'name': 'shaver_(electric)'}, {'frequency': 'c', 'synset': 'shaving_cream.n.01', 'synonyms': ['shaving_cream', 'shaving_soap'], 'id': 940, 'def': 'toiletry consisting that forms a rich lather for softening the beard before shaving', 'name': 'shaving_cream'}, {'frequency': 'r', 'synset': 'shawl.n.01', 'synonyms': ['shawl'], 'id': 941, 'def': 'cloak consisting of an oblong piece of cloth used to cover the head and shoulders', 'name': 'shawl'}, {'frequency': 'r', 'synset': 'shears.n.01', 'synonyms': ['shears'], 'id': 942, 'def': 'large scissors with strong blades', 'name': 'shears'}, {'frequency': 'f', 'synset': 'sheep.n.01', 'synonyms': ['sheep'], 'id': 943, 'def': 'woolly usually horned ruminant mammal related to the goat', 'name': 'sheep'}, {'frequency': 'r', 'synset': 'shepherd_dog.n.01', 'synonyms': ['shepherd_dog', 'sheepdog'], 'id': 944, 'def': 'any of various usually long-haired breeds of dog reared to herd and guard sheep', 'name': 'shepherd_dog'}, {'frequency': 'r', 'synset': 'sherbert.n.01', 'synonyms': ['sherbert', 'sherbet'], 'id': 945, 'def': 'a frozen dessert made primarily of fruit juice and sugar', 'name': 'sherbert'}, {'frequency': 'c', 'synset': 'shield.n.02', 'synonyms': ['shield'], 'id': 946, 'def': 'armor carried on the arm to intercept blows', 'name': 'shield'}, {'frequency': 'f', 'synset': 'shirt.n.01', 'synonyms': ['shirt'], 'id': 947, 'def': 'a garment worn on the upper half of the body', 'name': 'shirt'}, {'frequency': 'f', 'synset': 'shoe.n.01', 'synonyms': ['shoe', 'sneaker_(type_of_shoe)', 'tennis_shoe'], 'id': 948, 'def': 'common footwear covering the foot', 'name': 'shoe'}, {'frequency': 'f', 'synset': 'shopping_bag.n.01', 'synonyms': ['shopping_bag'], 'id': 949, 'def': 'a bag made of plastic or strong paper (often with handles); used to transport goods after shopping', 'name': 'shopping_bag'}, {'frequency': 'c', 'synset': 'shopping_cart.n.01', 'synonyms': ['shopping_cart'], 'id': 950, 'def': 'a handcart that holds groceries or other goods while shopping', 'name': 'shopping_cart'}, {'frequency': 'f', 'synset': 'short_pants.n.01', 'synonyms': ['short_pants', 'shorts_(clothing)', 'trunks_(clothing)'], 'id': 951, 'def': 'trousers that end at or above the knee', 'name': 'short_pants'}, {'frequency': 'r', 'synset': 'shot_glass.n.01', 'synonyms': ['shot_glass'], 'id': 952, 'def': 'a small glass adequate to hold a single swallow of whiskey', 'name': 'shot_glass'}, {'frequency': 'f', 'synset': 'shoulder_bag.n.01', 'synonyms': ['shoulder_bag'], 'id': 953, 'def': 'a large handbag that can be carried by a strap looped over the shoulder', 'name': 'shoulder_bag'}, {'frequency': 'c', 'synset': 'shovel.n.01', 'synonyms': ['shovel'], 'id': 954, 'def': 'a hand tool for lifting loose material such as snow, dirt, etc.', 'name': 'shovel'}, {'frequency': 'f', 'synset': 'shower.n.01', 'synonyms': ['shower_head'], 'id': 955, 'def': 'a plumbing fixture that sprays water over you', 'name': 'shower_head'}, {'frequency': 'r', 'synset': 'shower_cap.n.01', 'synonyms': ['shower_cap'], 'id': 956, 'def': 'a tight cap worn to keep hair dry while showering', 'name': 'shower_cap'}, {'frequency': 'f', 'synset': 'shower_curtain.n.01', 'synonyms': ['shower_curtain'], 'id': 957, 'def': 'a curtain that keeps water from splashing out of the shower area', 'name': 'shower_curtain'}, {'frequency': 'r', 'synset': 'shredder.n.01', 'synonyms': ['shredder_(for_paper)'], 'id': 958, 'def': 'a device that shreds documents', 'name': 'shredder_(for_paper)'}, {'frequency': 'f', 'synset': 'signboard.n.01', 'synonyms': ['signboard'], 'id': 959, 'def': 'structure displaying a board on which advertisements can be posted', 'name': 'signboard'}, {'frequency': 'c', 'synset': 'silo.n.01', 'synonyms': ['silo'], 'id': 960, 'def': 'a cylindrical tower used for storing goods', 'name': 'silo'}, {'frequency': 'f', 'synset': 'sink.n.01', 'synonyms': ['sink'], 'id': 961, 'def': 'plumbing fixture consisting of a water basin fixed to a wall or floor and having a drainpipe', 'name': 'sink'}, {'frequency': 'f', 'synset': 'skateboard.n.01', 'synonyms': ['skateboard'], 'id': 962, 'def': 'a board with wheels that is ridden in a standing or crouching position and propelled by foot', 'name': 'skateboard'}, {'frequency': 'c', 'synset': 'skewer.n.01', 'synonyms': ['skewer'], 'id': 963, 'def': 'a long pin for holding meat in position while it is being roasted', 'name': 'skewer'}, {'frequency': 'f', 'synset': 'ski.n.01', 'synonyms': ['ski'], 'id': 964, 'def': 'sports equipment for skiing on snow', 'name': 'ski'}, {'frequency': 'f', 'synset': 'ski_boot.n.01', 'synonyms': ['ski_boot'], 'id': 965, 'def': 'a stiff boot that is fastened to a ski with a ski binding', 'name': 'ski_boot'}, {'frequency': 'f', 'synset': 'ski_parka.n.01', 'synonyms': ['ski_parka', 'ski_jacket'], 'id': 966, 'def': 'a parka to be worn while skiing', 'name': 'ski_parka'}, {'frequency': 'f', 'synset': 'ski_pole.n.01', 'synonyms': ['ski_pole'], 'id': 967, 'def': 'a pole with metal points used as an aid in skiing', 'name': 'ski_pole'}, {'frequency': 'f', 'synset': 'skirt.n.02', 'synonyms': ['skirt'], 'id': 968, 'def': 'a garment hanging from the waist; worn mainly by girls and women', 'name': 'skirt'}, {'frequency': 'r', 'synset': 'skullcap.n.01', 'synonyms': ['skullcap'], 'id': 969, 'def': 'rounded brimless cap fitting the crown of the head', 'name': 'skullcap'}, {'frequency': 'c', 'synset': 'sled.n.01', 'synonyms': ['sled', 'sledge', 'sleigh'], 'id': 970, 'def': 'a vehicle or flat object for transportation over snow by sliding or pulled by dogs, etc.', 'name': 'sled'}, {'frequency': 'c', 'synset': 'sleeping_bag.n.01', 'synonyms': ['sleeping_bag'], 'id': 971, 'def': 'large padded bag designed to be slept in outdoors', 'name': 'sleeping_bag'}, {'frequency': 'r', 'synset': 'sling.n.05', 'synonyms': ['sling_(bandage)', 'triangular_bandage'], 'id': 972, 'def': 'bandage to support an injured forearm; slung over the shoulder or neck', 'name': 'sling_(bandage)'}, {'frequency': 'c', 'synset': 'slipper.n.01', 'synonyms': ['slipper_(footwear)', 'carpet_slipper_(footwear)'], 'id': 973, 'def': 'low footwear that can be slipped on and off easily; usually worn indoors', 'name': 'slipper_(footwear)'}, {'frequency': 'r', 'synset': 'smoothie.n.02', 'synonyms': ['smoothie'], 'id': 974, 'def': 'a thick smooth drink consisting of fresh fruit pureed with ice cream or yoghurt or milk', 'name': 'smoothie'}, {'frequency': 'r', 'synset': 'snake.n.01', 'synonyms': ['snake', 'serpent'], 'id': 975, 'def': 'limbless scaly elongate reptile; some are venomous', 'name': 'snake'}, {'frequency': 'f', 'synset': 'snowboard.n.01', 'synonyms': ['snowboard'], 'id': 976, 'def': 'a board that resembles a broad ski or a small surfboard; used in a standing position to slide down snow-covered slopes', 'name': 'snowboard'}, {'frequency': 'c', 'synset': 'snowman.n.01', 'synonyms': ['snowman'], 'id': 977, 'def': 'a figure of a person made of packed snow', 'name': 'snowman'}, {'frequency': 'c', 'synset': 'snowmobile.n.01', 'synonyms': ['snowmobile'], 'id': 978, 'def': 'tracked vehicle for travel on snow having skis in front', 'name': 'snowmobile'}, {'frequency': 'f', 'synset': 'soap.n.01', 'synonyms': ['soap'], 'id': 979, 'def': 'a cleansing agent made from the salts of vegetable or animal fats', 'name': 'soap'}, {'frequency': 'f', 'synset': 'soccer_ball.n.01', 'synonyms': ['soccer_ball'], 'id': 980, 'def': "an inflated ball used in playing soccer (called `football' outside of the United States)", 'name': 'soccer_ball'}, {'frequency': 'f', 'synset': 'sock.n.01', 'synonyms': ['sock'], 'id': 981, 'def': 'cloth covering for the foot; worn inside the shoe; reaches to between the ankle and the knee', 'name': 'sock'}, {'frequency': 'f', 'synset': 'sofa.n.01', 'synonyms': ['sofa', 'couch', 'lounge'], 'id': 982, 'def': 'an upholstered seat for more than one person', 'name': 'sofa'}, {'frequency': 'r', 'synset': 'softball.n.01', 'synonyms': ['softball'], 'id': 983, 'def': 'ball used in playing softball', 'name': 'softball'}, {'frequency': 'c', 'synset': 'solar_array.n.01', 'synonyms': ['solar_array', 'solar_battery', 'solar_panel'], 'id': 984, 'def': 'electrical device consisting of a large array of connected solar cells', 'name': 'solar_array'}, {'frequency': 'r', 'synset': 'sombrero.n.02', 'synonyms': ['sombrero'], 'id': 985, 'def': 'a straw hat with a tall crown and broad brim; worn in American southwest and in Mexico', 'name': 'sombrero'}, {'frequency': 'f', 'synset': 'soup.n.01', 'synonyms': ['soup'], 'id': 986, 'def': 'liquid food especially of meat or fish or vegetable stock often containing pieces of solid food', 'name': 'soup'}, {'frequency': 'r', 'synset': 'soup_bowl.n.01', 'synonyms': ['soup_bowl'], 'id': 987, 'def': 'a bowl for serving soup', 'name': 'soup_bowl'}, {'frequency': 'c', 'synset': 'soupspoon.n.01', 'synonyms': ['soupspoon'], 'id': 988, 'def': 'a spoon with a rounded bowl for eating soup', 'name': 'soupspoon'}, {'frequency': 'c', 'synset': 'sour_cream.n.01', 'synonyms': ['sour_cream', 'soured_cream'], 'id': 989, 'def': 'soured light cream', 'name': 'sour_cream'}, {'frequency': 'r', 'synset': 'soya_milk.n.01', 'synonyms': ['soya_milk', 'soybean_milk', 'soymilk'], 'id': 990, 'def': 'a milk substitute containing soybean flour and water; used in some infant formulas and in making tofu', 'name': 'soya_milk'}, {'frequency': 'r', 'synset': 'space_shuttle.n.01', 'synonyms': ['space_shuttle'], 'id': 991, 'def': "a reusable spacecraft with wings for a controlled descent through the Earth's atmosphere", 'name': 'space_shuttle'}, {'frequency': 'r', 'synset': 'sparkler.n.02', 'synonyms': ['sparkler_(fireworks)'], 'id': 992, 'def': 'a firework that burns slowly and throws out a shower of sparks', 'name': 'sparkler_(fireworks)'}, {'frequency': 'f', 'synset': 'spatula.n.02', 'synonyms': ['spatula'], 'id': 993, 'def': 'a hand tool with a thin flexible blade used to mix or spread soft substances', 'name': 'spatula'}, {'frequency': 'r', 'synset': 'spear.n.01', 'synonyms': ['spear', 'lance'], 'id': 994, 'def': 'a long pointed rod used as a tool or weapon', 'name': 'spear'}, {'frequency': 'f', 'synset': 'spectacles.n.01', 'synonyms': ['spectacles', 'specs', 'eyeglasses', 'glasses'], 'id': 995, 'def': 'optical instrument consisting of a frame that holds a pair of lenses for correcting defective vision', 'name': 'spectacles'}, {'frequency': 'c', 'synset': 'spice_rack.n.01', 'synonyms': ['spice_rack'], 'id': 996, 'def': 'a rack for displaying containers filled with spices', 'name': 'spice_rack'}, {'frequency': 'c', 'synset': 'spider.n.01', 'synonyms': ['spider'], 'id': 997, 'def': 'predatory arachnid with eight legs, two poison fangs, two feelers, and usually two silk-spinning organs at the back end of the body', 'name': 'spider'}, {'frequency': 'r', 'synset': 'spiny_lobster.n.02', 'synonyms': ['crawfish', 'crayfish'], 'id': 998, 'def': 'large edible marine crustacean having a spiny carapace but lacking the large pincers of true lobsters', 'name': 'crawfish'}, {'frequency': 'c', 'synset': 'sponge.n.01', 'synonyms': ['sponge'], 'id': 999, 'def': 'a porous mass usable to absorb water typically used for cleaning', 'name': 'sponge'}, {'frequency': 'f', 'synset': 'spoon.n.01', 'synonyms': ['spoon'], 'id': 1000, 'def': 'a piece of cutlery with a shallow bowl-shaped container and a handle', 'name': 'spoon'}, {'frequency': 'c', 'synset': 'sportswear.n.01', 'synonyms': ['sportswear', 'athletic_wear', 'activewear'], 'id': 1001, 'def': 'attire worn for sport or for casual wear', 'name': 'sportswear'}, {'frequency': 'c', 'synset': 'spotlight.n.02', 'synonyms': ['spotlight'], 'id': 1002, 'def': 'a lamp that produces a strong beam of light to illuminate a restricted area; used to focus attention of a stage performer', 'name': 'spotlight'}, {'frequency': 'r', 'synset': 'squid.n.01', 'synonyms': ['squid_(food)', 'calamari', 'calamary'], 'id': 1003, 'def': '(Italian cuisine) squid prepared as food', 'name': 'squid_(food)'}, {'frequency': 'c', 'synset': 'squirrel.n.01', 'synonyms': ['squirrel'], 'id': 1004, 'def': 'a kind of arboreal rodent having a long bushy tail', 'name': 'squirrel'}, {'frequency': 'r', 'synset': 'stagecoach.n.01', 'synonyms': ['stagecoach'], 'id': 1005, 'def': 'a large coach-and-four formerly used to carry passengers and mail on regular routes between towns', 'name': 'stagecoach'}, {'frequency': 'c', 'synset': 'stapler.n.01', 'synonyms': ['stapler_(stapling_machine)'], 'id': 1006, 'def': 'a machine that inserts staples into sheets of paper in order to fasten them together', 'name': 'stapler_(stapling_machine)'}, {'frequency': 'c', 'synset': 'starfish.n.01', 'synonyms': ['starfish', 'sea_star'], 'id': 1007, 'def': 'echinoderms characterized by five arms extending from a central disk', 'name': 'starfish'}, {'frequency': 'f', 'synset': 'statue.n.01', 'synonyms': ['statue_(sculpture)'], 'id': 1008, 'def': 'a sculpture representing a human or animal', 'name': 'statue_(sculpture)'}, {'frequency': 'c', 'synset': 'steak.n.01', 'synonyms': ['steak_(food)'], 'id': 1009, 'def': 'a slice of meat cut from the fleshy part of an animal or large fish', 'name': 'steak_(food)'}, {'frequency': 'r', 'synset': 'steak_knife.n.01', 'synonyms': ['steak_knife'], 'id': 1010, 'def': 'a sharp table knife used in eating steak', 'name': 'steak_knife'}, {'frequency': 'f', 'synset': 'steering_wheel.n.01', 'synonyms': ['steering_wheel'], 'id': 1011, 'def': 'a handwheel that is used for steering', 'name': 'steering_wheel'}, {'frequency': 'r', 'synset': 'step_ladder.n.01', 'synonyms': ['stepladder'], 'id': 1012, 'def': 'a folding portable ladder hinged at the top', 'name': 'stepladder'}, {'frequency': 'c', 'synset': 'step_stool.n.01', 'synonyms': ['step_stool'], 'id': 1013, 'def': 'a stool that has one or two steps that fold under the seat', 'name': 'step_stool'}, {'frequency': 'c', 'synset': 'stereo.n.01', 'synonyms': ['stereo_(sound_system)'], 'id': 1014, 'def': 'electronic device for playing audio', 'name': 'stereo_(sound_system)'}, {'frequency': 'r', 'synset': 'stew.n.02', 'synonyms': ['stew'], 'id': 1015, 'def': 'food prepared by stewing especially meat or fish with vegetables', 'name': 'stew'}, {'frequency': 'r', 'synset': 'stirrer.n.02', 'synonyms': ['stirrer'], 'id': 1016, 'def': 'an implement used for stirring', 'name': 'stirrer'}, {'frequency': 'f', 'synset': 'stirrup.n.01', 'synonyms': ['stirrup'], 'id': 1017, 'def': "support consisting of metal loops into which rider's feet go", 'name': 'stirrup'}, {'frequency': 'f', 'synset': 'stool.n.01', 'synonyms': ['stool'], 'id': 1018, 'def': 'a simple seat without a back or arms', 'name': 'stool'}, {'frequency': 'f', 'synset': 'stop_sign.n.01', 'synonyms': ['stop_sign'], 'id': 1019, 'def': 'a traffic sign to notify drivers that they must come to a complete stop', 'name': 'stop_sign'}, {'frequency': 'f', 'synset': 'stoplight.n.01', 'synonyms': ['brake_light'], 'id': 1020, 'def': 'a red light on the rear of a motor vehicle that signals when the brakes are applied', 'name': 'brake_light'}, {'frequency': 'f', 'synset': 'stove.n.01', 'synonyms': ['stove', 'kitchen_stove', 'range_(kitchen_appliance)', 'kitchen_range', 'cooking_stove'], 'id': 1021, 'def': 'a kitchen appliance used for cooking food', 'name': 'stove'}, {'frequency': 'c', 'synset': 'strainer.n.01', 'synonyms': ['strainer'], 'id': 1022, 'def': 'a filter to retain larger pieces while smaller pieces and liquids pass through', 'name': 'strainer'}, {'frequency': 'f', 'synset': 'strap.n.01', 'synonyms': ['strap'], 'id': 1023, 'def': 'an elongated strip of material for binding things together or holding', 'name': 'strap'}, {'frequency': 'f', 'synset': 'straw.n.04', 'synonyms': ['straw_(for_drinking)', 'drinking_straw'], 'id': 1024, 'def': 'a thin paper or plastic tube used to suck liquids into the mouth', 'name': 'straw_(for_drinking)'}, {'frequency': 'f', 'synset': 'strawberry.n.01', 'synonyms': ['strawberry'], 'id': 1025, 'def': 'sweet fleshy red fruit', 'name': 'strawberry'}, {'frequency': 'f', 'synset': 'street_sign.n.01', 'synonyms': ['street_sign'], 'id': 1026, 'def': 'a sign visible from the street', 'name': 'street_sign'}, {'frequency': 'f', 'synset': 'streetlight.n.01', 'synonyms': ['streetlight', 'street_lamp'], 'id': 1027, 'def': 'a lamp supported on a lamppost; for illuminating a street', 'name': 'streetlight'}, {'frequency': 'r', 'synset': 'string_cheese.n.01', 'synonyms': ['string_cheese'], 'id': 1028, 'def': 'cheese formed in long strings twisted together', 'name': 'string_cheese'}, {'frequency': 'r', 'synset': 'stylus.n.02', 'synonyms': ['stylus'], 'id': 1029, 'def': 'a pointed tool for writing or drawing or engraving, including pens', 'name': 'stylus'}, {'frequency': 'r', 'synset': 'subwoofer.n.01', 'synonyms': ['subwoofer'], 'id': 1030, 'def': 'a loudspeaker that is designed to reproduce very low bass frequencies', 'name': 'subwoofer'}, {'frequency': 'r', 'synset': 'sugar_bowl.n.01', 'synonyms': ['sugar_bowl'], 'id': 1031, 'def': 'a dish in which sugar is served', 'name': 'sugar_bowl'}, {'frequency': 'r', 'synset': 'sugarcane.n.01', 'synonyms': ['sugarcane_(plant)'], 'id': 1032, 'def': 'juicy canes whose sap is a source of molasses and commercial sugar; fresh canes are sometimes chewed for the juice', 'name': 'sugarcane_(plant)'}, {'frequency': 'f', 'synset': 'suit.n.01', 'synonyms': ['suit_(clothing)'], 'id': 1033, 'def': 'a set of garments (usually including a jacket and trousers or skirt) for outerwear all of the same fabric and color', 'name': 'suit_(clothing)'}, {'frequency': 'c', 'synset': 'sunflower.n.01', 'synonyms': ['sunflower'], 'id': 1034, 'def': 'any plant of the genus Helianthus having large flower heads with dark disk florets and showy yellow rays', 'name': 'sunflower'}, {'frequency': 'f', 'synset': 'sunglasses.n.01', 'synonyms': ['sunglasses'], 'id': 1035, 'def': 'spectacles that are darkened or polarized to protect the eyes from the glare of the sun', 'name': 'sunglasses'}, {'frequency': 'c', 'synset': 'sunhat.n.01', 'synonyms': ['sunhat'], 'id': 1036, 'def': 'a hat with a broad brim that protects the face from direct exposure to the sun', 'name': 'sunhat'}, {'frequency': 'f', 'synset': 'surfboard.n.01', 'synonyms': ['surfboard'], 'id': 1037, 'def': 'a narrow buoyant board for riding surf', 'name': 'surfboard'}, {'frequency': 'c', 'synset': 'sushi.n.01', 'synonyms': ['sushi'], 'id': 1038, 'def': 'rice (with raw fish) wrapped in seaweed', 'name': 'sushi'}, {'frequency': 'c', 'synset': 'swab.n.02', 'synonyms': ['mop'], 'id': 1039, 'def': 'cleaning implement consisting of absorbent material fastened to a handle; for cleaning floors', 'name': 'mop'}, {'frequency': 'c', 'synset': 'sweat_pants.n.01', 'synonyms': ['sweat_pants'], 'id': 1040, 'def': 'loose-fitting trousers with elastic cuffs; worn by athletes', 'name': 'sweat_pants'}, {'frequency': 'c', 'synset': 'sweatband.n.02', 'synonyms': ['sweatband'], 'id': 1041, 'def': 'a band of material tied around the forehead or wrist to absorb sweat', 'name': 'sweatband'}, {'frequency': 'f', 'synset': 'sweater.n.01', 'synonyms': ['sweater'], 'id': 1042, 'def': 'a crocheted or knitted garment covering the upper part of the body', 'name': 'sweater'}, {'frequency': 'f', 'synset': 'sweatshirt.n.01', 'synonyms': ['sweatshirt'], 'id': 1043, 'def': 'cotton knit pullover with long sleeves worn during athletic activity', 'name': 'sweatshirt'}, {'frequency': 'c', 'synset': 'sweet_potato.n.02', 'synonyms': ['sweet_potato'], 'id': 1044, 'def': 'the edible tuberous root of the sweet potato vine', 'name': 'sweet_potato'}, {'frequency': 'f', 'synset': 'swimsuit.n.01', 'synonyms': ['swimsuit', 'swimwear', 'bathing_suit', 'swimming_costume', 'bathing_costume', 'swimming_trunks', 'bathing_trunks'], 'id': 1045, 'def': 'garment worn for swimming', 'name': 'swimsuit'}, {'frequency': 'c', 'synset': 'sword.n.01', 'synonyms': ['sword'], 'id': 1046, 'def': 'a cutting or thrusting weapon that has a long metal blade', 'name': 'sword'}, {'frequency': 'r', 'synset': 'syringe.n.01', 'synonyms': ['syringe'], 'id': 1047, 'def': 'a medical instrument used to inject or withdraw fluids', 'name': 'syringe'}, {'frequency': 'r', 'synset': 'tabasco.n.02', 'synonyms': ['Tabasco_sauce'], 'id': 1048, 'def': 'very spicy sauce (trade name Tabasco) made from fully-aged red peppers', 'name': 'Tabasco_sauce'}, {'frequency': 'r', 'synset': 'table-tennis_table.n.01', 'synonyms': ['table-tennis_table', 'ping-pong_table'], 'id': 1049, 'def': 'a table used for playing table tennis', 'name': 'table-tennis_table'}, {'frequency': 'f', 'synset': 'table.n.02', 'synonyms': ['table'], 'id': 1050, 'def': 'a piece of furniture having a smooth flat top that is usually supported by one or more vertical legs', 'name': 'table'}, {'frequency': 'c', 'synset': 'table_lamp.n.01', 'synonyms': ['table_lamp'], 'id': 1051, 'def': 'a lamp that sits on a table', 'name': 'table_lamp'}, {'frequency': 'f', 'synset': 'tablecloth.n.01', 'synonyms': ['tablecloth'], 'id': 1052, 'def': 'a covering spread over a dining table', 'name': 'tablecloth'}, {'frequency': 'r', 'synset': 'tachometer.n.01', 'synonyms': ['tachometer'], 'id': 1053, 'def': 'measuring instrument for indicating speed of rotation', 'name': 'tachometer'}, {'frequency': 'r', 'synset': 'taco.n.02', 'synonyms': ['taco'], 'id': 1054, 'def': 'a small tortilla cupped around a filling', 'name': 'taco'}, {'frequency': 'f', 'synset': 'tag.n.02', 'synonyms': ['tag'], 'id': 1055, 'def': 'a label associated with something for the purpose of identification or information', 'name': 'tag'}, {'frequency': 'f', 'synset': 'taillight.n.01', 'synonyms': ['taillight', 'rear_light'], 'id': 1056, 'def': 'lamp (usually red) mounted at the rear of a motor vehicle', 'name': 'taillight'}, {'frequency': 'r', 'synset': 'tambourine.n.01', 'synonyms': ['tambourine'], 'id': 1057, 'def': 'a shallow drum with a single drumhead and with metallic disks in the sides', 'name': 'tambourine'}, {'frequency': 'r', 'synset': 'tank.n.01', 'synonyms': ['army_tank', 'armored_combat_vehicle', 'armoured_combat_vehicle'], 'id': 1058, 'def': 'an enclosed armored military vehicle; has a cannon and moves on caterpillar treads', 'name': 'army_tank'}, {'frequency': 'f', 'synset': 'tank.n.02', 'synonyms': ['tank_(storage_vessel)', 'storage_tank'], 'id': 1059, 'def': 'a large (usually metallic) vessel for holding gases or liquids', 'name': 'tank_(storage_vessel)'}, {'frequency': 'f', 'synset': 'tank_top.n.01', 'synonyms': ['tank_top_(clothing)'], 'id': 1060, 'def': 'a tight-fitting sleeveless shirt with wide shoulder straps and low neck and no front opening', 'name': 'tank_top_(clothing)'}, {'frequency': 'f', 'synset': 'tape.n.01', 'synonyms': ['tape_(sticky_cloth_or_paper)'], 'id': 1061, 'def': 'a long thin piece of cloth or paper as used for binding or fastening', 'name': 'tape_(sticky_cloth_or_paper)'}, {'frequency': 'c', 'synset': 'tape.n.04', 'synonyms': ['tape_measure', 'measuring_tape'], 'id': 1062, 'def': 'measuring instrument consisting of a narrow strip (cloth or metal) marked in inches or centimeters and used for measuring lengths', 'name': 'tape_measure'}, {'frequency': 'c', 'synset': 'tapestry.n.02', 'synonyms': ['tapestry'], 'id': 1063, 'def': 'a heavy textile with a woven design; used for curtains and upholstery', 'name': 'tapestry'}, {'frequency': 'f', 'synset': 'tarpaulin.n.01', 'synonyms': ['tarp'], 'id': 1064, 'def': 'waterproofed canvas', 'name': 'tarp'}, {'frequency': 'c', 'synset': 'tartan.n.01', 'synonyms': ['tartan', 'plaid'], 'id': 1065, 'def': 'a cloth having a crisscross design', 'name': 'tartan'}, {'frequency': 'c', 'synset': 'tassel.n.01', 'synonyms': ['tassel'], 'id': 1066, 'def': 'adornment consisting of a bunch of cords fastened at one end', 'name': 'tassel'}, {'frequency': 'c', 'synset': 'tea_bag.n.01', 'synonyms': ['tea_bag'], 'id': 1067, 'def': 'a measured amount of tea in a bag for an individual serving of tea', 'name': 'tea_bag'}, {'frequency': 'c', 'synset': 'teacup.n.02', 'synonyms': ['teacup'], 'id': 1068, 'def': 'a cup from which tea is drunk', 'name': 'teacup'}, {'frequency': 'c', 'synset': 'teakettle.n.01', 'synonyms': ['teakettle'], 'id': 1069, 'def': 'kettle for boiling water to make tea', 'name': 'teakettle'}, {'frequency': 'f', 'synset': 'teapot.n.01', 'synonyms': ['teapot'], 'id': 1070, 'def': 'pot for brewing tea; usually has a spout and handle', 'name': 'teapot'}, {'frequency': 'f', 'synset': 'teddy.n.01', 'synonyms': ['teddy_bear'], 'id': 1071, 'def': "plaything consisting of a child's toy bear (usually plush and stuffed with soft materials)", 'name': 'teddy_bear'}, {'frequency': 'f', 'synset': 'telephone.n.01', 'synonyms': ['telephone', 'phone', 'telephone_set'], 'id': 1072, 'def': 'electronic device for communicating by voice over long distances (includes wired and wireless/cell phones)', 'name': 'telephone'}, {'frequency': 'c', 'synset': 'telephone_booth.n.01', 'synonyms': ['telephone_booth', 'phone_booth', 'call_box', 'telephone_box', 'telephone_kiosk'], 'id': 1073, 'def': 'booth for using a telephone', 'name': 'telephone_booth'}, {'frequency': 'f', 'synset': 'telephone_pole.n.01', 'synonyms': ['telephone_pole', 'telegraph_pole', 'telegraph_post'], 'id': 1074, 'def': 'tall pole supporting telephone wires', 'name': 'telephone_pole'}, {'frequency': 'r', 'synset': 'telephoto_lens.n.01', 'synonyms': ['telephoto_lens', 'zoom_lens'], 'id': 1075, 'def': 'a camera lens that magnifies the image', 'name': 'telephoto_lens'}, {'frequency': 'c', 'synset': 'television_camera.n.01', 'synonyms': ['television_camera', 'tv_camera'], 'id': 1076, 'def': 'television equipment for capturing and recording video', 'name': 'television_camera'}, {'frequency': 'f', 'synset': 'television_receiver.n.01', 'synonyms': ['television_set', 'tv', 'tv_set'], 'id': 1077, 'def': 'an electronic device that receives television signals and displays them on a screen', 'name': 'television_set'}, {'frequency': 'f', 'synset': 'tennis_ball.n.01', 'synonyms': ['tennis_ball'], 'id': 1078, 'def': 'ball about the size of a fist used in playing tennis', 'name': 'tennis_ball'}, {'frequency': 'f', 'synset': 'tennis_racket.n.01', 'synonyms': ['tennis_racket'], 'id': 1079, 'def': 'a racket used to play tennis', 'name': 'tennis_racket'}, {'frequency': 'r', 'synset': 'tequila.n.01', 'synonyms': ['tequila'], 'id': 1080, 'def': 'Mexican liquor made from fermented juices of an agave plant', 'name': 'tequila'}, {'frequency': 'c', 'synset': 'thermometer.n.01', 'synonyms': ['thermometer'], 'id': 1081, 'def': 'measuring instrument for measuring temperature', 'name': 'thermometer'}, {'frequency': 'c', 'synset': 'thermos.n.01', 'synonyms': ['thermos_bottle'], 'id': 1082, 'def': 'vacuum flask that preserves temperature of hot or cold drinks', 'name': 'thermos_bottle'}, {'frequency': 'f', 'synset': 'thermostat.n.01', 'synonyms': ['thermostat'], 'id': 1083, 'def': 'a regulator for automatically regulating temperature by starting or stopping the supply of heat', 'name': 'thermostat'}, {'frequency': 'r', 'synset': 'thimble.n.02', 'synonyms': ['thimble'], 'id': 1084, 'def': 'a small metal cap to protect the finger while sewing; can be used as a small container', 'name': 'thimble'}, {'frequency': 'c', 'synset': 'thread.n.01', 'synonyms': ['thread', 'yarn'], 'id': 1085, 'def': 'a fine cord of twisted fibers (of cotton or silk or wool or nylon etc.) used in sewing and weaving', 'name': 'thread'}, {'frequency': 'c', 'synset': 'thumbtack.n.01', 'synonyms': ['thumbtack', 'drawing_pin', 'pushpin'], 'id': 1086, 'def': 'a tack for attaching papers to a bulletin board or drawing board', 'name': 'thumbtack'}, {'frequency': 'c', 'synset': 'tiara.n.01', 'synonyms': ['tiara'], 'id': 1087, 'def': 'a jeweled headdress worn by women on formal occasions', 'name': 'tiara'}, {'frequency': 'c', 'synset': 'tiger.n.02', 'synonyms': ['tiger'], 'id': 1088, 'def': 'large feline of forests in most of Asia having a tawny coat with black stripes', 'name': 'tiger'}, {'frequency': 'c', 'synset': 'tights.n.01', 'synonyms': ['tights_(clothing)', 'leotards'], 'id': 1089, 'def': 'skintight knit hose covering the body from the waist to the feet worn by acrobats and dancers and as stockings by women and girls', 'name': 'tights_(clothing)'}, {'frequency': 'c', 'synset': 'timer.n.01', 'synonyms': ['timer', 'stopwatch'], 'id': 1090, 'def': 'a timepiece that measures a time interval and signals its end', 'name': 'timer'}, {'frequency': 'f', 'synset': 'tinfoil.n.01', 'synonyms': ['tinfoil'], 'id': 1091, 'def': 'foil made of tin or an alloy of tin and lead', 'name': 'tinfoil'}, {'frequency': 'c', 'synset': 'tinsel.n.01', 'synonyms': ['tinsel'], 'id': 1092, 'def': 'a showy decoration that is basically valueless', 'name': 'tinsel'}, {'frequency': 'f', 'synset': 'tissue.n.02', 'synonyms': ['tissue_paper'], 'id': 1093, 'def': 'a soft thin (usually translucent) paper', 'name': 'tissue_paper'}, {'frequency': 'c', 'synset': 'toast.n.01', 'synonyms': ['toast_(food)'], 'id': 1094, 'def': 'slice of bread that has been toasted', 'name': 'toast_(food)'}, {'frequency': 'f', 'synset': 'toaster.n.02', 'synonyms': ['toaster'], 'id': 1095, 'def': 'a kitchen appliance (usually electric) for toasting bread', 'name': 'toaster'}, {'frequency': 'f', 'synset': 'toaster_oven.n.01', 'synonyms': ['toaster_oven'], 'id': 1096, 'def': 'kitchen appliance consisting of a small electric oven for toasting or warming food', 'name': 'toaster_oven'}, {'frequency': 'f', 'synset': 'toilet.n.02', 'synonyms': ['toilet'], 'id': 1097, 'def': 'a plumbing fixture for defecation and urination', 'name': 'toilet'}, {'frequency': 'f', 'synset': 'toilet_tissue.n.01', 'synonyms': ['toilet_tissue', 'toilet_paper', 'bathroom_tissue'], 'id': 1098, 'def': 'a soft thin absorbent paper for use in toilets', 'name': 'toilet_tissue'}, {'frequency': 'f', 'synset': 'tomato.n.01', 'synonyms': ['tomato'], 'id': 1099, 'def': 'mildly acid red or yellow pulpy fruit eaten as a vegetable', 'name': 'tomato'}, {'frequency': 'f', 'synset': 'tongs.n.01', 'synonyms': ['tongs'], 'id': 1100, 'def': 'any of various devices for taking hold of objects; usually have two hinged legs with handles above and pointed hooks below', 'name': 'tongs'}, {'frequency': 'c', 'synset': 'toolbox.n.01', 'synonyms': ['toolbox'], 'id': 1101, 'def': 'a box or chest or cabinet for holding hand tools', 'name': 'toolbox'}, {'frequency': 'f', 'synset': 'toothbrush.n.01', 'synonyms': ['toothbrush'], 'id': 1102, 'def': 'small brush; has long handle; used to clean teeth', 'name': 'toothbrush'}, {'frequency': 'f', 'synset': 'toothpaste.n.01', 'synonyms': ['toothpaste'], 'id': 1103, 'def': 'a dentifrice in the form of a paste', 'name': 'toothpaste'}, {'frequency': 'f', 'synset': 'toothpick.n.01', 'synonyms': ['toothpick'], 'id': 1104, 'def': 'pick consisting of a small strip of wood or plastic; used to pick food from between the teeth', 'name': 'toothpick'}, {'frequency': 'f', 'synset': 'top.n.09', 'synonyms': ['cover'], 'id': 1105, 'def': 'covering for a hole (especially a hole in the top of a container)', 'name': 'cover'}, {'frequency': 'c', 'synset': 'tortilla.n.01', 'synonyms': ['tortilla'], 'id': 1106, 'def': 'thin unleavened pancake made from cornmeal or wheat flour', 'name': 'tortilla'}, {'frequency': 'c', 'synset': 'tow_truck.n.01', 'synonyms': ['tow_truck'], 'id': 1107, 'def': 'a truck equipped to hoist and pull wrecked cars (or to remove cars from no-parking zones)', 'name': 'tow_truck'}, {'frequency': 'f', 'synset': 'towel.n.01', 'synonyms': ['towel'], 'id': 1108, 'def': 'a rectangular piece of absorbent cloth (or paper) for drying or wiping', 'name': 'towel'}, {'frequency': 'f', 'synset': 'towel_rack.n.01', 'synonyms': ['towel_rack', 'towel_rail', 'towel_bar'], 'id': 1109, 'def': 'a rack consisting of one or more bars on which towels can be hung', 'name': 'towel_rack'}, {'frequency': 'f', 'synset': 'toy.n.03', 'synonyms': ['toy'], 'id': 1110, 'def': 'a device regarded as providing amusement', 'name': 'toy'}, {'frequency': 'c', 'synset': 'tractor.n.01', 'synonyms': ['tractor_(farm_equipment)'], 'id': 1111, 'def': 'a wheeled vehicle with large wheels; used in farming and other applications', 'name': 'tractor_(farm_equipment)'}, {'frequency': 'f', 'synset': 'traffic_light.n.01', 'synonyms': ['traffic_light'], 'id': 1112, 'def': 'a device to control vehicle traffic often consisting of three or more lights', 'name': 'traffic_light'}, {'frequency': 'c', 'synset': 'trail_bike.n.01', 'synonyms': ['dirt_bike'], 'id': 1113, 'def': 'a lightweight motorcycle equipped with rugged tires and suspension for off-road use', 'name': 'dirt_bike'}, {'frequency': 'f', 'synset': 'trailer_truck.n.01', 'synonyms': ['trailer_truck', 'tractor_trailer', 'trucking_rig', 'articulated_lorry', 'semi_truck'], 'id': 1114, 'def': 'a truck consisting of a tractor and trailer together', 'name': 'trailer_truck'}, {'frequency': 'f', 'synset': 'train.n.01', 'synonyms': ['train_(railroad_vehicle)', 'railroad_train'], 'id': 1115, 'def': 'public or private transport provided by a line of railway cars coupled together and drawn by a locomotive', 'name': 'train_(railroad_vehicle)'}, {'frequency': 'r', 'synset': 'trampoline.n.01', 'synonyms': ['trampoline'], 'id': 1116, 'def': 'gymnastic apparatus consisting of a strong canvas sheet attached with springs to a metal frame', 'name': 'trampoline'}, {'frequency': 'f', 'synset': 'tray.n.01', 'synonyms': ['tray'], 'id': 1117, 'def': 'an open receptacle for holding or displaying or serving articles or food', 'name': 'tray'}, {'frequency': 'r', 'synset': 'trench_coat.n.01', 'synonyms': ['trench_coat'], 'id': 1118, 'def': 'a military style raincoat; belted with deep pockets', 'name': 'trench_coat'}, {'frequency': 'r', 'synset': 'triangle.n.05', 'synonyms': ['triangle_(musical_instrument)'], 'id': 1119, 'def': 'a percussion instrument consisting of a metal bar bent in the shape of an open triangle', 'name': 'triangle_(musical_instrument)'}, {'frequency': 'c', 'synset': 'tricycle.n.01', 'synonyms': ['tricycle'], 'id': 1120, 'def': 'a vehicle with three wheels that is moved by foot pedals', 'name': 'tricycle'}, {'frequency': 'f', 'synset': 'tripod.n.01', 'synonyms': ['tripod'], 'id': 1121, 'def': 'a three-legged rack used for support', 'name': 'tripod'}, {'frequency': 'f', 'synset': 'trouser.n.01', 'synonyms': ['trousers', 'pants_(clothing)'], 'id': 1122, 'def': 'a garment extending from the waist to the knee or ankle, covering each leg separately', 'name': 'trousers'}, {'frequency': 'f', 'synset': 'truck.n.01', 'synonyms': ['truck'], 'id': 1123, 'def': 'an automotive vehicle suitable for hauling', 'name': 'truck'}, {'frequency': 'r', 'synset': 'truffle.n.03', 'synonyms': ['truffle_(chocolate)', 'chocolate_truffle'], 'id': 1124, 'def': 'creamy chocolate candy', 'name': 'truffle_(chocolate)'}, {'frequency': 'c', 'synset': 'trunk.n.02', 'synonyms': ['trunk'], 'id': 1125, 'def': 'luggage consisting of a large strong case used when traveling or for storage', 'name': 'trunk'}, {'frequency': 'r', 'synset': 'tub.n.02', 'synonyms': ['vat'], 'id': 1126, 'def': 'a large vessel for holding or storing liquids', 'name': 'vat'}, {'frequency': 'c', 'synset': 'turban.n.01', 'synonyms': ['turban'], 'id': 1127, 'def': 'a traditional headdress consisting of a long scarf wrapped around the head', 'name': 'turban'}, {'frequency': 'c', 'synset': 'turkey.n.04', 'synonyms': ['turkey_(food)'], 'id': 1128, 'def': 'flesh of large domesticated fowl usually roasted', 'name': 'turkey_(food)'}, {'frequency': 'r', 'synset': 'turnip.n.01', 'synonyms': ['turnip'], 'id': 1129, 'def': 'widely cultivated plant having a large fleshy edible white or yellow root', 'name': 'turnip'}, {'frequency': 'c', 'synset': 'turtle.n.02', 'synonyms': ['turtle'], 'id': 1130, 'def': 'any of various aquatic and land reptiles having a bony shell and flipper-like limbs for swimming', 'name': 'turtle'}, {'frequency': 'c', 'synset': 'turtleneck.n.01', 'synonyms': ['turtleneck_(clothing)', 'polo-neck'], 'id': 1131, 'def': 'a sweater or jersey with a high close-fitting collar', 'name': 'turtleneck_(clothing)'}, {'frequency': 'c', 'synset': 'typewriter.n.01', 'synonyms': ['typewriter'], 'id': 1132, 'def': 'hand-operated character printer for printing written messages one character at a time', 'name': 'typewriter'}, {'frequency': 'f', 'synset': 'umbrella.n.01', 'synonyms': ['umbrella'], 'id': 1133, 'def': 'a lightweight handheld collapsible canopy', 'name': 'umbrella'}, {'frequency': 'f', 'synset': 'underwear.n.01', 'synonyms': ['underwear', 'underclothes', 'underclothing', 'underpants'], 'id': 1134, 'def': 'undergarment worn next to the skin and under the outer garments', 'name': 'underwear'}, {'frequency': 'r', 'synset': 'unicycle.n.01', 'synonyms': ['unicycle'], 'id': 1135, 'def': 'a vehicle with a single wheel that is driven by pedals', 'name': 'unicycle'}, {'frequency': 'f', 'synset': 'urinal.n.01', 'synonyms': ['urinal'], 'id': 1136, 'def': 'a plumbing fixture (usually attached to the wall) used by men to urinate', 'name': 'urinal'}, {'frequency': 'c', 'synset': 'urn.n.01', 'synonyms': ['urn'], 'id': 1137, 'def': 'a large vase that usually has a pedestal or feet', 'name': 'urn'}, {'frequency': 'c', 'synset': 'vacuum.n.04', 'synonyms': ['vacuum_cleaner'], 'id': 1138, 'def': 'an electrical home appliance that cleans by suction', 'name': 'vacuum_cleaner'}, {'frequency': 'f', 'synset': 'vase.n.01', 'synonyms': ['vase'], 'id': 1139, 'def': 'an open jar of glass or porcelain used as an ornament or to hold flowers', 'name': 'vase'}, {'frequency': 'c', 'synset': 'vending_machine.n.01', 'synonyms': ['vending_machine'], 'id': 1140, 'def': 'a slot machine for selling goods', 'name': 'vending_machine'}, {'frequency': 'f', 'synset': 'vent.n.01', 'synonyms': ['vent', 'blowhole', 'air_vent'], 'id': 1141, 'def': 'a hole for the escape of gas or air', 'name': 'vent'}, {'frequency': 'f', 'synset': 'vest.n.01', 'synonyms': ['vest', 'waistcoat'], 'id': 1142, 'def': "a man's sleeveless garment worn underneath a coat", 'name': 'vest'}, {'frequency': 'c', 'synset': 'videotape.n.01', 'synonyms': ['videotape'], 'id': 1143, 'def': 'a video recording made on magnetic tape', 'name': 'videotape'}, {'frequency': 'r', 'synset': 'vinegar.n.01', 'synonyms': ['vinegar'], 'id': 1144, 'def': 'sour-tasting liquid produced usually by oxidation of the alcohol in wine or cider and used as a condiment or food preservative', 'name': 'vinegar'}, {'frequency': 'r', 'synset': 'violin.n.01', 'synonyms': ['violin', 'fiddle'], 'id': 1145, 'def': 'bowed stringed instrument that is the highest member of the violin family', 'name': 'violin'}, {'frequency': 'r', 'synset': 'vodka.n.01', 'synonyms': ['vodka'], 'id': 1146, 'def': 'unaged colorless liquor originating in Russia', 'name': 'vodka'}, {'frequency': 'c', 'synset': 'volleyball.n.02', 'synonyms': ['volleyball'], 'id': 1147, 'def': 'an inflated ball used in playing volleyball', 'name': 'volleyball'}, {'frequency': 'r', 'synset': 'vulture.n.01', 'synonyms': ['vulture'], 'id': 1148, 'def': 'any of various large birds of prey having naked heads and weak claws and feeding chiefly on carrion', 'name': 'vulture'}, {'frequency': 'c', 'synset': 'waffle.n.01', 'synonyms': ['waffle'], 'id': 1149, 'def': 'pancake batter baked in a waffle iron', 'name': 'waffle'}, {'frequency': 'r', 'synset': 'waffle_iron.n.01', 'synonyms': ['waffle_iron'], 'id': 1150, 'def': 'a kitchen appliance for baking waffles', 'name': 'waffle_iron'}, {'frequency': 'c', 'synset': 'wagon.n.01', 'synonyms': ['wagon'], 'id': 1151, 'def': 'any of various kinds of wheeled vehicles drawn by an animal or a tractor', 'name': 'wagon'}, {'frequency': 'c', 'synset': 'wagon_wheel.n.01', 'synonyms': ['wagon_wheel'], 'id': 1152, 'def': 'a wheel of a wagon', 'name': 'wagon_wheel'}, {'frequency': 'c', 'synset': 'walking_stick.n.01', 'synonyms': ['walking_stick'], 'id': 1153, 'def': 'a stick carried in the hand for support in walking', 'name': 'walking_stick'}, {'frequency': 'c', 'synset': 'wall_clock.n.01', 'synonyms': ['wall_clock'], 'id': 1154, 'def': 'a clock mounted on a wall', 'name': 'wall_clock'}, {'frequency': 'f', 'synset': 'wall_socket.n.01', 'synonyms': ['wall_socket', 'wall_plug', 'electric_outlet', 'electrical_outlet', 'outlet', 'electric_receptacle'], 'id': 1155, 'def': 'receptacle providing a place in a wiring system where current can be taken to run electrical devices', 'name': 'wall_socket'}, {'frequency': 'f', 'synset': 'wallet.n.01', 'synonyms': ['wallet', 'billfold'], 'id': 1156, 'def': 'a pocket-size case for holding papers and paper money', 'name': 'wallet'}, {'frequency': 'r', 'synset': 'walrus.n.01', 'synonyms': ['walrus'], 'id': 1157, 'def': 'either of two large northern marine mammals having ivory tusks and tough hide over thick blubber', 'name': 'walrus'}, {'frequency': 'r', 'synset': 'wardrobe.n.01', 'synonyms': ['wardrobe'], 'id': 1158, 'def': 'a tall piece of furniture that provides storage space for clothes; has a door and rails or hooks for hanging clothes', 'name': 'wardrobe'}, {'frequency': 'r', 'synset': 'washbasin.n.01', 'synonyms': ['washbasin', 'basin_(for_washing)', 'washbowl', 'washstand', 'handbasin'], 'id': 1159, 'def': 'a bathroom sink that is permanently installed and connected to a water supply and drainpipe; where you can wash your hands and face', 'name': 'washbasin'}, {'frequency': 'c', 'synset': 'washer.n.03', 'synonyms': ['automatic_washer', 'washing_machine'], 'id': 1160, 'def': 'a home appliance for washing clothes and linens automatically', 'name': 'automatic_washer'}, {'frequency': 'f', 'synset': 'watch.n.01', 'synonyms': ['watch', 'wristwatch'], 'id': 1161, 'def': 'a small, portable timepiece', 'name': 'watch'}, {'frequency': 'f', 'synset': 'water_bottle.n.01', 'synonyms': ['water_bottle'], 'id': 1162, 'def': 'a bottle for holding water', 'name': 'water_bottle'}, {'frequency': 'c', 'synset': 'water_cooler.n.01', 'synonyms': ['water_cooler'], 'id': 1163, 'def': 'a device for cooling and dispensing drinking water', 'name': 'water_cooler'}, {'frequency': 'c', 'synset': 'water_faucet.n.01', 'synonyms': ['water_faucet', 'water_tap', 'tap_(water_faucet)'], 'id': 1164, 'def': 'a faucet for drawing water from a pipe or cask', 'name': 'water_faucet'}, {'frequency': 'r', 'synset': 'water_heater.n.01', 'synonyms': ['water_heater', 'hot-water_heater'], 'id': 1165, 'def': 'a heater and storage tank to supply heated water', 'name': 'water_heater'}, {'frequency': 'c', 'synset': 'water_jug.n.01', 'synonyms': ['water_jug'], 'id': 1166, 'def': 'a jug that holds water', 'name': 'water_jug'}, {'frequency': 'r', 'synset': 'water_pistol.n.01', 'synonyms': ['water_gun', 'squirt_gun'], 'id': 1167, 'def': 'plaything consisting of a toy pistol that squirts water', 'name': 'water_gun'}, {'frequency': 'c', 'synset': 'water_scooter.n.01', 'synonyms': ['water_scooter', 'sea_scooter', 'jet_ski'], 'id': 1168, 'def': 'a motorboat resembling a motor scooter (NOT A SURFBOARD OR WATER SKI)', 'name': 'water_scooter'}, {'frequency': 'c', 'synset': 'water_ski.n.01', 'synonyms': ['water_ski'], 'id': 1169, 'def': 'broad ski for skimming over water towed by a speedboat (DO NOT MARK WATER)', 'name': 'water_ski'}, {'frequency': 'c', 'synset': 'water_tower.n.01', 'synonyms': ['water_tower'], 'id': 1170, 'def': 'a large reservoir for water', 'name': 'water_tower'}, {'frequency': 'c', 'synset': 'watering_can.n.01', 'synonyms': ['watering_can'], 'id': 1171, 'def': 'a container with a handle and a spout with a perforated nozzle; used to sprinkle water over plants', 'name': 'watering_can'}, {'frequency': 'f', 'synset': 'watermelon.n.02', 'synonyms': ['watermelon'], 'id': 1172, 'def': 'large oblong or roundish melon with a hard green rind and sweet watery red or occasionally yellowish pulp', 'name': 'watermelon'}, {'frequency': 'f', 'synset': 'weathervane.n.01', 'synonyms': ['weathervane', 'vane_(weathervane)', 'wind_vane'], 'id': 1173, 'def': 'mechanical device attached to an elevated structure; rotates freely to show the direction of the wind', 'name': 'weathervane'}, {'frequency': 'c', 'synset': 'webcam.n.01', 'synonyms': ['webcam'], 'id': 1174, 'def': 'a digital camera designed to take digital photographs and transmit them over the internet', 'name': 'webcam'}, {'frequency': 'c', 'synset': 'wedding_cake.n.01', 'synonyms': ['wedding_cake', 'bridecake'], 'id': 1175, 'def': 'a rich cake with two or more tiers and covered with frosting and decorations; served at a wedding reception', 'name': 'wedding_cake'}, {'frequency': 'c', 'synset': 'wedding_ring.n.01', 'synonyms': ['wedding_ring', 'wedding_band'], 'id': 1176, 'def': 'a ring given to the bride and/or groom at the wedding', 'name': 'wedding_ring'}, {'frequency': 'f', 'synset': 'wet_suit.n.01', 'synonyms': ['wet_suit'], 'id': 1177, 'def': 'a close-fitting garment made of a permeable material; worn in cold water to retain body heat', 'name': 'wet_suit'}, {'frequency': 'f', 'synset': 'wheel.n.01', 'synonyms': ['wheel'], 'id': 1178, 'def': 'a circular frame with spokes (or a solid disc) that can rotate on a shaft or axle', 'name': 'wheel'}, {'frequency': 'c', 'synset': 'wheelchair.n.01', 'synonyms': ['wheelchair'], 'id': 1179, 'def': 'a movable chair mounted on large wheels', 'name': 'wheelchair'}, {'frequency': 'c', 'synset': 'whipped_cream.n.01', 'synonyms': ['whipped_cream'], 'id': 1180, 'def': 'cream that has been beaten until light and fluffy', 'name': 'whipped_cream'}, {'frequency': 'c', 'synset': 'whistle.n.03', 'synonyms': ['whistle'], 'id': 1181, 'def': 'a small wind instrument that produces a whistling sound by blowing into it', 'name': 'whistle'}, {'frequency': 'c', 'synset': 'wig.n.01', 'synonyms': ['wig'], 'id': 1182, 'def': 'hairpiece covering the head and made of real or synthetic hair', 'name': 'wig'}, {'frequency': 'c', 'synset': 'wind_chime.n.01', 'synonyms': ['wind_chime'], 'id': 1183, 'def': 'a decorative arrangement of pieces of metal or glass or pottery that hang together loosely so the wind can cause them to tinkle', 'name': 'wind_chime'}, {'frequency': 'c', 'synset': 'windmill.n.01', 'synonyms': ['windmill'], 'id': 1184, 'def': 'A mill or turbine that is powered by wind', 'name': 'windmill'}, {'frequency': 'c', 'synset': 'window_box.n.01', 'synonyms': ['window_box_(for_plants)'], 'id': 1185, 'def': 'a container for growing plants on a windowsill', 'name': 'window_box_(for_plants)'}, {'frequency': 'f', 'synset': 'windshield_wiper.n.01', 'synonyms': ['windshield_wiper', 'windscreen_wiper', 'wiper_(for_windshield/screen)'], 'id': 1186, 'def': 'a mechanical device that cleans the windshield', 'name': 'windshield_wiper'}, {'frequency': 'c', 'synset': 'windsock.n.01', 'synonyms': ['windsock', 'air_sock', 'air-sleeve', 'wind_sleeve', 'wind_cone'], 'id': 1187, 'def': 'a truncated cloth cone mounted on a mast/pole; shows wind direction', 'name': 'windsock'}, {'frequency': 'f', 'synset': 'wine_bottle.n.01', 'synonyms': ['wine_bottle'], 'id': 1188, 'def': 'a bottle for holding wine', 'name': 'wine_bottle'}, {'frequency': 'c', 'synset': 'wine_bucket.n.01', 'synonyms': ['wine_bucket', 'wine_cooler'], 'id': 1189, 'def': 'a bucket of ice used to chill a bottle of wine', 'name': 'wine_bucket'}, {'frequency': 'f', 'synset': 'wineglass.n.01', 'synonyms': ['wineglass'], 'id': 1190, 'def': 'a glass that has a stem and in which wine is served', 'name': 'wineglass'}, {'frequency': 'f', 'synset': 'winker.n.02', 'synonyms': ['blinder_(for_horses)'], 'id': 1191, 'def': 'blinds that prevent a horse from seeing something on either side', 'name': 'blinder_(for_horses)'}, {'frequency': 'c', 'synset': 'wok.n.01', 'synonyms': ['wok'], 'id': 1192, 'def': 'pan with a convex bottom; used for frying in Chinese cooking', 'name': 'wok'}, {'frequency': 'r', 'synset': 'wolf.n.01', 'synonyms': ['wolf'], 'id': 1193, 'def': 'a wild carnivorous mammal of the dog family, living and hunting in packs', 'name': 'wolf'}, {'frequency': 'c', 'synset': 'wooden_spoon.n.02', 'synonyms': ['wooden_spoon'], 'id': 1194, 'def': 'a spoon made of wood', 'name': 'wooden_spoon'}, {'frequency': 'c', 'synset': 'wreath.n.01', 'synonyms': ['wreath'], 'id': 1195, 'def': 'an arrangement of flowers, leaves, or stems fastened in a ring', 'name': 'wreath'}, {'frequency': 'c', 'synset': 'wrench.n.03', 'synonyms': ['wrench', 'spanner'], 'id': 1196, 'def': 'a hand tool that is used to hold or twist a nut or bolt', 'name': 'wrench'}, {'frequency': 'f', 'synset': 'wristband.n.01', 'synonyms': ['wristband'], 'id': 1197, 'def': 'band consisting of a part of a sleeve that covers the wrist', 'name': 'wristband'}, {'frequency': 'f', 'synset': 'wristlet.n.01', 'synonyms': ['wristlet', 'wrist_band'], 'id': 1198, 'def': 'a band or bracelet worn around the wrist', 'name': 'wristlet'}, {'frequency': 'c', 'synset': 'yacht.n.01', 'synonyms': ['yacht'], 'id': 1199, 'def': 'an expensive vessel propelled by sail or power and used for cruising or racing', 'name': 'yacht'}, {'frequency': 'c', 'synset': 'yogurt.n.01', 'synonyms': ['yogurt', 'yoghurt', 'yoghourt'], 'id': 1200, 'def': 'a custard-like food made from curdled milk', 'name': 'yogurt'}, {'frequency': 'c', 'synset': 'yoke.n.07', 'synonyms': ['yoke_(animal_equipment)'], 'id': 1201, 'def': 'gear joining two animals at the neck; NOT egg yolk', 'name': 'yoke_(animal_equipment)'}, {'frequency': 'f', 'synset': 'zebra.n.01', 'synonyms': ['zebra'], 'id': 1202, 'def': 'any of several fleet black-and-white striped African equines', 'name': 'zebra'}, {'frequency': 'c', 'synset': 'zucchini.n.02', 'synonyms': ['zucchini', 'courgette'], 'id': 1203, 'def': 'small cucumber-shaped vegetable marrow; typically dark green', 'name': 'zucchini'}] # noqa +# fmt: on diff --git a/custom_detectron2/data/datasets/lvis_v1_category_image_count.py b/custom_detectron2/data/datasets/lvis_v1_category_image_count.py new file mode 100644 index 0000000000000000000000000000000000000000..31bf0cfcd5096ab87835db86a28671d474514c40 --- /dev/null +++ b/custom_detectron2/data/datasets/lvis_v1_category_image_count.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# Autogen with +# with open("lvis_v1_train.json", "r") as f: +# a = json.load(f) +# c = a["categories"] +# for x in c: +# del x["name"] +# del x["instance_count"] +# del x["def"] +# del x["synonyms"] +# del x["frequency"] +# del x["synset"] +# LVIS_CATEGORY_IMAGE_COUNT = repr(c) + " # noqa" +# with open("/tmp/lvis_category_image_count.py", "wt") as f: +# f.write(f"LVIS_CATEGORY_IMAGE_COUNT = {LVIS_CATEGORY_IMAGE_COUNT}") +# Then paste the contents of that file below + +# fmt: off +LVIS_CATEGORY_IMAGE_COUNT = [{'id': 1, 'image_count': 64}, {'id': 2, 'image_count': 364}, {'id': 3, 'image_count': 1911}, {'id': 4, 'image_count': 149}, {'id': 5, 'image_count': 29}, {'id': 6, 'image_count': 26}, {'id': 7, 'image_count': 59}, {'id': 8, 'image_count': 22}, {'id': 9, 'image_count': 12}, {'id': 10, 'image_count': 28}, {'id': 11, 'image_count': 505}, {'id': 12, 'image_count': 1207}, {'id': 13, 'image_count': 4}, {'id': 14, 'image_count': 10}, {'id': 15, 'image_count': 500}, {'id': 16, 'image_count': 33}, {'id': 17, 'image_count': 3}, {'id': 18, 'image_count': 44}, {'id': 19, 'image_count': 561}, {'id': 20, 'image_count': 8}, {'id': 21, 'image_count': 9}, {'id': 22, 'image_count': 33}, {'id': 23, 'image_count': 1883}, {'id': 24, 'image_count': 98}, {'id': 25, 'image_count': 70}, {'id': 26, 'image_count': 46}, {'id': 27, 'image_count': 117}, {'id': 28, 'image_count': 41}, {'id': 29, 'image_count': 1395}, {'id': 30, 'image_count': 7}, {'id': 31, 'image_count': 1}, {'id': 32, 'image_count': 314}, {'id': 33, 'image_count': 31}, {'id': 34, 'image_count': 1905}, {'id': 35, 'image_count': 1859}, {'id': 36, 'image_count': 1623}, {'id': 37, 'image_count': 47}, {'id': 38, 'image_count': 3}, {'id': 39, 'image_count': 3}, {'id': 40, 'image_count': 1}, {'id': 41, 'image_count': 305}, {'id': 42, 'image_count': 6}, {'id': 43, 'image_count': 210}, {'id': 44, 'image_count': 36}, {'id': 45, 'image_count': 1787}, {'id': 46, 'image_count': 17}, {'id': 47, 'image_count': 51}, {'id': 48, 'image_count': 138}, {'id': 49, 'image_count': 3}, {'id': 50, 'image_count': 1470}, {'id': 51, 'image_count': 3}, {'id': 52, 'image_count': 2}, {'id': 53, 'image_count': 186}, {'id': 54, 'image_count': 76}, {'id': 55, 'image_count': 26}, {'id': 56, 'image_count': 303}, {'id': 57, 'image_count': 738}, {'id': 58, 'image_count': 1799}, {'id': 59, 'image_count': 1934}, {'id': 60, 'image_count': 1609}, {'id': 61, 'image_count': 1622}, {'id': 62, 'image_count': 41}, {'id': 63, 'image_count': 4}, {'id': 64, 'image_count': 11}, {'id': 65, 'image_count': 270}, {'id': 66, 'image_count': 349}, {'id': 67, 'image_count': 42}, {'id': 68, 'image_count': 823}, {'id': 69, 'image_count': 6}, {'id': 70, 'image_count': 48}, {'id': 71, 'image_count': 3}, {'id': 72, 'image_count': 42}, {'id': 73, 'image_count': 24}, {'id': 74, 'image_count': 16}, {'id': 75, 'image_count': 605}, {'id': 76, 'image_count': 646}, {'id': 77, 'image_count': 1765}, {'id': 78, 'image_count': 2}, {'id': 79, 'image_count': 125}, {'id': 80, 'image_count': 1420}, {'id': 81, 'image_count': 140}, {'id': 82, 'image_count': 4}, {'id': 83, 'image_count': 322}, {'id': 84, 'image_count': 60}, {'id': 85, 'image_count': 2}, {'id': 86, 'image_count': 231}, {'id': 87, 'image_count': 333}, {'id': 88, 'image_count': 1941}, {'id': 89, 'image_count': 367}, {'id': 90, 'image_count': 1922}, {'id': 91, 'image_count': 18}, {'id': 92, 'image_count': 81}, {'id': 93, 'image_count': 1}, {'id': 94, 'image_count': 1852}, {'id': 95, 'image_count': 430}, {'id': 96, 'image_count': 247}, {'id': 97, 'image_count': 94}, {'id': 98, 'image_count': 21}, {'id': 99, 'image_count': 1821}, {'id': 100, 'image_count': 16}, {'id': 101, 'image_count': 12}, {'id': 102, 'image_count': 25}, {'id': 103, 'image_count': 41}, {'id': 104, 'image_count': 244}, {'id': 105, 'image_count': 7}, {'id': 106, 'image_count': 1}, {'id': 107, 'image_count': 40}, {'id': 108, 'image_count': 40}, {'id': 109, 'image_count': 104}, {'id': 110, 'image_count': 1671}, {'id': 111, 'image_count': 49}, {'id': 112, 'image_count': 243}, {'id': 113, 'image_count': 2}, {'id': 114, 'image_count': 242}, {'id': 115, 'image_count': 271}, {'id': 116, 'image_count': 104}, {'id': 117, 'image_count': 8}, {'id': 118, 'image_count': 1758}, {'id': 119, 'image_count': 1}, {'id': 120, 'image_count': 48}, {'id': 121, 'image_count': 14}, {'id': 122, 'image_count': 40}, {'id': 123, 'image_count': 1}, {'id': 124, 'image_count': 37}, {'id': 125, 'image_count': 1510}, {'id': 126, 'image_count': 6}, {'id': 127, 'image_count': 1903}, {'id': 128, 'image_count': 70}, {'id': 129, 'image_count': 86}, {'id': 130, 'image_count': 7}, {'id': 131, 'image_count': 5}, {'id': 132, 'image_count': 1406}, {'id': 133, 'image_count': 1901}, {'id': 134, 'image_count': 15}, {'id': 135, 'image_count': 28}, {'id': 136, 'image_count': 6}, {'id': 137, 'image_count': 494}, {'id': 138, 'image_count': 234}, {'id': 139, 'image_count': 1922}, {'id': 140, 'image_count': 1}, {'id': 141, 'image_count': 35}, {'id': 142, 'image_count': 5}, {'id': 143, 'image_count': 1828}, {'id': 144, 'image_count': 8}, {'id': 145, 'image_count': 63}, {'id': 146, 'image_count': 1668}, {'id': 147, 'image_count': 4}, {'id': 148, 'image_count': 95}, {'id': 149, 'image_count': 17}, {'id': 150, 'image_count': 1567}, {'id': 151, 'image_count': 2}, {'id': 152, 'image_count': 103}, {'id': 153, 'image_count': 50}, {'id': 154, 'image_count': 1309}, {'id': 155, 'image_count': 6}, {'id': 156, 'image_count': 92}, {'id': 157, 'image_count': 19}, {'id': 158, 'image_count': 37}, {'id': 159, 'image_count': 4}, {'id': 160, 'image_count': 709}, {'id': 161, 'image_count': 9}, {'id': 162, 'image_count': 82}, {'id': 163, 'image_count': 15}, {'id': 164, 'image_count': 3}, {'id': 165, 'image_count': 61}, {'id': 166, 'image_count': 51}, {'id': 167, 'image_count': 5}, {'id': 168, 'image_count': 13}, {'id': 169, 'image_count': 642}, {'id': 170, 'image_count': 24}, {'id': 171, 'image_count': 255}, {'id': 172, 'image_count': 9}, {'id': 173, 'image_count': 1808}, {'id': 174, 'image_count': 31}, {'id': 175, 'image_count': 158}, {'id': 176, 'image_count': 80}, {'id': 177, 'image_count': 1884}, {'id': 178, 'image_count': 158}, {'id': 179, 'image_count': 2}, {'id': 180, 'image_count': 12}, {'id': 181, 'image_count': 1659}, {'id': 182, 'image_count': 7}, {'id': 183, 'image_count': 834}, {'id': 184, 'image_count': 57}, {'id': 185, 'image_count': 174}, {'id': 186, 'image_count': 95}, {'id': 187, 'image_count': 27}, {'id': 188, 'image_count': 22}, {'id': 189, 'image_count': 1391}, {'id': 190, 'image_count': 90}, {'id': 191, 'image_count': 40}, {'id': 192, 'image_count': 445}, {'id': 193, 'image_count': 21}, {'id': 194, 'image_count': 1132}, {'id': 195, 'image_count': 177}, {'id': 196, 'image_count': 4}, {'id': 197, 'image_count': 17}, {'id': 198, 'image_count': 84}, {'id': 199, 'image_count': 55}, {'id': 200, 'image_count': 30}, {'id': 201, 'image_count': 25}, {'id': 202, 'image_count': 2}, {'id': 203, 'image_count': 125}, {'id': 204, 'image_count': 1135}, {'id': 205, 'image_count': 19}, {'id': 206, 'image_count': 72}, {'id': 207, 'image_count': 1926}, {'id': 208, 'image_count': 159}, {'id': 209, 'image_count': 7}, {'id': 210, 'image_count': 1}, {'id': 211, 'image_count': 13}, {'id': 212, 'image_count': 35}, {'id': 213, 'image_count': 18}, {'id': 214, 'image_count': 8}, {'id': 215, 'image_count': 6}, {'id': 216, 'image_count': 35}, {'id': 217, 'image_count': 1222}, {'id': 218, 'image_count': 103}, {'id': 219, 'image_count': 28}, {'id': 220, 'image_count': 63}, {'id': 221, 'image_count': 28}, {'id': 222, 'image_count': 5}, {'id': 223, 'image_count': 7}, {'id': 224, 'image_count': 14}, {'id': 225, 'image_count': 1918}, {'id': 226, 'image_count': 133}, {'id': 227, 'image_count': 16}, {'id': 228, 'image_count': 27}, {'id': 229, 'image_count': 110}, {'id': 230, 'image_count': 1895}, {'id': 231, 'image_count': 4}, {'id': 232, 'image_count': 1927}, {'id': 233, 'image_count': 8}, {'id': 234, 'image_count': 1}, {'id': 235, 'image_count': 263}, {'id': 236, 'image_count': 10}, {'id': 237, 'image_count': 2}, {'id': 238, 'image_count': 3}, {'id': 239, 'image_count': 87}, {'id': 240, 'image_count': 9}, {'id': 241, 'image_count': 71}, {'id': 242, 'image_count': 13}, {'id': 243, 'image_count': 18}, {'id': 244, 'image_count': 2}, {'id': 245, 'image_count': 5}, {'id': 246, 'image_count': 45}, {'id': 247, 'image_count': 1}, {'id': 248, 'image_count': 23}, {'id': 249, 'image_count': 32}, {'id': 250, 'image_count': 4}, {'id': 251, 'image_count': 1}, {'id': 252, 'image_count': 858}, {'id': 253, 'image_count': 661}, {'id': 254, 'image_count': 168}, {'id': 255, 'image_count': 210}, {'id': 256, 'image_count': 65}, {'id': 257, 'image_count': 4}, {'id': 258, 'image_count': 2}, {'id': 259, 'image_count': 159}, {'id': 260, 'image_count': 31}, {'id': 261, 'image_count': 811}, {'id': 262, 'image_count': 1}, {'id': 263, 'image_count': 42}, {'id': 264, 'image_count': 27}, {'id': 265, 'image_count': 2}, {'id': 266, 'image_count': 5}, {'id': 267, 'image_count': 95}, {'id': 268, 'image_count': 32}, {'id': 269, 'image_count': 1}, {'id': 270, 'image_count': 1}, {'id': 271, 'image_count': 1844}, {'id': 272, 'image_count': 897}, {'id': 273, 'image_count': 31}, {'id': 274, 'image_count': 23}, {'id': 275, 'image_count': 1}, {'id': 276, 'image_count': 202}, {'id': 277, 'image_count': 746}, {'id': 278, 'image_count': 44}, {'id': 279, 'image_count': 14}, {'id': 280, 'image_count': 26}, {'id': 281, 'image_count': 1}, {'id': 282, 'image_count': 2}, {'id': 283, 'image_count': 25}, {'id': 284, 'image_count': 238}, {'id': 285, 'image_count': 592}, {'id': 286, 'image_count': 26}, {'id': 287, 'image_count': 5}, {'id': 288, 'image_count': 42}, {'id': 289, 'image_count': 13}, {'id': 290, 'image_count': 46}, {'id': 291, 'image_count': 1}, {'id': 292, 'image_count': 8}, {'id': 293, 'image_count': 34}, {'id': 294, 'image_count': 5}, {'id': 295, 'image_count': 1}, {'id': 296, 'image_count': 1871}, {'id': 297, 'image_count': 717}, {'id': 298, 'image_count': 1010}, {'id': 299, 'image_count': 679}, {'id': 300, 'image_count': 3}, {'id': 301, 'image_count': 4}, {'id': 302, 'image_count': 1}, {'id': 303, 'image_count': 166}, {'id': 304, 'image_count': 2}, {'id': 305, 'image_count': 266}, {'id': 306, 'image_count': 101}, {'id': 307, 'image_count': 6}, {'id': 308, 'image_count': 14}, {'id': 309, 'image_count': 133}, {'id': 310, 'image_count': 2}, {'id': 311, 'image_count': 38}, {'id': 312, 'image_count': 95}, {'id': 313, 'image_count': 1}, {'id': 314, 'image_count': 12}, {'id': 315, 'image_count': 49}, {'id': 316, 'image_count': 5}, {'id': 317, 'image_count': 5}, {'id': 318, 'image_count': 16}, {'id': 319, 'image_count': 216}, {'id': 320, 'image_count': 12}, {'id': 321, 'image_count': 1}, {'id': 322, 'image_count': 54}, {'id': 323, 'image_count': 5}, {'id': 324, 'image_count': 245}, {'id': 325, 'image_count': 12}, {'id': 326, 'image_count': 7}, {'id': 327, 'image_count': 35}, {'id': 328, 'image_count': 36}, {'id': 329, 'image_count': 32}, {'id': 330, 'image_count': 1027}, {'id': 331, 'image_count': 10}, {'id': 332, 'image_count': 12}, {'id': 333, 'image_count': 1}, {'id': 334, 'image_count': 67}, {'id': 335, 'image_count': 71}, {'id': 336, 'image_count': 30}, {'id': 337, 'image_count': 48}, {'id': 338, 'image_count': 249}, {'id': 339, 'image_count': 13}, {'id': 340, 'image_count': 29}, {'id': 341, 'image_count': 14}, {'id': 342, 'image_count': 236}, {'id': 343, 'image_count': 15}, {'id': 344, 'image_count': 1521}, {'id': 345, 'image_count': 25}, {'id': 346, 'image_count': 249}, {'id': 347, 'image_count': 139}, {'id': 348, 'image_count': 2}, {'id': 349, 'image_count': 2}, {'id': 350, 'image_count': 1890}, {'id': 351, 'image_count': 1240}, {'id': 352, 'image_count': 1}, {'id': 353, 'image_count': 9}, {'id': 354, 'image_count': 1}, {'id': 355, 'image_count': 3}, {'id': 356, 'image_count': 11}, {'id': 357, 'image_count': 4}, {'id': 358, 'image_count': 236}, {'id': 359, 'image_count': 44}, {'id': 360, 'image_count': 19}, {'id': 361, 'image_count': 1100}, {'id': 362, 'image_count': 7}, {'id': 363, 'image_count': 69}, {'id': 364, 'image_count': 2}, {'id': 365, 'image_count': 8}, {'id': 366, 'image_count': 5}, {'id': 367, 'image_count': 227}, {'id': 368, 'image_count': 6}, {'id': 369, 'image_count': 106}, {'id': 370, 'image_count': 81}, {'id': 371, 'image_count': 17}, {'id': 372, 'image_count': 134}, {'id': 373, 'image_count': 312}, {'id': 374, 'image_count': 8}, {'id': 375, 'image_count': 271}, {'id': 376, 'image_count': 2}, {'id': 377, 'image_count': 103}, {'id': 378, 'image_count': 1938}, {'id': 379, 'image_count': 574}, {'id': 380, 'image_count': 120}, {'id': 381, 'image_count': 2}, {'id': 382, 'image_count': 2}, {'id': 383, 'image_count': 13}, {'id': 384, 'image_count': 29}, {'id': 385, 'image_count': 1710}, {'id': 386, 'image_count': 66}, {'id': 387, 'image_count': 1008}, {'id': 388, 'image_count': 1}, {'id': 389, 'image_count': 3}, {'id': 390, 'image_count': 1942}, {'id': 391, 'image_count': 19}, {'id': 392, 'image_count': 1488}, {'id': 393, 'image_count': 46}, {'id': 394, 'image_count': 106}, {'id': 395, 'image_count': 115}, {'id': 396, 'image_count': 19}, {'id': 397, 'image_count': 2}, {'id': 398, 'image_count': 1}, {'id': 399, 'image_count': 28}, {'id': 400, 'image_count': 9}, {'id': 401, 'image_count': 192}, {'id': 402, 'image_count': 12}, {'id': 403, 'image_count': 21}, {'id': 404, 'image_count': 247}, {'id': 405, 'image_count': 6}, {'id': 406, 'image_count': 64}, {'id': 407, 'image_count': 7}, {'id': 408, 'image_count': 40}, {'id': 409, 'image_count': 542}, {'id': 410, 'image_count': 2}, {'id': 411, 'image_count': 1898}, {'id': 412, 'image_count': 36}, {'id': 413, 'image_count': 4}, {'id': 414, 'image_count': 1}, {'id': 415, 'image_count': 191}, {'id': 416, 'image_count': 6}, {'id': 417, 'image_count': 41}, {'id': 418, 'image_count': 39}, {'id': 419, 'image_count': 46}, {'id': 420, 'image_count': 1}, {'id': 421, 'image_count': 1451}, {'id': 422, 'image_count': 1878}, {'id': 423, 'image_count': 11}, {'id': 424, 'image_count': 82}, {'id': 425, 'image_count': 18}, {'id': 426, 'image_count': 1}, {'id': 427, 'image_count': 7}, {'id': 428, 'image_count': 3}, {'id': 429, 'image_count': 575}, {'id': 430, 'image_count': 1907}, {'id': 431, 'image_count': 8}, {'id': 432, 'image_count': 4}, {'id': 433, 'image_count': 32}, {'id': 434, 'image_count': 11}, {'id': 435, 'image_count': 4}, {'id': 436, 'image_count': 54}, {'id': 437, 'image_count': 202}, {'id': 438, 'image_count': 32}, {'id': 439, 'image_count': 3}, {'id': 440, 'image_count': 130}, {'id': 441, 'image_count': 119}, {'id': 442, 'image_count': 141}, {'id': 443, 'image_count': 29}, {'id': 444, 'image_count': 525}, {'id': 445, 'image_count': 1323}, {'id': 446, 'image_count': 2}, {'id': 447, 'image_count': 113}, {'id': 448, 'image_count': 16}, {'id': 449, 'image_count': 7}, {'id': 450, 'image_count': 35}, {'id': 451, 'image_count': 1908}, {'id': 452, 'image_count': 353}, {'id': 453, 'image_count': 18}, {'id': 454, 'image_count': 14}, {'id': 455, 'image_count': 77}, {'id': 456, 'image_count': 8}, {'id': 457, 'image_count': 37}, {'id': 458, 'image_count': 1}, {'id': 459, 'image_count': 346}, {'id': 460, 'image_count': 19}, {'id': 461, 'image_count': 1779}, {'id': 462, 'image_count': 23}, {'id': 463, 'image_count': 25}, {'id': 464, 'image_count': 67}, {'id': 465, 'image_count': 19}, {'id': 466, 'image_count': 28}, {'id': 467, 'image_count': 4}, {'id': 468, 'image_count': 27}, {'id': 469, 'image_count': 1861}, {'id': 470, 'image_count': 11}, {'id': 471, 'image_count': 13}, {'id': 472, 'image_count': 13}, {'id': 473, 'image_count': 32}, {'id': 474, 'image_count': 1767}, {'id': 475, 'image_count': 42}, {'id': 476, 'image_count': 17}, {'id': 477, 'image_count': 128}, {'id': 478, 'image_count': 1}, {'id': 479, 'image_count': 9}, {'id': 480, 'image_count': 10}, {'id': 481, 'image_count': 4}, {'id': 482, 'image_count': 9}, {'id': 483, 'image_count': 18}, {'id': 484, 'image_count': 41}, {'id': 485, 'image_count': 28}, {'id': 486, 'image_count': 3}, {'id': 487, 'image_count': 65}, {'id': 488, 'image_count': 9}, {'id': 489, 'image_count': 23}, {'id': 490, 'image_count': 24}, {'id': 491, 'image_count': 1}, {'id': 492, 'image_count': 2}, {'id': 493, 'image_count': 59}, {'id': 494, 'image_count': 48}, {'id': 495, 'image_count': 17}, {'id': 496, 'image_count': 1877}, {'id': 497, 'image_count': 18}, {'id': 498, 'image_count': 1920}, {'id': 499, 'image_count': 50}, {'id': 500, 'image_count': 1890}, {'id': 501, 'image_count': 99}, {'id': 502, 'image_count': 1530}, {'id': 503, 'image_count': 3}, {'id': 504, 'image_count': 11}, {'id': 505, 'image_count': 19}, {'id': 506, 'image_count': 3}, {'id': 507, 'image_count': 63}, {'id': 508, 'image_count': 5}, {'id': 509, 'image_count': 6}, {'id': 510, 'image_count': 233}, {'id': 511, 'image_count': 54}, {'id': 512, 'image_count': 36}, {'id': 513, 'image_count': 10}, {'id': 514, 'image_count': 124}, {'id': 515, 'image_count': 101}, {'id': 516, 'image_count': 3}, {'id': 517, 'image_count': 363}, {'id': 518, 'image_count': 3}, {'id': 519, 'image_count': 30}, {'id': 520, 'image_count': 18}, {'id': 521, 'image_count': 199}, {'id': 522, 'image_count': 97}, {'id': 523, 'image_count': 32}, {'id': 524, 'image_count': 121}, {'id': 525, 'image_count': 16}, {'id': 526, 'image_count': 12}, {'id': 527, 'image_count': 2}, {'id': 528, 'image_count': 214}, {'id': 529, 'image_count': 48}, {'id': 530, 'image_count': 26}, {'id': 531, 'image_count': 13}, {'id': 532, 'image_count': 4}, {'id': 533, 'image_count': 11}, {'id': 534, 'image_count': 123}, {'id': 535, 'image_count': 7}, {'id': 536, 'image_count': 200}, {'id': 537, 'image_count': 91}, {'id': 538, 'image_count': 9}, {'id': 539, 'image_count': 72}, {'id': 540, 'image_count': 1886}, {'id': 541, 'image_count': 4}, {'id': 542, 'image_count': 1}, {'id': 543, 'image_count': 1}, {'id': 544, 'image_count': 1932}, {'id': 545, 'image_count': 4}, {'id': 546, 'image_count': 56}, {'id': 547, 'image_count': 854}, {'id': 548, 'image_count': 755}, {'id': 549, 'image_count': 1843}, {'id': 550, 'image_count': 96}, {'id': 551, 'image_count': 7}, {'id': 552, 'image_count': 74}, {'id': 553, 'image_count': 66}, {'id': 554, 'image_count': 57}, {'id': 555, 'image_count': 44}, {'id': 556, 'image_count': 1905}, {'id': 557, 'image_count': 4}, {'id': 558, 'image_count': 90}, {'id': 559, 'image_count': 1635}, {'id': 560, 'image_count': 8}, {'id': 561, 'image_count': 5}, {'id': 562, 'image_count': 50}, {'id': 563, 'image_count': 545}, {'id': 564, 'image_count': 20}, {'id': 565, 'image_count': 193}, {'id': 566, 'image_count': 285}, {'id': 567, 'image_count': 3}, {'id': 568, 'image_count': 1}, {'id': 569, 'image_count': 1904}, {'id': 570, 'image_count': 294}, {'id': 571, 'image_count': 3}, {'id': 572, 'image_count': 5}, {'id': 573, 'image_count': 24}, {'id': 574, 'image_count': 2}, {'id': 575, 'image_count': 2}, {'id': 576, 'image_count': 16}, {'id': 577, 'image_count': 8}, {'id': 578, 'image_count': 154}, {'id': 579, 'image_count': 66}, {'id': 580, 'image_count': 1}, {'id': 581, 'image_count': 24}, {'id': 582, 'image_count': 1}, {'id': 583, 'image_count': 4}, {'id': 584, 'image_count': 75}, {'id': 585, 'image_count': 6}, {'id': 586, 'image_count': 126}, {'id': 587, 'image_count': 24}, {'id': 588, 'image_count': 22}, {'id': 589, 'image_count': 1872}, {'id': 590, 'image_count': 16}, {'id': 591, 'image_count': 423}, {'id': 592, 'image_count': 1927}, {'id': 593, 'image_count': 38}, {'id': 594, 'image_count': 3}, {'id': 595, 'image_count': 1945}, {'id': 596, 'image_count': 35}, {'id': 597, 'image_count': 1}, {'id': 598, 'image_count': 13}, {'id': 599, 'image_count': 9}, {'id': 600, 'image_count': 14}, {'id': 601, 'image_count': 37}, {'id': 602, 'image_count': 3}, {'id': 603, 'image_count': 4}, {'id': 604, 'image_count': 100}, {'id': 605, 'image_count': 195}, {'id': 606, 'image_count': 1}, {'id': 607, 'image_count': 12}, {'id': 608, 'image_count': 24}, {'id': 609, 'image_count': 489}, {'id': 610, 'image_count': 10}, {'id': 611, 'image_count': 1689}, {'id': 612, 'image_count': 42}, {'id': 613, 'image_count': 81}, {'id': 614, 'image_count': 894}, {'id': 615, 'image_count': 1868}, {'id': 616, 'image_count': 7}, {'id': 617, 'image_count': 1567}, {'id': 618, 'image_count': 10}, {'id': 619, 'image_count': 8}, {'id': 620, 'image_count': 7}, {'id': 621, 'image_count': 629}, {'id': 622, 'image_count': 89}, {'id': 623, 'image_count': 15}, {'id': 624, 'image_count': 134}, {'id': 625, 'image_count': 4}, {'id': 626, 'image_count': 1802}, {'id': 627, 'image_count': 595}, {'id': 628, 'image_count': 1210}, {'id': 629, 'image_count': 48}, {'id': 630, 'image_count': 418}, {'id': 631, 'image_count': 1846}, {'id': 632, 'image_count': 5}, {'id': 633, 'image_count': 221}, {'id': 634, 'image_count': 10}, {'id': 635, 'image_count': 7}, {'id': 636, 'image_count': 76}, {'id': 637, 'image_count': 22}, {'id': 638, 'image_count': 10}, {'id': 639, 'image_count': 341}, {'id': 640, 'image_count': 1}, {'id': 641, 'image_count': 705}, {'id': 642, 'image_count': 1900}, {'id': 643, 'image_count': 188}, {'id': 644, 'image_count': 227}, {'id': 645, 'image_count': 861}, {'id': 646, 'image_count': 6}, {'id': 647, 'image_count': 115}, {'id': 648, 'image_count': 5}, {'id': 649, 'image_count': 43}, {'id': 650, 'image_count': 14}, {'id': 651, 'image_count': 6}, {'id': 652, 'image_count': 15}, {'id': 653, 'image_count': 1167}, {'id': 654, 'image_count': 15}, {'id': 655, 'image_count': 994}, {'id': 656, 'image_count': 28}, {'id': 657, 'image_count': 2}, {'id': 658, 'image_count': 338}, {'id': 659, 'image_count': 334}, {'id': 660, 'image_count': 15}, {'id': 661, 'image_count': 102}, {'id': 662, 'image_count': 1}, {'id': 663, 'image_count': 8}, {'id': 664, 'image_count': 1}, {'id': 665, 'image_count': 1}, {'id': 666, 'image_count': 28}, {'id': 667, 'image_count': 91}, {'id': 668, 'image_count': 260}, {'id': 669, 'image_count': 131}, {'id': 670, 'image_count': 128}, {'id': 671, 'image_count': 3}, {'id': 672, 'image_count': 10}, {'id': 673, 'image_count': 39}, {'id': 674, 'image_count': 2}, {'id': 675, 'image_count': 925}, {'id': 676, 'image_count': 354}, {'id': 677, 'image_count': 31}, {'id': 678, 'image_count': 10}, {'id': 679, 'image_count': 215}, {'id': 680, 'image_count': 71}, {'id': 681, 'image_count': 43}, {'id': 682, 'image_count': 28}, {'id': 683, 'image_count': 34}, {'id': 684, 'image_count': 16}, {'id': 685, 'image_count': 273}, {'id': 686, 'image_count': 2}, {'id': 687, 'image_count': 999}, {'id': 688, 'image_count': 4}, {'id': 689, 'image_count': 107}, {'id': 690, 'image_count': 2}, {'id': 691, 'image_count': 1}, {'id': 692, 'image_count': 454}, {'id': 693, 'image_count': 9}, {'id': 694, 'image_count': 1901}, {'id': 695, 'image_count': 61}, {'id': 696, 'image_count': 91}, {'id': 697, 'image_count': 46}, {'id': 698, 'image_count': 1402}, {'id': 699, 'image_count': 74}, {'id': 700, 'image_count': 421}, {'id': 701, 'image_count': 226}, {'id': 702, 'image_count': 10}, {'id': 703, 'image_count': 1720}, {'id': 704, 'image_count': 261}, {'id': 705, 'image_count': 1337}, {'id': 706, 'image_count': 293}, {'id': 707, 'image_count': 62}, {'id': 708, 'image_count': 814}, {'id': 709, 'image_count': 407}, {'id': 710, 'image_count': 6}, {'id': 711, 'image_count': 16}, {'id': 712, 'image_count': 7}, {'id': 713, 'image_count': 1791}, {'id': 714, 'image_count': 2}, {'id': 715, 'image_count': 1915}, {'id': 716, 'image_count': 1940}, {'id': 717, 'image_count': 13}, {'id': 718, 'image_count': 16}, {'id': 719, 'image_count': 448}, {'id': 720, 'image_count': 12}, {'id': 721, 'image_count': 18}, {'id': 722, 'image_count': 4}, {'id': 723, 'image_count': 71}, {'id': 724, 'image_count': 189}, {'id': 725, 'image_count': 74}, {'id': 726, 'image_count': 103}, {'id': 727, 'image_count': 3}, {'id': 728, 'image_count': 110}, {'id': 729, 'image_count': 5}, {'id': 730, 'image_count': 9}, {'id': 731, 'image_count': 15}, {'id': 732, 'image_count': 25}, {'id': 733, 'image_count': 7}, {'id': 734, 'image_count': 647}, {'id': 735, 'image_count': 824}, {'id': 736, 'image_count': 100}, {'id': 737, 'image_count': 47}, {'id': 738, 'image_count': 121}, {'id': 739, 'image_count': 731}, {'id': 740, 'image_count': 73}, {'id': 741, 'image_count': 49}, {'id': 742, 'image_count': 23}, {'id': 743, 'image_count': 4}, {'id': 744, 'image_count': 62}, {'id': 745, 'image_count': 118}, {'id': 746, 'image_count': 99}, {'id': 747, 'image_count': 40}, {'id': 748, 'image_count': 1036}, {'id': 749, 'image_count': 105}, {'id': 750, 'image_count': 21}, {'id': 751, 'image_count': 229}, {'id': 752, 'image_count': 7}, {'id': 753, 'image_count': 72}, {'id': 754, 'image_count': 9}, {'id': 755, 'image_count': 10}, {'id': 756, 'image_count': 328}, {'id': 757, 'image_count': 468}, {'id': 758, 'image_count': 1}, {'id': 759, 'image_count': 2}, {'id': 760, 'image_count': 24}, {'id': 761, 'image_count': 11}, {'id': 762, 'image_count': 72}, {'id': 763, 'image_count': 17}, {'id': 764, 'image_count': 10}, {'id': 765, 'image_count': 17}, {'id': 766, 'image_count': 489}, {'id': 767, 'image_count': 47}, {'id': 768, 'image_count': 93}, {'id': 769, 'image_count': 1}, {'id': 770, 'image_count': 12}, {'id': 771, 'image_count': 228}, {'id': 772, 'image_count': 5}, {'id': 773, 'image_count': 76}, {'id': 774, 'image_count': 71}, {'id': 775, 'image_count': 30}, {'id': 776, 'image_count': 109}, {'id': 777, 'image_count': 14}, {'id': 778, 'image_count': 1}, {'id': 779, 'image_count': 8}, {'id': 780, 'image_count': 26}, {'id': 781, 'image_count': 339}, {'id': 782, 'image_count': 153}, {'id': 783, 'image_count': 2}, {'id': 784, 'image_count': 3}, {'id': 785, 'image_count': 8}, {'id': 786, 'image_count': 47}, {'id': 787, 'image_count': 8}, {'id': 788, 'image_count': 6}, {'id': 789, 'image_count': 116}, {'id': 790, 'image_count': 69}, {'id': 791, 'image_count': 13}, {'id': 792, 'image_count': 6}, {'id': 793, 'image_count': 1928}, {'id': 794, 'image_count': 79}, {'id': 795, 'image_count': 14}, {'id': 796, 'image_count': 7}, {'id': 797, 'image_count': 20}, {'id': 798, 'image_count': 114}, {'id': 799, 'image_count': 221}, {'id': 800, 'image_count': 502}, {'id': 801, 'image_count': 62}, {'id': 802, 'image_count': 87}, {'id': 803, 'image_count': 4}, {'id': 804, 'image_count': 1912}, {'id': 805, 'image_count': 7}, {'id': 806, 'image_count': 186}, {'id': 807, 'image_count': 18}, {'id': 808, 'image_count': 4}, {'id': 809, 'image_count': 3}, {'id': 810, 'image_count': 7}, {'id': 811, 'image_count': 1413}, {'id': 812, 'image_count': 7}, {'id': 813, 'image_count': 12}, {'id': 814, 'image_count': 248}, {'id': 815, 'image_count': 4}, {'id': 816, 'image_count': 1881}, {'id': 817, 'image_count': 529}, {'id': 818, 'image_count': 1932}, {'id': 819, 'image_count': 50}, {'id': 820, 'image_count': 3}, {'id': 821, 'image_count': 28}, {'id': 822, 'image_count': 10}, {'id': 823, 'image_count': 5}, {'id': 824, 'image_count': 5}, {'id': 825, 'image_count': 18}, {'id': 826, 'image_count': 14}, {'id': 827, 'image_count': 1890}, {'id': 828, 'image_count': 660}, {'id': 829, 'image_count': 8}, {'id': 830, 'image_count': 25}, {'id': 831, 'image_count': 10}, {'id': 832, 'image_count': 218}, {'id': 833, 'image_count': 36}, {'id': 834, 'image_count': 16}, {'id': 835, 'image_count': 808}, {'id': 836, 'image_count': 479}, {'id': 837, 'image_count': 1404}, {'id': 838, 'image_count': 307}, {'id': 839, 'image_count': 57}, {'id': 840, 'image_count': 28}, {'id': 841, 'image_count': 80}, {'id': 842, 'image_count': 11}, {'id': 843, 'image_count': 92}, {'id': 844, 'image_count': 20}, {'id': 845, 'image_count': 194}, {'id': 846, 'image_count': 23}, {'id': 847, 'image_count': 52}, {'id': 848, 'image_count': 673}, {'id': 849, 'image_count': 2}, {'id': 850, 'image_count': 2}, {'id': 851, 'image_count': 1}, {'id': 852, 'image_count': 2}, {'id': 853, 'image_count': 8}, {'id': 854, 'image_count': 80}, {'id': 855, 'image_count': 3}, {'id': 856, 'image_count': 3}, {'id': 857, 'image_count': 15}, {'id': 858, 'image_count': 2}, {'id': 859, 'image_count': 10}, {'id': 860, 'image_count': 386}, {'id': 861, 'image_count': 65}, {'id': 862, 'image_count': 3}, {'id': 863, 'image_count': 35}, {'id': 864, 'image_count': 5}, {'id': 865, 'image_count': 180}, {'id': 866, 'image_count': 99}, {'id': 867, 'image_count': 49}, {'id': 868, 'image_count': 28}, {'id': 869, 'image_count': 1}, {'id': 870, 'image_count': 52}, {'id': 871, 'image_count': 36}, {'id': 872, 'image_count': 70}, {'id': 873, 'image_count': 6}, {'id': 874, 'image_count': 29}, {'id': 875, 'image_count': 24}, {'id': 876, 'image_count': 1115}, {'id': 877, 'image_count': 61}, {'id': 878, 'image_count': 18}, {'id': 879, 'image_count': 18}, {'id': 880, 'image_count': 665}, {'id': 881, 'image_count': 1096}, {'id': 882, 'image_count': 29}, {'id': 883, 'image_count': 8}, {'id': 884, 'image_count': 14}, {'id': 885, 'image_count': 1622}, {'id': 886, 'image_count': 2}, {'id': 887, 'image_count': 3}, {'id': 888, 'image_count': 32}, {'id': 889, 'image_count': 55}, {'id': 890, 'image_count': 1}, {'id': 891, 'image_count': 10}, {'id': 892, 'image_count': 10}, {'id': 893, 'image_count': 47}, {'id': 894, 'image_count': 3}, {'id': 895, 'image_count': 29}, {'id': 896, 'image_count': 342}, {'id': 897, 'image_count': 25}, {'id': 898, 'image_count': 1469}, {'id': 899, 'image_count': 521}, {'id': 900, 'image_count': 347}, {'id': 901, 'image_count': 35}, {'id': 902, 'image_count': 7}, {'id': 903, 'image_count': 207}, {'id': 904, 'image_count': 108}, {'id': 905, 'image_count': 2}, {'id': 906, 'image_count': 34}, {'id': 907, 'image_count': 12}, {'id': 908, 'image_count': 10}, {'id': 909, 'image_count': 13}, {'id': 910, 'image_count': 361}, {'id': 911, 'image_count': 1023}, {'id': 912, 'image_count': 782}, {'id': 913, 'image_count': 2}, {'id': 914, 'image_count': 5}, {'id': 915, 'image_count': 247}, {'id': 916, 'image_count': 221}, {'id': 917, 'image_count': 4}, {'id': 918, 'image_count': 8}, {'id': 919, 'image_count': 158}, {'id': 920, 'image_count': 3}, {'id': 921, 'image_count': 752}, {'id': 922, 'image_count': 64}, {'id': 923, 'image_count': 707}, {'id': 924, 'image_count': 143}, {'id': 925, 'image_count': 1}, {'id': 926, 'image_count': 49}, {'id': 927, 'image_count': 126}, {'id': 928, 'image_count': 76}, {'id': 929, 'image_count': 11}, {'id': 930, 'image_count': 11}, {'id': 931, 'image_count': 4}, {'id': 932, 'image_count': 39}, {'id': 933, 'image_count': 11}, {'id': 934, 'image_count': 13}, {'id': 935, 'image_count': 91}, {'id': 936, 'image_count': 14}, {'id': 937, 'image_count': 5}, {'id': 938, 'image_count': 3}, {'id': 939, 'image_count': 10}, {'id': 940, 'image_count': 18}, {'id': 941, 'image_count': 9}, {'id': 942, 'image_count': 6}, {'id': 943, 'image_count': 951}, {'id': 944, 'image_count': 2}, {'id': 945, 'image_count': 1}, {'id': 946, 'image_count': 19}, {'id': 947, 'image_count': 1942}, {'id': 948, 'image_count': 1916}, {'id': 949, 'image_count': 139}, {'id': 950, 'image_count': 43}, {'id': 951, 'image_count': 1969}, {'id': 952, 'image_count': 5}, {'id': 953, 'image_count': 134}, {'id': 954, 'image_count': 74}, {'id': 955, 'image_count': 381}, {'id': 956, 'image_count': 1}, {'id': 957, 'image_count': 381}, {'id': 958, 'image_count': 6}, {'id': 959, 'image_count': 1826}, {'id': 960, 'image_count': 28}, {'id': 961, 'image_count': 1635}, {'id': 962, 'image_count': 1967}, {'id': 963, 'image_count': 16}, {'id': 964, 'image_count': 1926}, {'id': 965, 'image_count': 1789}, {'id': 966, 'image_count': 401}, {'id': 967, 'image_count': 1968}, {'id': 968, 'image_count': 1167}, {'id': 969, 'image_count': 1}, {'id': 970, 'image_count': 56}, {'id': 971, 'image_count': 17}, {'id': 972, 'image_count': 1}, {'id': 973, 'image_count': 58}, {'id': 974, 'image_count': 9}, {'id': 975, 'image_count': 8}, {'id': 976, 'image_count': 1124}, {'id': 977, 'image_count': 31}, {'id': 978, 'image_count': 16}, {'id': 979, 'image_count': 491}, {'id': 980, 'image_count': 432}, {'id': 981, 'image_count': 1945}, {'id': 982, 'image_count': 1899}, {'id': 983, 'image_count': 5}, {'id': 984, 'image_count': 28}, {'id': 985, 'image_count': 7}, {'id': 986, 'image_count': 146}, {'id': 987, 'image_count': 1}, {'id': 988, 'image_count': 25}, {'id': 989, 'image_count': 22}, {'id': 990, 'image_count': 1}, {'id': 991, 'image_count': 10}, {'id': 992, 'image_count': 9}, {'id': 993, 'image_count': 308}, {'id': 994, 'image_count': 4}, {'id': 995, 'image_count': 1969}, {'id': 996, 'image_count': 45}, {'id': 997, 'image_count': 12}, {'id': 998, 'image_count': 1}, {'id': 999, 'image_count': 85}, {'id': 1000, 'image_count': 1127}, {'id': 1001, 'image_count': 11}, {'id': 1002, 'image_count': 60}, {'id': 1003, 'image_count': 1}, {'id': 1004, 'image_count': 16}, {'id': 1005, 'image_count': 1}, {'id': 1006, 'image_count': 65}, {'id': 1007, 'image_count': 13}, {'id': 1008, 'image_count': 655}, {'id': 1009, 'image_count': 51}, {'id': 1010, 'image_count': 1}, {'id': 1011, 'image_count': 673}, {'id': 1012, 'image_count': 5}, {'id': 1013, 'image_count': 36}, {'id': 1014, 'image_count': 54}, {'id': 1015, 'image_count': 5}, {'id': 1016, 'image_count': 8}, {'id': 1017, 'image_count': 305}, {'id': 1018, 'image_count': 297}, {'id': 1019, 'image_count': 1053}, {'id': 1020, 'image_count': 223}, {'id': 1021, 'image_count': 1037}, {'id': 1022, 'image_count': 63}, {'id': 1023, 'image_count': 1881}, {'id': 1024, 'image_count': 507}, {'id': 1025, 'image_count': 333}, {'id': 1026, 'image_count': 1911}, {'id': 1027, 'image_count': 1765}, {'id': 1028, 'image_count': 1}, {'id': 1029, 'image_count': 5}, {'id': 1030, 'image_count': 1}, {'id': 1031, 'image_count': 9}, {'id': 1032, 'image_count': 2}, {'id': 1033, 'image_count': 151}, {'id': 1034, 'image_count': 82}, {'id': 1035, 'image_count': 1931}, {'id': 1036, 'image_count': 41}, {'id': 1037, 'image_count': 1895}, {'id': 1038, 'image_count': 24}, {'id': 1039, 'image_count': 22}, {'id': 1040, 'image_count': 35}, {'id': 1041, 'image_count': 69}, {'id': 1042, 'image_count': 962}, {'id': 1043, 'image_count': 588}, {'id': 1044, 'image_count': 21}, {'id': 1045, 'image_count': 825}, {'id': 1046, 'image_count': 52}, {'id': 1047, 'image_count': 5}, {'id': 1048, 'image_count': 5}, {'id': 1049, 'image_count': 5}, {'id': 1050, 'image_count': 1860}, {'id': 1051, 'image_count': 56}, {'id': 1052, 'image_count': 1582}, {'id': 1053, 'image_count': 7}, {'id': 1054, 'image_count': 2}, {'id': 1055, 'image_count': 1562}, {'id': 1056, 'image_count': 1885}, {'id': 1057, 'image_count': 1}, {'id': 1058, 'image_count': 5}, {'id': 1059, 'image_count': 137}, {'id': 1060, 'image_count': 1094}, {'id': 1061, 'image_count': 134}, {'id': 1062, 'image_count': 29}, {'id': 1063, 'image_count': 22}, {'id': 1064, 'image_count': 522}, {'id': 1065, 'image_count': 50}, {'id': 1066, 'image_count': 68}, {'id': 1067, 'image_count': 16}, {'id': 1068, 'image_count': 40}, {'id': 1069, 'image_count': 35}, {'id': 1070, 'image_count': 135}, {'id': 1071, 'image_count': 1413}, {'id': 1072, 'image_count': 772}, {'id': 1073, 'image_count': 50}, {'id': 1074, 'image_count': 1015}, {'id': 1075, 'image_count': 1}, {'id': 1076, 'image_count': 65}, {'id': 1077, 'image_count': 1900}, {'id': 1078, 'image_count': 1302}, {'id': 1079, 'image_count': 1977}, {'id': 1080, 'image_count': 2}, {'id': 1081, 'image_count': 29}, {'id': 1082, 'image_count': 36}, {'id': 1083, 'image_count': 138}, {'id': 1084, 'image_count': 4}, {'id': 1085, 'image_count': 67}, {'id': 1086, 'image_count': 26}, {'id': 1087, 'image_count': 25}, {'id': 1088, 'image_count': 33}, {'id': 1089, 'image_count': 37}, {'id': 1090, 'image_count': 50}, {'id': 1091, 'image_count': 270}, {'id': 1092, 'image_count': 12}, {'id': 1093, 'image_count': 316}, {'id': 1094, 'image_count': 41}, {'id': 1095, 'image_count': 224}, {'id': 1096, 'image_count': 105}, {'id': 1097, 'image_count': 1925}, {'id': 1098, 'image_count': 1021}, {'id': 1099, 'image_count': 1213}, {'id': 1100, 'image_count': 172}, {'id': 1101, 'image_count': 28}, {'id': 1102, 'image_count': 745}, {'id': 1103, 'image_count': 187}, {'id': 1104, 'image_count': 147}, {'id': 1105, 'image_count': 136}, {'id': 1106, 'image_count': 34}, {'id': 1107, 'image_count': 41}, {'id': 1108, 'image_count': 636}, {'id': 1109, 'image_count': 570}, {'id': 1110, 'image_count': 1149}, {'id': 1111, 'image_count': 61}, {'id': 1112, 'image_count': 1890}, {'id': 1113, 'image_count': 18}, {'id': 1114, 'image_count': 143}, {'id': 1115, 'image_count': 1517}, {'id': 1116, 'image_count': 7}, {'id': 1117, 'image_count': 943}, {'id': 1118, 'image_count': 6}, {'id': 1119, 'image_count': 1}, {'id': 1120, 'image_count': 11}, {'id': 1121, 'image_count': 101}, {'id': 1122, 'image_count': 1909}, {'id': 1123, 'image_count': 800}, {'id': 1124, 'image_count': 1}, {'id': 1125, 'image_count': 44}, {'id': 1126, 'image_count': 3}, {'id': 1127, 'image_count': 44}, {'id': 1128, 'image_count': 31}, {'id': 1129, 'image_count': 7}, {'id': 1130, 'image_count': 20}, {'id': 1131, 'image_count': 11}, {'id': 1132, 'image_count': 13}, {'id': 1133, 'image_count': 1924}, {'id': 1134, 'image_count': 113}, {'id': 1135, 'image_count': 2}, {'id': 1136, 'image_count': 139}, {'id': 1137, 'image_count': 12}, {'id': 1138, 'image_count': 37}, {'id': 1139, 'image_count': 1866}, {'id': 1140, 'image_count': 47}, {'id': 1141, 'image_count': 1468}, {'id': 1142, 'image_count': 729}, {'id': 1143, 'image_count': 24}, {'id': 1144, 'image_count': 1}, {'id': 1145, 'image_count': 10}, {'id': 1146, 'image_count': 3}, {'id': 1147, 'image_count': 14}, {'id': 1148, 'image_count': 4}, {'id': 1149, 'image_count': 29}, {'id': 1150, 'image_count': 4}, {'id': 1151, 'image_count': 70}, {'id': 1152, 'image_count': 46}, {'id': 1153, 'image_count': 14}, {'id': 1154, 'image_count': 48}, {'id': 1155, 'image_count': 1855}, {'id': 1156, 'image_count': 113}, {'id': 1157, 'image_count': 1}, {'id': 1158, 'image_count': 1}, {'id': 1159, 'image_count': 10}, {'id': 1160, 'image_count': 54}, {'id': 1161, 'image_count': 1923}, {'id': 1162, 'image_count': 630}, {'id': 1163, 'image_count': 31}, {'id': 1164, 'image_count': 69}, {'id': 1165, 'image_count': 7}, {'id': 1166, 'image_count': 11}, {'id': 1167, 'image_count': 1}, {'id': 1168, 'image_count': 30}, {'id': 1169, 'image_count': 50}, {'id': 1170, 'image_count': 45}, {'id': 1171, 'image_count': 28}, {'id': 1172, 'image_count': 114}, {'id': 1173, 'image_count': 193}, {'id': 1174, 'image_count': 21}, {'id': 1175, 'image_count': 91}, {'id': 1176, 'image_count': 31}, {'id': 1177, 'image_count': 1469}, {'id': 1178, 'image_count': 1924}, {'id': 1179, 'image_count': 87}, {'id': 1180, 'image_count': 77}, {'id': 1181, 'image_count': 11}, {'id': 1182, 'image_count': 47}, {'id': 1183, 'image_count': 21}, {'id': 1184, 'image_count': 47}, {'id': 1185, 'image_count': 70}, {'id': 1186, 'image_count': 1838}, {'id': 1187, 'image_count': 19}, {'id': 1188, 'image_count': 531}, {'id': 1189, 'image_count': 11}, {'id': 1190, 'image_count': 941}, {'id': 1191, 'image_count': 113}, {'id': 1192, 'image_count': 26}, {'id': 1193, 'image_count': 5}, {'id': 1194, 'image_count': 56}, {'id': 1195, 'image_count': 73}, {'id': 1196, 'image_count': 32}, {'id': 1197, 'image_count': 128}, {'id': 1198, 'image_count': 623}, {'id': 1199, 'image_count': 12}, {'id': 1200, 'image_count': 52}, {'id': 1201, 'image_count': 11}, {'id': 1202, 'image_count': 1674}, {'id': 1203, 'image_count': 81}] # noqa +# fmt: on diff --git a/custom_detectron2/data/datasets/pascal_voc.py b/custom_detectron2/data/datasets/pascal_voc.py new file mode 100644 index 0000000000000000000000000000000000000000..d67ce3ee5a6e323bd0859f8d6ce5c84a6c3cc21b --- /dev/null +++ b/custom_detectron2/data/datasets/pascal_voc.py @@ -0,0 +1,82 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import numpy as np +import os +import xml.etree.ElementTree as ET +from typing import List, Tuple, Union + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog +from custom_detectron2.structures import BoxMode +from custom_detectron2.utils.file_io import PathManager + +__all__ = ["load_voc_instances", "register_pascal_voc"] + + +# fmt: off +CLASS_NAMES = ( + "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", + "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", + "pottedplant", "sheep", "sofa", "train", "tvmonitor" +) +# fmt: on + + +def load_voc_instances(dirname: str, split: str, class_names: Union[List[str], Tuple[str, ...]]): + """ + Load Pascal VOC detection annotations to Detectron2 format. + + Args: + dirname: Contain "Annotations", "ImageSets", "JPEGImages" + split (str): one of "train", "test", "val", "trainval" + class_names: list or tuple of class names + """ + with PathManager.open(os.path.join(dirname, "ImageSets", "Main", split + ".txt")) as f: + fileids = np.loadtxt(f, dtype=np.str) + + # Needs to read many small annotation files. Makes sense at local + annotation_dirname = PathManager.get_local_path(os.path.join(dirname, "Annotations/")) + dicts = [] + for fileid in fileids: + anno_file = os.path.join(annotation_dirname, fileid + ".xml") + jpeg_file = os.path.join(dirname, "JPEGImages", fileid + ".jpg") + + with PathManager.open(anno_file) as f: + tree = ET.parse(f) + + r = { + "file_name": jpeg_file, + "image_id": fileid, + "height": int(tree.findall("./size/height")[0].text), + "width": int(tree.findall("./size/width")[0].text), + } + instances = [] + + for obj in tree.findall("object"): + cls = obj.find("name").text + # We include "difficult" samples in training. + # Based on limited experiments, they don't hurt accuracy. + # difficult = int(obj.find("difficult").text) + # if difficult == 1: + # continue + bbox = obj.find("bndbox") + bbox = [float(bbox.find(x).text) for x in ["xmin", "ymin", "xmax", "ymax"]] + # Original annotations are integers in the range [1, W or H] + # Assuming they mean 1-based pixel indices (inclusive), + # a box with annotation (xmin=1, xmax=W) covers the whole image. + # In coordinate space this is represented by (xmin=0, xmax=W) + bbox[0] -= 1.0 + bbox[1] -= 1.0 + instances.append( + {"category_id": class_names.index(cls), "bbox": bbox, "bbox_mode": BoxMode.XYXY_ABS} + ) + r["annotations"] = instances + dicts.append(r) + return dicts + + +def register_pascal_voc(name, dirname, split, year, class_names=CLASS_NAMES): + DatasetCatalog.register(name, lambda: load_voc_instances(dirname, split, class_names)) + MetadataCatalog.get(name).set( + thing_classes=list(class_names), dirname=dirname, year=year, split=split + ) diff --git a/custom_detectron2/data/datasets/register_coco.py b/custom_detectron2/data/datasets/register_coco.py new file mode 100644 index 0000000000000000000000000000000000000000..e564438d5bf016bcdbb65b4bbdc215d79f579f8a --- /dev/null +++ b/custom_detectron2/data/datasets/register_coco.py @@ -0,0 +1,3 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .coco import register_coco_instances # noqa +from .coco_panoptic import register_coco_panoptic_separated # noqa diff --git a/custom_detectron2/data/detection_utils.py b/custom_detectron2/data/detection_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..425f3e2d501b1f68bec23f8d662a6aff83ddcebc --- /dev/null +++ b/custom_detectron2/data/detection_utils.py @@ -0,0 +1,659 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +Common data processing utilities that are used in a +typical object detection data pipeline. +""" +import logging +import numpy as np +from typing import List, Union +import custom_pycocotools.mask as mask_util +import torch +from PIL import Image + +from custom_detectron2.structures import ( + BitMasks, + Boxes, + BoxMode, + Instances, + Keypoints, + PolygonMasks, + RotatedBoxes, + polygons_to_bitmask, +) +from custom_detectron2.utils.file_io import PathManager + +from . import transforms as T +from .catalog import MetadataCatalog + +__all__ = [ + "SizeMismatchError", + "convert_image_to_rgb", + "check_image_size", + "transform_proposals", + "transform_instance_annotations", + "annotations_to_instances", + "annotations_to_instances_rotated", + "build_augmentation", + "build_transform_gen", + "create_keypoint_hflip_indices", + "filter_empty_instances", + "read_image", +] + + +class SizeMismatchError(ValueError): + """ + When loaded image has difference width/height compared with annotation. + """ + + +# https://en.wikipedia.org/wiki/YUV#SDTV_with_BT.601 +_M_RGB2YUV = [[0.299, 0.587, 0.114], [-0.14713, -0.28886, 0.436], [0.615, -0.51499, -0.10001]] +_M_YUV2RGB = [[1.0, 0.0, 1.13983], [1.0, -0.39465, -0.58060], [1.0, 2.03211, 0.0]] + +# https://www.exiv2.org/tags.html +_EXIF_ORIENT = 274 # exif 'Orientation' tag + + +def convert_PIL_to_numpy(image, format): + """ + Convert PIL image to numpy array of target format. + + Args: + image (PIL.Image): a PIL image + format (str): the format of output image + + Returns: + (np.ndarray): also see `read_image` + """ + if format is not None: + # PIL only supports RGB, so convert to RGB and flip channels over below + conversion_format = format + if format in ["BGR", "YUV-BT.601"]: + conversion_format = "RGB" + image = image.convert(conversion_format) + image = np.asarray(image) + # PIL squeezes out the channel dimension for "L", so make it HWC + if format == "L": + image = np.expand_dims(image, -1) + + # handle formats not supported by PIL + elif format == "BGR": + # flip channels if needed + image = image[:, :, ::-1] + elif format == "YUV-BT.601": + image = image / 255.0 + image = np.dot(image, np.array(_M_RGB2YUV).T) + + return image + + +def convert_image_to_rgb(image, format): + """ + Convert an image from given format to RGB. + + Args: + image (np.ndarray or Tensor): an HWC image + format (str): the format of input image, also see `read_image` + + Returns: + (np.ndarray): (H,W,3) RGB image in 0-255 range, can be either float or uint8 + """ + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + if format == "BGR": + image = image[:, :, [2, 1, 0]] + elif format == "YUV-BT.601": + image = np.dot(image, np.array(_M_YUV2RGB).T) + image = image * 255.0 + else: + if format == "L": + image = image[:, :, 0] + image = image.astype(np.uint8) + image = np.asarray(Image.fromarray(image, mode=format).convert("RGB")) + return image + + +def _apply_exif_orientation(image): + """ + Applies the exif orientation correctly. + + This code exists per the bug: + https://github.com/python-pillow/Pillow/issues/3973 + with the function `ImageOps.exif_transpose`. The Pillow source raises errors with + various methods, especially `tobytes` + + Function based on: + https://github.com/wkentaro/labelme/blob/v4.5.4/labelme/utils/image.py#L59 + https://github.com/python-pillow/Pillow/blob/7.1.2/src/PIL/ImageOps.py#L527 + + Args: + image (PIL.Image): a PIL image + + Returns: + (PIL.Image): the PIL image with exif orientation applied, if applicable + """ + if not hasattr(image, "getexif"): + return image + + try: + exif = image.getexif() + except Exception: # https://github.com/facebookresearch/detectron2/issues/1885 + exif = None + + if exif is None: + return image + + orientation = exif.get(_EXIF_ORIENT) + + method = { + 2: Image.FLIP_LEFT_RIGHT, + 3: Image.ROTATE_180, + 4: Image.FLIP_TOP_BOTTOM, + 5: Image.TRANSPOSE, + 6: Image.ROTATE_270, + 7: Image.TRANSVERSE, + 8: Image.ROTATE_90, + }.get(orientation) + + if method is not None: + return image.transpose(method) + return image + + +def read_image(file_name, format=None): + """ + Read an image into the given format. + Will apply rotation and flipping if the image has such exif information. + + Args: + file_name (str): image file path + format (str): one of the supported image modes in PIL, or "BGR" or "YUV-BT.601". + + Returns: + image (np.ndarray): + an HWC image in the given format, which is 0-255, uint8 for + supported image modes in PIL or "BGR"; float (0-1 for Y) for YUV-BT.601. + """ + with PathManager.open(file_name, "rb") as f: + image = Image.open(f) + + # work around this bug: https://github.com/python-pillow/Pillow/issues/3973 + image = _apply_exif_orientation(image) + return convert_PIL_to_numpy(image, format) + + +def check_image_size(dataset_dict, image): + """ + Raise an error if the image does not match the size specified in the dict. + """ + if "width" in dataset_dict or "height" in dataset_dict: + image_wh = (image.shape[1], image.shape[0]) + expected_wh = (dataset_dict["width"], dataset_dict["height"]) + if not image_wh == expected_wh: + raise SizeMismatchError( + "Mismatched image shape{}, got {}, expect {}.".format( + " for image " + dataset_dict["file_name"] + if "file_name" in dataset_dict + else "", + image_wh, + expected_wh, + ) + + " Please check the width/height in your annotation." + ) + + # To ensure bbox always remap to original image size + if "width" not in dataset_dict: + dataset_dict["width"] = image.shape[1] + if "height" not in dataset_dict: + dataset_dict["height"] = image.shape[0] + + +def transform_proposals(dataset_dict, image_shape, transforms, *, proposal_topk, min_box_size=0): + """ + Apply transformations to the proposals in dataset_dict, if any. + + Args: + dataset_dict (dict): a dict read from the dataset, possibly + contains fields "proposal_boxes", "proposal_objectness_logits", "proposal_bbox_mode" + image_shape (tuple): height, width + transforms (TransformList): + proposal_topk (int): only keep top-K scoring proposals + min_box_size (int): proposals with either side smaller than this + threshold are removed + + The input dict is modified in-place, with abovementioned keys removed. A new + key "proposals" will be added. Its value is an `Instances` + object which contains the transformed proposals in its field + "proposal_boxes" and "objectness_logits". + """ + if "proposal_boxes" in dataset_dict: + # Transform proposal boxes + boxes = transforms.apply_box( + BoxMode.convert( + dataset_dict.pop("proposal_boxes"), + dataset_dict.pop("proposal_bbox_mode"), + BoxMode.XYXY_ABS, + ) + ) + boxes = Boxes(boxes) + objectness_logits = torch.as_tensor( + dataset_dict.pop("proposal_objectness_logits").astype("float32") + ) + + boxes.clip(image_shape) + keep = boxes.nonempty(threshold=min_box_size) + boxes = boxes[keep] + objectness_logits = objectness_logits[keep] + + proposals = Instances(image_shape) + proposals.proposal_boxes = boxes[:proposal_topk] + proposals.objectness_logits = objectness_logits[:proposal_topk] + dataset_dict["proposals"] = proposals + + +def get_bbox(annotation): + """ + Get bbox from data + Args: + annotation (dict): dict of instance annotations for a single instance. + Returns: + bbox (ndarray): x1, y1, x2, y2 coordinates + """ + # bbox is 1d (per-instance bounding box) + bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS) + return bbox + + +def transform_instance_annotations( + annotation, transforms, image_size, *, keypoint_hflip_indices=None +): + """ + Apply transforms to box, segmentation and keypoints annotations of a single instance. + + It will use `transforms.apply_box` for the box, and + `transforms.apply_coords` for segmentation polygons & keypoints. + If you need anything more specially designed for each data structure, + you'll need to implement your own version of this function or the transforms. + + Args: + annotation (dict): dict of instance annotations for a single instance. + It will be modified in-place. + transforms (TransformList or list[Transform]): + image_size (tuple): the height, width of the transformed image + keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`. + + Returns: + dict: + the same input dict with fields "bbox", "segmentation", "keypoints" + transformed according to `transforms`. + The "bbox_mode" field will be set to XYXY_ABS. + """ + if isinstance(transforms, (tuple, list)): + transforms = T.TransformList(transforms) + # bbox is 1d (per-instance bounding box) + bbox = BoxMode.convert(annotation["bbox"], annotation["bbox_mode"], BoxMode.XYXY_ABS) + # clip transformed bbox to image size + bbox = transforms.apply_box(np.array([bbox]))[0].clip(min=0) + annotation["bbox"] = np.minimum(bbox, list(image_size + image_size)[::-1]) + annotation["bbox_mode"] = BoxMode.XYXY_ABS + + if "segmentation" in annotation: + # each instance contains 1 or more polygons + segm = annotation["segmentation"] + if isinstance(segm, list): + # polygons + polygons = [np.asarray(p).reshape(-1, 2) for p in segm] + annotation["segmentation"] = [ + p.reshape(-1) for p in transforms.apply_polygons(polygons) + ] + elif isinstance(segm, dict): + # RLE + mask = mask_util.decode(segm) + mask = transforms.apply_segmentation(mask) + assert tuple(mask.shape[:2]) == image_size + annotation["segmentation"] = mask + else: + raise ValueError( + "Cannot transform segmentation of type '{}'!" + "Supported types are: polygons as list[list[float] or ndarray]," + " COCO-style RLE as a dict.".format(type(segm)) + ) + + if "keypoints" in annotation: + keypoints = transform_keypoint_annotations( + annotation["keypoints"], transforms, image_size, keypoint_hflip_indices + ) + annotation["keypoints"] = keypoints + + return annotation + + +def transform_keypoint_annotations(keypoints, transforms, image_size, keypoint_hflip_indices=None): + """ + Transform keypoint annotations of an image. + If a keypoint is transformed out of image boundary, it will be marked "unlabeled" (visibility=0) + + Args: + keypoints (list[float]): Nx3 float in Detectron2's Dataset format. + Each point is represented by (x, y, visibility). + transforms (TransformList): + image_size (tuple): the height, width of the transformed image + keypoint_hflip_indices (ndarray[int]): see `create_keypoint_hflip_indices`. + When `transforms` includes horizontal flip, will use the index + mapping to flip keypoints. + """ + # (N*3,) -> (N, 3) + keypoints = np.asarray(keypoints, dtype="float64").reshape(-1, 3) + keypoints_xy = transforms.apply_coords(keypoints[:, :2]) + + # Set all out-of-boundary points to "unlabeled" + inside = (keypoints_xy >= np.array([0, 0])) & (keypoints_xy <= np.array(image_size[::-1])) + inside = inside.all(axis=1) + keypoints[:, :2] = keypoints_xy + keypoints[:, 2][~inside] = 0 + + # This assumes that HorizFlipTransform is the only one that does flip + do_hflip = sum(isinstance(t, T.HFlipTransform) for t in transforms.transforms) % 2 == 1 + + # Alternative way: check if probe points was horizontally flipped. + # probe = np.asarray([[0.0, 0.0], [image_width, 0.0]]) + # probe_aug = transforms.apply_coords(probe.copy()) + # do_hflip = np.sign(probe[1][0] - probe[0][0]) != np.sign(probe_aug[1][0] - probe_aug[0][0]) # noqa + + # If flipped, swap each keypoint with its opposite-handed equivalent + if do_hflip: + if keypoint_hflip_indices is None: + raise ValueError("Cannot flip keypoints without providing flip indices!") + if len(keypoints) != len(keypoint_hflip_indices): + raise ValueError( + "Keypoint data has {} points, but metadata " + "contains {} points!".format(len(keypoints), len(keypoint_hflip_indices)) + ) + keypoints = keypoints[np.asarray(keypoint_hflip_indices, dtype=np.int32), :] + + # Maintain COCO convention that if visibility == 0 (unlabeled), then x, y = 0 + keypoints[keypoints[:, 2] == 0] = 0 + return keypoints + + +def annotations_to_instances(annos, image_size, mask_format="polygon"): + """ + Create an :class:`Instances` object used by the models, + from instance annotations in the dataset dict. + + Args: + annos (list[dict]): a list of instance annotations in one image, each + element for one instance. + image_size (tuple): height, width + + Returns: + Instances: + It will contain fields "gt_boxes", "gt_classes", + "gt_masks", "gt_keypoints", if they can be obtained from `annos`. + This is the format that builtin models expect. + """ + boxes = ( + np.stack( + [BoxMode.convert(obj["bbox"], obj["bbox_mode"], BoxMode.XYXY_ABS) for obj in annos] + ) + if len(annos) + else np.zeros((0, 4)) + ) + target = Instances(image_size) + target.gt_boxes = Boxes(boxes) + + classes = [int(obj["category_id"]) for obj in annos] + classes = torch.tensor(classes, dtype=torch.int64) + target.gt_classes = classes + + if len(annos) and "segmentation" in annos[0]: + segms = [obj["segmentation"] for obj in annos] + if mask_format == "polygon": + try: + masks = PolygonMasks(segms) + except ValueError as e: + raise ValueError( + "Failed to use mask_format=='polygon' from the given annotations!" + ) from e + else: + assert mask_format == "bitmask", mask_format + masks = [] + for segm in segms: + if isinstance(segm, list): + # polygon + masks.append(polygons_to_bitmask(segm, *image_size)) + elif isinstance(segm, dict): + # COCO RLE + masks.append(mask_util.decode(segm)) + elif isinstance(segm, np.ndarray): + assert segm.ndim == 2, "Expect segmentation of 2 dimensions, got {}.".format( + segm.ndim + ) + # mask array + masks.append(segm) + else: + raise ValueError( + "Cannot convert segmentation of type '{}' to BitMasks!" + "Supported types are: polygons as list[list[float] or ndarray]," + " COCO-style RLE as a dict, or a binary segmentation mask " + " in a 2D numpy array of shape HxW.".format(type(segm)) + ) + # torch.from_numpy does not support array with negative stride. + masks = BitMasks( + torch.stack([torch.from_numpy(np.ascontiguousarray(x)) for x in masks]) + ) + target.gt_masks = masks + + if len(annos) and "keypoints" in annos[0]: + kpts = [obj.get("keypoints", []) for obj in annos] + target.gt_keypoints = Keypoints(kpts) + + return target + + +def annotations_to_instances_rotated(annos, image_size): + """ + Create an :class:`Instances` object used by the models, + from instance annotations in the dataset dict. + Compared to `annotations_to_instances`, this function is for rotated boxes only + + Args: + annos (list[dict]): a list of instance annotations in one image, each + element for one instance. + image_size (tuple): height, width + + Returns: + Instances: + Containing fields "gt_boxes", "gt_classes", + if they can be obtained from `annos`. + This is the format that builtin models expect. + """ + boxes = [obj["bbox"] for obj in annos] + target = Instances(image_size) + boxes = target.gt_boxes = RotatedBoxes(boxes) + boxes.clip(image_size) + + classes = [obj["category_id"] for obj in annos] + classes = torch.tensor(classes, dtype=torch.int64) + target.gt_classes = classes + + return target + + +def filter_empty_instances( + instances, by_box=True, by_mask=True, box_threshold=1e-5, return_mask=False +): + """ + Filter out empty instances in an `Instances` object. + + Args: + instances (Instances): + by_box (bool): whether to filter out instances with empty boxes + by_mask (bool): whether to filter out instances with empty masks + box_threshold (float): minimum width and height to be considered non-empty + return_mask (bool): whether to return boolean mask of filtered instances + + Returns: + Instances: the filtered instances. + tensor[bool], optional: boolean mask of filtered instances + """ + assert by_box or by_mask + r = [] + if by_box: + r.append(instances.gt_boxes.nonempty(threshold=box_threshold)) + if instances.has("gt_masks") and by_mask: + r.append(instances.gt_masks.nonempty()) + + # TODO: can also filter visible keypoints + + if not r: + return instances + m = r[0] + for x in r[1:]: + m = m & x + if return_mask: + return instances[m], m + return instances[m] + + +def create_keypoint_hflip_indices(dataset_names: Union[str, List[str]]) -> List[int]: + """ + Args: + dataset_names: list of dataset names + + Returns: + list[int]: a list of size=#keypoints, storing the + horizontally-flipped keypoint indices. + """ + if isinstance(dataset_names, str): + dataset_names = [dataset_names] + + check_metadata_consistency("keypoint_names", dataset_names) + check_metadata_consistency("keypoint_flip_map", dataset_names) + + meta = MetadataCatalog.get(dataset_names[0]) + names = meta.keypoint_names + # TODO flip -> hflip + flip_map = dict(meta.keypoint_flip_map) + flip_map.update({v: k for k, v in flip_map.items()}) + flipped_names = [i if i not in flip_map else flip_map[i] for i in names] + flip_indices = [names.index(i) for i in flipped_names] + return flip_indices + + +def get_fed_loss_cls_weights(dataset_names: Union[str, List[str]], freq_weight_power=1.0): + """ + Get frequency weight for each class sorted by class id. + We now calcualte freqency weight using image_count to the power freq_weight_power. + + Args: + dataset_names: list of dataset names + freq_weight_power: power value + """ + if isinstance(dataset_names, str): + dataset_names = [dataset_names] + + check_metadata_consistency("class_image_count", dataset_names) + + meta = MetadataCatalog.get(dataset_names[0]) + class_freq_meta = meta.class_image_count + class_freq = torch.tensor( + [c["image_count"] for c in sorted(class_freq_meta, key=lambda x: x["id"])] + ) + class_freq_weight = class_freq.float() ** freq_weight_power + return class_freq_weight + + +def gen_crop_transform_with_instance(crop_size, image_size, instance): + """ + Generate a CropTransform so that the cropping region contains + the center of the given instance. + + Args: + crop_size (tuple): h, w in pixels + image_size (tuple): h, w + instance (dict): an annotation dict of one instance, in Detectron2's + dataset format. + """ + crop_size = np.asarray(crop_size, dtype=np.int32) + bbox = BoxMode.convert(instance["bbox"], instance["bbox_mode"], BoxMode.XYXY_ABS) + center_yx = (bbox[1] + bbox[3]) * 0.5, (bbox[0] + bbox[2]) * 0.5 + assert ( + image_size[0] >= center_yx[0] and image_size[1] >= center_yx[1] + ), "The annotation bounding box is outside of the image!" + assert ( + image_size[0] >= crop_size[0] and image_size[1] >= crop_size[1] + ), "Crop size is larger than image size!" + + min_yx = np.maximum(np.floor(center_yx).astype(np.int32) - crop_size, 0) + max_yx = np.maximum(np.asarray(image_size, dtype=np.int32) - crop_size, 0) + max_yx = np.minimum(max_yx, np.ceil(center_yx).astype(np.int32)) + + y0 = np.random.randint(min_yx[0], max_yx[0] + 1) + x0 = np.random.randint(min_yx[1], max_yx[1] + 1) + return T.CropTransform(x0, y0, crop_size[1], crop_size[0]) + + +def check_metadata_consistency(key, dataset_names): + """ + Check that the datasets have consistent metadata. + + Args: + key (str): a metadata key + dataset_names (list[str]): a list of dataset names + + Raises: + AttributeError: if the key does not exist in the metadata + ValueError: if the given datasets do not have the same metadata values defined by key + """ + if len(dataset_names) == 0: + return + logger = logging.getLogger(__name__) + entries_per_dataset = [getattr(MetadataCatalog.get(d), key) for d in dataset_names] + for idx, entry in enumerate(entries_per_dataset): + if entry != entries_per_dataset[0]: + logger.error( + "Metadata '{}' for dataset '{}' is '{}'".format(key, dataset_names[idx], str(entry)) + ) + logger.error( + "Metadata '{}' for dataset '{}' is '{}'".format( + key, dataset_names[0], str(entries_per_dataset[0]) + ) + ) + raise ValueError("Datasets have different metadata '{}'!".format(key)) + + +def build_augmentation(cfg, is_train): + """ + Create a list of default :class:`Augmentation` from config. + Now it includes resizing and flipping. + + Returns: + list[Augmentation] + """ + if is_train: + min_size = cfg.INPUT.MIN_SIZE_TRAIN + max_size = cfg.INPUT.MAX_SIZE_TRAIN + sample_style = cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING + else: + min_size = cfg.INPUT.MIN_SIZE_TEST + max_size = cfg.INPUT.MAX_SIZE_TEST + sample_style = "choice" + augmentation = [T.ResizeShortestEdge(min_size, max_size, sample_style)] + if is_train and cfg.INPUT.RANDOM_FLIP != "none": + augmentation.append( + T.RandomFlip( + horizontal=cfg.INPUT.RANDOM_FLIP == "horizontal", + vertical=cfg.INPUT.RANDOM_FLIP == "vertical", + ) + ) + return augmentation + + +build_transform_gen = build_augmentation +""" +Alias for backward-compatibility. +""" diff --git a/custom_detectron2/data/samplers/__init__.py b/custom_detectron2/data/samplers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..85c9f1a9df8a4038fbd4246239b699402e382309 --- /dev/null +++ b/custom_detectron2/data/samplers/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .distributed_sampler import ( + InferenceSampler, + RandomSubsetTrainingSampler, + RepeatFactorTrainingSampler, + TrainingSampler, +) + +from .grouped_batch_sampler import GroupedBatchSampler + +__all__ = [ + "GroupedBatchSampler", + "TrainingSampler", + "RandomSubsetTrainingSampler", + "InferenceSampler", + "RepeatFactorTrainingSampler", +] diff --git a/custom_detectron2/data/samplers/distributed_sampler.py b/custom_detectron2/data/samplers/distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..eedc4c3ba88886c614a26dbc38fa1860bb503505 --- /dev/null +++ b/custom_detectron2/data/samplers/distributed_sampler.py @@ -0,0 +1,278 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import math +from collections import defaultdict +from typing import Optional +import torch +from torch.utils.data.sampler import Sampler + +from custom_detectron2.utils import comm + +logger = logging.getLogger(__name__) + + +class TrainingSampler(Sampler): + """ + In training, we only care about the "infinite stream" of training data. + So this sampler produces an infinite stream of indices and + all workers cooperate to correctly shuffle the indices and sample different indices. + + The samplers in each worker effectively produces `indices[worker_id::num_workers]` + where `indices` is an infinite stream of indices consisting of + `shuffle(range(size)) + shuffle(range(size)) + ...` (if shuffle is True) + or `range(size) + range(size) + ...` (if shuffle is False) + + Note that this sampler does not shard based on pytorch DataLoader worker id. + A sampler passed to pytorch DataLoader is used only with map-style dataset + and will not be executed inside workers. + But if this sampler is used in a way that it gets execute inside a dataloader + worker, then extra work needs to be done to shard its outputs based on worker id. + This is required so that workers don't produce identical data. + :class:`ToIterableDataset` implements this logic. + This note is true for all samplers in detectron2. + """ + + def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + shuffle (bool): whether to shuffle the indices or not + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + if not isinstance(size, int): + raise TypeError(f"TrainingSampler(size=) expects an int. Got type {type(size)}.") + if size <= 0: + raise ValueError(f"TrainingSampler(size=) expects a positive int. Got {size}.") + self._size = size + self._shuffle = shuffle + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + if self._shuffle: + yield from torch.randperm(self._size, generator=g).tolist() + else: + yield from torch.arange(self._size).tolist() + + +class RandomSubsetTrainingSampler(TrainingSampler): + """ + Similar to TrainingSampler, but only sample a random subset of indices. + This is useful when you want to estimate the accuracy vs data-number curves by + training the model with different subset_ratio. + """ + + def __init__( + self, + size: int, + subset_ratio: float, + shuffle: bool = True, + seed_shuffle: Optional[int] = None, + seed_subset: Optional[int] = None, + ): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + subset_ratio (float): the ratio of subset data to sample from the underlying dataset + shuffle (bool): whether to shuffle the indices or not + seed_shuffle (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + seed_subset (int): the seed to randomize the subset to be sampled. + Must be the same across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + super().__init__(size=size, shuffle=shuffle, seed=seed_shuffle) + + assert 0.0 < subset_ratio <= 1.0 + self._size_subset = int(size * subset_ratio) + assert self._size_subset > 0 + if seed_subset is None: + seed_subset = comm.shared_random_seed() + self._seed_subset = int(seed_subset) + + # randomly generate the subset indexes to be sampled from + g = torch.Generator() + g.manual_seed(self._seed_subset) + indexes_randperm = torch.randperm(self._size, generator=g) + self._indexes_subset = indexes_randperm[: self._size_subset] + + logger.info("Using RandomSubsetTrainingSampler......") + logger.info(f"Randomly sample {self._size_subset} data from the original {self._size} data") + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) # self._seed equals seed_shuffle from __init__() + while True: + if self._shuffle: + # generate a random permutation to shuffle self._indexes_subset + randperm = torch.randperm(self._size_subset, generator=g) + yield from self._indexes_subset[randperm].tolist() + else: + yield from self._indexes_subset.tolist() + + +class RepeatFactorTrainingSampler(Sampler): + """ + Similar to TrainingSampler, but a sample may appear more times than others based + on its "repeat factor". This is suitable for training on class imbalanced datasets like LVIS. + """ + + def __init__(self, repeat_factors, *, shuffle=True, seed=None): + """ + Args: + repeat_factors (Tensor): a float vector, the repeat factor for each indice. When it's + full of ones, it is equivalent to ``TrainingSampler(len(repeat_factors), ...)``. + shuffle (bool): whether to shuffle the indices or not + seed (int): the initial seed of the shuffle. Must be the same + across all workers. If None, will use a random seed shared + among workers (require synchronization among all workers). + """ + self._shuffle = shuffle + if seed is None: + seed = comm.shared_random_seed() + self._seed = int(seed) + + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + + # Split into whole number (_int_part) and fractional (_frac_part) parts. + self._int_part = torch.trunc(repeat_factors) + self._frac_part = repeat_factors - self._int_part + + @staticmethod + def repeat_factors_from_category_frequency(dataset_dicts, repeat_thresh): + """ + Compute (fractional) per-image repeat factors based on category frequency. + The repeat factor for an image is a function of the frequency of the rarest + category labeled in that image. The "frequency of category c" in [0, 1] is defined + as the fraction of images in the training set (without repeats) in which category c + appears. + See :paper:`lvis` (>= v2) Appendix B.2. + + Args: + dataset_dicts (list[dict]): annotations in Detectron2 dataset format. + repeat_thresh (float): frequency threshold below which data is repeated. + If the frequency is half of `repeat_thresh`, the image will be + repeated twice. + + Returns: + torch.Tensor: + the i-th element is the repeat factor for the dataset image at index i. + """ + # 1. For each category c, compute the fraction of images that contain it: f(c) + category_freq = defaultdict(int) + for dataset_dict in dataset_dicts: # For each image (without repeats) + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + for cat_id in cat_ids: + category_freq[cat_id] += 1 + num_images = len(dataset_dicts) + for k, v in category_freq.items(): + category_freq[k] = v / num_images + + # 2. For each category c, compute the category-level repeat factor: + # r(c) = max(1, sqrt(t / f(c))) + category_rep = { + cat_id: max(1.0, math.sqrt(repeat_thresh / cat_freq)) + for cat_id, cat_freq in category_freq.items() + } + + # 3. For each image I, compute the image-level repeat factor: + # r(I) = max_{c in I} r(c) + rep_factors = [] + for dataset_dict in dataset_dicts: + cat_ids = {ann["category_id"] for ann in dataset_dict["annotations"]} + rep_factor = max({category_rep[cat_id] for cat_id in cat_ids}, default=1.0) + rep_factors.append(rep_factor) + + return torch.tensor(rep_factors, dtype=torch.float32) + + def _get_epoch_indices(self, generator): + """ + Create a list of dataset indices (with repeats) to use for one epoch. + + Args: + generator (torch.Generator): pseudo random number generator used for + stochastic rounding. + + Returns: + torch.Tensor: list of dataset indices to use in one epoch. Each index + is repeated based on its calculated repeat factor. + """ + # Since repeat factors are fractional, we use stochastic rounding so + # that the target repeat factor is achieved in expectation over the + # course of training + rands = torch.rand(len(self._frac_part), generator=generator) + rep_factors = self._int_part + (rands < self._frac_part).float() + # Construct a list of indices in which we repeat images as specified + indices = [] + for dataset_index, rep_factor in enumerate(rep_factors): + indices.extend([dataset_index] * int(rep_factor.item())) + return torch.tensor(indices, dtype=torch.int64) + + def __iter__(self): + start = self._rank + yield from itertools.islice(self._infinite_indices(), start, None, self._world_size) + + def _infinite_indices(self): + g = torch.Generator() + g.manual_seed(self._seed) + while True: + # Sample indices with repeats determined by stochastic rounding; each + # "epoch" may have a slightly different size due to the rounding. + indices = self._get_epoch_indices(g) + if self._shuffle: + randperm = torch.randperm(len(indices), generator=g) + yield from indices[randperm].tolist() + else: + yield from indices.tolist() + + +class InferenceSampler(Sampler): + """ + Produce indices for inference across all workers. + Inference needs to run on the __exact__ set of samples, + therefore when the total number of samples is not divisible by the number of workers, + this sampler produces different number of samples on different workers. + """ + + def __init__(self, size: int): + """ + Args: + size (int): the total number of data of the underlying dataset to sample from + """ + self._size = size + assert size > 0 + self._rank = comm.get_rank() + self._world_size = comm.get_world_size() + self._local_indices = self._get_local_indices(size, self._world_size, self._rank) + + @staticmethod + def _get_local_indices(total_size, world_size, rank): + shard_size = total_size // world_size + left = total_size % world_size + shard_sizes = [shard_size + int(r < left) for r in range(world_size)] + + begin = sum(shard_sizes[:rank]) + end = min(sum(shard_sizes[: rank + 1]), total_size) + return range(begin, end) + + def __iter__(self): + yield from self._local_indices + + def __len__(self): + return len(self._local_indices) diff --git a/custom_detectron2/data/samplers/grouped_batch_sampler.py b/custom_detectron2/data/samplers/grouped_batch_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..5b247730aacd04dd0c752664acde3257c4eddd71 --- /dev/null +++ b/custom_detectron2/data/samplers/grouped_batch_sampler.py @@ -0,0 +1,47 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from torch.utils.data.sampler import BatchSampler, Sampler + + +class GroupedBatchSampler(BatchSampler): + """ + Wraps another sampler to yield a mini-batch of indices. + It enforces that the batch only contain elements from the same group. + It also tries to provide mini-batches which follows an ordering which is + as close as possible to the ordering from the original sampler. + """ + + def __init__(self, sampler, group_ids, batch_size): + """ + Args: + sampler (Sampler): Base sampler. + group_ids (list[int]): If the sampler produces indices in range [0, N), + `group_ids` must be a list of `N` ints which contains the group id of each sample. + The group ids must be a set of integers in the range [0, num_groups). + batch_size (int): Size of mini-batch. + """ + if not isinstance(sampler, Sampler): + raise ValueError( + "sampler should be an instance of " + "torch.utils.data.Sampler, but got sampler={}".format(sampler) + ) + self.sampler = sampler + self.group_ids = np.asarray(group_ids) + assert self.group_ids.ndim == 1 + self.batch_size = batch_size + groups = np.unique(self.group_ids).tolist() + + # buffer the indices of each group until batch size is reached + self.buffer_per_group = {k: [] for k in groups} + + def __iter__(self): + for idx in self.sampler: + group_id = self.group_ids[idx] + group_buffer = self.buffer_per_group[group_id] + group_buffer.append(idx) + if len(group_buffer) == self.batch_size: + yield group_buffer[:] # yield a copy of the list + del group_buffer[:] + + def __len__(self): + raise NotImplementedError("len() of GroupedBatchSampler is not well-defined.") diff --git a/custom_detectron2/data/transforms/__init__.py b/custom_detectron2/data/transforms/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d3721fc22e5aa3b4100458235f99c45498646d --- /dev/null +++ b/custom_detectron2/data/transforms/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from fvcore.transforms.transform import Transform, TransformList # order them first +from fvcore.transforms.transform import * +from .transform import * +from .augmentation import * +from .augmentation_impl import * + +__all__ = [k for k in globals().keys() if not k.startswith("_")] + + +from custom_detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/custom_detectron2/data/transforms/augmentation.py b/custom_detectron2/data/transforms/augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..63dd41aef658c9b51c7246880399405a029c5580 --- /dev/null +++ b/custom_detectron2/data/transforms/augmentation.py @@ -0,0 +1,380 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import inspect +import numpy as np +import pprint +from typing import Any, List, Optional, Tuple, Union +from fvcore.transforms.transform import Transform, TransformList + +""" +See "Data Augmentation" tutorial for an overview of the system: +https://detectron2.readthedocs.io/tutorials/augmentation.html +""" + + +__all__ = [ + "Augmentation", + "AugmentationList", + "AugInput", + "TransformGen", + "apply_transform_gens", + "StandardAugInput", + "apply_augmentations", +] + + +def _check_img_dtype(img): + assert isinstance(img, np.ndarray), "[Augmentation] Needs an numpy array, but got a {}!".format( + type(img) + ) + assert not isinstance(img.dtype, np.integer) or ( + img.dtype == np.uint8 + ), "[Augmentation] Got image of type {}, use uint8 or floating points instead!".format( + img.dtype + ) + assert img.ndim in [2, 3], img.ndim + + +def _get_aug_input_args(aug, aug_input) -> List[Any]: + """ + Get the arguments to be passed to ``aug.get_transform`` from the input ``aug_input``. + """ + if aug.input_args is None: + # Decide what attributes are needed automatically + prms = list(inspect.signature(aug.get_transform).parameters.items()) + # The default behavior is: if there is one parameter, then its "image" + # (work automatically for majority of use cases, and also avoid BC breaking), + # Otherwise, use the argument names. + if len(prms) == 1: + names = ("image",) + else: + names = [] + for name, prm in prms: + if prm.kind in ( + inspect.Parameter.VAR_POSITIONAL, + inspect.Parameter.VAR_KEYWORD, + ): + raise TypeError( + f""" \ +The default implementation of `{type(aug)}.__call__` does not allow \ +`{type(aug)}.get_transform` to use variable-length arguments (*args, **kwargs)! \ +If arguments are unknown, reimplement `__call__` instead. \ +""" + ) + names.append(name) + aug.input_args = tuple(names) + + args = [] + for f in aug.input_args: + try: + args.append(getattr(aug_input, f)) + except AttributeError as e: + raise AttributeError( + f"{type(aug)}.get_transform needs input attribute '{f}', " + f"but it is not an attribute of {type(aug_input)}!" + ) from e + return args + + +class Augmentation: + """ + Augmentation defines (often random) policies/strategies to generate :class:`Transform` + from data. It is often used for pre-processing of input data. + + A "policy" that generates a :class:`Transform` may, in the most general case, + need arbitrary information from input data in order to determine what transforms + to apply. Therefore, each :class:`Augmentation` instance defines the arguments + needed by its :meth:`get_transform` method. When called with the positional arguments, + the :meth:`get_transform` method executes the policy. + + Note that :class:`Augmentation` defines the policies to create a :class:`Transform`, + but not how to execute the actual transform operations to those data. + Its :meth:`__call__` method will use :meth:`AugInput.transform` to execute the transform. + + The returned `Transform` object is meant to describe deterministic transformation, which means + it can be re-applied on associated data, e.g. the geometry of an image and its segmentation + masks need to be transformed together. + (If such re-application is not needed, then determinism is not a crucial requirement.) + """ + + input_args: Optional[Tuple[str]] = None + """ + Stores the attribute names needed by :meth:`get_transform`, e.g. ``("image", "sem_seg")``. + By default, it is just a tuple of argument names in :meth:`self.get_transform`, which often only + contain "image". As long as the argument name convention is followed, there is no need for + users to touch this attribute. + """ + + def _init(self, params=None): + if params: + for k, v in params.items(): + if k != "self" and not k.startswith("_"): + setattr(self, k, v) + + def get_transform(self, *args) -> Transform: + """ + Execute the policy based on input data, and decide what transform to apply to inputs. + + Args: + args: Any fixed-length positional arguments. By default, the name of the arguments + should exist in the :class:`AugInput` to be used. + + Returns: + Transform: Returns the deterministic transform to apply to the input. + + Examples: + :: + class MyAug: + # if a policy needs to know both image and semantic segmentation + def get_transform(image, sem_seg) -> T.Transform: + pass + tfm: Transform = MyAug().get_transform(image, sem_seg) + new_image = tfm.apply_image(image) + + Notes: + Users can freely use arbitrary new argument names in custom + :meth:`get_transform` method, as long as they are available in the + input data. In detectron2 we use the following convention: + + * image: (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or + floating point in range [0, 1] or [0, 255]. + * boxes: (N,4) ndarray of float32. It represents the instance bounding boxes + of N instances. Each is in XYXY format in unit of absolute coordinates. + * sem_seg: (H,W) ndarray of type uint8. Each element is an integer label of pixel. + + We do not specify convention for other types and do not include builtin + :class:`Augmentation` that uses other types in detectron2. + """ + raise NotImplementedError + + def __call__(self, aug_input) -> Transform: + """ + Augment the given `aug_input` **in-place**, and return the transform that's used. + + This method will be called to apply the augmentation. In most augmentation, it + is enough to use the default implementation, which calls :meth:`get_transform` + using the inputs. But a subclass can overwrite it to have more complicated logic. + + Args: + aug_input (AugInput): an object that has attributes needed by this augmentation + (defined by ``self.get_transform``). Its ``transform`` method will be called + to in-place transform it. + + Returns: + Transform: the transform that is applied on the input. + """ + args = _get_aug_input_args(self, aug_input) + tfm = self.get_transform(*args) + assert isinstance(tfm, (Transform, TransformList)), ( + f"{type(self)}.get_transform must return an instance of Transform! " + f"Got {type(tfm)} instead." + ) + aug_input.transform(tfm) + return tfm + + def _rand_range(self, low=1.0, high=None, size=None): + """ + Uniform float random number between low and high. + """ + if high is None: + low, high = 0, low + if size is None: + size = [] + return np.random.uniform(low, high, size) + + def __repr__(self): + """ + Produce something like: + "MyAugmentation(field1={self.field1}, field2={self.field2})" + """ + try: + sig = inspect.signature(self.__init__) + classname = type(self).__name__ + argstr = [] + for name, param in sig.parameters.items(): + assert ( + param.kind != param.VAR_POSITIONAL and param.kind != param.VAR_KEYWORD + ), "The default __repr__ doesn't support *args or **kwargs" + assert hasattr(self, name), ( + "Attribute {} not found! " + "Default __repr__ only works if attributes match the constructor.".format(name) + ) + attr = getattr(self, name) + default = param.default + if default is attr: + continue + attr_str = pprint.pformat(attr) + if "\n" in attr_str: + # don't show it if pformat decides to use >1 lines + attr_str = "..." + argstr.append("{}={}".format(name, attr_str)) + return "{}({})".format(classname, ", ".join(argstr)) + except AssertionError: + return super().__repr__() + + __str__ = __repr__ + + +class _TransformToAug(Augmentation): + def __init__(self, tfm: Transform): + self.tfm = tfm + + def get_transform(self, *args): + return self.tfm + + def __repr__(self): + return repr(self.tfm) + + __str__ = __repr__ + + +def _transform_to_aug(tfm_or_aug): + """ + Wrap Transform into Augmentation. + Private, used internally to implement augmentations. + """ + assert isinstance(tfm_or_aug, (Transform, Augmentation)), tfm_or_aug + if isinstance(tfm_or_aug, Augmentation): + return tfm_or_aug + else: + return _TransformToAug(tfm_or_aug) + + +class AugmentationList(Augmentation): + """ + Apply a sequence of augmentations. + + It has ``__call__`` method to apply the augmentations. + + Note that :meth:`get_transform` method is impossible (will throw error if called) + for :class:`AugmentationList`, because in order to apply a sequence of augmentations, + the kth augmentation must be applied first, to provide inputs needed by the (k+1)th + augmentation. + """ + + def __init__(self, augs): + """ + Args: + augs (list[Augmentation or Transform]): + """ + super().__init__() + self.augs = [_transform_to_aug(x) for x in augs] + + def __call__(self, aug_input) -> TransformList: + tfms = [] + for x in self.augs: + tfm = x(aug_input) + tfms.append(tfm) + return TransformList(tfms) + + def __repr__(self): + msgs = [str(x) for x in self.augs] + return "AugmentationList[{}]".format(", ".join(msgs)) + + __str__ = __repr__ + + +class AugInput: + """ + Input that can be used with :meth:`Augmentation.__call__`. + This is a standard implementation for the majority of use cases. + This class provides the standard attributes **"image", "boxes", "sem_seg"** + defined in :meth:`__init__` and they may be needed by different augmentations. + Most augmentation policies do not need attributes beyond these three. + + After applying augmentations to these attributes (using :meth:`AugInput.transform`), + the returned transforms can then be used to transform other data structures that users have. + + Examples: + :: + input = AugInput(image, boxes=boxes) + tfms = augmentation(input) + transformed_image = input.image + transformed_boxes = input.boxes + transformed_other_data = tfms.apply_other(other_data) + + An extended project that works with new data types may implement augmentation policies + that need other inputs. An algorithm may need to transform inputs in a way different + from the standard approach defined in this class. In those rare situations, users can + implement a class similar to this class, that satify the following condition: + + * The input must provide access to these data in the form of attribute access + (``getattr``). For example, if an :class:`Augmentation` to be applied needs "image" + and "sem_seg" arguments, its input must have the attribute "image" and "sem_seg". + * The input must have a ``transform(tfm: Transform) -> None`` method which + in-place transforms all its attributes. + """ + + # TODO maybe should support more builtin data types here + def __init__( + self, + image: np.ndarray, + *, + boxes: Optional[np.ndarray] = None, + sem_seg: Optional[np.ndarray] = None, + ): + """ + Args: + image (ndarray): (H,W) or (H,W,C) ndarray of type uint8 in range [0, 255], or + floating point in range [0, 1] or [0, 255]. The meaning of C is up + to users. + boxes (ndarray or None): Nx4 float32 boxes in XYXY_ABS mode + sem_seg (ndarray or None): HxW uint8 semantic segmentation mask. Each element + is an integer label of pixel. + """ + _check_img_dtype(image) + self.image = image + self.boxes = boxes + self.sem_seg = sem_seg + + def transform(self, tfm: Transform) -> None: + """ + In-place transform all attributes of this class. + + By "in-place", it means after calling this method, accessing an attribute such + as ``self.image`` will return transformed data. + """ + self.image = tfm.apply_image(self.image) + if self.boxes is not None: + self.boxes = tfm.apply_box(self.boxes) + if self.sem_seg is not None: + self.sem_seg = tfm.apply_segmentation(self.sem_seg) + + def apply_augmentations( + self, augmentations: List[Union[Augmentation, Transform]] + ) -> TransformList: + """ + Equivalent of ``AugmentationList(augmentations)(self)`` + """ + return AugmentationList(augmentations)(self) + + +def apply_augmentations(augmentations: List[Union[Transform, Augmentation]], inputs): + """ + Use ``T.AugmentationList(augmentations)(inputs)`` instead. + """ + if isinstance(inputs, np.ndarray): + # handle the common case of image-only Augmentation, also for backward compatibility + image_only = True + inputs = AugInput(inputs) + else: + image_only = False + tfms = inputs.apply_augmentations(augmentations) + return inputs.image if image_only else inputs, tfms + + +apply_transform_gens = apply_augmentations +""" +Alias for backward-compatibility. +""" + +TransformGen = Augmentation +""" +Alias for Augmentation, since it is something that generates :class:`Transform`s +""" + +StandardAugInput = AugInput +""" +Alias for compatibility. It's not worth the complexity to have two classes. +""" diff --git a/custom_detectron2/data/transforms/augmentation_impl.py b/custom_detectron2/data/transforms/augmentation_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..228b16318e8d4aff3c759ddd7367a1e92e6cbc07 --- /dev/null +++ b/custom_detectron2/data/transforms/augmentation_impl.py @@ -0,0 +1,736 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Implement many useful :class:`Augmentation`. +""" +import numpy as np +import sys +from numpy import random +from typing import Tuple +import torch +from fvcore.transforms.transform import ( + BlendTransform, + CropTransform, + HFlipTransform, + NoOpTransform, + PadTransform, + Transform, + TransformList, + VFlipTransform, +) +from PIL import Image + +from custom_detectron2.structures import Boxes, pairwise_iou + +from .augmentation import Augmentation, _transform_to_aug +from .transform import ExtentTransform, ResizeTransform, RotationTransform + +__all__ = [ + "FixedSizeCrop", + "RandomApply", + "RandomBrightness", + "RandomContrast", + "RandomCrop", + "RandomExtent", + "RandomFlip", + "RandomSaturation", + "RandomLighting", + "RandomRotation", + "Resize", + "ResizeScale", + "ResizeShortestEdge", + "RandomCrop_CategoryAreaConstraint", + "RandomResize", + "MinIoURandomCrop", +] + + +class RandomApply(Augmentation): + """ + Randomly apply an augmentation with a given probability. + """ + + def __init__(self, tfm_or_aug, prob=0.5): + """ + Args: + tfm_or_aug (Transform, Augmentation): the transform or augmentation + to be applied. It can either be a `Transform` or `Augmentation` + instance. + prob (float): probability between 0.0 and 1.0 that + the wrapper transformation is applied + """ + super().__init__() + self.aug = _transform_to_aug(tfm_or_aug) + assert 0.0 <= prob <= 1.0, f"Probablity must be between 0.0 and 1.0 (given: {prob})" + self.prob = prob + + def get_transform(self, *args): + do = self._rand_range() < self.prob + if do: + return self.aug.get_transform(*args) + else: + return NoOpTransform() + + def __call__(self, aug_input): + do = self._rand_range() < self.prob + if do: + return self.aug(aug_input) + else: + return NoOpTransform() + + +class RandomFlip(Augmentation): + """ + Flip the image horizontally or vertically with the given probability. + """ + + def __init__(self, prob=0.5, *, horizontal=True, vertical=False): + """ + Args: + prob (float): probability of flip. + horizontal (boolean): whether to apply horizontal flipping + vertical (boolean): whether to apply vertical flipping + """ + super().__init__() + + if horizontal and vertical: + raise ValueError("Cannot do both horiz and vert. Please use two Flip instead.") + if not horizontal and not vertical: + raise ValueError("At least one of horiz or vert has to be True!") + self._init(locals()) + + def get_transform(self, image): + h, w = image.shape[:2] + do = self._rand_range() < self.prob + if do: + if self.horizontal: + return HFlipTransform(w) + elif self.vertical: + return VFlipTransform(h) + else: + return NoOpTransform() + + +class Resize(Augmentation): + """Resize image to a fixed target size""" + + def __init__(self, shape, interp=Image.BILINEAR): + """ + Args: + shape: (h, w) tuple or a int + interp: PIL interpolation method + """ + if isinstance(shape, int): + shape = (shape, shape) + shape = tuple(shape) + self._init(locals()) + + def get_transform(self, image): + return ResizeTransform( + image.shape[0], image.shape[1], self.shape[0], self.shape[1], self.interp + ) + + +class ResizeShortestEdge(Augmentation): + """ + Resize the image while keeping the aspect ratio unchanged. + It attempts to scale the shorter edge to the given `short_edge_length`, + as long as the longer edge does not exceed `max_size`. + If `max_size` is reached, then downscale so that the longer edge does not exceed max_size. + """ + + @torch.jit.unused + def __init__( + self, short_edge_length, max_size=sys.maxsize, sample_style="range", interp=Image.BILINEAR + ): + """ + Args: + short_edge_length (list[int]): If ``sample_style=="range"``, + a [min, max] interval from which to sample the shortest edge length. + If ``sample_style=="choice"``, a list of shortest edge lengths to sample from. + max_size (int): maximum allowed longest edge length. + sample_style (str): either "range" or "choice". + """ + super().__init__() + assert sample_style in ["range", "choice"], sample_style + + self.is_range = sample_style == "range" + if isinstance(short_edge_length, int): + short_edge_length = (short_edge_length, short_edge_length) + if self.is_range: + assert len(short_edge_length) == 2, ( + "short_edge_length must be two values using 'range' sample style." + f" Got {short_edge_length}!" + ) + self._init(locals()) + + @torch.jit.unused + def get_transform(self, image): + h, w = image.shape[:2] + if self.is_range: + size = np.random.randint(self.short_edge_length[0], self.short_edge_length[1] + 1) + else: + size = np.random.choice(self.short_edge_length) + if size == 0: + return NoOpTransform() + + newh, neww = ResizeShortestEdge.get_output_shape(h, w, size, self.max_size) + return ResizeTransform(h, w, newh, neww, self.interp) + + @staticmethod + def get_output_shape( + oldh: int, oldw: int, short_edge_length: int, max_size: int + ) -> Tuple[int, int]: + """ + Compute the output size given input size and target short edge length. + """ + h, w = oldh, oldw + size = short_edge_length * 1.0 + scale = size / min(h, w) + if h < w: + newh, neww = size, scale * w + else: + newh, neww = scale * h, size + if max(newh, neww) > max_size: + scale = max_size * 1.0 / max(newh, neww) + newh = newh * scale + neww = neww * scale + neww = int(neww + 0.5) + newh = int(newh + 0.5) + return (newh, neww) + + +class ResizeScale(Augmentation): + """ + Takes target size as input and randomly scales the given target size between `min_scale` + and `max_scale`. It then scales the input image such that it fits inside the scaled target + box, keeping the aspect ratio constant. + This implements the resize part of the Google's 'resize_and_crop' data augmentation: + https://github.com/tensorflow/tpu/blob/master/models/official/detection/utils/input_utils.py#L127 + """ + + def __init__( + self, + min_scale: float, + max_scale: float, + target_height: int, + target_width: int, + interp: int = Image.BILINEAR, + ): + """ + Args: + min_scale: minimum image scale range. + max_scale: maximum image scale range. + target_height: target image height. + target_width: target image width. + interp: image interpolation method. + """ + super().__init__() + self._init(locals()) + + def _get_resize(self, image: np.ndarray, scale: float) -> Transform: + input_size = image.shape[:2] + + # Compute new target size given a scale. + target_size = (self.target_height, self.target_width) + target_scale_size = np.multiply(target_size, scale) + + # Compute actual rescaling applied to input image and output size. + output_scale = np.minimum( + target_scale_size[0] / input_size[0], target_scale_size[1] / input_size[1] + ) + output_size = np.round(np.multiply(input_size, output_scale)).astype(int) + + return ResizeTransform( + input_size[0], input_size[1], output_size[0], output_size[1], self.interp + ) + + def get_transform(self, image: np.ndarray) -> Transform: + random_scale = np.random.uniform(self.min_scale, self.max_scale) + return self._get_resize(image, random_scale) + + +class RandomRotation(Augmentation): + """ + This method returns a copy of this image, rotated the given + number of degrees counter clockwise around the given center. + """ + + def __init__(self, angle, expand=True, center=None, sample_style="range", interp=None): + """ + Args: + angle (list[float]): If ``sample_style=="range"``, + a [min, max] interval from which to sample the angle (in degrees). + If ``sample_style=="choice"``, a list of angles to sample from + expand (bool): choose if the image should be resized to fit the whole + rotated image (default), or simply cropped + center (list[[float, float]]): If ``sample_style=="range"``, + a [[minx, miny], [maxx, maxy]] relative interval from which to sample the center, + [0, 0] being the top left of the image and [1, 1] the bottom right. + If ``sample_style=="choice"``, a list of centers to sample from + Default: None, which means that the center of rotation is the center of the image + center has no effect if expand=True because it only affects shifting + """ + super().__init__() + assert sample_style in ["range", "choice"], sample_style + self.is_range = sample_style == "range" + if isinstance(angle, (float, int)): + angle = (angle, angle) + if center is not None and isinstance(center[0], (float, int)): + center = (center, center) + self._init(locals()) + + def get_transform(self, image): + h, w = image.shape[:2] + center = None + if self.is_range: + angle = np.random.uniform(self.angle[0], self.angle[1]) + if self.center is not None: + center = ( + np.random.uniform(self.center[0][0], self.center[1][0]), + np.random.uniform(self.center[0][1], self.center[1][1]), + ) + else: + angle = np.random.choice(self.angle) + if self.center is not None: + center = np.random.choice(self.center) + + if center is not None: + center = (w * center[0], h * center[1]) # Convert to absolute coordinates + + if angle % 360 == 0: + return NoOpTransform() + + return RotationTransform(h, w, angle, expand=self.expand, center=center, interp=self.interp) + + +class FixedSizeCrop(Augmentation): + """ + If `crop_size` is smaller than the input image size, then it uses a random crop of + the crop size. If `crop_size` is larger than the input image size, then it pads + the right and the bottom of the image to the crop size if `pad` is True, otherwise + it returns the smaller image. + """ + + def __init__( + self, + crop_size: Tuple[int], + pad: bool = True, + pad_value: float = 128.0, + seg_pad_value: int = 255, + ): + """ + Args: + crop_size: target image (height, width). + pad: if True, will pad images smaller than `crop_size` up to `crop_size` + pad_value: the padding value to the image. + seg_pad_value: the padding value to the segmentation mask. + """ + super().__init__() + self._init(locals()) + + def _get_crop(self, image: np.ndarray) -> Transform: + # Compute the image scale and scaled size. + input_size = image.shape[:2] + output_size = self.crop_size + + # Add random crop if the image is scaled up. + max_offset = np.subtract(input_size, output_size) + max_offset = np.maximum(max_offset, 0) + offset = np.multiply(max_offset, np.random.uniform(0.0, 1.0)) + offset = np.round(offset).astype(int) + return CropTransform( + offset[1], offset[0], output_size[1], output_size[0], input_size[1], input_size[0] + ) + + def _get_pad(self, image: np.ndarray) -> Transform: + # Compute the image scale and scaled size. + input_size = image.shape[:2] + output_size = self.crop_size + + # Add padding if the image is scaled down. + pad_size = np.subtract(output_size, input_size) + pad_size = np.maximum(pad_size, 0) + original_size = np.minimum(input_size, output_size) + return PadTransform( + 0, + 0, + pad_size[1], + pad_size[0], + original_size[1], + original_size[0], + self.pad_value, + self.seg_pad_value, + ) + + def get_transform(self, image: np.ndarray) -> TransformList: + transforms = [self._get_crop(image)] + if self.pad: + transforms.append(self._get_pad(image)) + return TransformList(transforms) + + +class RandomCrop(Augmentation): + """ + Randomly crop a rectangle region out of an image. + """ + + def __init__(self, crop_type: str, crop_size): + """ + Args: + crop_type (str): one of "relative_range", "relative", "absolute", "absolute_range". + crop_size (tuple[float, float]): two floats, explained below. + + - "relative": crop a (H * crop_size[0], W * crop_size[1]) region from an input image of + size (H, W). crop size should be in (0, 1] + - "relative_range": uniformly sample two values from [crop_size[0], 1] + and [crop_size[1]], 1], and use them as in "relative" crop type. + - "absolute" crop a (crop_size[0], crop_size[1]) region from input image. + crop_size must be smaller than the input image size. + - "absolute_range", for an input of size (H, W), uniformly sample H_crop in + [crop_size[0], min(H, crop_size[1])] and W_crop in [crop_size[0], min(W, crop_size[1])]. + Then crop a region (H_crop, W_crop). + """ + # TODO style of relative_range and absolute_range are not consistent: + # one takes (h, w) but another takes (min, max) + super().__init__() + assert crop_type in ["relative_range", "relative", "absolute", "absolute_range"] + self._init(locals()) + + def get_transform(self, image): + h, w = image.shape[:2] + croph, cropw = self.get_crop_size((h, w)) + assert h >= croph and w >= cropw, "Shape computation in {} has bugs.".format(self) + h0 = np.random.randint(h - croph + 1) + w0 = np.random.randint(w - cropw + 1) + return CropTransform(w0, h0, cropw, croph) + + def get_crop_size(self, image_size): + """ + Args: + image_size (tuple): height, width + + Returns: + crop_size (tuple): height, width in absolute pixels + """ + h, w = image_size + if self.crop_type == "relative": + ch, cw = self.crop_size + return int(h * ch + 0.5), int(w * cw + 0.5) + elif self.crop_type == "relative_range": + crop_size = np.asarray(self.crop_size, dtype=np.float32) + ch, cw = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * ch + 0.5), int(w * cw + 0.5) + elif self.crop_type == "absolute": + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == "absolute_range": + assert self.crop_size[0] <= self.crop_size[1] + ch = np.random.randint(min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1) + cw = np.random.randint(min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1) + return ch, cw + else: + raise NotImplementedError("Unknown crop type {}".format(self.crop_type)) + + +class RandomCrop_CategoryAreaConstraint(Augmentation): + """ + Similar to :class:`RandomCrop`, but find a cropping window such that no single category + occupies a ratio of more than `single_category_max_area` in semantic segmentation ground + truth, which can cause unstability in training. The function attempts to find such a valid + cropping window for at most 10 times. + """ + + def __init__( + self, + crop_type: str, + crop_size, + single_category_max_area: float = 1.0, + ignored_category: int = None, + ): + """ + Args: + crop_type, crop_size: same as in :class:`RandomCrop` + single_category_max_area: the maximum allowed area ratio of a + category. Set to 1.0 to disable + ignored_category: allow this category in the semantic segmentation + ground truth to exceed the area ratio. Usually set to the category + that's ignored in training. + """ + self.crop_aug = RandomCrop(crop_type, crop_size) + self._init(locals()) + + def get_transform(self, image, sem_seg): + if self.single_category_max_area >= 1.0: + return self.crop_aug.get_transform(image) + else: + h, w = sem_seg.shape + for _ in range(10): + crop_size = self.crop_aug.get_crop_size((h, w)) + y0 = np.random.randint(h - crop_size[0] + 1) + x0 = np.random.randint(w - crop_size[1] + 1) + sem_seg_temp = sem_seg[y0 : y0 + crop_size[0], x0 : x0 + crop_size[1]] + labels, cnt = np.unique(sem_seg_temp, return_counts=True) + if self.ignored_category is not None: + cnt = cnt[labels != self.ignored_category] + if len(cnt) > 1 and np.max(cnt) < np.sum(cnt) * self.single_category_max_area: + break + crop_tfm = CropTransform(x0, y0, crop_size[1], crop_size[0]) + return crop_tfm + + +class RandomExtent(Augmentation): + """ + Outputs an image by cropping a random "subrect" of the source image. + + The subrect can be parameterized to include pixels outside the source image, + in which case they will be set to zeros (i.e. black). The size of the output + image will vary with the size of the random subrect. + """ + + def __init__(self, scale_range, shift_range): + """ + Args: + output_size (h, w): Dimensions of output image + scale_range (l, h): Range of input-to-output size scaling factor + shift_range (x, y): Range of shifts of the cropped subrect. The rect + is shifted by [w / 2 * Uniform(-x, x), h / 2 * Uniform(-y, y)], + where (w, h) is the (width, height) of the input image. Set each + component to zero to crop at the image's center. + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + img_h, img_w = image.shape[:2] + + # Initialize src_rect to fit the input image. + src_rect = np.array([-0.5 * img_w, -0.5 * img_h, 0.5 * img_w, 0.5 * img_h]) + + # Apply a random scaling to the src_rect. + src_rect *= np.random.uniform(self.scale_range[0], self.scale_range[1]) + + # Apply a random shift to the coordinates origin. + src_rect[0::2] += self.shift_range[0] * img_w * (np.random.rand() - 0.5) + src_rect[1::2] += self.shift_range[1] * img_h * (np.random.rand() - 0.5) + + # Map src_rect coordinates into image coordinates (center at corner). + src_rect[0::2] += 0.5 * img_w + src_rect[1::2] += 0.5 * img_h + + return ExtentTransform( + src_rect=(src_rect[0], src_rect[1], src_rect[2], src_rect[3]), + output_size=(int(src_rect[3] - src_rect[1]), int(src_rect[2] - src_rect[0])), + ) + + +class RandomContrast(Augmentation): + """ + Randomly transforms image contrast. + + Contrast intensity is uniformly sampled in (intensity_min, intensity_max). + - intensity < 1 will reduce contrast + - intensity = 1 will preserve the input image + - intensity > 1 will increase contrast + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation + intensity_max (float): Maximum augmentation + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + w = np.random.uniform(self.intensity_min, self.intensity_max) + return BlendTransform(src_image=image.mean(), src_weight=1 - w, dst_weight=w) + + +class RandomBrightness(Augmentation): + """ + Randomly transforms image brightness. + + Brightness intensity is uniformly sampled in (intensity_min, intensity_max). + - intensity < 1 will reduce brightness + - intensity = 1 will preserve the input image + - intensity > 1 will increase brightness + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation + intensity_max (float): Maximum augmentation + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + w = np.random.uniform(self.intensity_min, self.intensity_max) + return BlendTransform(src_image=0, src_weight=1 - w, dst_weight=w) + + +class RandomSaturation(Augmentation): + """ + Randomly transforms saturation of an RGB image. + Input images are assumed to have 'RGB' channel order. + + Saturation intensity is uniformly sampled in (intensity_min, intensity_max). + - intensity < 1 will reduce saturation (make the image more grayscale) + - intensity = 1 will preserve the input image + - intensity > 1 will increase saturation + + See: https://pillow.readthedocs.io/en/3.0.x/reference/ImageEnhance.html + """ + + def __init__(self, intensity_min, intensity_max): + """ + Args: + intensity_min (float): Minimum augmentation (1 preserves input). + intensity_max (float): Maximum augmentation (1 preserves input). + """ + super().__init__() + self._init(locals()) + + def get_transform(self, image): + assert image.shape[-1] == 3, "RandomSaturation only works on RGB images" + w = np.random.uniform(self.intensity_min, self.intensity_max) + grayscale = image.dot([0.299, 0.587, 0.114])[:, :, np.newaxis] + return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w) + + +class RandomLighting(Augmentation): + """ + The "lighting" augmentation described in AlexNet, using fixed PCA over ImageNet. + Input images are assumed to have 'RGB' channel order. + + The degree of color jittering is randomly sampled via a normal distribution, + with standard deviation given by the scale parameter. + """ + + def __init__(self, scale): + """ + Args: + scale (float): Standard deviation of principal component weighting. + """ + super().__init__() + self._init(locals()) + self.eigen_vecs = np.array( + [[-0.5675, 0.7192, 0.4009], [-0.5808, -0.0045, -0.8140], [-0.5836, -0.6948, 0.4203]] + ) + self.eigen_vals = np.array([0.2175, 0.0188, 0.0045]) + + def get_transform(self, image): + assert image.shape[-1] == 3, "RandomLighting only works on RGB images" + weights = np.random.normal(scale=self.scale, size=3) + return BlendTransform( + src_image=self.eigen_vecs.dot(weights * self.eigen_vals), src_weight=1.0, dst_weight=1.0 + ) + + +class RandomResize(Augmentation): + """Randomly resize image to a target size in shape_list""" + + def __init__(self, shape_list, interp=Image.BILINEAR): + """ + Args: + shape_list: a list of shapes in (h, w) + interp: PIL interpolation method + """ + self.shape_list = shape_list + self._init(locals()) + + def get_transform(self, image): + shape_idx = np.random.randint(low=0, high=len(self.shape_list)) + h, w = self.shape_list[shape_idx] + return ResizeTransform(image.shape[0], image.shape[1], h, w, self.interp) + + +class MinIoURandomCrop(Augmentation): + """Random crop the image & bboxes, the cropped patches have minimum IoU + requirement with original image & bboxes, the IoU threshold is randomly + selected from min_ious. + + Args: + min_ious (tuple): minimum IoU threshold for all intersections with + bounding boxes + min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w, + where a >= min_crop_size) + mode_trials: number of trials for sampling min_ious threshold + crop_trials: number of trials for sampling crop_size after cropping + """ + + def __init__( + self, + min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), + min_crop_size=0.3, + mode_trials=1000, + crop_trials=50, + ): + self.min_ious = min_ious + self.sample_mode = (1, *min_ious, 0) + self.min_crop_size = min_crop_size + self.mode_trials = mode_trials + self.crop_trials = crop_trials + + def get_transform(self, image, boxes): + """Call function to crop images and bounding boxes with minimum IoU + constraint. + + Args: + boxes: ground truth boxes in (x1, y1, x2, y2) format + """ + if boxes is None: + return NoOpTransform() + h, w, c = image.shape + for _ in range(self.mode_trials): + mode = random.choice(self.sample_mode) + self.mode = mode + if mode == 1: + return NoOpTransform() + + min_iou = mode + for _ in range(self.crop_trials): + new_w = random.uniform(self.min_crop_size * w, w) + new_h = random.uniform(self.min_crop_size * h, h) + + # h / w in [0.5, 2] + if new_h / new_w < 0.5 or new_h / new_w > 2: + continue + + left = random.uniform(w - new_w) + top = random.uniform(h - new_h) + + patch = np.array((int(left), int(top), int(left + new_w), int(top + new_h))) + # Line or point crop is not allowed + if patch[2] == patch[0] or patch[3] == patch[1]: + continue + overlaps = pairwise_iou( + Boxes(patch.reshape(-1, 4)), Boxes(boxes.reshape(-1, 4)) + ).reshape(-1) + if len(overlaps) > 0 and overlaps.min() < min_iou: + continue + + # center of boxes should inside the crop img + # only adjust boxes and instance masks when the gt is not empty + if len(overlaps) > 0: + # adjust boxes + def is_center_of_bboxes_in_patch(boxes, patch): + center = (boxes[:, :2] + boxes[:, 2:]) / 2 + mask = ( + (center[:, 0] > patch[0]) + * (center[:, 1] > patch[1]) + * (center[:, 0] < patch[2]) + * (center[:, 1] < patch[3]) + ) + return mask + + mask = is_center_of_bboxes_in_patch(boxes, patch) + if not mask.any(): + continue + return CropTransform(int(left), int(top), int(new_w), int(new_h)) diff --git a/custom_detectron2/data/transforms/transform.py b/custom_detectron2/data/transforms/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..46769a2569ffc6223a95990f8db5973757e7d23f --- /dev/null +++ b/custom_detectron2/data/transforms/transform.py @@ -0,0 +1,351 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +See "Data Augmentation" tutorial for an overview of the system: +https://detectron2.readthedocs.io/tutorials/augmentation.html +""" + +import numpy as np +import torch +import torch.nn.functional as F +from fvcore.transforms.transform import ( + CropTransform, + HFlipTransform, + NoOpTransform, + Transform, + TransformList, +) +from PIL import Image + +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + pass + +__all__ = [ + "ExtentTransform", + "ResizeTransform", + "RotationTransform", + "ColorTransform", + "PILColorTransform", +] + + +class ExtentTransform(Transform): + """ + Extracts a subregion from the source image and scales it to the output size. + + The fill color is used to map pixels from the source rect that fall outside + the source image. + + See: https://pillow.readthedocs.io/en/latest/PIL.html#PIL.ImageTransform.ExtentTransform + """ + + def __init__(self, src_rect, output_size, interp=Image.BILINEAR, fill=0): + """ + Args: + src_rect (x0, y0, x1, y1): src coordinates + output_size (h, w): dst image size + interp: PIL interpolation methods + fill: Fill color used when src_rect extends outside image + """ + super().__init__() + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + h, w = self.output_size + if len(img.shape) > 2 and img.shape[2] == 1: + pil_image = Image.fromarray(img[:, :, 0], mode="L") + else: + pil_image = Image.fromarray(img) + pil_image = pil_image.transform( + size=(w, h), + method=Image.EXTENT, + data=self.src_rect, + resample=interp if interp else self.interp, + fill=self.fill, + ) + ret = np.asarray(pil_image) + if len(img.shape) > 2 and img.shape[2] == 1: + ret = np.expand_dims(ret, -1) + return ret + + def apply_coords(self, coords): + # Transform image center from source coordinates into output coordinates + # and then map the new origin to the corner of the output image. + h, w = self.output_size + x0, y0, x1, y1 = self.src_rect + new_coords = coords.astype(np.float32) + new_coords[:, 0] -= 0.5 * (x0 + x1) + new_coords[:, 1] -= 0.5 * (y0 + y1) + new_coords[:, 0] *= w / (x1 - x0) + new_coords[:, 1] *= h / (y1 - y0) + new_coords[:, 0] += 0.5 * w + new_coords[:, 1] += 0.5 * h + return new_coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + +class ResizeTransform(Transform): + """ + Resize the image to a target size. + """ + + def __init__(self, h, w, new_h, new_w, interp=None): + """ + Args: + h, w (int): original image size + new_h, new_w (int): new image size + interp: PIL interpolation methods, defaults to bilinear. + """ + # TODO decide on PIL vs opencv + super().__init__() + if interp is None: + interp = Image.BILINEAR + self._set_attributes(locals()) + + def apply_image(self, img, interp=None): + assert img.shape[:2] == (self.h, self.w) + assert len(img.shape) <= 4 + interp_method = interp if interp is not None else self.interp + + if img.dtype == np.uint8: + if len(img.shape) > 2 and img.shape[2] == 1: + pil_image = Image.fromarray(img[:, :, 0], mode="L") + else: + pil_image = Image.fromarray(img) + pil_image = pil_image.resize((self.new_w, self.new_h), interp_method) + ret = np.asarray(pil_image) + if len(img.shape) > 2 and img.shape[2] == 1: + ret = np.expand_dims(ret, -1) + else: + # PIL only supports uint8 + if any(x < 0 for x in img.strides): + img = np.ascontiguousarray(img) + img = torch.from_numpy(img) + shape = list(img.shape) + shape_4d = shape[:2] + [1] * (4 - len(shape)) + shape[2:] + img = img.view(shape_4d).permute(2, 3, 0, 1) # hw(c) -> nchw + _PIL_RESIZE_TO_INTERPOLATE_MODE = { + Image.NEAREST: "nearest", + Image.BILINEAR: "bilinear", + Image.BICUBIC: "bicubic", + } + mode = _PIL_RESIZE_TO_INTERPOLATE_MODE[interp_method] + align_corners = None if mode == "nearest" else False + img = F.interpolate( + img, (self.new_h, self.new_w), mode=mode, align_corners=align_corners + ) + shape[:2] = (self.new_h, self.new_w) + ret = img.permute(2, 3, 0, 1).view(shape).numpy() # nchw -> hw(c) + + return ret + + def apply_coords(self, coords): + coords[:, 0] = coords[:, 0] * (self.new_w * 1.0 / self.w) + coords[:, 1] = coords[:, 1] * (self.new_h * 1.0 / self.h) + return coords + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=Image.NEAREST) + return segmentation + + def inverse(self): + return ResizeTransform(self.new_h, self.new_w, self.h, self.w, self.interp) + + +class RotationTransform(Transform): + """ + This method returns a copy of this image, rotated the given + number of degrees counter clockwise around its center. + """ + + def __init__(self, h, w, angle, expand=True, center=None, interp=None): + """ + Args: + h, w (int): original image size + angle (float): degrees for rotation + expand (bool): choose if the image should be resized to fit the whole + rotated image (default), or simply cropped + center (tuple (width, height)): coordinates of the rotation center + if left to None, the center will be fit to the center of each image + center has no effect if expand=True because it only affects shifting + interp: cv2 interpolation method, default cv2.INTER_LINEAR + """ + super().__init__() + image_center = np.array((w / 2, h / 2)) + if center is None: + center = image_center + if interp is None: + interp = cv2.INTER_LINEAR + abs_cos, abs_sin = (abs(np.cos(np.deg2rad(angle))), abs(np.sin(np.deg2rad(angle)))) + if expand: + # find the new width and height bounds + bound_w, bound_h = np.rint( + [h * abs_sin + w * abs_cos, h * abs_cos + w * abs_sin] + ).astype(int) + else: + bound_w, bound_h = w, h + + self._set_attributes(locals()) + self.rm_coords = self.create_rotation_matrix() + # Needed because of this problem https://github.com/opencv/opencv/issues/11784 + self.rm_image = self.create_rotation_matrix(offset=-0.5) + + def apply_image(self, img, interp=None): + """ + img should be a numpy array, formatted as Height * Width * Nchannels + """ + if len(img) == 0 or self.angle % 360 == 0: + return img + assert img.shape[:2] == (self.h, self.w) + interp = interp if interp is not None else self.interp + return cv2.warpAffine(img, self.rm_image, (self.bound_w, self.bound_h), flags=interp) + + def apply_coords(self, coords): + """ + coords should be a N * 2 array-like, containing N couples of (x, y) points + """ + coords = np.asarray(coords, dtype=float) + if len(coords) == 0 or self.angle % 360 == 0: + return coords + return cv2.transform(coords[:, np.newaxis, :], self.rm_coords)[:, 0, :] + + def apply_segmentation(self, segmentation): + segmentation = self.apply_image(segmentation, interp=cv2.INTER_NEAREST) + return segmentation + + def create_rotation_matrix(self, offset=0): + center = (self.center[0] + offset, self.center[1] + offset) + rm = cv2.getRotationMatrix2D(tuple(center), self.angle, 1) + if self.expand: + # Find the coordinates of the center of rotation in the new image + # The only point for which we know the future coordinates is the center of the image + rot_im_center = cv2.transform(self.image_center[None, None, :] + offset, rm)[0, 0, :] + new_center = np.array([self.bound_w / 2, self.bound_h / 2]) + offset - rot_im_center + # shift the rotation center to the new coordinates + rm[:, 2] += new_center + return rm + + def inverse(self): + """ + The inverse is to rotate it back with expand, and crop to get the original shape. + """ + if not self.expand: # Not possible to inverse if a part of the image is lost + raise NotImplementedError() + rotation = RotationTransform( + self.bound_h, self.bound_w, -self.angle, True, None, self.interp + ) + crop = CropTransform( + (rotation.bound_w - self.w) // 2, (rotation.bound_h - self.h) // 2, self.w, self.h + ) + return TransformList([rotation, crop]) + + +class ColorTransform(Transform): + """ + Generic wrapper for any photometric transforms. + These transformations should only affect the color space and + not the coordinate space of the image (e.g. annotation + coordinates such as bounding boxes should not be changed) + """ + + def __init__(self, op): + """ + Args: + op (Callable): operation to be applied to the image, + which takes in an ndarray and returns an ndarray. + """ + if not callable(op): + raise ValueError("op parameter should be callable") + super().__init__() + self._set_attributes(locals()) + + def apply_image(self, img): + return self.op(img) + + def apply_coords(self, coords): + return coords + + def inverse(self): + return NoOpTransform() + + def apply_segmentation(self, segmentation): + return segmentation + + +class PILColorTransform(ColorTransform): + """ + Generic wrapper for PIL Photometric image transforms, + which affect the color space and not the coordinate + space of the image + """ + + def __init__(self, op): + """ + Args: + op (Callable): operation to be applied to the image, + which takes in a PIL Image and returns a transformed + PIL Image. + For reference on possible operations see: + - https://pillow.readthedocs.io/en/stable/ + """ + if not callable(op): + raise ValueError("op parameter should be callable") + super().__init__(op) + + def apply_image(self, img): + img = Image.fromarray(img) + return np.asarray(super().apply_image(img)) + + +def HFlip_rotated_box(transform, rotated_boxes): + """ + Apply the horizontal flip transform on rotated boxes. + + Args: + rotated_boxes (ndarray): Nx5 floating point array of + (x_center, y_center, width, height, angle_degrees) format + in absolute coordinates. + """ + # Transform x_center + rotated_boxes[:, 0] = transform.width - rotated_boxes[:, 0] + # Transform angle + rotated_boxes[:, 4] = -rotated_boxes[:, 4] + return rotated_boxes + + +def Resize_rotated_box(transform, rotated_boxes): + """ + Apply the resizing transform on rotated boxes. For details of how these (approximation) + formulas are derived, please refer to :meth:`RotatedBoxes.scale`. + + Args: + rotated_boxes (ndarray): Nx5 floating point array of + (x_center, y_center, width, height, angle_degrees) format + in absolute coordinates. + """ + scale_factor_x = transform.new_w * 1.0 / transform.w + scale_factor_y = transform.new_h * 1.0 / transform.h + rotated_boxes[:, 0] *= scale_factor_x + rotated_boxes[:, 1] *= scale_factor_y + theta = rotated_boxes[:, 4] * np.pi / 180.0 + c = np.cos(theta) + s = np.sin(theta) + rotated_boxes[:, 2] *= np.sqrt(np.square(scale_factor_x * c) + np.square(scale_factor_y * s)) + rotated_boxes[:, 3] *= np.sqrt(np.square(scale_factor_x * s) + np.square(scale_factor_y * c)) + rotated_boxes[:, 4] = np.arctan2(scale_factor_x * s, scale_factor_y * c) * 180 / np.pi + + return rotated_boxes + + +HFlipTransform.register_type("rotated_box", HFlip_rotated_box) +ResizeTransform.register_type("rotated_box", Resize_rotated_box) + +# not necessary any more with latest fvcore +NoOpTransform.register_type("rotated_box", lambda t, x: x) diff --git a/custom_detectron2/engine/__init__.py b/custom_detectron2/engine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08a61572b4c7d09c8d400e903a96cbf5b2cc4763 --- /dev/null +++ b/custom_detectron2/engine/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from .launch import * +from .train_loop import * + +__all__ = [k for k in globals().keys() if not k.startswith("_")] + + +# prefer to let hooks and defaults live in separate namespaces (therefore not in __all__) +# but still make them available here +from .hooks import * +from .defaults import * diff --git a/custom_detectron2/engine/defaults.py b/custom_detectron2/engine/defaults.py new file mode 100644 index 0000000000000000000000000000000000000000..574f87cd7a11d11dfb91b9db62c795ba6403aaa9 --- /dev/null +++ b/custom_detectron2/engine/defaults.py @@ -0,0 +1,715 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +This file contains components with some default boilerplate logic user may need +in training / testing. They will not work for everyone, but many users may find them useful. + +The behavior of functions/classes in this file is subject to change, +since they are meant to represent the "common default behavior" people need in their projects. +""" + +import argparse +import logging +import os +import sys +import weakref +from collections import OrderedDict +from typing import Optional +import torch +from fvcore.nn.precise_bn import get_bn_modules +from omegaconf import OmegaConf +from torch.nn.parallel import DistributedDataParallel + +import custom_detectron2.data.transforms as T +from custom_detectron2.checkpoint import DetectionCheckpointer +from custom_detectron2.config import CfgNode, LazyConfig +from custom_detectron2.data import ( + MetadataCatalog, + build_detection_test_loader, + build_detection_train_loader, +) +from custom_detectron2.evaluation import ( + DatasetEvaluator, + inference_on_dataset, + print_csv_format, + verify_results, +) +from custom_detectron2.modeling import build_model +from custom_detectron2.solver import build_lr_scheduler, build_optimizer +from custom_detectron2.utils import comm +from custom_detectron2.utils.collect_env import collect_env_info +from custom_detectron2.utils.env import seed_all_rng +from custom_detectron2.utils.events import CommonMetricPrinter, JSONWriter, TensorboardXWriter +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.logger import setup_logger + +from . import hooks +from .train_loop import AMPTrainer, SimpleTrainer, TrainerBase + +__all__ = [ + "create_ddp_model", + "default_argument_parser", + "default_setup", + "default_writers", + "DefaultPredictor", + "DefaultTrainer", +] + + +def create_ddp_model(model, *, fp16_compression=False, **kwargs): + """ + Create a DistributedDataParallel model if there are >1 processes. + + Args: + model: a torch.nn.Module + fp16_compression: add fp16 compression hooks to the ddp object. + See more at https://pytorch.org/docs/stable/ddp_comm_hooks.html#torch.distributed.algorithms.ddp_comm_hooks.default_hooks.fp16_compress_hook + kwargs: other arguments of :module:`torch.nn.parallel.DistributedDataParallel`. + """ # noqa + if comm.get_world_size() == 1: + return model + if "device_ids" not in kwargs: + kwargs["device_ids"] = [comm.get_local_rank()] + ddp = DistributedDataParallel(model, **kwargs) + if fp16_compression: + from torch.distributed.algorithms.ddp_comm_hooks import default as comm_hooks + + ddp.register_comm_hook(state=None, hook=comm_hooks.fp16_compress_hook) + return ddp + + +def default_argument_parser(epilog=None): + """ + Create a parser with some common arguments used by detectron2 users. + + Args: + epilog (str): epilog passed to ArgumentParser describing the usage. + + Returns: + argparse.ArgumentParser: + """ + parser = argparse.ArgumentParser( + epilog=epilog + or f""" +Examples: + +Run on single machine: + $ {sys.argv[0]} --num-gpus 8 --config-file cfg.yaml + +Change some config options: + $ {sys.argv[0]} --config-file cfg.yaml MODEL.WEIGHTS /path/to/weight.pth SOLVER.BASE_LR 0.001 + +Run on multiple machines: + (machine0)$ {sys.argv[0]} --machine-rank 0 --num-machines 2 --dist-url [--other-flags] + (machine1)$ {sys.argv[0]} --machine-rank 1 --num-machines 2 --dist-url [--other-flags] +""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument("--config-file", default="", metavar="FILE", help="path to config file") + parser.add_argument( + "--resume", + action="store_true", + help="Whether to attempt to resume from the checkpoint directory. " + "See documentation of `DefaultTrainer.resume_or_load()` for what it means.", + ) + parser.add_argument("--eval-only", action="store_true", help="perform evaluation only") + parser.add_argument("--num-gpus", type=int, default=1, help="number of gpus *per machine*") + parser.add_argument("--num-machines", type=int, default=1, help="total number of machines") + parser.add_argument( + "--machine-rank", type=int, default=0, help="the rank of this machine (unique per machine)" + ) + + # PyTorch still may leave orphan processes in multi-gpu training. + # Therefore we use a deterministic way to obtain port, + # so that users are aware of orphan processes by seeing the port occupied. + port = 2**15 + 2**14 + hash(os.getuid() if sys.platform != "win32" else 1) % 2**14 + parser.add_argument( + "--dist-url", + default="tcp://127.0.0.1:{}".format(port), + help="initialization URL for pytorch distributed backend. See " + "https://pytorch.org/docs/stable/distributed.html for details.", + ) + parser.add_argument( + "opts", + help=""" +Modify config options at the end of the command. For Yacs configs, use +space-separated "PATH.KEY VALUE" pairs. +For python-based LazyConfig, use "path.key=value". + """.strip(), + default=None, + nargs=argparse.REMAINDER, + ) + return parser + + +def _try_get_key(cfg, *keys, default=None): + """ + Try select keys from cfg until the first key that exists. Otherwise return default. + """ + if isinstance(cfg, CfgNode): + cfg = OmegaConf.create(cfg.dump()) + for k in keys: + none = object() + p = OmegaConf.select(cfg, k, default=none) + if p is not none: + return p + return default + + +def _highlight(code, filename): + try: + import pygments + except ImportError: + return code + + from pygments.lexers import Python3Lexer, YamlLexer + from pygments.formatters import Terminal256Formatter + + lexer = Python3Lexer() if filename.endswith(".py") else YamlLexer() + code = pygments.highlight(code, lexer, Terminal256Formatter(style="monokai")) + return code + + +def default_setup(cfg, args): + """ + Perform some basic common setups at the beginning of a job, including: + + 1. Set up the detectron2 logger + 2. Log basic information about environment, cmdline arguments, and config + 3. Backup the config to the output directory + + Args: + cfg (CfgNode or omegaconf.DictConfig): the full config to be used + args (argparse.NameSpace): the command line arguments to be logged + """ + output_dir = _try_get_key(cfg, "OUTPUT_DIR", "output_dir", "train.output_dir") + if comm.is_main_process() and output_dir: + PathManager.mkdirs(output_dir) + + rank = comm.get_rank() + setup_logger(output_dir, distributed_rank=rank, name="fvcore") + logger = setup_logger(output_dir, distributed_rank=rank) + + logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size())) + logger.info("Environment info:\n" + collect_env_info()) + + logger.info("Command line arguments: " + str(args)) + if hasattr(args, "config_file") and args.config_file != "": + logger.info( + "Contents of args.config_file={}:\n{}".format( + args.config_file, + _highlight(PathManager.open(args.config_file, "r").read(), args.config_file), + ) + ) + + if comm.is_main_process() and output_dir: + # Note: some of our scripts may expect the existence of + # config.yaml in output directory + path = os.path.join(output_dir, "config.yaml") + if isinstance(cfg, CfgNode): + logger.info("Running with full config:\n{}".format(_highlight(cfg.dump(), ".yaml"))) + with PathManager.open(path, "w") as f: + f.write(cfg.dump()) + else: + LazyConfig.save(cfg, path) + logger.info("Full config saved to {}".format(path)) + + # make sure each worker has a different, yet deterministic seed if specified + seed = _try_get_key(cfg, "SEED", "train.seed", default=-1) + seed_all_rng(None if seed < 0 else seed + rank) + + # cudnn benchmark has large overhead. It shouldn't be used considering the small size of + # typical validation set. + if not (hasattr(args, "eval_only") and args.eval_only): + torch.backends.cudnn.benchmark = _try_get_key( + cfg, "CUDNN_BENCHMARK", "train.cudnn_benchmark", default=False + ) + + +def default_writers(output_dir: str, max_iter: Optional[int] = None): + """ + Build a list of :class:`EventWriter` to be used. + It now consists of a :class:`CommonMetricPrinter`, + :class:`TensorboardXWriter` and :class:`JSONWriter`. + + Args: + output_dir: directory to store JSON metrics and tensorboard events + max_iter: the total number of iterations + + Returns: + list[EventWriter]: a list of :class:`EventWriter` objects. + """ + PathManager.mkdirs(output_dir) + return [ + # It may not always print what you want to see, since it prints "common" metrics only. + CommonMetricPrinter(max_iter), + JSONWriter(os.path.join(output_dir, "metrics.json")), + TensorboardXWriter(output_dir), + ] + + +class DefaultPredictor: + """ + Create a simple end-to-end predictor with the given config that runs on + single device for a single input image. + + Compared to using the model directly, this class does the following additions: + + 1. Load checkpoint from `cfg.MODEL.WEIGHTS`. + 2. Always take BGR image as the input and apply conversion defined by `cfg.INPUT.FORMAT`. + 3. Apply resizing defined by `cfg.INPUT.{MIN,MAX}_SIZE_TEST`. + 4. Take one input image and produce a single output, instead of a batch. + + This is meant for simple demo purposes, so it does the above steps automatically. + This is not meant for benchmarks or running complicated inference logic. + If you'd like to do anything more complicated, please refer to its source code as + examples to build and use the model manually. + + Attributes: + metadata (Metadata): the metadata of the underlying dataset, obtained from + cfg.DATASETS.TEST. + + Examples: + :: + pred = DefaultPredictor(cfg) + inputs = cv2.imread("input.jpg") + outputs = pred(inputs) + """ + + def __init__(self, cfg): + self.cfg = cfg.clone() # cfg can be modified by model + self.model = build_model(self.cfg) + self.model.eval() + if len(cfg.DATASETS.TEST): + self.metadata = MetadataCatalog.get(cfg.DATASETS.TEST[0]) + + checkpointer = DetectionCheckpointer(self.model) + checkpointer.load(cfg.MODEL.WEIGHTS) + + self.aug = T.ResizeShortestEdge( + [cfg.INPUT.MIN_SIZE_TEST, cfg.INPUT.MIN_SIZE_TEST], cfg.INPUT.MAX_SIZE_TEST + ) + + self.input_format = cfg.INPUT.FORMAT + assert self.input_format in ["RGB", "BGR"], self.input_format + + def __call__(self, original_image): + """ + Args: + original_image (np.ndarray): an image of shape (H, W, C) (in BGR order). + + Returns: + predictions (dict): + the output of the model for one image only. + See :doc:`/tutorials/models` for details about the format. + """ + with torch.no_grad(): # https://github.com/sphinx-doc/sphinx/issues/4258 + # Apply pre-processing to image. + if self.input_format == "RGB": + # whether the model expects BGR inputs or RGB + original_image = original_image[:, :, ::-1] + height, width = original_image.shape[:2] + image = self.aug.get_transform(original_image).apply_image(original_image) + image = torch.as_tensor(image.astype("float32").transpose(2, 0, 1)) + + inputs = {"image": image, "height": height, "width": width} + predictions = self.model([inputs])[0] + return predictions + + +class DefaultTrainer(TrainerBase): + """ + A trainer with default training logic. It does the following: + + 1. Create a :class:`SimpleTrainer` using model, optimizer, dataloader + defined by the given config. Create a LR scheduler defined by the config. + 2. Load the last checkpoint or `cfg.MODEL.WEIGHTS`, if exists, when + `resume_or_load` is called. + 3. Register a few common hooks defined by the config. + + It is created to simplify the **standard model training workflow** and reduce code boilerplate + for users who only need the standard training workflow, with standard features. + It means this class makes *many assumptions* about your training logic that + may easily become invalid in a new research. In fact, any assumptions beyond those made in the + :class:`SimpleTrainer` are too much for research. + + The code of this class has been annotated about restrictive assumptions it makes. + When they do not work for you, you're encouraged to: + + 1. Overwrite methods of this class, OR: + 2. Use :class:`SimpleTrainer`, which only does minimal SGD training and + nothing else. You can then add your own hooks if needed. OR: + 3. Write your own training loop similar to `tools/plain_train_net.py`. + + See the :doc:`/tutorials/training` tutorials for more details. + + Note that the behavior of this class, like other functions/classes in + this file, is not stable, since it is meant to represent the "common default behavior". + It is only guaranteed to work well with the standard models and training workflow in detectron2. + To obtain more stable behavior, write your own training logic with other public APIs. + + Examples: + :: + trainer = DefaultTrainer(cfg) + trainer.resume_or_load() # load last checkpoint or MODEL.WEIGHTS + trainer.train() + + Attributes: + scheduler: + checkpointer (DetectionCheckpointer): + cfg (CfgNode): + """ + + def __init__(self, cfg): + """ + Args: + cfg (CfgNode): + """ + super().__init__() + logger = logging.getLogger("detectron2") + if not logger.isEnabledFor(logging.INFO): # setup_logger is not called for d2 + setup_logger() + cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size()) + + # Assume these objects must be constructed in this order. + model = self.build_model(cfg) + optimizer = self.build_optimizer(cfg, model) + data_loader = self.build_train_loader(cfg) + + model = create_ddp_model(model, broadcast_buffers=False) + self._trainer = (AMPTrainer if cfg.SOLVER.AMP.ENABLED else SimpleTrainer)( + model, data_loader, optimizer + ) + + self.scheduler = self.build_lr_scheduler(cfg, optimizer) + self.checkpointer = DetectionCheckpointer( + # Assume you want to save checkpoints together with logs/statistics + model, + cfg.OUTPUT_DIR, + trainer=weakref.proxy(self), + ) + self.start_iter = 0 + self.max_iter = cfg.SOLVER.MAX_ITER + self.cfg = cfg + + self.register_hooks(self.build_hooks()) + + def resume_or_load(self, resume=True): + """ + If `resume==True` and `cfg.OUTPUT_DIR` contains the last checkpoint (defined by + a `last_checkpoint` file), resume from the file. Resuming means loading all + available states (eg. optimizer and scheduler) and update iteration counter + from the checkpoint. ``cfg.MODEL.WEIGHTS`` will not be used. + + Otherwise, this is considered as an independent training. The method will load model + weights from the file `cfg.MODEL.WEIGHTS` (but will not load other states) and start + from iteration 0. + + Args: + resume (bool): whether to do resume or not + """ + self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS, resume=resume) + if resume and self.checkpointer.has_checkpoint(): + # The checkpoint stores the training iteration that just finished, thus we start + # at the next iteration + self.start_iter = self.iter + 1 + + def build_hooks(self): + """ + Build a list of default hooks, including timing, evaluation, + checkpointing, lr scheduling, precise BN, writing events. + + Returns: + list[HookBase]: + """ + cfg = self.cfg.clone() + cfg.defrost() + cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN + + ret = [ + hooks.IterationTimer(), + hooks.LRScheduler(), + hooks.PreciseBN( + # Run at the same freq as (but before) evaluation. + cfg.TEST.EVAL_PERIOD, + self.model, + # Build a new data loader to not affect training + self.build_train_loader(cfg), + cfg.TEST.PRECISE_BN.NUM_ITER, + ) + if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model) + else None, + ] + + # Do PreciseBN before checkpointer, because it updates the model and need to + # be saved by checkpointer. + # This is not always the best: if checkpointing has a different frequency, + # some checkpoints may have more precise statistics than others. + if comm.is_main_process(): + ret.append(hooks.PeriodicCheckpointer(self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD)) + + def test_and_save_results(): + self._last_eval_results = self.test(self.cfg, self.model) + return self._last_eval_results + + # Do evaluation after checkpointer, because then if it fails, + # we can use the saved checkpoint to debug. + ret.append(hooks.EvalHook(cfg.TEST.EVAL_PERIOD, test_and_save_results)) + + if comm.is_main_process(): + # Here the default print/log frequency of each writer is used. + # run writers in the end, so that evaluation metrics are written + ret.append(hooks.PeriodicWriter(self.build_writers(), period=20)) + return ret + + def build_writers(self): + """ + Build a list of writers to be used using :func:`default_writers()`. + If you'd like a different list of writers, you can overwrite it in + your trainer. + + Returns: + list[EventWriter]: a list of :class:`EventWriter` objects. + """ + return default_writers(self.cfg.OUTPUT_DIR, self.max_iter) + + def train(self): + """ + Run training. + + Returns: + OrderedDict of results, if evaluation is enabled. Otherwise None. + """ + super().train(self.start_iter, self.max_iter) + if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process(): + assert hasattr( + self, "_last_eval_results" + ), "No evaluation results obtained during training!" + verify_results(self.cfg, self._last_eval_results) + return self._last_eval_results + + def run_step(self): + self._trainer.iter = self.iter + self._trainer.run_step() + + def state_dict(self): + ret = super().state_dict() + ret["_trainer"] = self._trainer.state_dict() + return ret + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self._trainer.load_state_dict(state_dict["_trainer"]) + + @classmethod + def build_model(cls, cfg): + """ + Returns: + torch.nn.Module: + + It now calls :func:`detectron2.modeling.build_model`. + Overwrite it if you'd like a different model. + """ + model = build_model(cfg) + logger = logging.getLogger(__name__) + logger.info("Model:\n{}".format(model)) + return model + + @classmethod + def build_optimizer(cls, cfg, model): + """ + Returns: + torch.optim.Optimizer: + + It now calls :func:`detectron2.solver.build_optimizer`. + Overwrite it if you'd like a different optimizer. + """ + return build_optimizer(cfg, model) + + @classmethod + def build_lr_scheduler(cls, cfg, optimizer): + """ + It now calls :func:`detectron2.solver.build_lr_scheduler`. + Overwrite it if you'd like a different scheduler. + """ + return build_lr_scheduler(cfg, optimizer) + + @classmethod + def build_train_loader(cls, cfg): + """ + Returns: + iterable + + It now calls :func:`detectron2.data.build_detection_train_loader`. + Overwrite it if you'd like a different data loader. + """ + return build_detection_train_loader(cfg) + + @classmethod + def build_test_loader(cls, cfg, dataset_name): + """ + Returns: + iterable + + It now calls :func:`detectron2.data.build_detection_test_loader`. + Overwrite it if you'd like a different data loader. + """ + return build_detection_test_loader(cfg, dataset_name) + + @classmethod + def build_evaluator(cls, cfg, dataset_name): + """ + Returns: + DatasetEvaluator or None + + It is not implemented by default. + """ + raise NotImplementedError( + """ +If you want DefaultTrainer to automatically run evaluation, +please implement `build_evaluator()` in subclasses (see train_net.py for example). +Alternatively, you can call evaluation functions yourself (see Colab balloon tutorial for example). +""" + ) + + @classmethod + def test(cls, cfg, model, evaluators=None): + """ + Evaluate the given model. The given model is expected to already contain + weights to evaluate. + + Args: + cfg (CfgNode): + model (nn.Module): + evaluators (list[DatasetEvaluator] or None): if None, will call + :meth:`build_evaluator`. Otherwise, must have the same length as + ``cfg.DATASETS.TEST``. + + Returns: + dict: a dict of result metrics + """ + logger = logging.getLogger(__name__) + if isinstance(evaluators, DatasetEvaluator): + evaluators = [evaluators] + if evaluators is not None: + assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format( + len(cfg.DATASETS.TEST), len(evaluators) + ) + + results = OrderedDict() + for idx, dataset_name in enumerate(cfg.DATASETS.TEST): + data_loader = cls.build_test_loader(cfg, dataset_name) + # When evaluators are passed in as arguments, + # implicitly assume that evaluators can be created before data_loader. + if evaluators is not None: + evaluator = evaluators[idx] + else: + try: + evaluator = cls.build_evaluator(cfg, dataset_name) + except NotImplementedError: + logger.warn( + "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, " + "or implement its `build_evaluator` method." + ) + results[dataset_name] = {} + continue + results_i = inference_on_dataset(model, data_loader, evaluator) + results[dataset_name] = results_i + if comm.is_main_process(): + assert isinstance( + results_i, dict + ), "Evaluator must return a dict on the main process. Got {} instead.".format( + results_i + ) + logger.info("Evaluation results for {} in csv format:".format(dataset_name)) + print_csv_format(results_i) + + if len(results) == 1: + results = list(results.values())[0] + return results + + @staticmethod + def auto_scale_workers(cfg, num_workers: int): + """ + When the config is defined for certain number of workers (according to + ``cfg.SOLVER.REFERENCE_WORLD_SIZE``) that's different from the number of + workers currently in use, returns a new cfg where the total batch size + is scaled so that the per-GPU batch size stays the same as the + original ``IMS_PER_BATCH // REFERENCE_WORLD_SIZE``. + + Other config options are also scaled accordingly: + * training steps and warmup steps are scaled inverse proportionally. + * learning rate are scaled proportionally, following :paper:`ImageNet in 1h`. + + For example, with the original config like the following: + + .. code-block:: yaml + + IMS_PER_BATCH: 16 + BASE_LR: 0.1 + REFERENCE_WORLD_SIZE: 8 + MAX_ITER: 5000 + STEPS: (4000,) + CHECKPOINT_PERIOD: 1000 + + When this config is used on 16 GPUs instead of the reference number 8, + calling this method will return a new config with: + + .. code-block:: yaml + + IMS_PER_BATCH: 32 + BASE_LR: 0.2 + REFERENCE_WORLD_SIZE: 16 + MAX_ITER: 2500 + STEPS: (2000,) + CHECKPOINT_PERIOD: 500 + + Note that both the original config and this new config can be trained on 16 GPUs. + It's up to user whether to enable this feature (by setting ``REFERENCE_WORLD_SIZE``). + + Returns: + CfgNode: a new config. Same as original if ``cfg.SOLVER.REFERENCE_WORLD_SIZE==0``. + """ + old_world_size = cfg.SOLVER.REFERENCE_WORLD_SIZE + if old_world_size == 0 or old_world_size == num_workers: + return cfg + cfg = cfg.clone() + frozen = cfg.is_frozen() + cfg.defrost() + + assert ( + cfg.SOLVER.IMS_PER_BATCH % old_world_size == 0 + ), "Invalid REFERENCE_WORLD_SIZE in config!" + scale = num_workers / old_world_size + bs = cfg.SOLVER.IMS_PER_BATCH = int(round(cfg.SOLVER.IMS_PER_BATCH * scale)) + lr = cfg.SOLVER.BASE_LR = cfg.SOLVER.BASE_LR * scale + max_iter = cfg.SOLVER.MAX_ITER = int(round(cfg.SOLVER.MAX_ITER / scale)) + warmup_iter = cfg.SOLVER.WARMUP_ITERS = int(round(cfg.SOLVER.WARMUP_ITERS / scale)) + cfg.SOLVER.STEPS = tuple(int(round(s / scale)) for s in cfg.SOLVER.STEPS) + cfg.TEST.EVAL_PERIOD = int(round(cfg.TEST.EVAL_PERIOD / scale)) + cfg.SOLVER.CHECKPOINT_PERIOD = int(round(cfg.SOLVER.CHECKPOINT_PERIOD / scale)) + cfg.SOLVER.REFERENCE_WORLD_SIZE = num_workers # maintain invariant + logger = logging.getLogger(__name__) + logger.info( + f"Auto-scaling the config to batch_size={bs}, learning_rate={lr}, " + f"max_iter={max_iter}, warmup={warmup_iter}." + ) + + if frozen: + cfg.freeze() + return cfg + + +# Access basic attributes from the underlying trainer +for _attr in ["model", "data_loader", "optimizer"]: + setattr( + DefaultTrainer, + _attr, + property( + # getter + lambda self, x=_attr: getattr(self._trainer, x), + # setter + lambda self, value, x=_attr: setattr(self._trainer, x, value), + ), + ) diff --git a/custom_detectron2/engine/hooks.py b/custom_detectron2/engine/hooks.py new file mode 100644 index 0000000000000000000000000000000000000000..a8773d612a05bb028e6db98b60030fcd4fe04981 --- /dev/null +++ b/custom_detectron2/engine/hooks.py @@ -0,0 +1,690 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import datetime +import itertools +import logging +import math +import operator +import os +import tempfile +import time +import warnings +from collections import Counter +import torch +from fvcore.common.checkpoint import Checkpointer +from fvcore.common.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer +from fvcore.common.param_scheduler import ParamScheduler +from fvcore.common.timer import Timer +from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats + +import custom_detectron2.utils.comm as comm +from custom_detectron2.evaluation.testing import flatten_results_dict +from custom_detectron2.solver import LRMultiplier +from custom_detectron2.solver import LRScheduler as _LRScheduler +from custom_detectron2.utils.events import EventStorage, EventWriter +from custom_detectron2.utils.file_io import PathManager + +from .train_loop import HookBase + +__all__ = [ + "CallbackHook", + "IterationTimer", + "PeriodicWriter", + "PeriodicCheckpointer", + "BestCheckpointer", + "LRScheduler", + "AutogradProfiler", + "EvalHook", + "PreciseBN", + "TorchProfiler", + "TorchMemoryStats", +] + + +""" +Implement some common hooks. +""" + + +class CallbackHook(HookBase): + """ + Create a hook using callback functions provided by the user. + """ + + def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None): + """ + Each argument is a function that takes one argument: the trainer. + """ + self._before_train = before_train + self._before_step = before_step + self._after_step = after_step + self._after_train = after_train + + def before_train(self): + if self._before_train: + self._before_train(self.trainer) + + def after_train(self): + if self._after_train: + self._after_train(self.trainer) + # The functions may be closures that hold reference to the trainer + # Therefore, delete them to avoid circular reference. + del self._before_train, self._after_train + del self._before_step, self._after_step + + def before_step(self): + if self._before_step: + self._before_step(self.trainer) + + def after_step(self): + if self._after_step: + self._after_step(self.trainer) + + +class IterationTimer(HookBase): + """ + Track the time spent for each iteration (each run_step call in the trainer). + Print a summary in the end of training. + + This hook uses the time between the call to its :meth:`before_step` + and :meth:`after_step` methods. + Under the convention that :meth:`before_step` of all hooks should only + take negligible amount of time, the :class:`IterationTimer` hook should be + placed at the beginning of the list of hooks to obtain accurate timing. + """ + + def __init__(self, warmup_iter=3): + """ + Args: + warmup_iter (int): the number of iterations at the beginning to exclude + from timing. + """ + self._warmup_iter = warmup_iter + self._step_timer = Timer() + self._start_time = time.perf_counter() + self._total_timer = Timer() + + def before_train(self): + self._start_time = time.perf_counter() + self._total_timer.reset() + self._total_timer.pause() + + def after_train(self): + logger = logging.getLogger(__name__) + total_time = time.perf_counter() - self._start_time + total_time_minus_hooks = self._total_timer.seconds() + hook_time = total_time - total_time_minus_hooks + + num_iter = self.trainer.storage.iter + 1 - self.trainer.start_iter - self._warmup_iter + + if num_iter > 0 and total_time_minus_hooks > 0: + # Speed is meaningful only after warmup + # NOTE this format is parsed by grep in some scripts + logger.info( + "Overall training speed: {} iterations in {} ({:.4f} s / it)".format( + num_iter, + str(datetime.timedelta(seconds=int(total_time_minus_hooks))), + total_time_minus_hooks / num_iter, + ) + ) + + logger.info( + "Total training time: {} ({} on hooks)".format( + str(datetime.timedelta(seconds=int(total_time))), + str(datetime.timedelta(seconds=int(hook_time))), + ) + ) + + def before_step(self): + self._step_timer.reset() + self._total_timer.resume() + + def after_step(self): + # +1 because we're in after_step, the current step is done + # but not yet counted + iter_done = self.trainer.storage.iter - self.trainer.start_iter + 1 + if iter_done >= self._warmup_iter: + sec = self._step_timer.seconds() + self.trainer.storage.put_scalars(time=sec) + else: + self._start_time = time.perf_counter() + self._total_timer.reset() + + self._total_timer.pause() + + +class PeriodicWriter(HookBase): + """ + Write events to EventStorage (by calling ``writer.write()``) periodically. + + It is executed every ``period`` iterations and after the last iteration. + Note that ``period`` does not affect how data is smoothed by each writer. + """ + + def __init__(self, writers, period=20): + """ + Args: + writers (list[EventWriter]): a list of EventWriter objects + period (int): + """ + self._writers = writers + for w in writers: + assert isinstance(w, EventWriter), w + self._period = period + + def after_step(self): + if (self.trainer.iter + 1) % self._period == 0 or ( + self.trainer.iter == self.trainer.max_iter - 1 + ): + for writer in self._writers: + writer.write() + + def after_train(self): + for writer in self._writers: + # If any new data is found (e.g. produced by other after_train), + # write them before closing + writer.write() + writer.close() + + +class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase): + """ + Same as :class:`detectron2.checkpoint.PeriodicCheckpointer`, but as a hook. + + Note that when used as a hook, + it is unable to save additional data other than what's defined + by the given `checkpointer`. + + It is executed every ``period`` iterations and after the last iteration. + """ + + def before_train(self): + self.max_iter = self.trainer.max_iter + + def after_step(self): + # No way to use **kwargs + self.step(self.trainer.iter) + + +class BestCheckpointer(HookBase): + """ + Checkpoints best weights based off given metric. + + This hook should be used in conjunction to and executed after the hook + that produces the metric, e.g. `EvalHook`. + """ + + def __init__( + self, + eval_period: int, + checkpointer: Checkpointer, + val_metric: str, + mode: str = "max", + file_prefix: str = "model_best", + ) -> None: + """ + Args: + eval_period (int): the period `EvalHook` is set to run. + checkpointer: the checkpointer object used to save checkpoints. + val_metric (str): validation metric to track for best checkpoint, e.g. "bbox/AP50" + mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be + maximized or minimized, e.g. for "bbox/AP50" it should be "max" + file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" + """ + self._logger = logging.getLogger(__name__) + self._period = eval_period + self._val_metric = val_metric + assert mode in [ + "max", + "min", + ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.' + if mode == "max": + self._compare = operator.gt + else: + self._compare = operator.lt + self._checkpointer = checkpointer + self._file_prefix = file_prefix + self.best_metric = None + self.best_iter = None + + def _update_best(self, val, iteration): + if math.isnan(val) or math.isinf(val): + return False + self.best_metric = val + self.best_iter = iteration + return True + + def _best_checking(self): + metric_tuple = self.trainer.storage.latest().get(self._val_metric) + if metric_tuple is None: + self._logger.warning( + f"Given val metric {self._val_metric} does not seem to be computed/stored." + "Will not be checkpointing based on it." + ) + return + else: + latest_metric, metric_iter = metric_tuple + + if self.best_metric is None: + if self._update_best(latest_metric, metric_iter): + additional_state = {"iteration": metric_iter} + self._checkpointer.save(f"{self._file_prefix}", **additional_state) + self._logger.info( + f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps" + ) + elif self._compare(latest_metric, self.best_metric): + additional_state = {"iteration": metric_iter} + self._checkpointer.save(f"{self._file_prefix}", **additional_state) + self._logger.info( + f"Saved best model as latest eval score for {self._val_metric} is " + f"{latest_metric:0.5f}, better than last best score " + f"{self.best_metric:0.5f} @ iteration {self.best_iter}." + ) + self._update_best(latest_metric, metric_iter) + else: + self._logger.info( + f"Not saving as latest eval score for {self._val_metric} is {latest_metric:0.5f}, " + f"not better than best score {self.best_metric:0.5f} @ iteration {self.best_iter}." + ) + + def after_step(self): + # same conditions as `EvalHook` + next_iter = self.trainer.iter + 1 + if ( + self._period > 0 + and next_iter % self._period == 0 + and next_iter != self.trainer.max_iter + ): + self._best_checking() + + def after_train(self): + # same conditions as `EvalHook` + if self.trainer.iter + 1 >= self.trainer.max_iter: + self._best_checking() + + +class LRScheduler(HookBase): + """ + A hook which executes a torch builtin LR scheduler and summarizes the LR. + It is executed after every iteration. + """ + + def __init__(self, optimizer=None, scheduler=None): + """ + Args: + optimizer (torch.optim.Optimizer): + scheduler (torch.optim.LRScheduler or fvcore.common.param_scheduler.ParamScheduler): + if a :class:`ParamScheduler` object, it defines the multiplier over the base LR + in the optimizer. + + If any argument is not given, will try to obtain it from the trainer. + """ + self._optimizer = optimizer + self._scheduler = scheduler + + def before_train(self): + self._optimizer = self._optimizer or self.trainer.optimizer + if isinstance(self.scheduler, ParamScheduler): + self._scheduler = LRMultiplier( + self._optimizer, + self.scheduler, + self.trainer.max_iter, + last_iter=self.trainer.iter - 1, + ) + self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer) + + @staticmethod + def get_best_param_group_id(optimizer): + # NOTE: some heuristics on what LR to summarize + # summarize the param group with most parameters + largest_group = max(len(g["params"]) for g in optimizer.param_groups) + + if largest_group == 1: + # If all groups have one parameter, + # then find the most common initial LR, and use it for summary + lr_count = Counter([g["lr"] for g in optimizer.param_groups]) + lr = lr_count.most_common()[0][0] + for i, g in enumerate(optimizer.param_groups): + if g["lr"] == lr: + return i + else: + for i, g in enumerate(optimizer.param_groups): + if len(g["params"]) == largest_group: + return i + + def after_step(self): + lr = self._optimizer.param_groups[self._best_param_group_id]["lr"] + self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False) + self.scheduler.step() + + @property + def scheduler(self): + return self._scheduler or self.trainer.scheduler + + def state_dict(self): + if isinstance(self.scheduler, _LRScheduler): + return self.scheduler.state_dict() + return {} + + def load_state_dict(self, state_dict): + if isinstance(self.scheduler, _LRScheduler): + logger = logging.getLogger(__name__) + logger.info("Loading scheduler from state_dict ...") + self.scheduler.load_state_dict(state_dict) + + +class TorchProfiler(HookBase): + """ + A hook which runs `torch.profiler.profile`. + + Examples: + :: + hooks.TorchProfiler( + lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR + ) + + The above example will run the profiler for iteration 10~20 and dump + results to ``OUTPUT_DIR``. We did not profile the first few iterations + because they are typically slower than the rest. + The result files can be loaded in the ``chrome://tracing`` page in chrome browser, + and the tensorboard visualizations can be visualized using + ``tensorboard --logdir OUTPUT_DIR/log`` + """ + + def __init__(self, enable_predicate, output_dir, *, activities=None, save_tensorboard=True): + """ + Args: + enable_predicate (callable[trainer -> bool]): a function which takes a trainer, + and returns whether to enable the profiler. + It will be called once every step, and can be used to select which steps to profile. + output_dir (str): the output directory to dump tracing files. + activities (iterable): same as in `torch.profiler.profile`. + save_tensorboard (bool): whether to save tensorboard visualizations at (output_dir)/log/ + """ + self._enable_predicate = enable_predicate + self._activities = activities + self._output_dir = output_dir + self._save_tensorboard = save_tensorboard + + def before_step(self): + if self._enable_predicate(self.trainer): + if self._save_tensorboard: + on_trace_ready = torch.profiler.tensorboard_trace_handler( + os.path.join( + self._output_dir, + "log", + "profiler-tensorboard-iter{}".format(self.trainer.iter), + ), + f"worker{comm.get_rank()}", + ) + else: + on_trace_ready = None + self._profiler = torch.profiler.profile( + activities=self._activities, + on_trace_ready=on_trace_ready, + record_shapes=True, + profile_memory=True, + with_stack=True, + with_flops=True, + ) + self._profiler.__enter__() + else: + self._profiler = None + + def after_step(self): + if self._profiler is None: + return + self._profiler.__exit__(None, None, None) + if not self._save_tensorboard: + PathManager.mkdirs(self._output_dir) + out_file = os.path.join( + self._output_dir, "profiler-trace-iter{}.json".format(self.trainer.iter) + ) + if "://" not in out_file: + self._profiler.export_chrome_trace(out_file) + else: + # Support non-posix filesystems + with tempfile.TemporaryDirectory(prefix="detectron2_profiler") as d: + tmp_file = os.path.join(d, "tmp.json") + self._profiler.export_chrome_trace(tmp_file) + with open(tmp_file) as f: + content = f.read() + with PathManager.open(out_file, "w") as f: + f.write(content) + + +class AutogradProfiler(TorchProfiler): + """ + A hook which runs `torch.autograd.profiler.profile`. + + Examples: + :: + hooks.AutogradProfiler( + lambda trainer: 10 < trainer.iter < 20, self.cfg.OUTPUT_DIR + ) + + The above example will run the profiler for iteration 10~20 and dump + results to ``OUTPUT_DIR``. We did not profile the first few iterations + because they are typically slower than the rest. + The result files can be loaded in the ``chrome://tracing`` page in chrome browser. + + Note: + When used together with NCCL on older version of GPUs, + autograd profiler may cause deadlock because it unnecessarily allocates + memory on every device it sees. The memory management calls, if + interleaved with NCCL calls, lead to deadlock on GPUs that do not + support ``cudaLaunchCooperativeKernelMultiDevice``. + """ + + def __init__(self, enable_predicate, output_dir, *, use_cuda=True): + """ + Args: + enable_predicate (callable[trainer -> bool]): a function which takes a trainer, + and returns whether to enable the profiler. + It will be called once every step, and can be used to select which steps to profile. + output_dir (str): the output directory to dump tracing files. + use_cuda (bool): same as in `torch.autograd.profiler.profile`. + """ + warnings.warn("AutogradProfiler has been deprecated in favor of TorchProfiler.") + self._enable_predicate = enable_predicate + self._use_cuda = use_cuda + self._output_dir = output_dir + + def before_step(self): + if self._enable_predicate(self.trainer): + self._profiler = torch.autograd.profiler.profile(use_cuda=self._use_cuda) + self._profiler.__enter__() + else: + self._profiler = None + + +class EvalHook(HookBase): + """ + Run an evaluation function periodically, and at the end of training. + + It is executed every ``eval_period`` iterations and after the last iteration. + """ + + def __init__(self, eval_period, eval_function, eval_after_train=True): + """ + Args: + eval_period (int): the period to run `eval_function`. Set to 0 to + not evaluate periodically (but still evaluate after the last iteration + if `eval_after_train` is True). + eval_function (callable): a function which takes no arguments, and + returns a nested dict of evaluation metrics. + eval_after_train (bool): whether to evaluate after the last iteration + + Note: + This hook must be enabled in all or none workers. + If you would like only certain workers to perform evaluation, + give other workers a no-op function (`eval_function=lambda: None`). + """ + self._period = eval_period + self._func = eval_function + self._eval_after_train = eval_after_train + + def _do_eval(self): + results = self._func() + + if results: + assert isinstance( + results, dict + ), "Eval function must return a dict. Got {} instead.".format(results) + + flattened_results = flatten_results_dict(results) + for k, v in flattened_results.items(): + try: + v = float(v) + except Exception as e: + raise ValueError( + "[EvalHook] eval_function should return a nested dict of float. " + "Got '{}: {}' instead.".format(k, v) + ) from e + self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False) + + # Evaluation may take different time among workers. + # A barrier make them start the next iteration together. + comm.synchronize() + + def after_step(self): + next_iter = self.trainer.iter + 1 + if self._period > 0 and next_iter % self._period == 0: + # do the last eval in after_train + if next_iter != self.trainer.max_iter: + self._do_eval() + + def after_train(self): + # This condition is to prevent the eval from running after a failed training + if self._eval_after_train and self.trainer.iter + 1 >= self.trainer.max_iter: + self._do_eval() + # func is likely a closure that holds reference to the trainer + # therefore we clean it to avoid circular reference in the end + del self._func + + +class PreciseBN(HookBase): + """ + The standard implementation of BatchNorm uses EMA in inference, which is + sometimes suboptimal. + This class computes the true average of statistics rather than the moving average, + and put true averages to every BN layer in the given model. + + It is executed every ``period`` iterations and after the last iteration. + """ + + def __init__(self, period, model, data_loader, num_iter): + """ + Args: + period (int): the period this hook is run, or 0 to not run during training. + The hook will always run in the end of training. + model (nn.Module): a module whose all BN layers in training mode will be + updated by precise BN. + Note that user is responsible for ensuring the BN layers to be + updated are in training mode when this hook is triggered. + data_loader (iterable): it will produce data to be run by `model(data)`. + num_iter (int): number of iterations used to compute the precise + statistics. + """ + self._logger = logging.getLogger(__name__) + if len(get_bn_modules(model)) == 0: + self._logger.info( + "PreciseBN is disabled because model does not contain BN layers in training mode." + ) + self._disabled = True + return + + self._model = model + self._data_loader = data_loader + self._num_iter = num_iter + self._period = period + self._disabled = False + + self._data_iter = None + + def after_step(self): + next_iter = self.trainer.iter + 1 + is_final = next_iter == self.trainer.max_iter + if is_final or (self._period > 0 and next_iter % self._period == 0): + self.update_stats() + + def update_stats(self): + """ + Update the model with precise statistics. Users can manually call this method. + """ + if self._disabled: + return + + if self._data_iter is None: + self._data_iter = iter(self._data_loader) + + def data_loader(): + for num_iter in itertools.count(1): + if num_iter % 100 == 0: + self._logger.info( + "Running precise-BN ... {}/{} iterations.".format(num_iter, self._num_iter) + ) + # This way we can reuse the same iterator + yield next(self._data_iter) + + with EventStorage(): # capture events in a new storage to discard them + self._logger.info( + "Running precise-BN for {} iterations... ".format(self._num_iter) + + "Note that this could produce different statistics every time." + ) + update_bn_stats(self._model, data_loader(), self._num_iter) + + +class TorchMemoryStats(HookBase): + """ + Writes pytorch's cuda memory statistics periodically. + """ + + def __init__(self, period=20, max_runs=10): + """ + Args: + period (int): Output stats each 'period' iterations + max_runs (int): Stop the logging after 'max_runs' + """ + + self._logger = logging.getLogger(__name__) + self._period = period + self._max_runs = max_runs + self._runs = 0 + + def after_step(self): + if self._runs > self._max_runs: + return + + if (self.trainer.iter + 1) % self._period == 0 or ( + self.trainer.iter == self.trainer.max_iter - 1 + ): + if torch.cuda.is_available(): + max_reserved_mb = torch.cuda.max_memory_reserved() / 1024.0 / 1024.0 + reserved_mb = torch.cuda.memory_reserved() / 1024.0 / 1024.0 + max_allocated_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + allocated_mb = torch.cuda.memory_allocated() / 1024.0 / 1024.0 + + self._logger.info( + ( + " iter: {} " + " max_reserved_mem: {:.0f}MB " + " reserved_mem: {:.0f}MB " + " max_allocated_mem: {:.0f}MB " + " allocated_mem: {:.0f}MB " + ).format( + self.trainer.iter, + max_reserved_mb, + reserved_mb, + max_allocated_mb, + allocated_mb, + ) + ) + + self._runs += 1 + if self._runs == self._max_runs: + mem_summary = torch.cuda.memory_summary() + self._logger.info("\n" + mem_summary) + + torch.cuda.reset_peak_memory_stats() diff --git a/custom_detectron2/engine/launch.py b/custom_detectron2/engine/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..91ce1305187778143b0f2f6e487bfadd2700fffd --- /dev/null +++ b/custom_detectron2/engine/launch.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from datetime import timedelta +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from custom_detectron2.utils import comm + +__all__ = ["DEFAULT_TIMEOUT", "launch"] + +DEFAULT_TIMEOUT = timedelta(minutes=30) + + +def _find_free_port(): + import socket + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # Binding to port 0 will cause the OS to find an available port for us + sock.bind(("", 0)) + port = sock.getsockname()[1] + sock.close() + # NOTE: there is still a chance the port could be taken by other processes. + return port + + +def launch( + main_func, + # Should be num_processes_per_machine, but kept for compatibility. + num_gpus_per_machine, + num_machines=1, + machine_rank=0, + dist_url=None, + args=(), + timeout=DEFAULT_TIMEOUT, +): + """ + Launch multi-process or distributed training. + This function must be called on all machines involved in the training. + It will spawn child processes (defined by ``num_gpus_per_machine``) on each machine. + + Args: + main_func: a function that will be called by `main_func(*args)` + num_gpus_per_machine (int): number of processes per machine. When + using GPUs, this should be the number of GPUs. + num_machines (int): the total number of machines + machine_rank (int): the rank of this machine + dist_url (str): url to connect to for distributed jobs, including protocol + e.g. "tcp://127.0.0.1:8686". + Can be set to "auto" to automatically select a free port on localhost + timeout (timedelta): timeout of the distributed workers + args (tuple): arguments passed to main_func + """ + world_size = num_machines * num_gpus_per_machine + if world_size > 1: + # https://github.com/pytorch/pytorch/pull/14391 + # TODO prctl in spawned processes + + if dist_url == "auto": + assert num_machines == 1, "dist_url=auto not supported in multi-machine jobs." + port = _find_free_port() + dist_url = f"tcp://127.0.0.1:{port}" + if num_machines > 1 and dist_url.startswith("file://"): + logger = logging.getLogger(__name__) + logger.warning( + "file:// is not a reliable init_method in multi-machine jobs. Prefer tcp://" + ) + + mp.start_processes( + _distributed_worker, + nprocs=num_gpus_per_machine, + args=( + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + args, + timeout, + ), + daemon=False, + ) + else: + main_func(*args) + + +def _distributed_worker( + local_rank, + main_func, + world_size, + num_gpus_per_machine, + machine_rank, + dist_url, + args, + timeout=DEFAULT_TIMEOUT, +): + has_gpu = torch.cuda.is_available() + if has_gpu: + assert num_gpus_per_machine <= torch.cuda.device_count() + global_rank = machine_rank * num_gpus_per_machine + local_rank + try: + dist.init_process_group( + backend="NCCL" if has_gpu else "GLOO", + init_method=dist_url, + world_size=world_size, + rank=global_rank, + timeout=timeout, + ) + except Exception as e: + logger = logging.getLogger(__name__) + logger.error("Process group URL: {}".format(dist_url)) + raise e + + # Setup the local process group. + comm.create_local_process_group(num_gpus_per_machine) + if has_gpu: + torch.cuda.set_device(local_rank) + + # synchronize is needed here to prevent a possible timeout after calling init_process_group + # See: https://github.com/facebookresearch/maskrcnn-benchmark/issues/172 + comm.synchronize() + + main_func(*args) diff --git a/custom_detectron2/engine/train_loop.py b/custom_detectron2/engine/train_loop.py new file mode 100644 index 0000000000000000000000000000000000000000..066055a99be04d87bde48efdcdbef162bfb27792 --- /dev/null +++ b/custom_detectron2/engine/train_loop.py @@ -0,0 +1,469 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +import numpy as np +import time +import weakref +from typing import List, Mapping, Optional +import torch +from torch.nn.parallel import DataParallel, DistributedDataParallel + +import custom_detectron2.utils.comm as comm +from custom_detectron2.utils.events import EventStorage, get_event_storage +from custom_detectron2.utils.logger import _log_api_usage + +__all__ = ["HookBase", "TrainerBase", "SimpleTrainer", "AMPTrainer"] + + +class HookBase: + """ + Base class for hooks that can be registered with :class:`TrainerBase`. + + Each hook can implement 4 methods. The way they are called is demonstrated + in the following snippet: + :: + hook.before_train() + for iter in range(start_iter, max_iter): + hook.before_step() + trainer.run_step() + hook.after_step() + iter += 1 + hook.after_train() + + Notes: + 1. In the hook method, users can access ``self.trainer`` to access more + properties about the context (e.g., model, current iteration, or config + if using :class:`DefaultTrainer`). + + 2. A hook that does something in :meth:`before_step` can often be + implemented equivalently in :meth:`after_step`. + If the hook takes non-trivial time, it is strongly recommended to + implement the hook in :meth:`after_step` instead of :meth:`before_step`. + The convention is that :meth:`before_step` should only take negligible time. + + Following this convention will allow hooks that do care about the difference + between :meth:`before_step` and :meth:`after_step` (e.g., timer) to + function properly. + + """ + + trainer: "TrainerBase" = None + """ + A weak reference to the trainer object. Set by the trainer when the hook is registered. + """ + + def before_train(self): + """ + Called before the first iteration. + """ + pass + + def after_train(self): + """ + Called after the last iteration. + """ + pass + + def before_step(self): + """ + Called before each iteration. + """ + pass + + def after_backward(self): + """ + Called after the backward pass of each iteration. + """ + pass + + def after_step(self): + """ + Called after each iteration. + """ + pass + + def state_dict(self): + """ + Hooks are stateless by default, but can be made checkpointable by + implementing `state_dict` and `load_state_dict`. + """ + return {} + + +class TrainerBase: + """ + Base class for iterative trainer with hooks. + + The only assumption we made here is: the training runs in a loop. + A subclass can implement what the loop is. + We made no assumptions about the existence of dataloader, optimizer, model, etc. + + Attributes: + iter(int): the current iteration. + + start_iter(int): The iteration to start with. + By convention the minimum possible value is 0. + + max_iter(int): The iteration to end training. + + storage(EventStorage): An EventStorage that's opened during the course of training. + """ + + def __init__(self) -> None: + self._hooks: List[HookBase] = [] + self.iter: int = 0 + self.start_iter: int = 0 + self.max_iter: int + self.storage: EventStorage + _log_api_usage("trainer." + self.__class__.__name__) + + def register_hooks(self, hooks: List[Optional[HookBase]]) -> None: + """ + Register hooks to the trainer. The hooks are executed in the order + they are registered. + + Args: + hooks (list[Optional[HookBase]]): list of hooks + """ + hooks = [h for h in hooks if h is not None] + for h in hooks: + assert isinstance(h, HookBase) + # To avoid circular reference, hooks and trainer cannot own each other. + # This normally does not matter, but will cause memory leak if the + # involved objects contain __del__: + # See http://engineering.hearsaysocial.com/2013/06/16/circular-references-in-python/ + h.trainer = weakref.proxy(self) + self._hooks.extend(hooks) + + def train(self, start_iter: int, max_iter: int): + """ + Args: + start_iter, max_iter (int): See docs above + """ + logger = logging.getLogger(__name__) + logger.info("Starting training from iteration {}".format(start_iter)) + + self.iter = self.start_iter = start_iter + self.max_iter = max_iter + + with EventStorage(start_iter) as self.storage: + try: + self.before_train() + for self.iter in range(start_iter, max_iter): + self.before_step() + self.run_step() + self.after_step() + # self.iter == max_iter can be used by `after_train` to + # tell whether the training successfully finished or failed + # due to exceptions. + self.iter += 1 + except Exception: + logger.exception("Exception during training:") + raise + finally: + self.after_train() + + def before_train(self): + for h in self._hooks: + h.before_train() + + def after_train(self): + self.storage.iter = self.iter + for h in self._hooks: + h.after_train() + + def before_step(self): + # Maintain the invariant that storage.iter == trainer.iter + # for the entire execution of each step + self.storage.iter = self.iter + + for h in self._hooks: + h.before_step() + + def after_backward(self): + for h in self._hooks: + h.after_backward() + + def after_step(self): + for h in self._hooks: + h.after_step() + + def run_step(self): + raise NotImplementedError + + def state_dict(self): + ret = {"iteration": self.iter} + hooks_state = {} + for h in self._hooks: + sd = h.state_dict() + if sd: + name = type(h).__qualname__ + if name in hooks_state: + # TODO handle repetitive stateful hooks + continue + hooks_state[name] = sd + if hooks_state: + ret["hooks"] = hooks_state + return ret + + def load_state_dict(self, state_dict): + logger = logging.getLogger(__name__) + self.iter = state_dict["iteration"] + for key, value in state_dict.get("hooks", {}).items(): + for h in self._hooks: + try: + name = type(h).__qualname__ + except AttributeError: + continue + if name == key: + h.load_state_dict(value) + break + else: + logger.warning(f"Cannot find the hook '{key}', its state_dict is ignored.") + + +class SimpleTrainer(TrainerBase): + """ + A simple trainer for the most common type of task: + single-cost single-optimizer single-data-source iterative optimization, + optionally using data-parallelism. + It assumes that every step, you: + + 1. Compute the loss with a data from the data_loader. + 2. Compute the gradients with the above loss. + 3. Update the model with the optimizer. + + All other tasks during training (checkpointing, logging, evaluation, LR schedule) + are maintained by hooks, which can be registered by :meth:`TrainerBase.register_hooks`. + + If you want to do anything fancier than this, + either subclass TrainerBase and implement your own `run_step`, + or write your own training loop. + """ + + def __init__(self, model, data_loader, optimizer, gather_metric_period=1): + """ + Args: + model: a torch Module. Takes a data from data_loader and returns a + dict of losses. + data_loader: an iterable. Contains data to be used to call model. + optimizer: a torch optimizer. + gather_metric_period: an int. Every gather_metric_period iterations + the metrics are gathered from all the ranks to rank 0 and logged. + """ + super().__init__() + + """ + We set the model to training mode in the trainer. + However it's valid to train a model that's in eval mode. + If you want your model (or a submodule of it) to behave + like evaluation during training, you can overwrite its train() method. + """ + model.train() + + self.model = model + self.data_loader = data_loader + # to access the data loader iterator, call `self._data_loader_iter` + self._data_loader_iter_obj = None + self.optimizer = optimizer + self.gather_metric_period = gather_metric_period + + def run_step(self): + """ + Implement the standard training logic described above. + """ + assert self.model.training, "[SimpleTrainer] model was changed to eval mode!" + start = time.perf_counter() + """ + If you want to do something with the data, you can wrap the dataloader. + """ + data = next(self._data_loader_iter) + data_time = time.perf_counter() - start + + """ + If you want to do something with the losses, you can wrap the model. + """ + loss_dict = self.model(data) + if isinstance(loss_dict, torch.Tensor): + losses = loss_dict + loss_dict = {"total_loss": loss_dict} + else: + losses = sum(loss_dict.values()) + + """ + If you need to accumulate gradients or do something similar, you can + wrap the optimizer with your custom `zero_grad()` method. + """ + self.optimizer.zero_grad() + losses.backward() + + self.after_backward() + + self._write_metrics(loss_dict, data_time) + + """ + If you need gradient clipping/scaling or other processing, you can + wrap the optimizer with your custom `step()` method. But it is + suboptimal as explained in https://arxiv.org/abs/2006.15704 Sec 3.2.4 + """ + self.optimizer.step() + + @property + def _data_loader_iter(self): + # only create the data loader iterator when it is used + if self._data_loader_iter_obj is None: + self._data_loader_iter_obj = iter(self.data_loader) + return self._data_loader_iter_obj + + def reset_data_loader(self, data_loader_builder): + """ + Delete and replace the current data loader with a new one, which will be created + by calling `data_loader_builder` (without argument). + """ + del self.data_loader + data_loader = data_loader_builder() + self.data_loader = data_loader + self._data_loader_iter_obj = None + + def _write_metrics( + self, + loss_dict: Mapping[str, torch.Tensor], + data_time: float, + prefix: str = "", + ) -> None: + if (self.iter + 1) % self.gather_metric_period == 0: + SimpleTrainer.write_metrics(loss_dict, data_time, prefix) + + @staticmethod + def write_metrics( + loss_dict: Mapping[str, torch.Tensor], + data_time: float, + prefix: str = "", + ) -> None: + """ + Args: + loss_dict (dict): dict of scalar losses + data_time (float): time taken by the dataloader iteration + prefix (str): prefix for logging keys + """ + metrics_dict = {k: v.detach().cpu().item() for k, v in loss_dict.items()} + metrics_dict["data_time"] = data_time + + # Gather metrics among all workers for logging + # This assumes we do DDP-style training, which is currently the only + # supported method in detectron2. + all_metrics_dict = comm.gather(metrics_dict) + + if comm.is_main_process(): + storage = get_event_storage() + + # data_time among workers can have high variance. The actual latency + # caused by data_time is the maximum among workers. + data_time = np.max([x.pop("data_time") for x in all_metrics_dict]) + storage.put_scalar("data_time", data_time) + + # average the rest metrics + metrics_dict = { + k: np.mean([x[k] for x in all_metrics_dict]) for k in all_metrics_dict[0].keys() + } + total_losses_reduced = sum(metrics_dict.values()) + if not np.isfinite(total_losses_reduced): + raise FloatingPointError( + f"Loss became infinite or NaN at iteration={storage.iter}!\n" + f"loss_dict = {metrics_dict}" + ) + + storage.put_scalar("{}total_loss".format(prefix), total_losses_reduced) + if len(metrics_dict) > 1: + storage.put_scalars(**metrics_dict) + + def state_dict(self): + ret = super().state_dict() + ret["optimizer"] = self.optimizer.state_dict() + return ret + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.optimizer.load_state_dict(state_dict["optimizer"]) + + +class AMPTrainer(SimpleTrainer): + """ + Like :class:`SimpleTrainer`, but uses PyTorch's native automatic mixed precision + in the training loop. + """ + + def __init__( + self, + model, + data_loader, + optimizer, + gather_metric_period=1, + grad_scaler=None, + precision: torch.dtype = torch.float16, + log_grad_scaler: bool = False, + ): + """ + Args: + model, data_loader, optimizer, gather_metric_period: same as in :class:`SimpleTrainer`. + grad_scaler: torch GradScaler to automatically scale gradients. + precision: torch.dtype as the target precision to cast to in computations + """ + unsupported = "AMPTrainer does not support single-process multi-device training!" + if isinstance(model, DistributedDataParallel): + assert not (model.device_ids and len(model.device_ids) > 1), unsupported + assert not isinstance(model, DataParallel), unsupported + + super().__init__(model, data_loader, optimizer, gather_metric_period) + + if grad_scaler is None: + from torch.cuda.amp import GradScaler + + grad_scaler = GradScaler() + self.grad_scaler = grad_scaler + self.precision = precision + self.log_grad_scaler = log_grad_scaler + + def run_step(self): + """ + Implement the AMP training logic. + """ + assert self.model.training, "[AMPTrainer] model was changed to eval mode!" + assert torch.cuda.is_available(), "[AMPTrainer] CUDA is required for AMP training!" + from torch.cuda.amp import autocast + + start = time.perf_counter() + data = next(self._data_loader_iter) + data_time = time.perf_counter() - start + + with autocast(dtype=self.precision): + loss_dict = self.model(data) + if isinstance(loss_dict, torch.Tensor): + losses = loss_dict + loss_dict = {"total_loss": loss_dict} + else: + losses = sum(loss_dict.values()) + + self.optimizer.zero_grad() + self.grad_scaler.scale(losses).backward() + + if self.log_grad_scaler: + storage = get_event_storage() + storage.put_scalar("[metric]grad_scaler", self.grad_scaler.get_scale()) + + self.after_backward() + + self._write_metrics(loss_dict, data_time) + + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + + def state_dict(self): + ret = super().state_dict() + ret["grad_scaler"] = self.grad_scaler.state_dict() + return ret + + def load_state_dict(self, state_dict): + super().load_state_dict(state_dict) + self.grad_scaler.load_state_dict(state_dict["grad_scaler"]) diff --git a/custom_detectron2/evaluation/__init__.py b/custom_detectron2/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d96609e8f2261a6800fe85fcf3e1eaeaa44455c6 --- /dev/null +++ b/custom_detectron2/evaluation/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .cityscapes_evaluation import CityscapesInstanceEvaluator, CityscapesSemSegEvaluator +from .coco_evaluation import COCOEvaluator +from .rotated_coco_evaluation import RotatedCOCOEvaluator +from .evaluator import DatasetEvaluator, DatasetEvaluators, inference_context, inference_on_dataset +from .lvis_evaluation import LVISEvaluator +from .panoptic_evaluation import COCOPanopticEvaluator +from .pascal_voc_evaluation import PascalVOCDetectionEvaluator +from .sem_seg_evaluation import SemSegEvaluator +from .testing import print_csv_format, verify_results + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/evaluation/cityscapes_evaluation.py b/custom_detectron2/evaluation/cityscapes_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..881aed078e2e1322dbd48e2006785888652441e4 --- /dev/null +++ b/custom_detectron2/evaluation/cityscapes_evaluation.py @@ -0,0 +1,197 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import glob +import logging +import numpy as np +import os +import tempfile +from collections import OrderedDict +import torch +from PIL import Image + +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.utils import comm +from custom_detectron2.utils.file_io import PathManager + +from .evaluator import DatasetEvaluator + + +class CityscapesEvaluator(DatasetEvaluator): + """ + Base class for evaluation using cityscapes API. + """ + + def __init__(self, dataset_name): + """ + Args: + dataset_name (str): the name of the dataset. + It must have the following metadata associated with it: + "thing_classes", "gt_dir". + """ + self._metadata = MetadataCatalog.get(dataset_name) + self._cpu_device = torch.device("cpu") + self._logger = logging.getLogger(__name__) + + def reset(self): + self._working_dir = tempfile.TemporaryDirectory(prefix="cityscapes_eval_") + self._temp_dir = self._working_dir.name + # All workers will write to the same results directory + # TODO this does not work in distributed training + assert ( + comm.get_local_size() == comm.get_world_size() + ), "CityscapesEvaluator currently do not work with multiple machines." + self._temp_dir = comm.all_gather(self._temp_dir)[0] + if self._temp_dir != self._working_dir.name: + self._working_dir.cleanup() + self._logger.info( + "Writing cityscapes results to temporary directory {} ...".format(self._temp_dir) + ) + + +class CityscapesInstanceEvaluator(CityscapesEvaluator): + """ + Evaluate instance segmentation results on cityscapes dataset using cityscapes API. + + Note: + * It does not work in multi-machine distributed training. + * It contains a synchronization, therefore has to be used on all ranks. + * Only the main process runs evaluation. + """ + + def process(self, inputs, outputs): + from cityscapesscripts.helpers.labels import name2label + + for input, output in zip(inputs, outputs): + file_name = input["file_name"] + basename = os.path.splitext(os.path.basename(file_name))[0] + pred_txt = os.path.join(self._temp_dir, basename + "_pred.txt") + + if "instances" in output: + output = output["instances"].to(self._cpu_device) + num_instances = len(output) + with open(pred_txt, "w") as fout: + for i in range(num_instances): + pred_class = output.pred_classes[i] + classes = self._metadata.thing_classes[pred_class] + class_id = name2label[classes].id + score = output.scores[i] + mask = output.pred_masks[i].numpy().astype("uint8") + png_filename = os.path.join( + self._temp_dir, basename + "_{}_{}.png".format(i, classes) + ) + + Image.fromarray(mask * 255).save(png_filename) + fout.write( + "{} {} {}\n".format(os.path.basename(png_filename), class_id, score) + ) + else: + # Cityscapes requires a prediction file for every ground truth image. + with open(pred_txt, "w") as fout: + pass + + def evaluate(self): + """ + Returns: + dict: has a key "segm", whose value is a dict of "AP" and "AP50". + """ + comm.synchronize() + if comm.get_rank() > 0: + return + import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as cityscapes_eval + + self._logger.info("Evaluating results under {} ...".format(self._temp_dir)) + + # set some global states in cityscapes evaluation API, before evaluating + cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir) + cityscapes_eval.args.predictionWalk = None + cityscapes_eval.args.JSONOutput = False + cityscapes_eval.args.colorized = False + cityscapes_eval.args.gtInstancesFile = os.path.join(self._temp_dir, "gtInstances.json") + + # These lines are adopted from + # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa + gt_dir = PathManager.get_local_path(self._metadata.gt_dir) + groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_instanceIds.png")) + assert len( + groundTruthImgList + ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format( + cityscapes_eval.args.groundTruthSearch + ) + predictionImgList = [] + for gt in groundTruthImgList: + predictionImgList.append(cityscapes_eval.getPrediction(gt, cityscapes_eval.args)) + results = cityscapes_eval.evaluateImgLists( + predictionImgList, groundTruthImgList, cityscapes_eval.args + )["averages"] + + ret = OrderedDict() + ret["segm"] = {"AP": results["allAp"] * 100, "AP50": results["allAp50%"] * 100} + self._working_dir.cleanup() + return ret + + +class CityscapesSemSegEvaluator(CityscapesEvaluator): + """ + Evaluate semantic segmentation results on cityscapes dataset using cityscapes API. + + Note: + * It does not work in multi-machine distributed training. + * It contains a synchronization, therefore has to be used on all ranks. + * Only the main process runs evaluation. + """ + + def process(self, inputs, outputs): + from cityscapesscripts.helpers.labels import trainId2label + + for input, output in zip(inputs, outputs): + file_name = input["file_name"] + basename = os.path.splitext(os.path.basename(file_name))[0] + pred_filename = os.path.join(self._temp_dir, basename + "_pred.png") + + output = output["sem_seg"].argmax(dim=0).to(self._cpu_device).numpy() + pred = 255 * np.ones(output.shape, dtype=np.uint8) + for train_id, label in trainId2label.items(): + if label.ignoreInEval: + continue + pred[output == train_id] = label.id + Image.fromarray(pred).save(pred_filename) + + def evaluate(self): + comm.synchronize() + if comm.get_rank() > 0: + return + # Load the Cityscapes eval script *after* setting the required env var, + # since the script reads CITYSCAPES_DATASET into global variables at load time. + import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as cityscapes_eval + + self._logger.info("Evaluating results under {} ...".format(self._temp_dir)) + + # set some global states in cityscapes evaluation API, before evaluating + cityscapes_eval.args.predictionPath = os.path.abspath(self._temp_dir) + cityscapes_eval.args.predictionWalk = None + cityscapes_eval.args.JSONOutput = False + cityscapes_eval.args.colorized = False + + # These lines are adopted from + # https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalPixelLevelSemanticLabeling.py # noqa + gt_dir = PathManager.get_local_path(self._metadata.gt_dir) + groundTruthImgList = glob.glob(os.path.join(gt_dir, "*", "*_gtFine_labelIds.png")) + assert len( + groundTruthImgList + ), "Cannot find any ground truth images to use for evaluation. Searched for: {}".format( + cityscapes_eval.args.groundTruthSearch + ) + predictionImgList = [] + for gt in groundTruthImgList: + predictionImgList.append(cityscapes_eval.getPrediction(cityscapes_eval.args, gt)) + results = cityscapes_eval.evaluateImgLists( + predictionImgList, groundTruthImgList, cityscapes_eval.args + ) + ret = OrderedDict() + ret["sem_seg"] = { + "IoU": 100.0 * results["averageScoreClasses"], + "iIoU": 100.0 * results["averageScoreInstClasses"], + "IoU_sup": 100.0 * results["averageScoreCategories"], + "iIoU_sup": 100.0 * results["averageScoreInstCategories"], + } + self._working_dir.cleanup() + return ret diff --git a/custom_detectron2/evaluation/coco_evaluation.py b/custom_detectron2/evaluation/coco_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..9d651a9fe2eb90d0c6a682ab1a832debedacaf12 --- /dev/null +++ b/custom_detectron2/evaluation/coco_evaluation.py @@ -0,0 +1,722 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import copy +import io +import itertools +import json +import logging +import numpy as np +import os +import pickle +from collections import OrderedDict +import custom_pycocotools.mask as mask_util +import torch +from custom_pycocotools.coco import COCO +from custom_pycocotools.cocoeval import COCOeval +from tabulate import tabulate + +import custom_detectron2.utils.comm as comm +from custom_detectron2.config import CfgNode +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.data.datasets.coco import convert_to_coco_json +from custom_detectron2.structures import Boxes, BoxMode, pairwise_iou +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.logger import create_small_table + +from .evaluator import DatasetEvaluator + +try: + from custom_detectron2.evaluation.fast_eval_api import COCOeval_opt +except ImportError: + COCOeval_opt = COCOeval + + +class COCOEvaluator(DatasetEvaluator): + """ + Evaluate AR for object proposals, AP for instance detection/segmentation, AP + for keypoint detection outputs using COCO's metrics. + See http://cocodataset.org/#detection-eval and + http://cocodataset.org/#keypoints-eval to understand its metrics. + The metrics range from 0 to 100 (instead of 0 to 1), where a -1 or NaN means + the metric cannot be computed (e.g. due to no predictions made). + + In addition to COCO, this evaluator is able to support any bounding box detection, + instance segmentation, or keypoint detection dataset. + """ + + def __init__( + self, + dataset_name, + tasks=None, + distributed=True, + output_dir=None, + *, + max_dets_per_image=None, + use_fast_impl=True, + kpt_oks_sigmas=(), + allow_cached_coco=True, + ): + """ + Args: + dataset_name (str): name of the dataset to be evaluated. + It must have either the following corresponding metadata: + + "json_file": the path to the COCO format annotation + + Or it must be in detectron2's standard dataset format + so it can be converted to COCO format automatically. + tasks (tuple[str]): tasks that can be evaluated under the given + configuration. A task is one of "bbox", "segm", "keypoints". + By default, will infer this automatically from predictions. + distributed (True): if True, will collect results from all ranks and run evaluation + in the main process. + Otherwise, will only evaluate the results in the current process. + output_dir (str): optional, an output directory to dump all + results predicted on the dataset. The dump contains two files: + + 1. "instances_predictions.pth" a file that can be loaded with `torch.load` and + contains all the results in the format they are produced by the model. + 2. "coco_instances_results.json" a json file in COCO's result format. + max_dets_per_image (int): limit on the maximum number of detections per image. + By default in COCO, this limit is to 100, but this can be customized + to be greater, as is needed in evaluation metrics AP fixed and AP pool + (see https://arxiv.org/pdf/2102.01066.pdf) + This doesn't affect keypoint evaluation. + use_fast_impl (bool): use a fast but **unofficial** implementation to compute AP. + Although the results should be very close to the official implementation in COCO + API, it is still recommended to compute results with the official API for use in + papers. The faster implementation also uses more RAM. + kpt_oks_sigmas (list[float]): The sigmas used to calculate keypoint OKS. + See http://cocodataset.org/#keypoints-eval + When empty, it will use the defaults in COCO. + Otherwise it should be the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS. + allow_cached_coco (bool): Whether to use cached coco json from previous validation + runs. You should set this to False if you need to use different validation data. + Defaults to True. + """ + self._logger = logging.getLogger(__name__) + self._distributed = distributed + self._output_dir = output_dir + + if use_fast_impl and (COCOeval_opt is COCOeval): + self._logger.info("Fast COCO eval is not built. Falling back to official COCO eval.") + use_fast_impl = False + self._use_fast_impl = use_fast_impl + + # COCOeval requires the limit on the number of detections per image (maxDets) to be a list + # with at least 3 elements. The default maxDets in COCOeval is [1, 10, 100], in which the + # 3rd element (100) is used as the limit on the number of detections per image when + # evaluating AP. COCOEvaluator expects an integer for max_dets_per_image, so for COCOeval, + # we reformat max_dets_per_image into [1, 10, max_dets_per_image], based on the defaults. + if max_dets_per_image is None: + max_dets_per_image = [1, 10, 100] + else: + max_dets_per_image = [1, 10, max_dets_per_image] + self._max_dets_per_image = max_dets_per_image + + if tasks is not None and isinstance(tasks, CfgNode): + kpt_oks_sigmas = ( + tasks.TEST.KEYPOINT_OKS_SIGMAS if not kpt_oks_sigmas else kpt_oks_sigmas + ) + self._logger.warn( + "COCO Evaluator instantiated using config, this is deprecated behavior." + " Please pass in explicit arguments instead." + ) + self._tasks = None # Infering it from predictions should be better + else: + self._tasks = tasks + + self._cpu_device = torch.device("cpu") + + self._metadata = MetadataCatalog.get(dataset_name) + if not hasattr(self._metadata, "json_file"): + if output_dir is None: + raise ValueError( + "output_dir must be provided to COCOEvaluator " + "for datasets not in COCO format." + ) + self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...") + + cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json") + self._metadata.json_file = cache_path + convert_to_coco_json(dataset_name, cache_path, allow_cached=allow_cached_coco) + + json_file = PathManager.get_local_path(self._metadata.json_file) + with contextlib.redirect_stdout(io.StringIO()): + self._coco_api = COCO(json_file) + + # Test set json files do not contain annotations (evaluation must be + # performed using the COCO evaluation server). + self._do_evaluation = "annotations" in self._coco_api.dataset + if self._do_evaluation: + self._kpt_oks_sigmas = kpt_oks_sigmas + + def reset(self): + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a COCO model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a COCO model. It is a list of dicts with key + "instances" that contains :class:`Instances`. + """ + for input, output in zip(inputs, outputs): + prediction = {"image_id": input["image_id"]} + + if "instances" in output: + instances = output["instances"].to(self._cpu_device) + prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) + if "proposals" in output: + prediction["proposals"] = output["proposals"].to(self._cpu_device) + if len(prediction) > 1: + self._predictions.append(prediction) + + def evaluate(self, img_ids=None): + """ + Args: + img_ids: a list of image IDs to evaluate on. Default to None for the whole dataset + """ + if self._distributed: + comm.synchronize() + predictions = comm.gather(self._predictions, dst=0) + predictions = list(itertools.chain(*predictions)) + + if not comm.is_main_process(): + return {} + else: + predictions = self._predictions + + if len(predictions) == 0: + self._logger.warning("[COCOEvaluator] Did not receive valid predictions.") + return {} + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "instances_predictions.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(predictions, f) + + self._results = OrderedDict() + if "proposals" in predictions[0]: + self._eval_box_proposals(predictions) + if "instances" in predictions[0]: + self._eval_predictions(predictions, img_ids=img_ids) + # Copy so the caller can do whatever with results + return copy.deepcopy(self._results) + + def _tasks_from_predictions(self, predictions): + """ + Get COCO API "tasks" (i.e. iou_type) from COCO-format predictions. + """ + tasks = {"bbox"} + for pred in predictions: + if "segmentation" in pred: + tasks.add("segm") + if "keypoints" in pred: + tasks.add("keypoints") + return sorted(tasks) + + def _eval_predictions(self, predictions, img_ids=None): + """ + Evaluate predictions. Fill self._results with the metrics of the tasks. + """ + self._logger.info("Preparing results for COCO format ...") + coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) + tasks = self._tasks or self._tasks_from_predictions(coco_results) + + # unmap the category ids for COCO + if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): + dataset_id_to_contiguous_id = self._metadata.thing_dataset_id_to_contiguous_id + all_contiguous_ids = list(dataset_id_to_contiguous_id.values()) + num_classes = len(all_contiguous_ids) + assert min(all_contiguous_ids) == 0 and max(all_contiguous_ids) == num_classes - 1 + + reverse_id_mapping = {v: k for k, v in dataset_id_to_contiguous_id.items()} + for result in coco_results: + category_id = result["category_id"] + assert category_id < num_classes, ( + f"A prediction has class={category_id}, " + f"but the dataset only has {num_classes} classes and " + f"predicted class id should be in [0, {num_classes - 1}]." + ) + result["category_id"] = reverse_id_mapping[category_id] + + if self._output_dir: + file_path = os.path.join(self._output_dir, "coco_instances_results.json") + self._logger.info("Saving results to {}".format(file_path)) + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(coco_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info( + "Evaluating predictions with {} COCO API...".format( + "unofficial" if self._use_fast_impl else "official" + ) + ) + for task in sorted(tasks): + assert task in {"bbox", "segm", "keypoints"}, f"Got unknown task: {task}!" + coco_eval = ( + _evaluate_predictions_on_coco( + self._coco_api, + coco_results, + task, + kpt_oks_sigmas=self._kpt_oks_sigmas, + cocoeval_fn=COCOeval_opt if self._use_fast_impl else COCOeval, + img_ids=img_ids, + max_dets_per_image=self._max_dets_per_image, + ) + if len(coco_results) > 0 + else None # cocoapi does not handle empty results very well + ) + + res = self._derive_coco_results( + coco_eval, task, class_names=self._metadata.get("thing_classes") + ) + self._results[task] = res + + def _eval_box_proposals(self, predictions): + """ + Evaluate the box proposals in predictions. + Fill self._results with the metrics for "box_proposals" task. + """ + if self._output_dir: + # Saving generated box proposals to file. + # Predicted box_proposals are in XYXY_ABS mode. + bbox_mode = BoxMode.XYXY_ABS.value + ids, boxes, objectness_logits = [], [], [] + for prediction in predictions: + ids.append(prediction["image_id"]) + boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy()) + objectness_logits.append(prediction["proposals"].objectness_logits.numpy()) + + proposal_data = { + "boxes": boxes, + "objectness_logits": objectness_logits, + "ids": ids, + "bbox_mode": bbox_mode, + } + with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f: + pickle.dump(proposal_data, f) + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating bbox proposals ...") + res = {} + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = _evaluate_box_proposals(predictions, self._coco_api, area=area, limit=limit) + key = "AR{}@{:d}".format(suffix, limit) + res[key] = float(stats["ar"].item() * 100) + self._logger.info("Proposal metrics: \n" + create_small_table(res)) + self._results["box_proposals"] = res + + def _derive_coco_results(self, coco_eval, iou_type, class_names=None): + """ + Derive the desired score numbers from summarized COCOeval. + + Args: + coco_eval (None or COCOEval): None represents no predictions from model. + iou_type (str): + class_names (None or list[str]): if provided, will use it to predict + per-category AP. + + Returns: + a dict of {metric name: score} + """ + + metrics = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl"], + "keypoints": ["AP", "AP50", "AP75", "APm", "APl"], + }[iou_type] + + if coco_eval is None: + self._logger.warn("No predictions from the model!") + return {metric: float("nan") for metric in metrics} + + # the standard metrics + results = { + metric: float(coco_eval.stats[idx] * 100 if coco_eval.stats[idx] >= 0 else "nan") + for idx, metric in enumerate(metrics) + } + self._logger.info( + "Evaluation results for {}: \n".format(iou_type) + create_small_table(results) + ) + if not np.isfinite(sum(results.values())): + self._logger.info("Some metrics cannot be computed and is shown as NaN.") + + if class_names is None or len(class_names) <= 1: + return results + # Compute per-category AP + # from https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L222-L252 # noqa + precisions = coco_eval.eval["precision"] + # precision has dims (iou, recall, cls, area range, max dets) + assert len(class_names) == precisions.shape[2] + + results_per_category = [] + for idx, name in enumerate(class_names): + # area range index 0: all area ranges + # max dets index -1: typically 100 per image + precision = precisions[:, :, idx, 0, -1] + precision = precision[precision > -1] + ap = np.mean(precision) if precision.size else float("nan") + results_per_category.append(("{}".format(name), float(ap * 100))) + + # tabulate it + N_COLS = min(6, len(results_per_category) * 2) + results_flatten = list(itertools.chain(*results_per_category)) + results_2d = itertools.zip_longest(*[results_flatten[i::N_COLS] for i in range(N_COLS)]) + table = tabulate( + results_2d, + tablefmt="pipe", + floatfmt=".3f", + headers=["category", "AP"] * (N_COLS // 2), + numalign="left", + ) + self._logger.info("Per-category {} AP: \n".format(iou_type) + table) + + results.update({"AP-" + name: ap for name, ap in results_per_category}) + return results + + +def instances_to_coco_json(instances, img_id): + """ + Dump an "Instances" object to a COCO-format json that's used for evaluation. + + Args: + instances (Instances): + img_id (int): the image id + + Returns: + list[dict]: list of json annotations in COCO format. + """ + num_instance = len(instances) + if num_instance == 0: + return [] + + boxes = instances.pred_boxes.tensor.numpy() + boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + boxes = boxes.tolist() + scores = instances.scores.tolist() + classes = instances.pred_classes.tolist() + + has_mask = instances.has("pred_masks") + if has_mask: + # use RLE to encode the masks, because they are too large and takes memory + # since this evaluator stores outputs of the entire dataset + rles = [ + mask_util.encode(np.array(mask[:, :, None], order="F", dtype="uint8"))[0] + for mask in instances.pred_masks + ] + for rle in rles: + # "counts" is an array encoded by mask_util as a byte-stream. Python3's + # json writer which always produces strings cannot serialize a bytestream + # unless you decode it. Thankfully, utf-8 works out (which is also what + # the custom_pycocotools/_mask.pyx does). + rle["counts"] = rle["counts"].decode("utf-8") + + has_keypoints = instances.has("pred_keypoints") + if has_keypoints: + keypoints = instances.pred_keypoints + + results = [] + for k in range(num_instance): + result = { + "image_id": img_id, + "category_id": classes[k], + "bbox": boxes[k], + "score": scores[k], + } + if has_mask: + result["segmentation"] = rles[k] + if has_keypoints: + # In COCO annotations, + # keypoints coordinates are pixel indices. + # However our predictions are floating point coordinates. + # Therefore we subtract 0.5 to be consistent with the annotation format. + # This is the inverse of data loading logic in `datasets/coco.py`. + keypoints[k][:, :2] -= 0.5 + result["keypoints"] = keypoints[k].flatten().tolist() + results.append(result) + return results + + +# inspired from Detectron: +# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa +def _evaluate_box_proposals(dataset_predictions, coco_api, thresholds=None, area="all", limit=None): + """ + Evaluate detection proposal recall metrics. This function is a much + faster alternative to the official COCO API recall evaluation code. However, + it produces slightly different results. + """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0**2, 1e5**2], # all + [0**2, 32**2], # small + [32**2, 96**2], # medium + [96**2, 1e5**2], # large + [96**2, 128**2], # 96-128 + [128**2, 256**2], # 128-256 + [256**2, 512**2], # 256-512 + [512**2, 1e5**2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for prediction_dict in dataset_predictions: + predictions = prediction_dict["proposals"] + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + inds = predictions.objectness_logits.sort(descending=True)[1] + predictions = predictions[inds] + + ann_ids = coco_api.getAnnIds(imgIds=prediction_dict["image_id"]) + anno = coco_api.loadAnns(ann_ids) + gt_boxes = [ + BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) + for obj in anno + if obj["iscrowd"] == 0 + ] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = Boxes(gt_boxes) + gt_areas = torch.as_tensor([obj["area"] for obj in anno if obj["iscrowd"] == 0]) + + if len(gt_boxes) == 0 or len(predictions) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if limit is not None and len(predictions) > limit: + predictions = predictions[:limit] + + overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(predictions), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + gt_overlaps = ( + torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32) + ) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +def _evaluate_predictions_on_coco( + coco_gt, + coco_results, + iou_type, + kpt_oks_sigmas=None, + cocoeval_fn=COCOeval_opt, + img_ids=None, + max_dets_per_image=None, +): + """ + Evaluate the coco results using COCOEval API. + """ + assert len(coco_results) > 0 + + if iou_type == "segm": + coco_results = copy.deepcopy(coco_results) + # When evaluating mask AP, if the results contain bbox, cocoapi will + # use the box area as the area of the instance, instead of the mask area. + # This leads to a different definition of small/medium/large. + # We remove the bbox field to let mask AP use mask area. + for c in coco_results: + c.pop("bbox", None) + + coco_dt = coco_gt.loadRes(coco_results) + coco_eval = cocoeval_fn(coco_gt, coco_dt, iou_type) + # For COCO, the default max_dets_per_image is [1, 10, 100]. + if max_dets_per_image is None: + max_dets_per_image = [1, 10, 100] # Default from COCOEval + else: + assert ( + len(max_dets_per_image) >= 3 + ), "COCOeval requires maxDets (and max_dets_per_image) to have length at least 3" + # In the case that user supplies a custom input for max_dets_per_image, + # apply COCOevalMaxDets to evaluate AP with the custom input. + if max_dets_per_image[2] != 100: + coco_eval = COCOevalMaxDets(coco_gt, coco_dt, iou_type) + if iou_type != "keypoints": + coco_eval.params.maxDets = max_dets_per_image + + if img_ids is not None: + coco_eval.params.imgIds = img_ids + + if iou_type == "keypoints": + # Use the COCO default keypoint OKS sigmas unless overrides are specified + if kpt_oks_sigmas: + assert hasattr(coco_eval.params, "kpt_oks_sigmas"), "custom_pycocotools is too old!" + coco_eval.params.kpt_oks_sigmas = np.array(kpt_oks_sigmas) + # COCOAPI requires every detection and every gt to have keypoints, so + # we just take the first entry from both + num_keypoints_dt = len(coco_results[0]["keypoints"]) // 3 + num_keypoints_gt = len(next(iter(coco_gt.anns.values()))["keypoints"]) // 3 + num_keypoints_oks = len(coco_eval.params.kpt_oks_sigmas) + assert num_keypoints_oks == num_keypoints_dt == num_keypoints_gt, ( + f"[COCOEvaluator] Prediction contain {num_keypoints_dt} keypoints. " + f"Ground truth contains {num_keypoints_gt} keypoints. " + f"The length of cfg.TEST.KEYPOINT_OKS_SIGMAS is {num_keypoints_oks}. " + "They have to agree with each other. For meaning of OKS, please refer to " + "http://cocodataset.org/#keypoints-eval." + ) + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + return coco_eval + + +class COCOevalMaxDets(COCOeval): + """ + Modified version of COCOeval for evaluating AP with a custom + maxDets (by default for COCO, maxDets is 100) + """ + + def summarize(self): + """ + Compute and display summary metrics for evaluation results given + a custom value for max_dets_per_image + """ + + def _summarize(ap=1, iouThr=None, areaRng="all", maxDets=100): + p = self.params + iStr = " {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}" + titleStr = "Average Precision" if ap == 1 else "Average Recall" + typeStr = "(AP)" if ap == 1 else "(AR)" + iouStr = ( + "{:0.2f}:{:0.2f}".format(p.iouThrs[0], p.iouThrs[-1]) + if iouThr is None + else "{:0.2f}".format(iouThr) + ) + + aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] + mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] + if ap == 1: + # dimension of precision: [TxRxKxAxM] + s = self.eval["precision"] + # IoU + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, :, aind, mind] + else: + # dimension of recall: [TxKxAxM] + s = self.eval["recall"] + if iouThr is not None: + t = np.where(iouThr == p.iouThrs)[0] + s = s[t] + s = s[:, :, aind, mind] + if len(s[s > -1]) == 0: + mean_s = -1 + else: + mean_s = np.mean(s[s > -1]) + print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) + return mean_s + + def _summarizeDets(): + stats = np.zeros((12,)) + # Evaluate AP using the custom limit on maximum detections per image + stats[0] = _summarize(1, maxDets=self.params.maxDets[2]) + stats[1] = _summarize(1, iouThr=0.5, maxDets=self.params.maxDets[2]) + stats[2] = _summarize(1, iouThr=0.75, maxDets=self.params.maxDets[2]) + stats[3] = _summarize(1, areaRng="small", maxDets=self.params.maxDets[2]) + stats[4] = _summarize(1, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[5] = _summarize(1, areaRng="large", maxDets=self.params.maxDets[2]) + stats[6] = _summarize(0, maxDets=self.params.maxDets[0]) + stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) + stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) + stats[9] = _summarize(0, areaRng="small", maxDets=self.params.maxDets[2]) + stats[10] = _summarize(0, areaRng="medium", maxDets=self.params.maxDets[2]) + stats[11] = _summarize(0, areaRng="large", maxDets=self.params.maxDets[2]) + return stats + + def _summarizeKps(): + stats = np.zeros((10,)) + stats[0] = _summarize(1, maxDets=20) + stats[1] = _summarize(1, maxDets=20, iouThr=0.5) + stats[2] = _summarize(1, maxDets=20, iouThr=0.75) + stats[3] = _summarize(1, maxDets=20, areaRng="medium") + stats[4] = _summarize(1, maxDets=20, areaRng="large") + stats[5] = _summarize(0, maxDets=20) + stats[6] = _summarize(0, maxDets=20, iouThr=0.5) + stats[7] = _summarize(0, maxDets=20, iouThr=0.75) + stats[8] = _summarize(0, maxDets=20, areaRng="medium") + stats[9] = _summarize(0, maxDets=20, areaRng="large") + return stats + + if not self.eval: + raise Exception("Please run accumulate() first") + iouType = self.params.iouType + if iouType == "segm" or iouType == "bbox": + summarize = _summarizeDets + elif iouType == "keypoints": + summarize = _summarizeKps + self.stats = summarize() + + def __str__(self): + self.summarize() diff --git a/custom_detectron2/evaluation/evaluator.py b/custom_detectron2/evaluation/evaluator.py new file mode 100644 index 0000000000000000000000000000000000000000..6465b00de7c0c3e6e5ca1de05e7284e7c85bcd80 --- /dev/null +++ b/custom_detectron2/evaluation/evaluator.py @@ -0,0 +1,224 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import datetime +import logging +import time +from collections import OrderedDict, abc +from contextlib import ExitStack, contextmanager +from typing import List, Union +import torch +from torch import nn + +from custom_detectron2.utils.comm import get_world_size, is_main_process +from custom_detectron2.utils.logger import log_every_n_seconds + + +class DatasetEvaluator: + """ + Base class for a dataset evaluator. + + The function :func:`inference_on_dataset` runs the model over + all samples in the dataset, and have a DatasetEvaluator to process the inputs/outputs. + + This class will accumulate information of the inputs/outputs (by :meth:`process`), + and produce evaluation results in the end (by :meth:`evaluate`). + """ + + def reset(self): + """ + Preparation for a new round of evaluation. + Should be called before starting a round of evaluation. + """ + pass + + def process(self, inputs, outputs): + """ + Process the pair of inputs and outputs. + If they contain batches, the pairs can be consumed one-by-one using `zip`: + + .. code-block:: python + + for input_, output in zip(inputs, outputs): + # do evaluation on single input/output pair + ... + + Args: + inputs (list): the inputs that's used to call the model. + outputs (list): the return value of `model(inputs)` + """ + pass + + def evaluate(self): + """ + Evaluate/summarize the performance, after processing all input/output pairs. + + Returns: + dict: + A new evaluator class can return a dict of arbitrary format + as long as the user can process the results. + In our train_net.py, we expect the following format: + + * key: the name of the task (e.g., bbox) + * value: a dict of {metric name: score}, e.g.: {"AP50": 80} + """ + pass + + +class DatasetEvaluators(DatasetEvaluator): + """ + Wrapper class to combine multiple :class:`DatasetEvaluator` instances. + + This class dispatches every evaluation call to + all of its :class:`DatasetEvaluator`. + """ + + def __init__(self, evaluators): + """ + Args: + evaluators (list): the evaluators to combine. + """ + super().__init__() + self._evaluators = evaluators + + def reset(self): + for evaluator in self._evaluators: + evaluator.reset() + + def process(self, inputs, outputs): + for evaluator in self._evaluators: + evaluator.process(inputs, outputs) + + def evaluate(self): + results = OrderedDict() + for evaluator in self._evaluators: + result = evaluator.evaluate() + if is_main_process() and result is not None: + for k, v in result.items(): + assert ( + k not in results + ), "Different evaluators produce results with the same key {}".format(k) + results[k] = v + return results + + +def inference_on_dataset( + model, data_loader, evaluator: Union[DatasetEvaluator, List[DatasetEvaluator], None] +): + """ + Run model on the data_loader and evaluate the metrics with evaluator. + Also benchmark the inference speed of `model.__call__` accurately. + The model will be used in eval mode. + + Args: + model (callable): a callable which takes an object from + `data_loader` and returns some outputs. + + If it's an nn.Module, it will be temporarily set to `eval` mode. + If you wish to evaluate a model in `training` mode instead, you can + wrap the given model and override its behavior of `.eval()` and `.train()`. + data_loader: an iterable object with a length. + The elements it generates will be the inputs to the model. + evaluator: the evaluator(s) to run. Use `None` if you only want to benchmark, + but don't want to do any evaluation. + + Returns: + The return value of `evaluator.evaluate()` + """ + num_devices = get_world_size() + logger = logging.getLogger(__name__) + logger.info("Start inference on {} batches".format(len(data_loader))) + + total = len(data_loader) # inference data loader must have a fixed length + if evaluator is None: + # create a no-op evaluator + evaluator = DatasetEvaluators([]) + if isinstance(evaluator, abc.MutableSequence): + evaluator = DatasetEvaluators(evaluator) + evaluator.reset() + + num_warmup = min(5, total - 1) + start_time = time.perf_counter() + total_data_time = 0 + total_compute_time = 0 + total_eval_time = 0 + with ExitStack() as stack: + if isinstance(model, nn.Module): + stack.enter_context(inference_context(model)) + stack.enter_context(torch.no_grad()) + + start_data_time = time.perf_counter() + for idx, inputs in enumerate(data_loader): + total_data_time += time.perf_counter() - start_data_time + if idx == num_warmup: + start_time = time.perf_counter() + total_data_time = 0 + total_compute_time = 0 + total_eval_time = 0 + + start_compute_time = time.perf_counter() + outputs = model(inputs) + if torch.cuda.is_available(): + torch.cuda.synchronize() + total_compute_time += time.perf_counter() - start_compute_time + + start_eval_time = time.perf_counter() + evaluator.process(inputs, outputs) + total_eval_time += time.perf_counter() - start_eval_time + + iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup) + data_seconds_per_iter = total_data_time / iters_after_start + compute_seconds_per_iter = total_compute_time / iters_after_start + eval_seconds_per_iter = total_eval_time / iters_after_start + total_seconds_per_iter = (time.perf_counter() - start_time) / iters_after_start + if idx >= num_warmup * 2 or compute_seconds_per_iter > 5: + eta = datetime.timedelta(seconds=int(total_seconds_per_iter * (total - idx - 1))) + log_every_n_seconds( + logging.INFO, + ( + f"Inference done {idx + 1}/{total}. " + f"Dataloading: {data_seconds_per_iter:.4f} s/iter. " + f"Inference: {compute_seconds_per_iter:.4f} s/iter. " + f"Eval: {eval_seconds_per_iter:.4f} s/iter. " + f"Total: {total_seconds_per_iter:.4f} s/iter. " + f"ETA={eta}" + ), + n=5, + ) + start_data_time = time.perf_counter() + + # Measure the time only for this worker (before the synchronization barrier) + total_time = time.perf_counter() - start_time + total_time_str = str(datetime.timedelta(seconds=total_time)) + # NOTE this format is parsed by grep + logger.info( + "Total inference time: {} ({:.6f} s / iter per device, on {} devices)".format( + total_time_str, total_time / (total - num_warmup), num_devices + ) + ) + total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time))) + logger.info( + "Total inference pure compute time: {} ({:.6f} s / iter per device, on {} devices)".format( + total_compute_time_str, total_compute_time / (total - num_warmup), num_devices + ) + ) + + results = evaluator.evaluate() + # An evaluator may return None when not in main process. + # Replace it by an empty dict instead to make it easier for downstream code to handle + if results is None: + results = {} + return results + + +@contextmanager +def inference_context(model): + """ + A context where the model is temporarily changed to eval mode, + and restored to previous mode afterwards. + + Args: + model: a torch Module + """ + training_mode = model.training + model.eval() + yield + model.train(training_mode) diff --git a/custom_detectron2/evaluation/fast_eval_api.py b/custom_detectron2/evaluation/fast_eval_api.py new file mode 100644 index 0000000000000000000000000000000000000000..659aabb71b057e7db3b11110e21819516dd39ee1 --- /dev/null +++ b/custom_detectron2/evaluation/fast_eval_api.py @@ -0,0 +1,121 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import numpy as np +import time +from custom_pycocotools.cocoeval import COCOeval + +from custom_detectron2 import _C + +logger = logging.getLogger(__name__) + + +class COCOeval_opt(COCOeval): + """ + This is a slightly modified version of the original COCO API, where the functions evaluateImg() + and accumulate() are implemented in C++ to speedup evaluation + """ + + def evaluate(self): + """ + Run per image evaluation on given images and store results in self.evalImgs_cpp, a + datastructure that isn't readable from Python but is used by a c++ implementation of + accumulate(). Unlike the original COCO PythonAPI, we don't populate the datastructure + self.evalImgs because this datastructure is a computational bottleneck. + :return: None + """ + tic = time.time() + + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = "segm" if p.useSegm == 1 else "bbox" + logger.info("Evaluate annotation type *{}*".format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() # bottleneck + + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == "segm" or p.iouType == "bbox": + computeIoU = self.computeIoU + elif p.iouType == "keypoints": + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) for imgId in p.imgIds for catId in catIds + } # bottleneck + + maxDet = p.maxDets[-1] + + # <<<< Beginning of code differences with original COCO API + def convert_instances_to_cpp(instances, is_det=False): + # Convert annotations for a list of instances in an image to a format that's fast + # to access in C++ + instances_cpp = [] + for instance in instances: + instance_cpp = _C.InstanceAnnotation( + int(instance["id"]), + instance["score"] if is_det else instance.get("score", 0.0), + instance["area"], + bool(instance.get("iscrowd", 0)), + bool(instance.get("ignore", 0)), + ) + instances_cpp.append(instance_cpp) + return instances_cpp + + # Convert GT annotations, detections, and IOUs to a format that's fast to access in C++ + ground_truth_instances = [ + [convert_instances_to_cpp(self._gts[imgId, catId]) for catId in p.catIds] + for imgId in p.imgIds + ] + detected_instances = [ + [convert_instances_to_cpp(self._dts[imgId, catId], is_det=True) for catId in p.catIds] + for imgId in p.imgIds + ] + ious = [[self.ious[imgId, catId] for catId in catIds] for imgId in p.imgIds] + + if not p.useCats: + # For each image, flatten per-category lists into a single list + ground_truth_instances = [[[o for c in i for o in c]] for i in ground_truth_instances] + detected_instances = [[[o for c in i for o in c]] for i in detected_instances] + + # Call C++ implementation of self.evaluateImgs() + self._evalImgs_cpp = _C.COCOevalEvaluateImages( + p.areaRng, maxDet, p.iouThrs, ious, ground_truth_instances, detected_instances + ) + self._evalImgs = None + + self._paramsEval = copy.deepcopy(self.params) + toc = time.time() + logger.info("COCOeval_opt.evaluate() finished in {:0.2f} seconds.".format(toc - tic)) + # >>>> End of code differences with original COCO API + + def accumulate(self): + """ + Accumulate per image evaluation results and store the result in self.eval. Does not + support changing parameter settings from those used by self.evaluate() + """ + logger.info("Accumulating evaluation results...") + tic = time.time() + assert hasattr( + self, "_evalImgs_cpp" + ), "evaluate() must be called before accmulate() is called." + + self.eval = _C.COCOevalAccumulate(self._paramsEval, self._evalImgs_cpp) + + # recall is num_iou_thresholds X num_categories X num_area_ranges X num_max_detections + self.eval["recall"] = np.array(self.eval["recall"]).reshape( + self.eval["counts"][:1] + self.eval["counts"][2:] + ) + + # precision and scores are num_iou_thresholds X num_recall_thresholds X num_categories X + # num_area_ranges X num_max_detections + self.eval["precision"] = np.array(self.eval["precision"]).reshape(self.eval["counts"]) + self.eval["scores"] = np.array(self.eval["scores"]).reshape(self.eval["counts"]) + toc = time.time() + logger.info("COCOeval_opt.accumulate() finished in {:0.2f} seconds.".format(toc - tic)) diff --git a/custom_detectron2/evaluation/lvis_evaluation.py b/custom_detectron2/evaluation/lvis_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7769b48c3feb96cbcb7ee56794c1a3a35c3540 --- /dev/null +++ b/custom_detectron2/evaluation/lvis_evaluation.py @@ -0,0 +1,380 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import itertools +import json +import logging +import os +import pickle +from collections import OrderedDict +import torch + +import custom_detectron2.utils.comm as comm +from custom_detectron2.config import CfgNode +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.structures import Boxes, BoxMode, pairwise_iou +from custom_detectron2.utils.file_io import PathManager +from custom_detectron2.utils.logger import create_small_table + +from .coco_evaluation import instances_to_coco_json +from .evaluator import DatasetEvaluator + + +class LVISEvaluator(DatasetEvaluator): + """ + Evaluate object proposal and instance detection/segmentation outputs using + LVIS's metrics and evaluation API. + """ + + def __init__( + self, + dataset_name, + tasks=None, + distributed=True, + output_dir=None, + *, + max_dets_per_image=None, + ): + """ + Args: + dataset_name (str): name of the dataset to be evaluated. + It must have the following corresponding metadata: + "json_file": the path to the LVIS format annotation + tasks (tuple[str]): tasks that can be evaluated under the given + configuration. A task is one of "bbox", "segm". + By default, will infer this automatically from predictions. + distributed (True): if True, will collect results from all ranks for evaluation. + Otherwise, will evaluate the results in the current process. + output_dir (str): optional, an output directory to dump results. + max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP + This limit, by default of the LVIS dataset, is 300. + """ + from lvis import LVIS + + self._logger = logging.getLogger(__name__) + + if tasks is not None and isinstance(tasks, CfgNode): + self._logger.warn( + "COCO Evaluator instantiated using config, this is deprecated behavior." + " Please pass in explicit arguments instead." + ) + self._tasks = None # Infering it from predictions should be better + else: + self._tasks = tasks + + self._distributed = distributed + self._output_dir = output_dir + self._max_dets_per_image = max_dets_per_image + + self._cpu_device = torch.device("cpu") + + self._metadata = MetadataCatalog.get(dataset_name) + json_file = PathManager.get_local_path(self._metadata.json_file) + self._lvis_api = LVIS(json_file) + # Test set json files do not contain annotations (evaluation must be + # performed using the LVIS evaluation server). + self._do_evaluation = len(self._lvis_api.get_ann_ids()) > 0 + + def reset(self): + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a LVIS model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a LVIS model. It is a list of dicts with key + "instances" that contains :class:`Instances`. + """ + for input, output in zip(inputs, outputs): + prediction = {"image_id": input["image_id"]} + + if "instances" in output: + instances = output["instances"].to(self._cpu_device) + prediction["instances"] = instances_to_coco_json(instances, input["image_id"]) + if "proposals" in output: + prediction["proposals"] = output["proposals"].to(self._cpu_device) + self._predictions.append(prediction) + + def evaluate(self): + if self._distributed: + comm.synchronize() + predictions = comm.gather(self._predictions, dst=0) + predictions = list(itertools.chain(*predictions)) + + if not comm.is_main_process(): + return + else: + predictions = self._predictions + + if len(predictions) == 0: + self._logger.warning("[LVISEvaluator] Did not receive valid predictions.") + return {} + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "instances_predictions.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(predictions, f) + + self._results = OrderedDict() + if "proposals" in predictions[0]: + self._eval_box_proposals(predictions) + if "instances" in predictions[0]: + self._eval_predictions(predictions) + # Copy so the caller can do whatever with results + return copy.deepcopy(self._results) + + def _tasks_from_predictions(self, predictions): + for pred in predictions: + if "segmentation" in pred: + return ("bbox", "segm") + return ("bbox",) + + def _eval_predictions(self, predictions): + """ + Evaluate predictions. Fill self._results with the metrics of the tasks. + + Args: + predictions (list[dict]): list of outputs from the model + """ + self._logger.info("Preparing results in the LVIS format ...") + lvis_results = list(itertools.chain(*[x["instances"] for x in predictions])) + tasks = self._tasks or self._tasks_from_predictions(lvis_results) + + # LVIS evaluator can be used to evaluate results for COCO dataset categories. + # In this case `_metadata` variable will have a field with COCO-specific category mapping. + if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): + reverse_id_mapping = { + v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items() + } + for result in lvis_results: + result["category_id"] = reverse_id_mapping[result["category_id"]] + else: + # unmap the category ids for LVIS (from 0-indexed to 1-indexed) + for result in lvis_results: + result["category_id"] += 1 + + if self._output_dir: + file_path = os.path.join(self._output_dir, "lvis_instances_results.json") + self._logger.info("Saving results to {}".format(file_path)) + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(lvis_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating predictions ...") + for task in sorted(tasks): + res = _evaluate_predictions_on_lvis( + self._lvis_api, + lvis_results, + task, + max_dets_per_image=self._max_dets_per_image, + class_names=self._metadata.get("thing_classes"), + ) + self._results[task] = res + + def _eval_box_proposals(self, predictions): + """ + Evaluate the box proposals in predictions. + Fill self._results with the metrics for "box_proposals" task. + """ + if self._output_dir: + # Saving generated box proposals to file. + # Predicted box_proposals are in XYXY_ABS mode. + bbox_mode = BoxMode.XYXY_ABS.value + ids, boxes, objectness_logits = [], [], [] + for prediction in predictions: + ids.append(prediction["image_id"]) + boxes.append(prediction["proposals"].proposal_boxes.tensor.numpy()) + objectness_logits.append(prediction["proposals"].objectness_logits.numpy()) + + proposal_data = { + "boxes": boxes, + "objectness_logits": objectness_logits, + "ids": ids, + "bbox_mode": bbox_mode, + } + with PathManager.open(os.path.join(self._output_dir, "box_proposals.pkl"), "wb") as f: + pickle.dump(proposal_data, f) + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating bbox proposals ...") + res = {} + areas = {"all": "", "small": "s", "medium": "m", "large": "l"} + for limit in [100, 1000]: + for area, suffix in areas.items(): + stats = _evaluate_box_proposals(predictions, self._lvis_api, area=area, limit=limit) + key = "AR{}@{:d}".format(suffix, limit) + res[key] = float(stats["ar"].item() * 100) + self._logger.info("Proposal metrics: \n" + create_small_table(res)) + self._results["box_proposals"] = res + + +# inspired from Detectron: +# https://github.com/facebookresearch/Detectron/blob/a6a835f5b8208c45d0dce217ce9bbda915f44df7/detectron/datasets/json_dataset_evaluator.py#L255 # noqa +def _evaluate_box_proposals(dataset_predictions, lvis_api, thresholds=None, area="all", limit=None): + """ + Evaluate detection proposal recall metrics. This function is a much + faster alternative to the official LVIS API recall evaluation code. However, + it produces slightly different results. + """ + # Record max overlap value for each gt box + # Return vector of overlap values + areas = { + "all": 0, + "small": 1, + "medium": 2, + "large": 3, + "96-128": 4, + "128-256": 5, + "256-512": 6, + "512-inf": 7, + } + area_ranges = [ + [0**2, 1e5**2], # all + [0**2, 32**2], # small + [32**2, 96**2], # medium + [96**2, 1e5**2], # large + [96**2, 128**2], # 96-128 + [128**2, 256**2], # 128-256 + [256**2, 512**2], # 256-512 + [512**2, 1e5**2], + ] # 512-inf + assert area in areas, "Unknown area range: {}".format(area) + area_range = area_ranges[areas[area]] + gt_overlaps = [] + num_pos = 0 + + for prediction_dict in dataset_predictions: + predictions = prediction_dict["proposals"] + + # sort predictions in descending order + # TODO maybe remove this and make it explicit in the documentation + inds = predictions.objectness_logits.sort(descending=True)[1] + predictions = predictions[inds] + + ann_ids = lvis_api.get_ann_ids(img_ids=[prediction_dict["image_id"]]) + anno = lvis_api.load_anns(ann_ids) + gt_boxes = [ + BoxMode.convert(obj["bbox"], BoxMode.XYWH_ABS, BoxMode.XYXY_ABS) for obj in anno + ] + gt_boxes = torch.as_tensor(gt_boxes).reshape(-1, 4) # guard against no boxes + gt_boxes = Boxes(gt_boxes) + gt_areas = torch.as_tensor([obj["area"] for obj in anno]) + + if len(gt_boxes) == 0 or len(predictions) == 0: + continue + + valid_gt_inds = (gt_areas >= area_range[0]) & (gt_areas <= area_range[1]) + gt_boxes = gt_boxes[valid_gt_inds] + + num_pos += len(gt_boxes) + + if len(gt_boxes) == 0: + continue + + if limit is not None and len(predictions) > limit: + predictions = predictions[:limit] + + overlaps = pairwise_iou(predictions.proposal_boxes, gt_boxes) + + _gt_overlaps = torch.zeros(len(gt_boxes)) + for j in range(min(len(predictions), len(gt_boxes))): + # find which proposal box maximally covers each gt box + # and get the iou amount of coverage for each gt box + max_overlaps, argmax_overlaps = overlaps.max(dim=0) + + # find which gt box is 'best' covered (i.e. 'best' = most iou) + gt_ovr, gt_ind = max_overlaps.max(dim=0) + assert gt_ovr >= 0 + # find the proposal box that covers the best covered gt box + box_ind = argmax_overlaps[gt_ind] + # record the iou coverage of this gt box + _gt_overlaps[j] = overlaps[box_ind, gt_ind] + assert _gt_overlaps[j] == gt_ovr + # mark the proposal box and the gt box as used + overlaps[box_ind, :] = -1 + overlaps[:, gt_ind] = -1 + + # append recorded iou coverage level + gt_overlaps.append(_gt_overlaps) + gt_overlaps = ( + torch.cat(gt_overlaps, dim=0) if len(gt_overlaps) else torch.zeros(0, dtype=torch.float32) + ) + gt_overlaps, _ = torch.sort(gt_overlaps) + + if thresholds is None: + step = 0.05 + thresholds = torch.arange(0.5, 0.95 + 1e-5, step, dtype=torch.float32) + recalls = torch.zeros_like(thresholds) + # compute recall for each iou threshold + for i, t in enumerate(thresholds): + recalls[i] = (gt_overlaps >= t).float().sum() / float(num_pos) + # ar = 2 * np.trapz(recalls, thresholds) + ar = recalls.mean() + return { + "ar": ar, + "recalls": recalls, + "thresholds": thresholds, + "gt_overlaps": gt_overlaps, + "num_pos": num_pos, + } + + +def _evaluate_predictions_on_lvis( + lvis_gt, lvis_results, iou_type, max_dets_per_image=None, class_names=None +): + """ + Args: + iou_type (str): + max_dets_per_image (None or int): limit on maximum detections per image in evaluating AP + This limit, by default of the LVIS dataset, is 300. + class_names (None or list[str]): if provided, will use it to predict + per-category AP. + + Returns: + a dict of {metric name: score} + """ + metrics = { + "bbox": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"], + "segm": ["AP", "AP50", "AP75", "APs", "APm", "APl", "APr", "APc", "APf"], + }[iou_type] + + logger = logging.getLogger(__name__) + + if len(lvis_results) == 0: # TODO: check if needed + logger.warn("No predictions from the model!") + return {metric: float("nan") for metric in metrics} + + if iou_type == "segm": + lvis_results = copy.deepcopy(lvis_results) + # When evaluating mask AP, if the results contain bbox, LVIS API will + # use the box area as the area of the instance, instead of the mask area. + # This leads to a different definition of small/medium/large. + # We remove the bbox field to let mask AP use mask area. + for c in lvis_results: + c.pop("bbox", None) + + if max_dets_per_image is None: + max_dets_per_image = 300 # Default for LVIS dataset + + from lvis import LVISEval, LVISResults + + logger.info(f"Evaluating with max detections per image = {max_dets_per_image}") + lvis_results = LVISResults(lvis_gt, lvis_results, max_dets=max_dets_per_image) + lvis_eval = LVISEval(lvis_gt, lvis_results, iou_type) + lvis_eval.run() + lvis_eval.print_results() + + # Pull the standard metrics from the LVIS results + results = lvis_eval.get_results() + results = {metric: float(results[metric] * 100) for metric in metrics} + logger.info("Evaluation results for {}: \n".format(iou_type) + create_small_table(results)) + return results diff --git a/custom_detectron2/evaluation/panoptic_evaluation.py b/custom_detectron2/evaluation/panoptic_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..833f3bd2f483360eb5cb7f6a5b6b02a2db0d01d6 --- /dev/null +++ b/custom_detectron2/evaluation/panoptic_evaluation.py @@ -0,0 +1,199 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import contextlib +import io +import itertools +import json +import logging +import numpy as np +import os +import tempfile +from collections import OrderedDict +from typing import Optional +from PIL import Image +from tabulate import tabulate + +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.utils import comm +from custom_detectron2.utils.file_io import PathManager + +from .evaluator import DatasetEvaluator + +logger = logging.getLogger(__name__) + + +class COCOPanopticEvaluator(DatasetEvaluator): + """ + Evaluate Panoptic Quality metrics on COCO using PanopticAPI. + It saves panoptic segmentation prediction in `output_dir` + + It contains a synchronize call and has to be called from all workers. + """ + + def __init__(self, dataset_name: str, output_dir: Optional[str] = None): + """ + Args: + dataset_name: name of the dataset + output_dir: output directory to save results for evaluation. + """ + self._metadata = MetadataCatalog.get(dataset_name) + self._thing_contiguous_id_to_dataset_id = { + v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items() + } + self._stuff_contiguous_id_to_dataset_id = { + v: k for k, v in self._metadata.stuff_dataset_id_to_contiguous_id.items() + } + + self._output_dir = output_dir + if self._output_dir is not None: + PathManager.mkdirs(self._output_dir) + + def reset(self): + self._predictions = [] + + def _convert_category_id(self, segment_info): + isthing = segment_info.pop("isthing", None) + if isthing is None: + # the model produces panoptic category id directly. No more conversion needed + return segment_info + if isthing is True: + segment_info["category_id"] = self._thing_contiguous_id_to_dataset_id[ + segment_info["category_id"] + ] + else: + segment_info["category_id"] = self._stuff_contiguous_id_to_dataset_id[ + segment_info["category_id"] + ] + return segment_info + + def process(self, inputs, outputs): + from panopticapi.utils import id2rgb + + for input, output in zip(inputs, outputs): + panoptic_img, segments_info = output["panoptic_seg"] + panoptic_img = panoptic_img.cpu().numpy() + if segments_info is None: + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label, and add 1 to panoptic_img since the official + # evaluation script uses 0 for VOID label. + label_divisor = self._metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_img): + if panoptic_label == -1: + # VOID region. + continue + pred_class = panoptic_label // label_divisor + isthing = ( + pred_class in self._metadata.thing_dataset_id_to_contiguous_id.values() + ) + segments_info.append( + { + "id": int(panoptic_label) + 1, + "category_id": int(pred_class), + "isthing": bool(isthing), + } + ) + # Official evaluation script uses 0 for VOID label. + panoptic_img += 1 + + file_name = os.path.basename(input["file_name"]) + file_name_png = os.path.splitext(file_name)[0] + ".png" + with io.BytesIO() as out: + Image.fromarray(id2rgb(panoptic_img)).save(out, format="PNG") + segments_info = [self._convert_category_id(x) for x in segments_info] + self._predictions.append( + { + "image_id": input["image_id"], + "file_name": file_name_png, + "png_string": out.getvalue(), + "segments_info": segments_info, + } + ) + + def evaluate(self): + comm.synchronize() + + self._predictions = comm.gather(self._predictions) + self._predictions = list(itertools.chain(*self._predictions)) + if not comm.is_main_process(): + return + + # PanopticApi requires local files + gt_json = PathManager.get_local_path(self._metadata.panoptic_json) + gt_folder = PathManager.get_local_path(self._metadata.panoptic_root) + + with tempfile.TemporaryDirectory(prefix="panoptic_eval") as pred_dir: + logger.info("Writing all panoptic predictions to {} ...".format(pred_dir)) + for p in self._predictions: + with open(os.path.join(pred_dir, p["file_name"]), "wb") as f: + f.write(p.pop("png_string")) + + with open(gt_json, "r") as f: + json_data = json.load(f) + json_data["annotations"] = self._predictions + + output_dir = self._output_dir or pred_dir + predictions_json = os.path.join(output_dir, "predictions.json") + with PathManager.open(predictions_json, "w") as f: + f.write(json.dumps(json_data)) + + from panopticapi.evaluation import pq_compute + + with contextlib.redirect_stdout(io.StringIO()): + pq_res = pq_compute( + gt_json, + PathManager.get_local_path(predictions_json), + gt_folder=gt_folder, + pred_folder=pred_dir, + ) + + res = {} + res["PQ"] = 100 * pq_res["All"]["pq"] + res["SQ"] = 100 * pq_res["All"]["sq"] + res["RQ"] = 100 * pq_res["All"]["rq"] + res["PQ_th"] = 100 * pq_res["Things"]["pq"] + res["SQ_th"] = 100 * pq_res["Things"]["sq"] + res["RQ_th"] = 100 * pq_res["Things"]["rq"] + res["PQ_st"] = 100 * pq_res["Stuff"]["pq"] + res["SQ_st"] = 100 * pq_res["Stuff"]["sq"] + res["RQ_st"] = 100 * pq_res["Stuff"]["rq"] + + results = OrderedDict({"panoptic_seg": res}) + _print_panoptic_results(pq_res) + + return results + + +def _print_panoptic_results(pq_res): + headers = ["", "PQ", "SQ", "RQ", "#categories"] + data = [] + for name in ["All", "Things", "Stuff"]: + row = [name] + [pq_res[name][k] * 100 for k in ["pq", "sq", "rq"]] + [pq_res[name]["n"]] + data.append(row) + table = tabulate( + data, headers=headers, tablefmt="pipe", floatfmt=".3f", stralign="center", numalign="center" + ) + logger.info("Panoptic Evaluation Results:\n" + table) + + +if __name__ == "__main__": + from custom_detectron2.utils.logger import setup_logger + + logger = setup_logger() + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--gt-json") + parser.add_argument("--gt-dir") + parser.add_argument("--pred-json") + parser.add_argument("--pred-dir") + args = parser.parse_args() + + from panopticapi.evaluation import pq_compute + + with contextlib.redirect_stdout(io.StringIO()): + pq_res = pq_compute( + args.gt_json, args.pred_json, gt_folder=args.gt_dir, pred_folder=args.pred_dir + ) + _print_panoptic_results(pq_res) diff --git a/custom_detectron2/evaluation/pascal_voc_evaluation.py b/custom_detectron2/evaluation/pascal_voc_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..7de637cf20a7389ebc60b16c491fbcbd3ab87305 --- /dev/null +++ b/custom_detectron2/evaluation/pascal_voc_evaluation.py @@ -0,0 +1,300 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +import numpy as np +import os +import tempfile +import xml.etree.ElementTree as ET +from collections import OrderedDict, defaultdict +from functools import lru_cache +import torch + +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.utils import comm +from custom_detectron2.utils.file_io import PathManager + +from .evaluator import DatasetEvaluator + + +class PascalVOCDetectionEvaluator(DatasetEvaluator): + """ + Evaluate Pascal VOC style AP for Pascal VOC dataset. + It contains a synchronization, therefore has to be called from all ranks. + + Note that the concept of AP can be implemented in different ways and may not + produce identical results. This class mimics the implementation of the official + Pascal VOC Matlab API, and should produce similar but not identical results to the + official API. + """ + + def __init__(self, dataset_name): + """ + Args: + dataset_name (str): name of the dataset, e.g., "voc_2007_test" + """ + self._dataset_name = dataset_name + meta = MetadataCatalog.get(dataset_name) + + # Too many tiny files, download all to local for speed. + annotation_dir_local = PathManager.get_local_path( + os.path.join(meta.dirname, "Annotations/") + ) + self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml") + self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt") + self._class_names = meta.thing_classes + assert meta.year in [2007, 2012], meta.year + self._is_2007 = meta.year == 2007 + self._cpu_device = torch.device("cpu") + self._logger = logging.getLogger(__name__) + + def reset(self): + self._predictions = defaultdict(list) # class name -> list of prediction strings + + def process(self, inputs, outputs): + for input, output in zip(inputs, outputs): + image_id = input["image_id"] + instances = output["instances"].to(self._cpu_device) + boxes = instances.pred_boxes.tensor.numpy() + scores = instances.scores.tolist() + classes = instances.pred_classes.tolist() + for box, score, cls in zip(boxes, scores, classes): + xmin, ymin, xmax, ymax = box + # The inverse of data loading logic in `datasets/pascal_voc.py` + xmin += 1 + ymin += 1 + self._predictions[cls].append( + f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}" + ) + + def evaluate(self): + """ + Returns: + dict: has a key "segm", whose value is a dict of "AP", "AP50", and "AP75". + """ + all_predictions = comm.gather(self._predictions, dst=0) + if not comm.is_main_process(): + return + predictions = defaultdict(list) + for predictions_per_rank in all_predictions: + for clsid, lines in predictions_per_rank.items(): + predictions[clsid].extend(lines) + del all_predictions + + self._logger.info( + "Evaluating {} using {} metric. " + "Note that results do not use the official Matlab API.".format( + self._dataset_name, 2007 if self._is_2007 else 2012 + ) + ) + + with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname: + res_file_template = os.path.join(dirname, "{}.txt") + + aps = defaultdict(list) # iou -> ap per class + for cls_id, cls_name in enumerate(self._class_names): + lines = predictions.get(cls_id, [""]) + + with open(res_file_template.format(cls_name), "w") as f: + f.write("\n".join(lines)) + + for thresh in range(50, 100, 5): + rec, prec, ap = voc_eval( + res_file_template, + self._anno_file_template, + self._image_set_path, + cls_name, + ovthresh=thresh / 100.0, + use_07_metric=self._is_2007, + ) + aps[thresh].append(ap * 100) + + ret = OrderedDict() + mAP = {iou: np.mean(x) for iou, x in aps.items()} + ret["bbox"] = {"AP": np.mean(list(mAP.values())), "AP50": mAP[50], "AP75": mAP[75]} + return ret + + +############################################################################## +# +# Below code is modified from +# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py +# -------------------------------------------------------- +# Fast/er R-CNN +# Licensed under The MIT License [see LICENSE for details] +# Written by Bharath Hariharan +# -------------------------------------------------------- + +"""Python implementation of the PASCAL VOC devkit's AP evaluation code.""" + + +@lru_cache(maxsize=None) +def parse_rec(filename): + """Parse a PASCAL VOC xml file.""" + with PathManager.open(filename) as f: + tree = ET.parse(f) + objects = [] + for obj in tree.findall("object"): + obj_struct = {} + obj_struct["name"] = obj.find("name").text + obj_struct["pose"] = obj.find("pose").text + obj_struct["truncated"] = int(obj.find("truncated").text) + obj_struct["difficult"] = int(obj.find("difficult").text) + bbox = obj.find("bndbox") + obj_struct["bbox"] = [ + int(bbox.find("xmin").text), + int(bbox.find("ymin").text), + int(bbox.find("xmax").text), + int(bbox.find("ymax").text), + ] + objects.append(obj_struct) + + return objects + + +def voc_ap(rec, prec, use_07_metric=False): + """Compute VOC AP given precision and recall. If use_07_metric is true, uses + the VOC 07 11-point method (default:False). + """ + if use_07_metric: + # 11 point metric + ap = 0.0 + for t in np.arange(0.0, 1.1, 0.1): + if np.sum(rec >= t) == 0: + p = 0 + else: + p = np.max(prec[rec >= t]) + ap = ap + p / 11.0 + else: + # correct AP calculation + # first append sentinel values at the end + mrec = np.concatenate(([0.0], rec, [1.0])) + mpre = np.concatenate(([0.0], prec, [0.0])) + + # compute the precision envelope + for i in range(mpre.size - 1, 0, -1): + mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) + + # to calculate area under PR curve, look for points + # where X axis (recall) changes value + i = np.where(mrec[1:] != mrec[:-1])[0] + + # and sum (\Delta recall) * prec + ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) + return ap + + +def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False): + """rec, prec, ap = voc_eval(detpath, + annopath, + imagesetfile, + classname, + [ovthresh], + [use_07_metric]) + + Top level function that does the PASCAL VOC evaluation. + + detpath: Path to detections + detpath.format(classname) should produce the detection results file. + annopath: Path to annotations + annopath.format(imagename) should be the xml annotations file. + imagesetfile: Text file containing the list of images, one image per line. + classname: Category name (duh) + [ovthresh]: Overlap threshold (default = 0.5) + [use_07_metric]: Whether to use VOC07's 11 point AP computation + (default False) + """ + # assumes detections are in detpath.format(classname) + # assumes annotations are in annopath.format(imagename) + # assumes imagesetfile is a text file with each line an image name + + # first load gt + # read list of images + with PathManager.open(imagesetfile, "r") as f: + lines = f.readlines() + imagenames = [x.strip() for x in lines] + + # load annots + recs = {} + for imagename in imagenames: + recs[imagename] = parse_rec(annopath.format(imagename)) + + # extract gt objects for this class + class_recs = {} + npos = 0 + for imagename in imagenames: + R = [obj for obj in recs[imagename] if obj["name"] == classname] + bbox = np.array([x["bbox"] for x in R]) + difficult = np.array([x["difficult"] for x in R]).astype(bool) + # difficult = np.array([False for x in R]).astype(bool) # treat all "difficult" as GT + det = [False] * len(R) + npos = npos + sum(~difficult) + class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det} + + # read dets + detfile = detpath.format(classname) + with open(detfile, "r") as f: + lines = f.readlines() + + splitlines = [x.strip().split(" ") for x in lines] + image_ids = [x[0] for x in splitlines] + confidence = np.array([float(x[1]) for x in splitlines]) + BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4) + + # sort by confidence + sorted_ind = np.argsort(-confidence) + BB = BB[sorted_ind, :] + image_ids = [image_ids[x] for x in sorted_ind] + + # go down dets and mark TPs and FPs + nd = len(image_ids) + tp = np.zeros(nd) + fp = np.zeros(nd) + for d in range(nd): + R = class_recs[image_ids[d]] + bb = BB[d, :].astype(float) + ovmax = -np.inf + BBGT = R["bbox"].astype(float) + + if BBGT.size > 0: + # compute overlaps + # intersection + ixmin = np.maximum(BBGT[:, 0], bb[0]) + iymin = np.maximum(BBGT[:, 1], bb[1]) + ixmax = np.minimum(BBGT[:, 2], bb[2]) + iymax = np.minimum(BBGT[:, 3], bb[3]) + iw = np.maximum(ixmax - ixmin + 1.0, 0.0) + ih = np.maximum(iymax - iymin + 1.0, 0.0) + inters = iw * ih + + # union + uni = ( + (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0) + + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0) + - inters + ) + + overlaps = inters / uni + ovmax = np.max(overlaps) + jmax = np.argmax(overlaps) + + if ovmax > ovthresh: + if not R["difficult"][jmax]: + if not R["det"][jmax]: + tp[d] = 1.0 + R["det"][jmax] = 1 + else: + fp[d] = 1.0 + else: + fp[d] = 1.0 + + # compute precision recall + fp = np.cumsum(fp) + tp = np.cumsum(tp) + rec = tp / float(npos) + # avoid divide by zero in case the first detection matches a difficult + # ground truth + prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps) + ap = voc_ap(rec, prec, use_07_metric) + + return rec, prec, ap diff --git a/custom_detectron2/evaluation/rotated_coco_evaluation.py b/custom_detectron2/evaluation/rotated_coco_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..d6d8b1a1cbb8243331f29f6a877f07373a7a6d97 --- /dev/null +++ b/custom_detectron2/evaluation/rotated_coco_evaluation.py @@ -0,0 +1,207 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import json +import numpy as np +import os +import torch +from custom_pycocotools.cocoeval import COCOeval, maskUtils + +from custom_detectron2.structures import BoxMode, RotatedBoxes, pairwise_iou_rotated +from custom_detectron2.utils.file_io import PathManager + +from .coco_evaluation import COCOEvaluator + + +class RotatedCOCOeval(COCOeval): + @staticmethod + def is_rotated(box_list): + if type(box_list) == np.ndarray: + return box_list.shape[1] == 5 + elif type(box_list) == list: + if box_list == []: # cannot decide the box_dim + return False + return np.all( + np.array( + [ + (len(obj) == 5) and ((type(obj) == list) or (type(obj) == np.ndarray)) + for obj in box_list + ] + ) + ) + return False + + @staticmethod + def boxlist_to_tensor(boxlist, output_box_dim): + if type(boxlist) == np.ndarray: + box_tensor = torch.from_numpy(boxlist) + elif type(boxlist) == list: + if boxlist == []: + return torch.zeros((0, output_box_dim), dtype=torch.float32) + else: + box_tensor = torch.FloatTensor(boxlist) + else: + raise Exception("Unrecognized boxlist type") + + input_box_dim = box_tensor.shape[1] + if input_box_dim != output_box_dim: + if input_box_dim == 4 and output_box_dim == 5: + box_tensor = BoxMode.convert(box_tensor, BoxMode.XYWH_ABS, BoxMode.XYWHA_ABS) + else: + raise Exception( + "Unable to convert from {}-dim box to {}-dim box".format( + input_box_dim, output_box_dim + ) + ) + return box_tensor + + def compute_iou_dt_gt(self, dt, gt, is_crowd): + if self.is_rotated(dt) or self.is_rotated(gt): + # TODO: take is_crowd into consideration + assert all(c == 0 for c in is_crowd) + dt = RotatedBoxes(self.boxlist_to_tensor(dt, output_box_dim=5)) + gt = RotatedBoxes(self.boxlist_to_tensor(gt, output_box_dim=5)) + return pairwise_iou_rotated(dt, gt) + else: + # This is the same as the classical COCO evaluation + return maskUtils.iou(dt, gt, is_crowd) + + def computeIoU(self, imgId, catId): + p = self.params + if p.useCats: + gt = self._gts[imgId, catId] + dt = self._dts[imgId, catId] + else: + gt = [_ for cId in p.catIds for _ in self._gts[imgId, cId]] + dt = [_ for cId in p.catIds for _ in self._dts[imgId, cId]] + if len(gt) == 0 and len(dt) == 0: + return [] + inds = np.argsort([-d["score"] for d in dt], kind="mergesort") + dt = [dt[i] for i in inds] + if len(dt) > p.maxDets[-1]: + dt = dt[0 : p.maxDets[-1]] + + assert p.iouType == "bbox", "unsupported iouType for iou computation" + + g = [g["bbox"] for g in gt] + d = [d["bbox"] for d in dt] + + # compute iou between each dt and gt region + iscrowd = [int(o["iscrowd"]) for o in gt] + + # Note: this function is copied from cocoeval.py in cocoapi + # and the major difference is here. + ious = self.compute_iou_dt_gt(d, g, iscrowd) + return ious + + +class RotatedCOCOEvaluator(COCOEvaluator): + """ + Evaluate object proposal/instance detection outputs using COCO-like metrics and APIs, + with rotated boxes support. + Note: this uses IOU only and does not consider angle differences. + """ + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a COCO model (e.g., GeneralizedRCNN). + It is a list of dict. Each dict corresponds to an image and + contains keys like "height", "width", "file_name", "image_id". + outputs: the outputs of a COCO model. It is a list of dicts with key + "instances" that contains :class:`Instances`. + """ + for input, output in zip(inputs, outputs): + prediction = {"image_id": input["image_id"]} + + if "instances" in output: + instances = output["instances"].to(self._cpu_device) + + prediction["instances"] = self.instances_to_json(instances, input["image_id"]) + if "proposals" in output: + prediction["proposals"] = output["proposals"].to(self._cpu_device) + self._predictions.append(prediction) + + def instances_to_json(self, instances, img_id): + num_instance = len(instances) + if num_instance == 0: + return [] + + boxes = instances.pred_boxes.tensor.numpy() + if boxes.shape[1] == 4: + boxes = BoxMode.convert(boxes, BoxMode.XYXY_ABS, BoxMode.XYWH_ABS) + boxes = boxes.tolist() + scores = instances.scores.tolist() + classes = instances.pred_classes.tolist() + + results = [] + for k in range(num_instance): + result = { + "image_id": img_id, + "category_id": classes[k], + "bbox": boxes[k], + "score": scores[k], + } + + results.append(result) + return results + + def _eval_predictions(self, predictions, img_ids=None): # img_ids: unused + """ + Evaluate predictions on the given tasks. + Fill self._results with the metrics of the tasks. + """ + self._logger.info("Preparing results for COCO format ...") + coco_results = list(itertools.chain(*[x["instances"] for x in predictions])) + + # unmap the category ids for COCO + if hasattr(self._metadata, "thing_dataset_id_to_contiguous_id"): + reverse_id_mapping = { + v: k for k, v in self._metadata.thing_dataset_id_to_contiguous_id.items() + } + for result in coco_results: + result["category_id"] = reverse_id_mapping[result["category_id"]] + + if self._output_dir: + file_path = os.path.join(self._output_dir, "coco_instances_results.json") + self._logger.info("Saving results to {}".format(file_path)) + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(coco_results)) + f.flush() + + if not self._do_evaluation: + self._logger.info("Annotations are not available for evaluation.") + return + + self._logger.info("Evaluating predictions ...") + + assert self._tasks is None or set(self._tasks) == { + "bbox" + }, "[RotatedCOCOEvaluator] Only bbox evaluation is supported" + coco_eval = ( + self._evaluate_predictions_on_coco(self._coco_api, coco_results) + if len(coco_results) > 0 + else None # cocoapi does not handle empty results very well + ) + + task = "bbox" + res = self._derive_coco_results( + coco_eval, task, class_names=self._metadata.get("thing_classes") + ) + self._results[task] = res + + def _evaluate_predictions_on_coco(self, coco_gt, coco_results): + """ + Evaluate the coco results using COCOEval API. + """ + assert len(coco_results) > 0 + + coco_dt = coco_gt.loadRes(coco_results) + + # Only bbox is supported for now + coco_eval = RotatedCOCOeval(coco_gt, coco_dt, iouType="bbox") + + coco_eval.evaluate() + coco_eval.accumulate() + coco_eval.summarize() + + return coco_eval diff --git a/custom_detectron2/evaluation/sem_seg_evaluation.py b/custom_detectron2/evaluation/sem_seg_evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..b82c23ddee9a916d0ba42b1536465e8a359c7932 --- /dev/null +++ b/custom_detectron2/evaluation/sem_seg_evaluation.py @@ -0,0 +1,265 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import json +import logging +import numpy as np +import os +from collections import OrderedDict +from typing import Optional, Union +import custom_pycocotools.mask as mask_util +import torch +from PIL import Image + +from custom_detectron2.data import DatasetCatalog, MetadataCatalog +from custom_detectron2.utils.comm import all_gather, is_main_process, synchronize +from custom_detectron2.utils.file_io import PathManager + +from .evaluator import DatasetEvaluator + +_CV2_IMPORTED = True +try: + import cv2 # noqa +except ImportError: + # OpenCV is an optional dependency at the moment + _CV2_IMPORTED = False + + +def load_image_into_numpy_array( + filename: str, + copy: bool = False, + dtype: Optional[Union[np.dtype, str]] = None, +) -> np.ndarray: + with PathManager.open(filename, "rb") as f: + array = np.array(Image.open(f), copy=copy, dtype=dtype) + return array + + +class SemSegEvaluator(DatasetEvaluator): + """ + Evaluate semantic segmentation metrics. + """ + + def __init__( + self, + dataset_name, + distributed=True, + output_dir=None, + *, + sem_seg_loading_fn=load_image_into_numpy_array, + num_classes=None, + ignore_label=None, + ): + """ + Args: + dataset_name (str): name of the dataset to be evaluated. + distributed (bool): if True, will collect results from all ranks for evaluation. + Otherwise, will evaluate the results in the current process. + output_dir (str): an output directory to dump results. + sem_seg_loading_fn: function to read sem seg file and load into numpy array. + Default provided, but projects can customize. + num_classes, ignore_label: deprecated argument + """ + self._logger = logging.getLogger(__name__) + if num_classes is not None: + self._logger.warn( + "SemSegEvaluator(num_classes) is deprecated! It should be obtained from metadata." + ) + if ignore_label is not None: + self._logger.warn( + "SemSegEvaluator(ignore_label) is deprecated! It should be obtained from metadata." + ) + self._dataset_name = dataset_name + self._distributed = distributed + self._output_dir = output_dir + + self._cpu_device = torch.device("cpu") + + self.input_file_to_gt_file = { + dataset_record["file_name"]: dataset_record["sem_seg_file_name"] + for dataset_record in DatasetCatalog.get(dataset_name) + } + + meta = MetadataCatalog.get(dataset_name) + # Dict that maps contiguous training ids to COCO category ids + try: + c2d = meta.stuff_dataset_id_to_contiguous_id + self._contiguous_id_to_dataset_id = {v: k for k, v in c2d.items()} + except AttributeError: + self._contiguous_id_to_dataset_id = None + self._class_names = meta.stuff_classes + self.sem_seg_loading_fn = sem_seg_loading_fn + self._num_classes = len(meta.stuff_classes) + if num_classes is not None: + assert self._num_classes == num_classes, f"{self._num_classes} != {num_classes}" + self._ignore_label = ignore_label if ignore_label is not None else meta.ignore_label + + # This is because cv2.erode did not work for int datatype. Only works for uint8. + self._compute_boundary_iou = True + if not _CV2_IMPORTED: + self._compute_boundary_iou = False + self._logger.warn( + """Boundary IoU calculation requires OpenCV. B-IoU metrics are + not going to be computed because OpenCV is not available to import.""" + ) + if self._num_classes >= np.iinfo(np.uint8).max: + self._compute_boundary_iou = False + self._logger.warn( + f"""SemSegEvaluator(num_classes) is more than supported value for Boundary IoU calculation! + B-IoU metrics are not going to be computed. Max allowed value (exclusive) + for num_classes for calculating Boundary IoU is {np.iinfo(np.uint8).max}. + The number of classes of dataset {self._dataset_name} is {self._num_classes}""" + ) + + def reset(self): + self._conf_matrix = np.zeros((self._num_classes + 1, self._num_classes + 1), dtype=np.int64) + self._b_conf_matrix = np.zeros( + (self._num_classes + 1, self._num_classes + 1), dtype=np.int64 + ) + self._predictions = [] + + def process(self, inputs, outputs): + """ + Args: + inputs: the inputs to a model. + It is a list of dicts. Each dict corresponds to an image and + contains keys like "height", "width", "file_name". + outputs: the outputs of a model. It is either list of semantic segmentation predictions + (Tensor [H, W]) or list of dicts with key "sem_seg" that contains semantic + segmentation prediction in the same format. + """ + for input, output in zip(inputs, outputs): + output = output["sem_seg"].argmax(dim=0).to(self._cpu_device) + pred = np.array(output, dtype=np.int) + gt_filename = self.input_file_to_gt_file[input["file_name"]] + gt = self.sem_seg_loading_fn(gt_filename, dtype=np.int) + + gt[gt == self._ignore_label] = self._num_classes + + self._conf_matrix += np.bincount( + (self._num_classes + 1) * pred.reshape(-1) + gt.reshape(-1), + minlength=self._conf_matrix.size, + ).reshape(self._conf_matrix.shape) + + if self._compute_boundary_iou: + b_gt = self._mask_to_boundary(gt.astype(np.uint8)) + b_pred = self._mask_to_boundary(pred.astype(np.uint8)) + + self._b_conf_matrix += np.bincount( + (self._num_classes + 1) * b_pred.reshape(-1) + b_gt.reshape(-1), + minlength=self._conf_matrix.size, + ).reshape(self._conf_matrix.shape) + + self._predictions.extend(self.encode_json_sem_seg(pred, input["file_name"])) + + def evaluate(self): + """ + Evaluates standard semantic segmentation metrics (http://cocodataset.org/#stuff-eval): + + * Mean intersection-over-union averaged across classes (mIoU) + * Frequency Weighted IoU (fwIoU) + * Mean pixel accuracy averaged across classes (mACC) + * Pixel Accuracy (pACC) + """ + if self._distributed: + synchronize() + conf_matrix_list = all_gather(self._conf_matrix) + b_conf_matrix_list = all_gather(self._b_conf_matrix) + self._predictions = all_gather(self._predictions) + self._predictions = list(itertools.chain(*self._predictions)) + if not is_main_process(): + return + + self._conf_matrix = np.zeros_like(self._conf_matrix) + for conf_matrix in conf_matrix_list: + self._conf_matrix += conf_matrix + + self._b_conf_matrix = np.zeros_like(self._b_conf_matrix) + for b_conf_matrix in b_conf_matrix_list: + self._b_conf_matrix += b_conf_matrix + + if self._output_dir: + PathManager.mkdirs(self._output_dir) + file_path = os.path.join(self._output_dir, "sem_seg_predictions.json") + with PathManager.open(file_path, "w") as f: + f.write(json.dumps(self._predictions)) + + acc = np.full(self._num_classes, np.nan, dtype=np.float) + iou = np.full(self._num_classes, np.nan, dtype=np.float) + tp = self._conf_matrix.diagonal()[:-1].astype(np.float) + pos_gt = np.sum(self._conf_matrix[:-1, :-1], axis=0).astype(np.float) + class_weights = pos_gt / np.sum(pos_gt) + pos_pred = np.sum(self._conf_matrix[:-1, :-1], axis=1).astype(np.float) + acc_valid = pos_gt > 0 + acc[acc_valid] = tp[acc_valid] / pos_gt[acc_valid] + union = pos_gt + pos_pred - tp + iou_valid = np.logical_and(acc_valid, union > 0) + iou[iou_valid] = tp[iou_valid] / union[iou_valid] + macc = np.sum(acc[acc_valid]) / np.sum(acc_valid) + miou = np.sum(iou[iou_valid]) / np.sum(iou_valid) + fiou = np.sum(iou[iou_valid] * class_weights[iou_valid]) + pacc = np.sum(tp) / np.sum(pos_gt) + + if self._compute_boundary_iou: + b_iou = np.full(self._num_classes, np.nan, dtype=np.float) + b_tp = self._b_conf_matrix.diagonal()[:-1].astype(np.float) + b_pos_gt = np.sum(self._b_conf_matrix[:-1, :-1], axis=0).astype(np.float) + b_pos_pred = np.sum(self._b_conf_matrix[:-1, :-1], axis=1).astype(np.float) + b_union = b_pos_gt + b_pos_pred - b_tp + b_iou_valid = b_union > 0 + b_iou[b_iou_valid] = b_tp[b_iou_valid] / b_union[b_iou_valid] + + res = {} + res["mIoU"] = 100 * miou + res["fwIoU"] = 100 * fiou + for i, name in enumerate(self._class_names): + res[f"IoU-{name}"] = 100 * iou[i] + if self._compute_boundary_iou: + res[f"BoundaryIoU-{name}"] = 100 * b_iou[i] + res[f"min(IoU, B-Iou)-{name}"] = 100 * min(iou[i], b_iou[i]) + res["mACC"] = 100 * macc + res["pACC"] = 100 * pacc + for i, name in enumerate(self._class_names): + res[f"ACC-{name}"] = 100 * acc[i] + + if self._output_dir: + file_path = os.path.join(self._output_dir, "sem_seg_evaluation.pth") + with PathManager.open(file_path, "wb") as f: + torch.save(res, f) + results = OrderedDict({"sem_seg": res}) + self._logger.info(results) + return results + + def encode_json_sem_seg(self, sem_seg, input_file_name): + """ + Convert semantic segmentation to COCO stuff format with segments encoded as RLEs. + See http://cocodataset.org/#format-results + """ + json_list = [] + for label in np.unique(sem_seg): + if self._contiguous_id_to_dataset_id is not None: + assert ( + label in self._contiguous_id_to_dataset_id + ), "Label {} is not in the metadata info for {}".format(label, self._dataset_name) + dataset_id = self._contiguous_id_to_dataset_id[label] + else: + dataset_id = int(label) + mask = (sem_seg == label).astype(np.uint8) + mask_rle = mask_util.encode(np.array(mask[:, :, None], order="F"))[0] + mask_rle["counts"] = mask_rle["counts"].decode("utf-8") + json_list.append( + {"file_name": input_file_name, "category_id": dataset_id, "segmentation": mask_rle} + ) + return json_list + + def _mask_to_boundary(self, mask: np.ndarray, dilation_ratio=0.02): + assert mask.ndim == 2, "mask_to_boundary expects a 2-dimensional image" + h, w = mask.shape + diag_len = np.sqrt(h**2 + w**2) + dilation = max(1, int(round(dilation_ratio * diag_len))) + kernel = np.ones((3, 3), dtype=np.uint8) + + padded_mask = cv2.copyMakeBorder(mask, 1, 1, 1, 1, cv2.BORDER_CONSTANT, value=0) + eroded_mask_with_padding = cv2.erode(padded_mask, kernel, iterations=dilation) + eroded_mask = eroded_mask_with_padding[1:-1, 1:-1] + boundary = mask - eroded_mask + return boundary diff --git a/custom_detectron2/evaluation/testing.py b/custom_detectron2/evaluation/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..9e5ae625bb0593fc20739dd3ea549157e4df4f3d --- /dev/null +++ b/custom_detectron2/evaluation/testing.py @@ -0,0 +1,85 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +import pprint +import sys +from collections.abc import Mapping + + +def print_csv_format(results): + """ + Print main metrics in a format similar to Detectron, + so that they are easy to copypaste into a spreadsheet. + + Args: + results (OrderedDict[dict]): task_name -> {metric -> score} + unordered dict can also be printed, but in arbitrary order + """ + assert isinstance(results, Mapping) or not len(results), results + logger = logging.getLogger(__name__) + for task, res in results.items(): + if isinstance(res, Mapping): + # Don't print "AP-category" metrics since they are usually not tracked. + important_res = [(k, v) for k, v in res.items() if "-" not in k] + logger.info("copypaste: Task: {}".format(task)) + logger.info("copypaste: " + ",".join([k[0] for k in important_res])) + logger.info("copypaste: " + ",".join(["{0:.4f}".format(k[1]) for k in important_res])) + else: + logger.info(f"copypaste: {task}={res}") + + +def verify_results(cfg, results): + """ + Args: + results (OrderedDict[dict]): task_name -> {metric -> score} + + Returns: + bool: whether the verification succeeds or not + """ + expected_results = cfg.TEST.EXPECTED_RESULTS + if not len(expected_results): + return True + + ok = True + for task, metric, expected, tolerance in expected_results: + actual = results[task].get(metric, None) + if actual is None: + ok = False + continue + if not np.isfinite(actual): + ok = False + continue + diff = abs(actual - expected) + if diff > tolerance: + ok = False + + logger = logging.getLogger(__name__) + if not ok: + logger.error("Result verification failed!") + logger.error("Expected Results: " + str(expected_results)) + logger.error("Actual Results: " + pprint.pformat(results)) + + sys.exit(1) + else: + logger.info("Results verification passed.") + return ok + + +def flatten_results_dict(results): + """ + Expand a hierarchical dict of scalars into a flat dict of scalars. + If results[k1][k2][k3] = v, the returned dict will have the entry + {"k1/k2/k3": v}. + + Args: + results (dict): + """ + r = {} + for k, v in results.items(): + if isinstance(v, Mapping): + v = flatten_results_dict(v) + for kk, vv in v.items(): + r[k + "/" + kk] = vv + else: + r[k] = v + return r diff --git a/custom_detectron2/export/README.md b/custom_detectron2/export/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c86ff62516f4e8e4b1a6c1f33f11192933cf3861 --- /dev/null +++ b/custom_detectron2/export/README.md @@ -0,0 +1,15 @@ + +This directory contains code to prepare a detectron2 model for deployment. +Currently it supports exporting a detectron2 model to TorchScript, ONNX, or (deprecated) Caffe2 format. + +Please see [documentation](https://detectron2.readthedocs.io/tutorials/deployment.html) for its usage. + + +### Acknowledgements + +Thanks to Mobile Vision team at Facebook for developing the Caffe2 conversion tools. + +Thanks to Computing Platform Department - PAI team at Alibaba Group (@bddpqq, @chenbohua3) who +help export Detectron2 models to TorchScript. + +Thanks to ONNX Converter team at Microsoft who help export Detectron2 models to ONNX. diff --git a/custom_detectron2/export/__init__.py b/custom_detectron2/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5a58758f64aae6071fa688be4400622ce6036efa --- /dev/null +++ b/custom_detectron2/export/__init__.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- + +import warnings + +from .flatten import TracingAdapter +from .torchscript import dump_torchscript_IR, scripting_with_instances + +try: + from caffe2.proto import caffe2_pb2 as _tmp + from caffe2.python import core + + # caffe2 is optional +except ImportError: + pass +else: + from .api import * + + +# TODO: Update ONNX Opset version and run tests when a newer PyTorch is supported +STABLE_ONNX_OPSET_VERSION = 11 + + +def add_export_config(cfg): + warnings.warn( + "add_export_config has been deprecated and behaves as no-op function.", DeprecationWarning + ) + return cfg + + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/export/api.py b/custom_detectron2/export/api.py new file mode 100644 index 0000000000000000000000000000000000000000..4f9d7dac1d086902f082910059d539083bd85c0d --- /dev/null +++ b/custom_detectron2/export/api.py @@ -0,0 +1,230 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import logging +import os +import torch +from caffe2.proto import caffe2_pb2 +from torch import nn + +from custom_detectron2.config import CfgNode +from custom_detectron2.utils.file_io import PathManager + +from .caffe2_inference import ProtobufDetectionModel +from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format +from .shared import get_pb_arg_vali, get_pb_arg_vals, save_graph + +__all__ = [ + "Caffe2Model", + "Caffe2Tracer", +] + + +class Caffe2Tracer: + """ + Make a detectron2 model traceable with Caffe2 operators. + This class creates a traceable version of a detectron2 model which: + + 1. Rewrite parts of the model using ops in Caffe2. Note that some ops do + not have GPU implementation in Caffe2. + 2. Remove post-processing and only produce raw layer outputs + + After making a traceable model, the class provide methods to export such a + model to different deployment formats. + Exported graph produced by this class take two input tensors: + + 1. (1, C, H, W) float "data" which is an image (usually in [0, 255]). + (H, W) often has to be padded to multiple of 32 (depend on the model + architecture). + 2. 1x3 float "im_info", each row of which is (height, width, 1.0). + Height and width are true image shapes before padding. + + The class currently only supports models using builtin meta architectures. + Batch inference is not supported, and contributions are welcome. + """ + + def __init__(self, cfg: CfgNode, model: nn.Module, inputs): + """ + Args: + cfg (CfgNode): a detectron2 config used to construct caffe2-compatible model. + model (nn.Module): An original pytorch model. Must be among a few official models + in detectron2 that can be converted to become caffe2-compatible automatically. + Weights have to be already loaded to this model. + inputs: sample inputs that the given model takes for inference. + Will be used to trace the model. For most models, random inputs with + no detected objects will not work as they lead to wrong traces. + """ + assert isinstance(cfg, CfgNode), cfg + assert isinstance(model, torch.nn.Module), type(model) + + # TODO make it support custom models, by passing in c2 model directly + C2MetaArch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[cfg.MODEL.META_ARCHITECTURE] + self.traceable_model = C2MetaArch(cfg, copy.deepcopy(model)) + self.inputs = inputs + self.traceable_inputs = self.traceable_model.get_caffe2_inputs(inputs) + + def export_caffe2(self): + """ + Export the model to Caffe2's protobuf format. + The returned object can be saved with its :meth:`.save_protobuf()` method. + The result can be loaded and executed using Caffe2 runtime. + + Returns: + :class:`Caffe2Model` + """ + from .caffe2_export import export_caffe2_detection_model + + predict_net, init_net = export_caffe2_detection_model( + self.traceable_model, self.traceable_inputs + ) + return Caffe2Model(predict_net, init_net) + + def export_onnx(self): + """ + Export the model to ONNX format. + Note that the exported model contains custom ops only available in caffe2, therefore it + cannot be directly executed by other runtime (such as onnxruntime or TensorRT). + Post-processing or transformation passes may be applied on the model to accommodate + different runtimes, but we currently do not provide support for them. + + Returns: + onnx.ModelProto: an onnx model. + """ + from .caffe2_export import export_onnx_model as export_onnx_model_impl + + return export_onnx_model_impl(self.traceable_model, (self.traceable_inputs,)) + + def export_torchscript(self): + """ + Export the model to a ``torch.jit.TracedModule`` by tracing. + The returned object can be saved to a file by ``.save()``. + + Returns: + torch.jit.TracedModule: a torch TracedModule + """ + logger = logging.getLogger(__name__) + logger.info("Tracing the model with torch.jit.trace ...") + with torch.no_grad(): + return torch.jit.trace(self.traceable_model, (self.traceable_inputs,)) + + +class Caffe2Model(nn.Module): + """ + A wrapper around the traced model in Caffe2's protobuf format. + The exported graph has different inputs/outputs from the original Pytorch + model, as explained in :class:`Caffe2Tracer`. This class wraps around the + exported graph to simulate the same interface as the original Pytorch model. + It also provides functions to save/load models in Caffe2's format.' + + Examples: + :: + c2_model = Caffe2Tracer(cfg, torch_model, inputs).export_caffe2() + inputs = [{"image": img_tensor_CHW}] + outputs = c2_model(inputs) + orig_outputs = torch_model(inputs) + """ + + def __init__(self, predict_net, init_net): + super().__init__() + self.eval() # always in eval mode + self._predict_net = predict_net + self._init_net = init_net + self._predictor = None + + __init__.__HIDE_SPHINX_DOC__ = True + + @property + def predict_net(self): + """ + caffe2.core.Net: the underlying caffe2 predict net + """ + return self._predict_net + + @property + def init_net(self): + """ + caffe2.core.Net: the underlying caffe2 init net + """ + return self._init_net + + def save_protobuf(self, output_dir): + """ + Save the model as caffe2's protobuf format. + It saves the following files: + + * "model.pb": definition of the graph. Can be visualized with + tools like `netron `_. + * "model_init.pb": model parameters + * "model.pbtxt": human-readable definition of the graph. Not + needed for deployment. + + Args: + output_dir (str): the output directory to save protobuf files. + """ + logger = logging.getLogger(__name__) + logger.info("Saving model to {} ...".format(output_dir)) + if not PathManager.exists(output_dir): + PathManager.mkdirs(output_dir) + + with PathManager.open(os.path.join(output_dir, "model.pb"), "wb") as f: + f.write(self._predict_net.SerializeToString()) + with PathManager.open(os.path.join(output_dir, "model.pbtxt"), "w") as f: + f.write(str(self._predict_net)) + with PathManager.open(os.path.join(output_dir, "model_init.pb"), "wb") as f: + f.write(self._init_net.SerializeToString()) + + def save_graph(self, output_file, inputs=None): + """ + Save the graph as SVG format. + + Args: + output_file (str): a SVG file + inputs: optional inputs given to the model. + If given, the inputs will be used to run the graph to record + shape of every tensor. The shape information will be + saved together with the graph. + """ + from .caffe2_export import run_and_save_graph + + if inputs is None: + save_graph(self._predict_net, output_file, op_only=False) + else: + size_divisibility = get_pb_arg_vali(self._predict_net, "size_divisibility", 0) + device = get_pb_arg_vals(self._predict_net, "device", b"cpu").decode("ascii") + inputs = convert_batched_inputs_to_c2_format(inputs, size_divisibility, device) + inputs = [x.cpu().numpy() for x in inputs] + run_and_save_graph(self._predict_net, self._init_net, inputs, output_file) + + @staticmethod + def load_protobuf(dir): + """ + Args: + dir (str): a directory used to save Caffe2Model with + :meth:`save_protobuf`. + The files "model.pb" and "model_init.pb" are needed. + + Returns: + Caffe2Model: the caffe2 model loaded from this directory. + """ + predict_net = caffe2_pb2.NetDef() + with PathManager.open(os.path.join(dir, "model.pb"), "rb") as f: + predict_net.ParseFromString(f.read()) + + init_net = caffe2_pb2.NetDef() + with PathManager.open(os.path.join(dir, "model_init.pb"), "rb") as f: + init_net.ParseFromString(f.read()) + + return Caffe2Model(predict_net, init_net) + + def __call__(self, inputs): + """ + An interface that wraps around a Caffe2 model and mimics detectron2's models' + input/output format. See details about the format at :doc:`/tutorials/models`. + This is used to compare the outputs of caffe2 model with its original torch model. + + Due to the extra conversion between Pytorch/Caffe2, this method is not meant for + benchmark. Because of the conversion, this method also has dependency + on detectron2 in order to convert to detectron2's output format. + """ + if self._predictor is None: + self._predictor = ProtobufDetectionModel(self._predict_net, self._init_net) + return self._predictor(inputs) diff --git a/custom_detectron2/export/c10.py b/custom_detectron2/export/c10.py new file mode 100644 index 0000000000000000000000000000000000000000..c657db1eecb3a80622bee34d8ea58b9c3ff913aa --- /dev/null +++ b/custom_detectron2/export/c10.py @@ -0,0 +1,557 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import math +from typing import Dict +import torch +import torch.nn.functional as F + +from custom_detectron2.layers import ShapeSpec, cat +from custom_detectron2.layers.roi_align_rotated import ROIAlignRotated +from custom_detectron2.modeling import poolers +from custom_detectron2.modeling.proposal_generator import rpn +from custom_detectron2.modeling.roi_heads.mask_head import mask_rcnn_inference +from custom_detectron2.structures import Boxes, ImageList, Instances, Keypoints, RotatedBoxes + +from .shared import alias, to_device + + +""" +This file contains caffe2-compatible implementation of several detectron2 components. +""" + + +class Caffe2Boxes(Boxes): + """ + Representing a list of detectron2.structures.Boxes from minibatch, each box + is represented by a 5d vector (batch index + 4 coordinates), or a 6d vector + (batch index + 5 coordinates) for RotatedBoxes. + """ + + def __init__(self, tensor): + assert isinstance(tensor, torch.Tensor) + assert tensor.dim() == 2 and tensor.size(-1) in [4, 5, 6], tensor.size() + # TODO: make tensor immutable when dim is Nx5 for Boxes, + # and Nx6 for RotatedBoxes? + self.tensor = tensor + + +# TODO clean up this class, maybe just extend Instances +class InstancesList(object): + """ + Tensor representation of a list of Instances object for a batch of images. + + When dealing with a batch of images with Caffe2 ops, a list of bboxes + (instances) are usually represented by single Tensor with size + (sigma(Ni), 5) or (sigma(Ni), 4) plus a batch split Tensor. This class is + for providing common functions to convert between these two representations. + """ + + def __init__(self, im_info, indices, extra_fields=None): + # [N, 3] -> (H, W, Scale) + self.im_info = im_info + # [N,] -> indice of batch to which the instance belongs + self.indices = indices + # [N, ...] + self.batch_extra_fields = extra_fields or {} + + self.image_size = self.im_info + + def get_fields(self): + """like `get_fields` in the Instances object, + but return each field in tensor representations""" + ret = {} + for k, v in self.batch_extra_fields.items(): + # if isinstance(v, torch.Tensor): + # tensor_rep = v + # elif isinstance(v, (Boxes, Keypoints)): + # tensor_rep = v.tensor + # else: + # raise ValueError("Can't find tensor representation for: {}".format()) + ret[k] = v + return ret + + def has(self, name): + return name in self.batch_extra_fields + + def set(self, name, value): + # len(tensor) is a bad practice that generates ONNX constants during tracing. + # Although not a problem for the `assert` statement below, torch ONNX exporter + # still raises a misleading warning as it does not this call comes from `assert` + if isinstance(value, Boxes): + data_len = value.tensor.shape[0] + elif isinstance(value, torch.Tensor): + data_len = value.shape[0] + else: + data_len = len(value) + if len(self.batch_extra_fields): + assert ( + len(self) == data_len + ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self)) + self.batch_extra_fields[name] = value + + def __getattr__(self, name): + if name not in self.batch_extra_fields: + raise AttributeError("Cannot find field '{}' in the given Instances!".format(name)) + return self.batch_extra_fields[name] + + def __len__(self): + return len(self.indices) + + def flatten(self): + ret = [] + for _, v in self.batch_extra_fields.items(): + if isinstance(v, (Boxes, Keypoints)): + ret.append(v.tensor) + else: + ret.append(v) + return ret + + @staticmethod + def to_d2_instances_list(instances_list): + """ + Convert InstancesList to List[Instances]. The input `instances_list` can + also be a List[Instances], in this case this method is a non-op. + """ + if not isinstance(instances_list, InstancesList): + assert all(isinstance(x, Instances) for x in instances_list) + return instances_list + + ret = [] + for i, info in enumerate(instances_list.im_info): + instances = Instances(torch.Size([int(info[0].item()), int(info[1].item())])) + + ids = instances_list.indices == i + for k, v in instances_list.batch_extra_fields.items(): + if isinstance(v, torch.Tensor): + instances.set(k, v[ids]) + continue + elif isinstance(v, Boxes): + instances.set(k, v[ids, -4:]) + continue + + target_type, tensor_source = v + assert isinstance(tensor_source, torch.Tensor) + assert tensor_source.shape[0] == instances_list.indices.shape[0] + tensor_source = tensor_source[ids] + + if issubclass(target_type, Boxes): + instances.set(k, Boxes(tensor_source[:, -4:])) + elif issubclass(target_type, Keypoints): + instances.set(k, Keypoints(tensor_source)) + elif issubclass(target_type, torch.Tensor): + instances.set(k, tensor_source) + else: + raise ValueError("Can't handle targe type: {}".format(target_type)) + + ret.append(instances) + return ret + + +class Caffe2Compatible(object): + """ + A model can inherit this class to indicate that it can be traced and deployed with caffe2. + """ + + def _get_tensor_mode(self): + return self._tensor_mode + + def _set_tensor_mode(self, v): + self._tensor_mode = v + + tensor_mode = property(_get_tensor_mode, _set_tensor_mode) + """ + If true, the model expects C2-style tensor only inputs/outputs format. + """ + + +class Caffe2RPN(Caffe2Compatible, rpn.RPN): + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = super(Caffe2Compatible, cls).from_config(cfg, input_shape) + assert tuple(cfg.MODEL.RPN.BBOX_REG_WEIGHTS) == (1.0, 1.0, 1.0, 1.0) or tuple( + cfg.MODEL.RPN.BBOX_REG_WEIGHTS + ) == (1.0, 1.0, 1.0, 1.0, 1.0) + return ret + + def _generate_proposals( + self, images, objectness_logits_pred, anchor_deltas_pred, gt_instances=None + ): + assert isinstance(images, ImageList) + if self.tensor_mode: + im_info = images.image_sizes + else: + im_info = torch.tensor([[im_sz[0], im_sz[1], 1.0] for im_sz in images.image_sizes]).to( + images.tensor.device + ) + assert isinstance(im_info, torch.Tensor) + + rpn_rois_list = [] + rpn_roi_probs_list = [] + for scores, bbox_deltas, cell_anchors_tensor, feat_stride in zip( + objectness_logits_pred, + anchor_deltas_pred, + [b for (n, b) in self.anchor_generator.cell_anchors.named_buffers()], + self.anchor_generator.strides, + ): + scores = scores.detach() + bbox_deltas = bbox_deltas.detach() + + rpn_rois, rpn_roi_probs = torch.ops._caffe2.GenerateProposals( + scores, + bbox_deltas, + im_info, + cell_anchors_tensor, + spatial_scale=1.0 / feat_stride, + pre_nms_topN=self.pre_nms_topk[self.training], + post_nms_topN=self.post_nms_topk[self.training], + nms_thresh=self.nms_thresh, + min_size=self.min_box_size, + # correct_transform_coords=True, # deprecated argument + angle_bound_on=True, # Default + angle_bound_lo=-180, + angle_bound_hi=180, + clip_angle_thresh=1.0, # Default + legacy_plus_one=False, + ) + rpn_rois_list.append(rpn_rois) + rpn_roi_probs_list.append(rpn_roi_probs) + + # For FPN in D2, in RPN all proposals from different levels are concated + # together, ranked and picked by top post_nms_topk. Then in ROIPooler + # it calculates level_assignments and calls the RoIAlign from + # the corresponding level. + + if len(objectness_logits_pred) == 1: + rpn_rois = rpn_rois_list[0] + rpn_roi_probs = rpn_roi_probs_list[0] + else: + assert len(rpn_rois_list) == len(rpn_roi_probs_list) + rpn_post_nms_topN = self.post_nms_topk[self.training] + + device = rpn_rois_list[0].device + input_list = [to_device(x, "cpu") for x in (rpn_rois_list + rpn_roi_probs_list)] + + # TODO remove this after confirming rpn_max_level/rpn_min_level + # is not needed in CollectRpnProposals. + feature_strides = list(self.anchor_generator.strides) + rpn_min_level = int(math.log2(feature_strides[0])) + rpn_max_level = int(math.log2(feature_strides[-1])) + assert (rpn_max_level - rpn_min_level + 1) == len( + rpn_rois_list + ), "CollectRpnProposals requires continuous levels" + + rpn_rois = torch.ops._caffe2.CollectRpnProposals( + input_list, + # NOTE: in current implementation, rpn_max_level and rpn_min_level + # are not needed, only the subtraction of two matters and it + # can be infer from the number of inputs. Keep them now for + # consistency. + rpn_max_level=2 + len(rpn_rois_list) - 1, + rpn_min_level=2, + rpn_post_nms_topN=rpn_post_nms_topN, + ) + rpn_rois = to_device(rpn_rois, device) + rpn_roi_probs = [] + + proposals = self.c2_postprocess(im_info, rpn_rois, rpn_roi_probs, self.tensor_mode) + return proposals, {} + + def forward(self, images, features, gt_instances=None): + assert not self.training + features = [features[f] for f in self.in_features] + objectness_logits_pred, anchor_deltas_pred = self.rpn_head(features) + return self._generate_proposals( + images, + objectness_logits_pred, + anchor_deltas_pred, + gt_instances, + ) + + @staticmethod + def c2_postprocess(im_info, rpn_rois, rpn_roi_probs, tensor_mode): + proposals = InstancesList( + im_info=im_info, + indices=rpn_rois[:, 0], + extra_fields={ + "proposal_boxes": Caffe2Boxes(rpn_rois), + "objectness_logits": (torch.Tensor, rpn_roi_probs), + }, + ) + if not tensor_mode: + proposals = InstancesList.to_d2_instances_list(proposals) + else: + proposals = [proposals] + return proposals + + +class Caffe2ROIPooler(Caffe2Compatible, poolers.ROIPooler): + @staticmethod + def c2_preprocess(box_lists): + assert all(isinstance(x, Boxes) for x in box_lists) + if all(isinstance(x, Caffe2Boxes) for x in box_lists): + # input is pure-tensor based + assert len(box_lists) == 1 + pooler_fmt_boxes = box_lists[0].tensor + else: + pooler_fmt_boxes = poolers.convert_boxes_to_pooler_format(box_lists) + return pooler_fmt_boxes + + def forward(self, x, box_lists): + assert not self.training + + pooler_fmt_boxes = self.c2_preprocess(box_lists) + num_level_assignments = len(self.level_poolers) + + if num_level_assignments == 1: + if isinstance(self.level_poolers[0], ROIAlignRotated): + c2_roi_align = torch.ops._caffe2.RoIAlignRotated + aligned = True + else: + c2_roi_align = torch.ops._caffe2.RoIAlign + aligned = self.level_poolers[0].aligned + + x0 = x[0] + if x0.is_quantized: + x0 = x0.dequantize() + + out = c2_roi_align( + x0, + pooler_fmt_boxes, + order="NCHW", + spatial_scale=float(self.level_poolers[0].spatial_scale), + pooled_h=int(self.output_size[0]), + pooled_w=int(self.output_size[1]), + sampling_ratio=int(self.level_poolers[0].sampling_ratio), + aligned=aligned, + ) + return out + + device = pooler_fmt_boxes.device + assert ( + self.max_level - self.min_level + 1 == 4 + ), "Currently DistributeFpnProposals only support 4 levels" + fpn_outputs = torch.ops._caffe2.DistributeFpnProposals( + to_device(pooler_fmt_boxes, "cpu"), + roi_canonical_scale=self.canonical_box_size, + roi_canonical_level=self.canonical_level, + roi_max_level=self.max_level, + roi_min_level=self.min_level, + legacy_plus_one=False, + ) + fpn_outputs = [to_device(x, device) for x in fpn_outputs] + + rois_fpn_list = fpn_outputs[:-1] + rois_idx_restore_int32 = fpn_outputs[-1] + + roi_feat_fpn_list = [] + for roi_fpn, x_level, pooler in zip(rois_fpn_list, x, self.level_poolers): + if isinstance(pooler, ROIAlignRotated): + c2_roi_align = torch.ops._caffe2.RoIAlignRotated + aligned = True + else: + c2_roi_align = torch.ops._caffe2.RoIAlign + aligned = bool(pooler.aligned) + + if x_level.is_quantized: + x_level = x_level.dequantize() + + roi_feat_fpn = c2_roi_align( + x_level, + roi_fpn, + order="NCHW", + spatial_scale=float(pooler.spatial_scale), + pooled_h=int(self.output_size[0]), + pooled_w=int(self.output_size[1]), + sampling_ratio=int(pooler.sampling_ratio), + aligned=aligned, + ) + roi_feat_fpn_list.append(roi_feat_fpn) + + roi_feat_shuffled = cat(roi_feat_fpn_list, dim=0) + assert roi_feat_shuffled.numel() > 0 and rois_idx_restore_int32.numel() > 0, ( + "Caffe2 export requires tracing with a model checkpoint + input that can produce valid" + " detections. But no detections were obtained with the given checkpoint and input!" + ) + roi_feat = torch.ops._caffe2.BatchPermutation(roi_feat_shuffled, rois_idx_restore_int32) + return roi_feat + + +class Caffe2FastRCNNOutputsInference: + def __init__(self, tensor_mode): + self.tensor_mode = tensor_mode # whether the output is caffe2 tensor mode + + def __call__(self, box_predictor, predictions, proposals): + """equivalent to FastRCNNOutputLayers.inference""" + num_classes = box_predictor.num_classes + score_thresh = box_predictor.test_score_thresh + nms_thresh = box_predictor.test_nms_thresh + topk_per_image = box_predictor.test_topk_per_image + is_rotated = len(box_predictor.box2box_transform.weights) == 5 + + if is_rotated: + box_dim = 5 + assert box_predictor.box2box_transform.weights[4] == 1, ( + "The weights for Rotated BBoxTransform in C2 have only 4 dimensions," + + " thus enforcing the angle weight to be 1 for now" + ) + box2box_transform_weights = box_predictor.box2box_transform.weights[:4] + else: + box_dim = 4 + box2box_transform_weights = box_predictor.box2box_transform.weights + + class_logits, box_regression = predictions + if num_classes + 1 == class_logits.shape[1]: + class_prob = F.softmax(class_logits, -1) + else: + assert num_classes == class_logits.shape[1] + class_prob = F.sigmoid(class_logits) + # BoxWithNMSLimit will infer num_classes from the shape of the class_prob + # So append a zero column as placeholder for the background class + class_prob = torch.cat((class_prob, torch.zeros(class_prob.shape[0], 1)), dim=1) + + assert box_regression.shape[1] % box_dim == 0 + cls_agnostic_bbox_reg = box_regression.shape[1] // box_dim == 1 + + input_tensor_mode = proposals[0].proposal_boxes.tensor.shape[1] == box_dim + 1 + + proposal_boxes = proposals[0].proposal_boxes + if isinstance(proposal_boxes, Caffe2Boxes): + rois = Caffe2Boxes.cat([p.proposal_boxes for p in proposals]) + elif isinstance(proposal_boxes, RotatedBoxes): + rois = RotatedBoxes.cat([p.proposal_boxes for p in proposals]) + elif isinstance(proposal_boxes, Boxes): + rois = Boxes.cat([p.proposal_boxes for p in proposals]) + else: + raise NotImplementedError( + 'Expected proposals[0].proposal_boxes to be type "Boxes", ' + f"instead got {type(proposal_boxes)}" + ) + + device, dtype = rois.tensor.device, rois.tensor.dtype + if input_tensor_mode: + im_info = proposals[0].image_size + rois = rois.tensor + else: + im_info = torch.tensor( + [[sz[0], sz[1], 1.0] for sz in [x.image_size for x in proposals]] + ) + batch_ids = cat( + [ + torch.full((b, 1), i, dtype=dtype, device=device) + for i, b in enumerate(len(p) for p in proposals) + ], + dim=0, + ) + rois = torch.cat([batch_ids, rois.tensor], dim=1) + + roi_pred_bbox, roi_batch_splits = torch.ops._caffe2.BBoxTransform( + to_device(rois, "cpu"), + to_device(box_regression, "cpu"), + to_device(im_info, "cpu"), + weights=box2box_transform_weights, + apply_scale=True, + rotated=is_rotated, + angle_bound_on=True, + angle_bound_lo=-180, + angle_bound_hi=180, + clip_angle_thresh=1.0, + legacy_plus_one=False, + ) + roi_pred_bbox = to_device(roi_pred_bbox, device) + roi_batch_splits = to_device(roi_batch_splits, device) + + nms_outputs = torch.ops._caffe2.BoxWithNMSLimit( + to_device(class_prob, "cpu"), + to_device(roi_pred_bbox, "cpu"), + to_device(roi_batch_splits, "cpu"), + score_thresh=float(score_thresh), + nms=float(nms_thresh), + detections_per_im=int(topk_per_image), + soft_nms_enabled=False, + soft_nms_method="linear", + soft_nms_sigma=0.5, + soft_nms_min_score_thres=0.001, + rotated=is_rotated, + cls_agnostic_bbox_reg=cls_agnostic_bbox_reg, + input_boxes_include_bg_cls=False, + output_classes_include_bg_cls=False, + legacy_plus_one=False, + ) + roi_score_nms = to_device(nms_outputs[0], device) + roi_bbox_nms = to_device(nms_outputs[1], device) + roi_class_nms = to_device(nms_outputs[2], device) + roi_batch_splits_nms = to_device(nms_outputs[3], device) + roi_keeps_nms = to_device(nms_outputs[4], device) + roi_keeps_size_nms = to_device(nms_outputs[5], device) + if not self.tensor_mode: + roi_class_nms = roi_class_nms.to(torch.int64) + + roi_batch_ids = cat( + [ + torch.full((b, 1), i, dtype=dtype, device=device) + for i, b in enumerate(int(x.item()) for x in roi_batch_splits_nms) + ], + dim=0, + ) + + roi_class_nms = alias(roi_class_nms, "class_nms") + roi_score_nms = alias(roi_score_nms, "score_nms") + roi_bbox_nms = alias(roi_bbox_nms, "bbox_nms") + roi_batch_splits_nms = alias(roi_batch_splits_nms, "batch_splits_nms") + roi_keeps_nms = alias(roi_keeps_nms, "keeps_nms") + roi_keeps_size_nms = alias(roi_keeps_size_nms, "keeps_size_nms") + + results = InstancesList( + im_info=im_info, + indices=roi_batch_ids[:, 0], + extra_fields={ + "pred_boxes": Caffe2Boxes(roi_bbox_nms), + "scores": roi_score_nms, + "pred_classes": roi_class_nms, + }, + ) + + if not self.tensor_mode: + results = InstancesList.to_d2_instances_list(results) + batch_splits = roi_batch_splits_nms.int().tolist() + kept_indices = list(roi_keeps_nms.to(torch.int64).split(batch_splits)) + else: + results = [results] + kept_indices = [roi_keeps_nms] + + return results, kept_indices + + +class Caffe2MaskRCNNInference: + def __call__(self, pred_mask_logits, pred_instances): + """equivalent to mask_head.mask_rcnn_inference""" + if all(isinstance(x, InstancesList) for x in pred_instances): + assert len(pred_instances) == 1 + mask_probs_pred = pred_mask_logits.sigmoid() + mask_probs_pred = alias(mask_probs_pred, "mask_fcn_probs") + pred_instances[0].set("pred_masks", mask_probs_pred) + else: + mask_rcnn_inference(pred_mask_logits, pred_instances) + + +class Caffe2KeypointRCNNInference: + def __init__(self, use_heatmap_max_keypoint): + self.use_heatmap_max_keypoint = use_heatmap_max_keypoint + + def __call__(self, pred_keypoint_logits, pred_instances): + # just return the keypoint heatmap for now, + # there will be option to call HeatmapMaxKeypointOp + output = alias(pred_keypoint_logits, "kps_score") + if all(isinstance(x, InstancesList) for x in pred_instances): + assert len(pred_instances) == 1 + if self.use_heatmap_max_keypoint: + device = output.device + output = torch.ops._caffe2.HeatmapMaxKeypoint( + to_device(output, "cpu"), + pred_instances[0].pred_boxes.tensor, + should_output_softmax=True, # worth make it configerable? + ) + output = to_device(output, device) + output = alias(output, "keypoints_out") + pred_instances[0].set("pred_keypoints", output) + return pred_keypoint_logits diff --git a/custom_detectron2/export/caffe2_export.py b/custom_detectron2/export/caffe2_export.py new file mode 100644 index 0000000000000000000000000000000000000000..d609c27c7deb396352967dbcbc79b1e00f2a2de1 --- /dev/null +++ b/custom_detectron2/export/caffe2_export.py @@ -0,0 +1,203 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import copy +import io +import logging +import numpy as np +from typing import List +import onnx +import onnx.optimizer +import torch +from caffe2.proto import caffe2_pb2 +from caffe2.python import core +from caffe2.python.onnx.backend import Caffe2Backend +from tabulate import tabulate +from termcolor import colored +from torch.onnx import OperatorExportTypes + +from .shared import ( + ScopedWS, + construct_init_net_from_params, + fuse_alias_placeholder, + fuse_copy_between_cpu_and_gpu, + get_params_from_init_net, + group_norm_replace_aten_with_caffe2, + infer_device_type, + remove_dead_end_ops, + remove_reshape_for_fc, + save_graph, +) + +logger = logging.getLogger(__name__) + + +def export_onnx_model(model, inputs): + """ + Trace and export a model to onnx format. + + Args: + model (nn.Module): + inputs (tuple[args]): the model will be called by `model(*inputs)` + + Returns: + an onnx model + """ + assert isinstance(model, torch.nn.Module) + + # make sure all modules are in eval mode, onnx may change the training state + # of the module if the states are not consistent + def _check_eval(module): + assert not module.training + + model.apply(_check_eval) + + # Export the model to ONNX + with torch.no_grad(): + with io.BytesIO() as f: + torch.onnx.export( + model, + inputs, + f, + operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK, + # verbose=True, # NOTE: uncomment this for debugging + # export_params=True, + ) + onnx_model = onnx.load_from_string(f.getvalue()) + + return onnx_model + + +def _op_stats(net_def): + type_count = {} + for t in [op.type for op in net_def.op]: + type_count[t] = type_count.get(t, 0) + 1 + type_count_list = sorted(type_count.items(), key=lambda kv: kv[0]) # alphabet + type_count_list = sorted(type_count_list, key=lambda kv: -kv[1]) # count + return "\n".join("{:>4}x {}".format(count, name) for name, count in type_count_list) + + +def _assign_device_option( + predict_net: caffe2_pb2.NetDef, init_net: caffe2_pb2.NetDef, tensor_inputs: List[torch.Tensor] +): + """ + ONNX exported network doesn't have concept of device, assign necessary + device option for each op in order to make it runable on GPU runtime. + """ + + def _get_device_type(torch_tensor): + assert torch_tensor.device.type in ["cpu", "cuda"] + assert torch_tensor.device.index == 0 + return torch_tensor.device.type + + def _assign_op_device_option(net_proto, net_ssa, blob_device_types): + for op, ssa_i in zip(net_proto.op, net_ssa): + if op.type in ["CopyCPUToGPU", "CopyGPUToCPU"]: + op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0)) + else: + devices = [blob_device_types[b] for b in ssa_i[0] + ssa_i[1]] + assert all(d == devices[0] for d in devices) + if devices[0] == "cuda": + op.device_option.CopyFrom(core.DeviceOption(caffe2_pb2.CUDA, 0)) + + # update ops in predict_net + predict_net_input_device_types = { + (name, 0): _get_device_type(tensor) + for name, tensor in zip(predict_net.external_input, tensor_inputs) + } + predict_net_device_types = infer_device_type( + predict_net, known_status=predict_net_input_device_types, device_name_style="pytorch" + ) + predict_net_ssa, _ = core.get_ssa(predict_net) + _assign_op_device_option(predict_net, predict_net_ssa, predict_net_device_types) + + # update ops in init_net + init_net_ssa, versions = core.get_ssa(init_net) + init_net_output_device_types = { + (name, versions[name]): predict_net_device_types[(name, 0)] + for name in init_net.external_output + } + init_net_device_types = infer_device_type( + init_net, known_status=init_net_output_device_types, device_name_style="pytorch" + ) + _assign_op_device_option(init_net, init_net_ssa, init_net_device_types) + + +def export_caffe2_detection_model(model: torch.nn.Module, tensor_inputs: List[torch.Tensor]): + """ + Export a caffe2-compatible Detectron2 model to caffe2 format via ONNX. + + Arg: + model: a caffe2-compatible version of detectron2 model, defined in caffe2_modeling.py + tensor_inputs: a list of tensors that caffe2 model takes as input. + """ + model = copy.deepcopy(model) + assert isinstance(model, torch.nn.Module) + assert hasattr(model, "encode_additional_info") + + # Export via ONNX + logger.info( + "Exporting a {} model via ONNX ...".format(type(model).__name__) + + " Some warnings from ONNX are expected and are usually not to worry about." + ) + onnx_model = export_onnx_model(model, (tensor_inputs,)) + # Convert ONNX model to Caffe2 protobuf + init_net, predict_net = Caffe2Backend.onnx_graph_to_caffe2_net(onnx_model) + ops_table = [[op.type, op.input, op.output] for op in predict_net.op] + table = tabulate(ops_table, headers=["type", "input", "output"], tablefmt="pipe") + logger.info( + "ONNX export Done. Exported predict_net (before optimizations):\n" + colored(table, "cyan") + ) + + # Apply protobuf optimization + fuse_alias_placeholder(predict_net, init_net) + if any(t.device.type != "cpu" for t in tensor_inputs): + fuse_copy_between_cpu_and_gpu(predict_net) + remove_dead_end_ops(init_net) + _assign_device_option(predict_net, init_net, tensor_inputs) + params, device_options = get_params_from_init_net(init_net) + predict_net, params = remove_reshape_for_fc(predict_net, params) + init_net = construct_init_net_from_params(params, device_options) + group_norm_replace_aten_with_caffe2(predict_net) + + # Record necessary information for running the pb model in Detectron2 system. + model.encode_additional_info(predict_net, init_net) + + logger.info("Operators used in predict_net: \n{}".format(_op_stats(predict_net))) + logger.info("Operators used in init_net: \n{}".format(_op_stats(init_net))) + + return predict_net, init_net + + +def run_and_save_graph(predict_net, init_net, tensor_inputs, graph_save_path): + """ + Run the caffe2 model on given inputs, recording the shape and draw the graph. + + predict_net/init_net: caffe2 model. + tensor_inputs: a list of tensors that caffe2 model takes as input. + graph_save_path: path for saving graph of exported model. + """ + + logger.info("Saving graph of ONNX exported model to {} ...".format(graph_save_path)) + save_graph(predict_net, graph_save_path, op_only=False) + + # Run the exported Caffe2 net + logger.info("Running ONNX exported model ...") + with ScopedWS("__ws_tmp__", True) as ws: + ws.RunNetOnce(init_net) + initialized_blobs = set(ws.Blobs()) + uninitialized = [inp for inp in predict_net.external_input if inp not in initialized_blobs] + for name, blob in zip(uninitialized, tensor_inputs): + ws.FeedBlob(name, blob) + + try: + ws.RunNetOnce(predict_net) + except RuntimeError as e: + logger.warning("Encountered RuntimeError: \n{}".format(str(e))) + + ws_blobs = {b: ws.FetchBlob(b) for b in ws.Blobs()} + blob_sizes = {b: ws_blobs[b].shape for b in ws_blobs if isinstance(ws_blobs[b], np.ndarray)} + + logger.info("Saving graph with blob shapes to {} ...".format(graph_save_path)) + save_graph(predict_net, graph_save_path, op_only=False, blob_sizes=blob_sizes) + + return ws_blobs diff --git a/custom_detectron2/export/caffe2_inference.py b/custom_detectron2/export/caffe2_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..deb886c0417285ed1d5ad85eb941fa1ac757cdab --- /dev/null +++ b/custom_detectron2/export/caffe2_inference.py @@ -0,0 +1,161 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +import numpy as np +from itertools import count +import torch +from caffe2.proto import caffe2_pb2 +from caffe2.python import core + +from .caffe2_modeling import META_ARCH_CAFFE2_EXPORT_TYPE_MAP, convert_batched_inputs_to_c2_format +from .shared import ScopedWS, get_pb_arg_vali, get_pb_arg_vals, infer_device_type + +logger = logging.getLogger(__name__) + + +# ===== ref: mobile-vision predictor's 'Caffe2Wrapper' class ====== +class ProtobufModel(torch.nn.Module): + """ + Wrapper of a caffe2's protobuf model. + It works just like nn.Module, but running caffe2 under the hood. + Input/Output are tuple[tensor] that match the caffe2 net's external_input/output. + """ + + _ids = count(0) + + def __init__(self, predict_net, init_net): + logger.info(f"Initializing ProtobufModel for: {predict_net.name} ...") + super().__init__() + assert isinstance(predict_net, caffe2_pb2.NetDef) + assert isinstance(init_net, caffe2_pb2.NetDef) + # create unique temporary workspace for each instance + self.ws_name = "__tmp_ProtobufModel_{}__".format(next(self._ids)) + self.net = core.Net(predict_net) + + logger.info("Running init_net once to fill the parameters ...") + with ScopedWS(self.ws_name, is_reset=True, is_cleanup=False) as ws: + ws.RunNetOnce(init_net) + uninitialized_external_input = [] + for blob in self.net.Proto().external_input: + if blob not in ws.Blobs(): + uninitialized_external_input.append(blob) + ws.CreateBlob(blob) + ws.CreateNet(self.net) + + self._error_msgs = set() + self._input_blobs = uninitialized_external_input + + def _infer_output_devices(self, inputs): + """ + Returns: + list[str]: list of device for each external output + """ + + def _get_device_type(torch_tensor): + assert torch_tensor.device.type in ["cpu", "cuda"] + assert torch_tensor.device.index == 0 + return torch_tensor.device.type + + predict_net = self.net.Proto() + input_device_types = { + (name, 0): _get_device_type(tensor) for name, tensor in zip(self._input_blobs, inputs) + } + device_type_map = infer_device_type( + predict_net, known_status=input_device_types, device_name_style="pytorch" + ) + ssa, versions = core.get_ssa(predict_net) + versioned_outputs = [(name, versions[name]) for name in predict_net.external_output] + output_devices = [device_type_map[outp] for outp in versioned_outputs] + return output_devices + + def forward(self, inputs): + """ + Args: + inputs (tuple[torch.Tensor]) + + Returns: + tuple[torch.Tensor] + """ + assert len(inputs) == len(self._input_blobs), ( + f"Length of inputs ({len(inputs)}) " + f"doesn't match the required input blobs: {self._input_blobs}" + ) + + with ScopedWS(self.ws_name, is_reset=False, is_cleanup=False) as ws: + for b, tensor in zip(self._input_blobs, inputs): + ws.FeedBlob(b, tensor) + + try: + ws.RunNet(self.net.Proto().name) + except RuntimeError as e: + if not str(e) in self._error_msgs: + self._error_msgs.add(str(e)) + logger.warning("Encountered new RuntimeError: \n{}".format(str(e))) + logger.warning("Catch the error and use partial results.") + + c2_outputs = [ws.FetchBlob(b) for b in self.net.Proto().external_output] + # Remove outputs of current run, this is necessary in order to + # prevent fetching the result from previous run if the model fails + # in the middle. + for b in self.net.Proto().external_output: + # Needs to create uninitialized blob to make the net runable. + # This is "equivalent" to: ws.RemoveBlob(b) then ws.CreateBlob(b), + # but there'no such API. + ws.FeedBlob(b, f"{b}, a C++ native class of type nullptr (uninitialized).") + + # Cast output to torch.Tensor on the desired device + output_devices = ( + self._infer_output_devices(inputs) + if any(t.device.type != "cpu" for t in inputs) + else ["cpu" for _ in self.net.Proto().external_output] + ) + + outputs = [] + for name, c2_output, device in zip( + self.net.Proto().external_output, c2_outputs, output_devices + ): + if not isinstance(c2_output, np.ndarray): + raise RuntimeError( + "Invalid output for blob {}, received: {}".format(name, c2_output) + ) + outputs.append(torch.tensor(c2_output).to(device=device)) + return tuple(outputs) + + +class ProtobufDetectionModel(torch.nn.Module): + """ + A class works just like a pytorch meta arch in terms of inference, but running + caffe2 model under the hood. + """ + + def __init__(self, predict_net, init_net, *, convert_outputs=None): + """ + Args: + predict_net, init_net (core.Net): caffe2 nets + convert_outptus (callable): a function that converts caffe2 + outputs to the same format of the original pytorch model. + By default, use the one defined in the caffe2 meta_arch. + """ + super().__init__() + self.protobuf_model = ProtobufModel(predict_net, init_net) + self.size_divisibility = get_pb_arg_vali(predict_net, "size_divisibility", 0) + self.device = get_pb_arg_vals(predict_net, "device", b"cpu").decode("ascii") + + if convert_outputs is None: + meta_arch = get_pb_arg_vals(predict_net, "meta_architecture", b"GeneralizedRCNN") + meta_arch = META_ARCH_CAFFE2_EXPORT_TYPE_MAP[meta_arch.decode("ascii")] + self._convert_outputs = meta_arch.get_outputs_converter(predict_net, init_net) + else: + self._convert_outputs = convert_outputs + + def _convert_inputs(self, batched_inputs): + # currently all models convert inputs in the same way + return convert_batched_inputs_to_c2_format( + batched_inputs, self.size_divisibility, self.device + ) + + def forward(self, batched_inputs): + c2_inputs = self._convert_inputs(batched_inputs) + c2_results = self.protobuf_model(c2_inputs) + c2_results = dict(zip(self.protobuf_model.net.Proto().external_output, c2_results)) + return self._convert_outputs(batched_inputs, c2_inputs, c2_results) diff --git a/custom_detectron2/export/caffe2_modeling.py b/custom_detectron2/export/caffe2_modeling.py new file mode 100644 index 0000000000000000000000000000000000000000..050751370255a8986bf75d02da38536f9abe9065 --- /dev/null +++ b/custom_detectron2/export/caffe2_modeling.py @@ -0,0 +1,419 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import functools +import io +import struct +import types +import torch + +from custom_detectron2.modeling import meta_arch +from custom_detectron2.modeling.box_regression import Box2BoxTransform +from custom_detectron2.modeling.roi_heads import keypoint_head +from custom_detectron2.structures import Boxes, ImageList, Instances, RotatedBoxes + +from .c10 import Caffe2Compatible +from .caffe2_patch import ROIHeadsPatcher, patch_generalized_rcnn +from .shared import ( + alias, + check_set_pb_arg, + get_pb_arg_floats, + get_pb_arg_valf, + get_pb_arg_vali, + get_pb_arg_vals, + mock_torch_nn_functional_interpolate, +) + + +def assemble_rcnn_outputs_by_name(image_sizes, tensor_outputs, force_mask_on=False): + """ + A function to assemble caffe2 model's outputs (i.e. Dict[str, Tensor]) + to detectron2's format (i.e. list of Instances instance). + This only works when the model follows the Caffe2 detectron's naming convention. + + Args: + image_sizes (List[List[int, int]]): [H, W] of every image. + tensor_outputs (Dict[str, Tensor]): external_output to its tensor. + + force_mask_on (Bool): if true, the it make sure there'll be pred_masks even + if the mask is not found from tensor_outputs (usually due to model crash) + """ + + results = [Instances(image_size) for image_size in image_sizes] + + batch_splits = tensor_outputs.get("batch_splits", None) + if batch_splits: + raise NotImplementedError() + assert len(image_sizes) == 1 + result = results[0] + + bbox_nms = tensor_outputs["bbox_nms"] + score_nms = tensor_outputs["score_nms"] + class_nms = tensor_outputs["class_nms"] + # Detection will always success because Conv support 0-batch + assert bbox_nms is not None + assert score_nms is not None + assert class_nms is not None + if bbox_nms.shape[1] == 5: + result.pred_boxes = RotatedBoxes(bbox_nms) + else: + result.pred_boxes = Boxes(bbox_nms) + result.scores = score_nms + result.pred_classes = class_nms.to(torch.int64) + + mask_fcn_probs = tensor_outputs.get("mask_fcn_probs", None) + if mask_fcn_probs is not None: + # finish the mask pred + mask_probs_pred = mask_fcn_probs + num_masks = mask_probs_pred.shape[0] + class_pred = result.pred_classes + indices = torch.arange(num_masks, device=class_pred.device) + mask_probs_pred = mask_probs_pred[indices, class_pred][:, None] + result.pred_masks = mask_probs_pred + elif force_mask_on: + # NOTE: there's no way to know the height/width of mask here, it won't be + # used anyway when batch size is 0, so just set them to 0. + result.pred_masks = torch.zeros([0, 1, 0, 0], dtype=torch.uint8) + + keypoints_out = tensor_outputs.get("keypoints_out", None) + kps_score = tensor_outputs.get("kps_score", None) + if keypoints_out is not None: + # keypoints_out: [N, 4, #kypoints], where 4 is in order of (x, y, score, prob) + keypoints_tensor = keypoints_out + # NOTE: it's possible that prob is not calculated if "should_output_softmax" + # is set to False in HeatmapMaxKeypoint, so just using raw score, seems + # it doesn't affect mAP. TODO: check more carefully. + keypoint_xyp = keypoints_tensor.transpose(1, 2)[:, :, [0, 1, 2]] + result.pred_keypoints = keypoint_xyp + elif kps_score is not None: + # keypoint heatmap to sparse data structure + pred_keypoint_logits = kps_score + keypoint_head.keypoint_rcnn_inference(pred_keypoint_logits, [result]) + + return results + + +def _cast_to_f32(f64): + return struct.unpack("f", struct.pack("f", f64))[0] + + +def set_caffe2_compatible_tensor_mode(model, enable=True): + def _fn(m): + if isinstance(m, Caffe2Compatible): + m.tensor_mode = enable + + model.apply(_fn) + + +def convert_batched_inputs_to_c2_format(batched_inputs, size_divisibility, device): + """ + See get_caffe2_inputs() below. + """ + assert all(isinstance(x, dict) for x in batched_inputs) + assert all(x["image"].dim() == 3 for x in batched_inputs) + + images = [x["image"] for x in batched_inputs] + images = ImageList.from_tensors(images, size_divisibility) + + im_info = [] + for input_per_image, image_size in zip(batched_inputs, images.image_sizes): + target_height = input_per_image.get("height", image_size[0]) + target_width = input_per_image.get("width", image_size[1]) # noqa + # NOTE: The scale inside im_info is kept as convention and for providing + # post-processing information if further processing is needed. For + # current Caffe2 model definitions that don't include post-processing inside + # the model, this number is not used. + # NOTE: There can be a slight difference between width and height + # scales, using a single number can results in numerical difference + # compared with D2's post-processing. + scale = target_height / image_size[0] + im_info.append([image_size[0], image_size[1], scale]) + im_info = torch.Tensor(im_info) + + return images.tensor.to(device), im_info.to(device) + + +class Caffe2MetaArch(Caffe2Compatible, torch.nn.Module): + """ + Base class for caffe2-compatible implementation of a meta architecture. + The forward is traceable and its traced graph can be converted to caffe2 + graph through ONNX. + """ + + def __init__(self, cfg, torch_model): + """ + Args: + cfg (CfgNode): + torch_model (nn.Module): the detectron2 model (meta_arch) to be + converted. + """ + super().__init__() + self._wrapped_model = torch_model + self.eval() + set_caffe2_compatible_tensor_mode(self, True) + + def get_caffe2_inputs(self, batched_inputs): + """ + Convert pytorch-style structured inputs to caffe2-style inputs that + are tuples of tensors. + + Args: + batched_inputs (list[dict]): inputs to a detectron2 model + in its standard format. Each dict has "image" (CHW tensor), and optionally + "height" and "width". + + Returns: + tuple[Tensor]: + tuple of tensors that will be the inputs to the + :meth:`forward` method. For existing models, the first + is an NCHW tensor (padded and batched); the second is + a im_info Nx3 tensor, where the rows are + (height, width, unused legacy parameter) + """ + return convert_batched_inputs_to_c2_format( + batched_inputs, + self._wrapped_model.backbone.size_divisibility, + self._wrapped_model.device, + ) + + def encode_additional_info(self, predict_net, init_net): + """ + Save extra metadata that will be used by inference in the output protobuf. + """ + pass + + def forward(self, inputs): + """ + Run the forward in caffe2-style. It has to use caffe2-compatible ops + and the method will be used for tracing. + + Args: + inputs (tuple[Tensor]): inputs defined by :meth:`get_caffe2_input`. + They will be the inputs of the converted caffe2 graph. + + Returns: + tuple[Tensor]: output tensors. They will be the outputs of the + converted caffe2 graph. + """ + raise NotImplementedError + + def _caffe2_preprocess_image(self, inputs): + """ + Caffe2 implementation of preprocess_image, which is called inside each MetaArch's forward. + It normalizes the input images, and the final caffe2 graph assumes the + inputs have been batched already. + """ + data, im_info = inputs + data = alias(data, "data") + im_info = alias(im_info, "im_info") + mean, std = self._wrapped_model.pixel_mean, self._wrapped_model.pixel_std + normalized_data = (data - mean) / std + normalized_data = alias(normalized_data, "normalized_data") + + # Pack (data, im_info) into ImageList which is recognized by self.inference. + images = ImageList(tensor=normalized_data, image_sizes=im_info) + return images + + @staticmethod + def get_outputs_converter(predict_net, init_net): + """ + Creates a function that converts outputs of the caffe2 model to + detectron2's standard format. + The function uses information in `predict_net` and `init_net` that are + available at inferene time. Therefore the function logic can be used in inference. + + The returned function has the following signature: + + def convert(batched_inputs, c2_inputs, c2_results) -> detectron2_outputs + + Where + + * batched_inputs (list[dict]): the original input format of the meta arch + * c2_inputs (tuple[Tensor]): the caffe2 inputs. + * c2_results (dict[str, Tensor]): the caffe2 output format, + corresponding to the outputs of the :meth:`forward` function. + * detectron2_outputs: the original output format of the meta arch. + + This function can be used to compare the outputs of the original meta arch and + the converted caffe2 graph. + + Returns: + callable: a callable of the above signature. + """ + raise NotImplementedError + + +class Caffe2GeneralizedRCNN(Caffe2MetaArch): + def __init__(self, cfg, torch_model): + assert isinstance(torch_model, meta_arch.GeneralizedRCNN) + torch_model = patch_generalized_rcnn(torch_model) + super().__init__(cfg, torch_model) + + try: + use_heatmap_max_keypoint = cfg.EXPORT_CAFFE2.USE_HEATMAP_MAX_KEYPOINT + except AttributeError: + use_heatmap_max_keypoint = False + self.roi_heads_patcher = ROIHeadsPatcher( + self._wrapped_model.roi_heads, use_heatmap_max_keypoint + ) + + def encode_additional_info(self, predict_net, init_net): + size_divisibility = self._wrapped_model.backbone.size_divisibility + check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility) + check_set_pb_arg( + predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii") + ) + check_set_pb_arg(predict_net, "meta_architecture", "s", b"GeneralizedRCNN") + + @mock_torch_nn_functional_interpolate() + def forward(self, inputs): + if not self.tensor_mode: + return self._wrapped_model.inference(inputs) + images = self._caffe2_preprocess_image(inputs) + features = self._wrapped_model.backbone(images.tensor) + proposals, _ = self._wrapped_model.proposal_generator(images, features) + with self.roi_heads_patcher.mock_roi_heads(): + detector_results, _ = self._wrapped_model.roi_heads(images, features, proposals) + return tuple(detector_results[0].flatten()) + + @staticmethod + def get_outputs_converter(predict_net, init_net): + def f(batched_inputs, c2_inputs, c2_results): + _, im_info = c2_inputs + image_sizes = [[int(im[0]), int(im[1])] for im in im_info] + results = assemble_rcnn_outputs_by_name(image_sizes, c2_results) + return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes) + + return f + + +class Caffe2RetinaNet(Caffe2MetaArch): + def __init__(self, cfg, torch_model): + assert isinstance(torch_model, meta_arch.RetinaNet) + super().__init__(cfg, torch_model) + + @mock_torch_nn_functional_interpolate() + def forward(self, inputs): + assert self.tensor_mode + images = self._caffe2_preprocess_image(inputs) + + # explicitly return the images sizes to avoid removing "im_info" by ONNX + # since it's not used in the forward path + return_tensors = [images.image_sizes] + + features = self._wrapped_model.backbone(images.tensor) + features = [features[f] for f in self._wrapped_model.head_in_features] + for i, feature_i in enumerate(features): + features[i] = alias(feature_i, "feature_{}".format(i), is_backward=True) + return_tensors.append(features[i]) + + pred_logits, pred_anchor_deltas = self._wrapped_model.head(features) + for i, (box_cls_i, box_delta_i) in enumerate(zip(pred_logits, pred_anchor_deltas)): + return_tensors.append(alias(box_cls_i, "box_cls_{}".format(i))) + return_tensors.append(alias(box_delta_i, "box_delta_{}".format(i))) + + return tuple(return_tensors) + + def encode_additional_info(self, predict_net, init_net): + size_divisibility = self._wrapped_model.backbone.size_divisibility + check_set_pb_arg(predict_net, "size_divisibility", "i", size_divisibility) + check_set_pb_arg( + predict_net, "device", "s", str.encode(str(self._wrapped_model.device), "ascii") + ) + check_set_pb_arg(predict_net, "meta_architecture", "s", b"RetinaNet") + + # Inference parameters: + check_set_pb_arg( + predict_net, "score_threshold", "f", _cast_to_f32(self._wrapped_model.test_score_thresh) + ) + check_set_pb_arg( + predict_net, "topk_candidates", "i", self._wrapped_model.test_topk_candidates + ) + check_set_pb_arg( + predict_net, "nms_threshold", "f", _cast_to_f32(self._wrapped_model.test_nms_thresh) + ) + check_set_pb_arg( + predict_net, + "max_detections_per_image", + "i", + self._wrapped_model.max_detections_per_image, + ) + + check_set_pb_arg( + predict_net, + "bbox_reg_weights", + "floats", + [_cast_to_f32(w) for w in self._wrapped_model.box2box_transform.weights], + ) + self._encode_anchor_generator_cfg(predict_net) + + def _encode_anchor_generator_cfg(self, predict_net): + # serialize anchor_generator for future use + serialized_anchor_generator = io.BytesIO() + torch.save(self._wrapped_model.anchor_generator, serialized_anchor_generator) + # Ideally we can put anchor generating inside the model, then we don't + # need to store this information. + bytes = serialized_anchor_generator.getvalue() + check_set_pb_arg(predict_net, "serialized_anchor_generator", "s", bytes) + + @staticmethod + def get_outputs_converter(predict_net, init_net): + self = types.SimpleNamespace() + serialized_anchor_generator = io.BytesIO( + get_pb_arg_vals(predict_net, "serialized_anchor_generator", None) + ) + self.anchor_generator = torch.load(serialized_anchor_generator) + bbox_reg_weights = get_pb_arg_floats(predict_net, "bbox_reg_weights", None) + self.box2box_transform = Box2BoxTransform(weights=tuple(bbox_reg_weights)) + self.test_score_thresh = get_pb_arg_valf(predict_net, "score_threshold", None) + self.test_topk_candidates = get_pb_arg_vali(predict_net, "topk_candidates", None) + self.test_nms_thresh = get_pb_arg_valf(predict_net, "nms_threshold", None) + self.max_detections_per_image = get_pb_arg_vali( + predict_net, "max_detections_per_image", None + ) + + # hack to reuse inference code from RetinaNet + for meth in [ + "forward_inference", + "inference_single_image", + "_transpose_dense_predictions", + "_decode_multi_level_predictions", + "_decode_per_level_predictions", + ]: + setattr(self, meth, functools.partial(getattr(meta_arch.RetinaNet, meth), self)) + + def f(batched_inputs, c2_inputs, c2_results): + _, im_info = c2_inputs + image_sizes = [[int(im[0]), int(im[1])] for im in im_info] + dummy_images = ImageList( + torch.randn( + ( + len(im_info), + 3, + ) + + tuple(image_sizes[0]) + ), + image_sizes, + ) + + num_features = len([x for x in c2_results.keys() if x.startswith("box_cls_")]) + pred_logits = [c2_results["box_cls_{}".format(i)] for i in range(num_features)] + pred_anchor_deltas = [c2_results["box_delta_{}".format(i)] for i in range(num_features)] + + # For each feature level, feature should have the same batch size and + # spatial dimension as the box_cls and box_delta. + dummy_features = [x.clone()[:, 0:0, :, :] for x in pred_logits] + # self.num_classess can be inferred + self.num_classes = pred_logits[0].shape[1] // (pred_anchor_deltas[0].shape[1] // 4) + + results = self.forward_inference( + dummy_images, dummy_features, [pred_logits, pred_anchor_deltas] + ) + return meta_arch.GeneralizedRCNN._postprocess(results, batched_inputs, image_sizes) + + return f + + +META_ARCH_CAFFE2_EXPORT_TYPE_MAP = { + "GeneralizedRCNN": Caffe2GeneralizedRCNN, + "RetinaNet": Caffe2RetinaNet, +} diff --git a/custom_detectron2/export/caffe2_patch.py b/custom_detectron2/export/caffe2_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..40d50429be5666006a760a3add98981a3d9b78c4 --- /dev/null +++ b/custom_detectron2/export/caffe2_patch.py @@ -0,0 +1,152 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import contextlib +from unittest import mock +import torch + +from custom_detectron2.modeling import poolers +from custom_detectron2.modeling.proposal_generator import rpn +from custom_detectron2.modeling.roi_heads import keypoint_head, mask_head +from custom_detectron2.modeling.roi_heads.fast_rcnn import FastRCNNOutputLayers + +from .c10 import ( + Caffe2Compatible, + Caffe2FastRCNNOutputsInference, + Caffe2KeypointRCNNInference, + Caffe2MaskRCNNInference, + Caffe2ROIPooler, + Caffe2RPN, +) + + +class GenericMixin(object): + pass + + +class Caffe2CompatibleConverter(object): + """ + A GenericUpdater which implements the `create_from` interface, by modifying + module object and assign it with another class replaceCls. + """ + + def __init__(self, replaceCls): + self.replaceCls = replaceCls + + def create_from(self, module): + # update module's class to the new class + assert isinstance(module, torch.nn.Module) + if issubclass(self.replaceCls, GenericMixin): + # replaceCls should act as mixin, create a new class on-the-fly + new_class = type( + "{}MixedWith{}".format(self.replaceCls.__name__, module.__class__.__name__), + (self.replaceCls, module.__class__), + {}, # {"new_method": lambda self: ...}, + ) + module.__class__ = new_class + else: + # replaceCls is complete class, this allow arbitrary class swap + module.__class__ = self.replaceCls + + # initialize Caffe2Compatible + if isinstance(module, Caffe2Compatible): + module.tensor_mode = False + + return module + + +def patch(model, target, updater, *args, **kwargs): + """ + recursively (post-order) update all modules with the target type and its + subclasses, make a initialization/composition/inheritance/... via the + updater.create_from. + """ + for name, module in model.named_children(): + model._modules[name] = patch(module, target, updater, *args, **kwargs) + if isinstance(model, target): + return updater.create_from(model, *args, **kwargs) + return model + + +def patch_generalized_rcnn(model): + ccc = Caffe2CompatibleConverter + model = patch(model, rpn.RPN, ccc(Caffe2RPN)) + model = patch(model, poolers.ROIPooler, ccc(Caffe2ROIPooler)) + + return model + + +@contextlib.contextmanager +def mock_fastrcnn_outputs_inference( + tensor_mode, check=True, box_predictor_type=FastRCNNOutputLayers +): + with mock.patch.object( + box_predictor_type, + "inference", + autospec=True, + side_effect=Caffe2FastRCNNOutputsInference(tensor_mode), + ) as mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +@contextlib.contextmanager +def mock_mask_rcnn_inference(tensor_mode, patched_module, check=True): + with mock.patch( + "{}.mask_rcnn_inference".format(patched_module), side_effect=Caffe2MaskRCNNInference() + ) as mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +@contextlib.contextmanager +def mock_keypoint_rcnn_inference(tensor_mode, patched_module, use_heatmap_max_keypoint, check=True): + with mock.patch( + "{}.keypoint_rcnn_inference".format(patched_module), + side_effect=Caffe2KeypointRCNNInference(use_heatmap_max_keypoint), + ) as mocked_func: + yield + if check: + assert mocked_func.call_count > 0 + + +class ROIHeadsPatcher: + def __init__(self, heads, use_heatmap_max_keypoint): + self.heads = heads + self.use_heatmap_max_keypoint = use_heatmap_max_keypoint + + @contextlib.contextmanager + def mock_roi_heads(self, tensor_mode=True): + """ + Patching several inference functions inside ROIHeads and its subclasses + + Args: + tensor_mode (bool): whether the inputs/outputs are caffe2's tensor + format or not. Default to True. + """ + # NOTE: this requries the `keypoint_rcnn_inference` and `mask_rcnn_inference` + # are called inside the same file as BaseXxxHead due to using mock.patch. + kpt_heads_mod = keypoint_head.BaseKeypointRCNNHead.__module__ + mask_head_mod = mask_head.BaseMaskRCNNHead.__module__ + + mock_ctx_managers = [ + mock_fastrcnn_outputs_inference( + tensor_mode=tensor_mode, + check=True, + box_predictor_type=type(self.heads.box_predictor), + ) + ] + if getattr(self.heads, "keypoint_on", False): + mock_ctx_managers += [ + mock_keypoint_rcnn_inference( + tensor_mode, kpt_heads_mod, self.use_heatmap_max_keypoint + ) + ] + if getattr(self.heads, "mask_on", False): + mock_ctx_managers += [mock_mask_rcnn_inference(tensor_mode, mask_head_mod)] + + with contextlib.ExitStack() as stack: # python 3.3+ + for mgr in mock_ctx_managers: + stack.enter_context(mgr) + yield diff --git a/custom_detectron2/export/flatten.py b/custom_detectron2/export/flatten.py new file mode 100644 index 0000000000000000000000000000000000000000..36c757b82ff1b2a106725c14ae959f08f035b6ba --- /dev/null +++ b/custom_detectron2/export/flatten.py @@ -0,0 +1,330 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import collections +from dataclasses import dataclass +from typing import Callable, List, Optional, Tuple +import torch +from torch import nn + +from custom_detectron2.structures import Boxes, Instances, ROIMasks +from custom_detectron2.utils.registry import _convert_target_to_string, locate + +from .torchscript_patch import patch_builtin_len + + +@dataclass +class Schema: + """ + A Schema defines how to flatten a possibly hierarchical object into tuple of + primitive objects, so it can be used as inputs/outputs of PyTorch's tracing. + + PyTorch does not support tracing a function that produces rich output + structures (e.g. dict, Instances, Boxes). To trace such a function, we + flatten the rich object into tuple of tensors, and return this tuple of tensors + instead. Meanwhile, we also need to know how to "rebuild" the original object + from the flattened results, so we can evaluate the flattened results. + A Schema defines how to flatten an object, and while flattening it, it records + necessary schemas so that the object can be rebuilt using the flattened outputs. + + The flattened object and the schema object is returned by ``.flatten`` classmethod. + Then the original object can be rebuilt with the ``__call__`` method of schema. + + A Schema is a dataclass that can be serialized easily. + """ + + # inspired by FetchMapper in tensorflow/python/client/session.py + + @classmethod + def flatten(cls, obj): + raise NotImplementedError + + def __call__(self, values): + raise NotImplementedError + + @staticmethod + def _concat(values): + ret = () + sizes = [] + for v in values: + assert isinstance(v, tuple), "Flattened results must be a tuple" + ret = ret + v + sizes.append(len(v)) + return ret, sizes + + @staticmethod + def _split(values, sizes): + if len(sizes): + expected_len = sum(sizes) + assert ( + len(values) == expected_len + ), f"Values has length {len(values)} but expect length {expected_len}." + ret = [] + for k in range(len(sizes)): + begin, end = sum(sizes[:k]), sum(sizes[: k + 1]) + ret.append(values[begin:end]) + return ret + + +@dataclass +class ListSchema(Schema): + schemas: List[Schema] # the schemas that define how to flatten each element in the list + sizes: List[int] # the flattened length of each element + + def __call__(self, values): + values = self._split(values, self.sizes) + if len(values) != len(self.schemas): + raise ValueError( + f"Values has length {len(values)} but schemas " f"has length {len(self.schemas)}!" + ) + values = [m(v) for m, v in zip(self.schemas, values)] + return list(values) + + @classmethod + def flatten(cls, obj): + res = [flatten_to_tuple(k) for k in obj] + values, sizes = cls._concat([k[0] for k in res]) + return values, cls([k[1] for k in res], sizes) + + +@dataclass +class TupleSchema(ListSchema): + def __call__(self, values): + return tuple(super().__call__(values)) + + +@dataclass +class IdentitySchema(Schema): + def __call__(self, values): + return values[0] + + @classmethod + def flatten(cls, obj): + return (obj,), cls() + + +@dataclass +class DictSchema(ListSchema): + keys: List[str] + + def __call__(self, values): + values = super().__call__(values) + return dict(zip(self.keys, values)) + + @classmethod + def flatten(cls, obj): + for k in obj.keys(): + if not isinstance(k, str): + raise KeyError("Only support flattening dictionaries if keys are str.") + keys = sorted(obj.keys()) + values = [obj[k] for k in keys] + ret, schema = ListSchema.flatten(values) + return ret, cls(schema.schemas, schema.sizes, keys) + + +@dataclass +class InstancesSchema(DictSchema): + def __call__(self, values): + image_size, fields = values[-1], values[:-1] + fields = super().__call__(fields) + return Instances(image_size, **fields) + + @classmethod + def flatten(cls, obj): + ret, schema = super().flatten(obj.get_fields()) + size = obj.image_size + if not isinstance(size, torch.Tensor): + size = torch.tensor(size) + return ret + (size,), schema + + +@dataclass +class TensorWrapSchema(Schema): + """ + For classes that are simple wrapper of tensors, e.g. + Boxes, RotatedBoxes, BitMasks + """ + + class_name: str + + def __call__(self, values): + return locate(self.class_name)(values[0]) + + @classmethod + def flatten(cls, obj): + return (obj.tensor,), cls(_convert_target_to_string(type(obj))) + + +# if more custom structures needed in the future, can allow +# passing in extra schemas for custom types +def flatten_to_tuple(obj): + """ + Flatten an object so it can be used for PyTorch tracing. + Also returns how to rebuild the original object from the flattened outputs. + + Returns: + res (tuple): the flattened results that can be used as tracing outputs + schema: an object with a ``__call__`` method such that ``schema(res) == obj``. + It is a pure dataclass that can be serialized. + """ + schemas = [ + ((str, bytes), IdentitySchema), + (list, ListSchema), + (tuple, TupleSchema), + (collections.abc.Mapping, DictSchema), + (Instances, InstancesSchema), + ((Boxes, ROIMasks), TensorWrapSchema), + ] + for klass, schema in schemas: + if isinstance(obj, klass): + F = schema + break + else: + F = IdentitySchema + + return F.flatten(obj) + + +class TracingAdapter(nn.Module): + """ + A model may take rich input/output format (e.g. dict or custom classes), + but `torch.jit.trace` requires tuple of tensors as input/output. + This adapter flattens input/output format of a model so it becomes traceable. + + It also records the necessary schema to rebuild model's inputs/outputs from flattened + inputs/outputs. + + Example: + :: + outputs = model(inputs) # inputs/outputs may be rich structure + adapter = TracingAdapter(model, inputs) + + # can now trace the model, with adapter.flattened_inputs, or another + # tuple of tensors with the same length and meaning + traced = torch.jit.trace(adapter, adapter.flattened_inputs) + + # traced model can only produce flattened outputs (tuple of tensors) + flattened_outputs = traced(*adapter.flattened_inputs) + # adapter knows the schema to convert it back (new_outputs == outputs) + new_outputs = adapter.outputs_schema(flattened_outputs) + """ + + flattened_inputs: Tuple[torch.Tensor] = None + """ + Flattened version of inputs given to this class's constructor. + """ + + inputs_schema: Schema = None + """ + Schema of the inputs given to this class's constructor. + """ + + outputs_schema: Schema = None + """ + Schema of the output produced by calling the given model with inputs. + """ + + def __init__( + self, + model: nn.Module, + inputs, + inference_func: Optional[Callable] = None, + allow_non_tensor: bool = False, + ): + """ + Args: + model: an nn.Module + inputs: An input argument or a tuple of input arguments used to call model. + After flattening, it has to only consist of tensors. + inference_func: a callable that takes (model, *inputs), calls the + model with inputs, and return outputs. By default it + is ``lambda model, *inputs: model(*inputs)``. Can be override + if you need to call the model differently. + allow_non_tensor: allow inputs/outputs to contain non-tensor objects. + This option will filter out non-tensor objects to make the + model traceable, but ``inputs_schema``/``outputs_schema`` cannot be + used anymore because inputs/outputs cannot be rebuilt from pure tensors. + This is useful when you're only interested in the single trace of + execution (e.g. for flop count), but not interested in + generalizing the traced graph to new inputs. + """ + super().__init__() + if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)): + model = model.module + self.model = model + if not isinstance(inputs, tuple): + inputs = (inputs,) + self.inputs = inputs + self.allow_non_tensor = allow_non_tensor + + if inference_func is None: + inference_func = lambda model, *inputs: model(*inputs) # noqa + self.inference_func = inference_func + + self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs) + + if all(isinstance(x, torch.Tensor) for x in self.flattened_inputs): + return + if self.allow_non_tensor: + self.flattened_inputs = tuple( + [x for x in self.flattened_inputs if isinstance(x, torch.Tensor)] + ) + self.inputs_schema = None + else: + for input in self.flattened_inputs: + if not isinstance(input, torch.Tensor): + raise ValueError( + "Inputs for tracing must only contain tensors. " + f"Got a {type(input)} instead." + ) + + def forward(self, *args: torch.Tensor): + with torch.no_grad(), patch_builtin_len(): + if self.inputs_schema is not None: + inputs_orig_format = self.inputs_schema(args) + else: + if len(args) != len(self.flattened_inputs) or any( + x is not y for x, y in zip(args, self.flattened_inputs) + ): + raise ValueError( + "TracingAdapter does not contain valid inputs_schema." + " So it cannot generalize to other inputs and must be" + " traced with `.flattened_inputs`." + ) + inputs_orig_format = self.inputs + + outputs = self.inference_func(self.model, *inputs_orig_format) + flattened_outputs, schema = flatten_to_tuple(outputs) + + flattened_output_tensors = tuple( + [x for x in flattened_outputs if isinstance(x, torch.Tensor)] + ) + if len(flattened_output_tensors) < len(flattened_outputs): + if self.allow_non_tensor: + flattened_outputs = flattened_output_tensors + self.outputs_schema = None + else: + raise ValueError( + "Model cannot be traced because some model outputs " + "cannot flatten to tensors." + ) + else: # schema is valid + if self.outputs_schema is None: + self.outputs_schema = schema + else: + assert self.outputs_schema == schema, ( + "Model should always return outputs with the same " + "structure so it can be traced!" + ) + return flattened_outputs + + def _create_wrapper(self, traced_model): + """ + Return a function that has an input/output interface the same as the + original model, but it calls the given traced model under the hood. + """ + + def forward(*args): + flattened_inputs, _ = flatten_to_tuple(args) + flattened_outputs = traced_model(*flattened_inputs) + return self.outputs_schema(flattened_outputs) + + return forward diff --git a/custom_detectron2/export/shared.py b/custom_detectron2/export/shared.py new file mode 100644 index 0000000000000000000000000000000000000000..53ba9335e26819f9381115eba17bbbe3816b469c --- /dev/null +++ b/custom_detectron2/export/shared.py @@ -0,0 +1,1039 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import collections +import copy +import functools +import logging +import numpy as np +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest import mock +import caffe2.python.utils as putils +import torch +import torch.nn.functional as F +from caffe2.proto import caffe2_pb2 +from caffe2.python import core, net_drawer, workspace +from torch.nn.functional import interpolate as interp + +logger = logging.getLogger(__name__) + + +# ==== torch/utils_toffee/cast.py ======================================= + + +def to_device(t, device_str): + """ + This function is a replacement of .to(another_device) such that it allows the + casting to be traced properly by explicitly calling the underlying copy ops. + It also avoids introducing unncessary op when casting to the same device. + """ + src = t.device + dst = torch.device(device_str) + + if src == dst: + return t + elif src.type == "cuda" and dst.type == "cpu": + return torch.ops._caffe2.CopyGPUToCPU(t) + elif src.type == "cpu" and dst.type == "cuda": + return torch.ops._caffe2.CopyCPUToGPU(t) + else: + raise RuntimeError("Can't cast tensor from device {} to device {}".format(src, dst)) + + +# ==== torch/utils_toffee/interpolate.py ======================================= + + +# Note: borrowed from vision/detection/fair/detectron/detectron/modeling/detector.py +def BilinearInterpolation(tensor_in, up_scale): + assert up_scale % 2 == 0, "Scale should be even" + + def upsample_filt(size): + factor = (size + 1) // 2 + if size % 2 == 1: + center = factor - 1 + else: + center = factor - 0.5 + + og = np.ogrid[:size, :size] + return (1 - abs(og[0] - center) / factor) * (1 - abs(og[1] - center) / factor) + + kernel_size = int(up_scale) * 2 + bil_filt = upsample_filt(kernel_size) + + dim = int(tensor_in.shape[1]) + kernel = np.zeros((dim, dim, kernel_size, kernel_size), dtype=np.float32) + kernel[range(dim), range(dim), :, :] = bil_filt + + tensor_out = F.conv_transpose2d( + tensor_in, + weight=to_device(torch.Tensor(kernel), tensor_in.device), + bias=None, + stride=int(up_scale), + padding=int(up_scale / 2), + ) + + return tensor_out + + +# NOTE: ONNX is incompatible with traced torch.nn.functional.interpolate if +# using dynamic `scale_factor` rather than static `size`. (T43166860) +# NOTE: Caffe2 Int8 conversion might not be able to quantize `size` properly. +def onnx_compatibale_interpolate( + input, size=None, scale_factor=None, mode="nearest", align_corners=None +): + # NOTE: The input dimensions are interpreted in the form: + # `mini-batch x channels x [optional depth] x [optional height] x width`. + if size is None and scale_factor is not None: + if input.dim() == 4: + if isinstance(scale_factor, (int, float)): + height_scale, width_scale = (scale_factor, scale_factor) + else: + assert isinstance(scale_factor, (tuple, list)) + assert len(scale_factor) == 2 + height_scale, width_scale = scale_factor + + assert not align_corners, "No matching C2 op for align_corners == True" + if mode == "nearest": + return torch.ops._caffe2.ResizeNearest( + input, order="NCHW", width_scale=width_scale, height_scale=height_scale + ) + elif mode == "bilinear": + logger.warning( + "Use F.conv_transpose2d for bilinear interpolate" + " because there's no such C2 op, this may cause significant" + " slowdown and the boundary pixels won't be as same as" + " using F.interpolate due to padding." + ) + assert height_scale == width_scale + return BilinearInterpolation(input, up_scale=height_scale) + logger.warning("Output size is not static, it might cause ONNX conversion issue") + + return interp(input, size, scale_factor, mode, align_corners) + + +def mock_torch_nn_functional_interpolate(): + def decorator(func): + @functools.wraps(func) + def _mock_torch_nn_functional_interpolate(*args, **kwargs): + if torch.onnx.is_in_onnx_export(): + with mock.patch( + "torch.nn.functional.interpolate", side_effect=onnx_compatibale_interpolate + ): + return func(*args, **kwargs) + else: + return func(*args, **kwargs) + + return _mock_torch_nn_functional_interpolate + + return decorator + + +# ==== torch/utils_caffe2/ws_utils.py ========================================== + + +class ScopedWS(object): + def __init__(self, ws_name, is_reset, is_cleanup=False): + self.ws_name = ws_name + self.is_reset = is_reset + self.is_cleanup = is_cleanup + self.org_ws = "" + + def __enter__(self): + self.org_ws = workspace.CurrentWorkspace() + if self.ws_name is not None: + workspace.SwitchWorkspace(self.ws_name, True) + if self.is_reset: + workspace.ResetWorkspace() + + return workspace + + def __exit__(self, *args): + if self.is_cleanup: + workspace.ResetWorkspace() + if self.ws_name is not None: + workspace.SwitchWorkspace(self.org_ws) + + +def fetch_any_blob(name): + bb = None + try: + bb = workspace.FetchBlob(name) + except TypeError: + bb = workspace.FetchInt8Blob(name) + except Exception as e: + logger.error("Get blob {} error: {}".format(name, e)) + + return bb + + +# ==== torch/utils_caffe2/protobuf.py ========================================== + + +def get_pb_arg(pb, arg_name): + for x in pb.arg: + if x.name == arg_name: + return x + return None + + +def get_pb_arg_valf(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.f if arg is not None else default_val + + +def get_pb_arg_floats(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(map(float, arg.floats)) if arg is not None else default_val + + +def get_pb_arg_ints(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(map(int, arg.ints)) if arg is not None else default_val + + +def get_pb_arg_vali(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.i if arg is not None else default_val + + +def get_pb_arg_vals(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return arg.s if arg is not None else default_val + + +def get_pb_arg_valstrings(pb, arg_name, default_val): + arg = get_pb_arg(pb, arg_name) + return list(arg.strings) if arg is not None else default_val + + +def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False): + arg = get_pb_arg(pb, arg_name) + if arg is None: + arg = putils.MakeArgument(arg_name, arg_value) + assert hasattr(arg, arg_attr) + pb.arg.extend([arg]) + if allow_override and getattr(arg, arg_attr) != arg_value: + logger.warning( + "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value) + ) + setattr(arg, arg_attr, arg_value) + else: + assert arg is not None + assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format( + getattr(arg, arg_attr), arg_value + ) + + +def _create_const_fill_op_from_numpy(name, tensor, device_option=None): + assert type(tensor) == np.ndarray + kTypeNameMapper = { + np.dtype("float32"): "GivenTensorFill", + np.dtype("int32"): "GivenTensorIntFill", + np.dtype("int64"): "GivenTensorInt64Fill", + np.dtype("uint8"): "GivenTensorStringFill", + } + + args_dict = {} + if tensor.dtype == np.dtype("uint8"): + args_dict.update({"values": [str(tensor.data)], "shape": [1]}) + else: + args_dict.update({"values": tensor, "shape": tensor.shape}) + + if device_option is not None: + args_dict["device_option"] = device_option + + return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict) + + +def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor): + assert type(int8_tensor) == workspace.Int8Tensor + kTypeNameMapper = { + np.dtype("int32"): "Int8GivenIntTensorFill", + np.dtype("uint8"): "Int8GivenTensorFill", + } + + tensor = int8_tensor.data + assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")] + values = tensor.tobytes() if tensor.dtype == np.dtype("uint8") else tensor + + return core.CreateOperator( + kTypeNameMapper[tensor.dtype], + [], + [name], + values=values, + shape=tensor.shape, + Y_scale=int8_tensor.scale, + Y_zero_point=int8_tensor.zero_point, + ) + + +def create_const_fill_op( + name: str, + blob: Union[np.ndarray, workspace.Int8Tensor], + device_option: Optional[caffe2_pb2.DeviceOption] = None, +) -> caffe2_pb2.OperatorDef: + """ + Given a blob object, return the Caffe2 operator that creates this blob + as constant. Currently support NumPy tensor and Caffe2 Int8Tensor. + """ + + tensor_type = type(blob) + assert tensor_type in [ + np.ndarray, + workspace.Int8Tensor, + ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format( + name, type(blob) + ) + + if tensor_type == np.ndarray: + return _create_const_fill_op_from_numpy(name, blob, device_option) + elif tensor_type == workspace.Int8Tensor: + assert device_option is None + return _create_const_fill_op_from_c2_int8_tensor(name, blob) + + +def construct_init_net_from_params( + params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None +) -> caffe2_pb2.NetDef: + """ + Construct the init_net from params dictionary + """ + init_net = caffe2_pb2.NetDef() + device_options = device_options or {} + for name, blob in params.items(): + if isinstance(blob, str): + logger.warning( + ( + "Blob {} with type {} is not supported in generating init net," + " skipped.".format(name, type(blob)) + ) + ) + continue + init_net.op.extend( + [create_const_fill_op(name, blob, device_option=device_options.get(name, None))] + ) + init_net.external_output.append(name) + return init_net + + +def get_producer_map(ssa): + """ + Return dict from versioned blob to (i, j), + where i is index of producer op, j is the index of output of that op. + """ + producer_map = {} + for i in range(len(ssa)): + outputs = ssa[i][1] + for j, outp in enumerate(outputs): + producer_map[outp] = (i, j) + return producer_map + + +def get_consumer_map(ssa): + """ + Return dict from versioned blob to list of (i, j), + where i is index of consumer op, j is the index of input of that op. + """ + consumer_map = collections.defaultdict(list) + for i in range(len(ssa)): + inputs = ssa[i][0] + for j, inp in enumerate(inputs): + consumer_map[inp].append((i, j)) + return consumer_map + + +def get_params_from_init_net( + init_net: caffe2_pb2.NetDef, +) -> [Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]: + """ + Take the output blobs from init_net by running it. + Outputs: + params: dict from blob name to numpy array + device_options: dict from blob name to the device option of its creating op + """ + # NOTE: this assumes that the params is determined by producer op with the + # only exception be CopyGPUToCPU which is CUDA op but returns CPU tensor. + def _get_device_option(producer_op): + if producer_op.type == "CopyGPUToCPU": + return caffe2_pb2.DeviceOption() + else: + return producer_op.device_option + + with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws: + ws.RunNetOnce(init_net) + params = {b: fetch_any_blob(b) for b in init_net.external_output} + ssa, versions = core.get_ssa(init_net) + producer_map = get_producer_map(ssa) + device_options = { + b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]]) + for b in init_net.external_output + } + return params, device_options + + +def _updater_raise(op, input_types, output_types): + raise RuntimeError( + "Failed to apply updater for op {} given input_types {} and" + " output_types {}".format(op, input_types, output_types) + ) + + +def _generic_status_identifier( + predict_net: caffe2_pb2.NetDef, + status_updater: Callable, + known_status: Dict[Tuple[str, int], Any], +) -> Dict[Tuple[str, int], Any]: + """ + Statically infer the status of each blob, the status can be such as device type + (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here + is versioned blob (Tuple[str, int]) in the format compatible with ssa. + Inputs: + predict_net: the caffe2 network + status_updater: a callable, given an op and the status of its input/output, + it returns the updated status of input/output. `None` is used for + representing unknown status. + known_status: a dict containing known status, used as initialization. + Outputs: + A dict mapping from versioned blob to its status + """ + ssa, versions = core.get_ssa(predict_net) + versioned_ext_input = [(b, 0) for b in predict_net.external_input] + versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output] + all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa]) + + allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output) + assert all(k in allowed_vbs for k in known_status) + assert all(v is not None for v in known_status.values()) + _known_status = copy.deepcopy(known_status) + + def _check_and_update(key, value): + assert value is not None + if key in _known_status: + if not _known_status[key] == value: + raise RuntimeError( + "Confilict status for {}, existing status {}, new status {}".format( + key, _known_status[key], value + ) + ) + _known_status[key] = value + + def _update_i(op, ssa_i): + versioned_inputs = ssa_i[0] + versioned_outputs = ssa_i[1] + + inputs_status = [_known_status.get(b, None) for b in versioned_inputs] + outputs_status = [_known_status.get(b, None) for b in versioned_outputs] + + new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status) + + for versioned_blob, status in zip( + versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status + ): + if status is not None: + _check_and_update(versioned_blob, status) + + for op, ssa_i in zip(predict_net.op, ssa): + _update_i(op, ssa_i) + for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)): + _update_i(op, ssa_i) + + # NOTE: This strictly checks all the blob from predict_net must be assgined + # a known status. However sometimes it's impossible (eg. having deadend op), + # we may relax this constraint if + for k in all_versioned_blobs: + if k not in _known_status: + raise NotImplementedError( + "Can not infer the status for {}. Currently only support the case where" + " a single forward and backward pass can identify status for all blobs.".format(k) + ) + + return _known_status + + +def infer_device_type( + predict_net: caffe2_pb2.NetDef, + known_status: Dict[Tuple[str, int], Any], + device_name_style: str = "caffe2", +) -> Dict[Tuple[str, int], str]: + """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob""" + + assert device_name_style in ["caffe2", "pytorch"] + _CPU_STR = "cpu" + _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda" + + def _copy_cpu_to_gpu_updater(op, input_types, output_types): + if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR: + _updater_raise(op, input_types, output_types) + return ([_CPU_STR], [_GPU_STR]) + + def _copy_gpu_to_cpu_updater(op, input_types, output_types): + if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR: + _updater_raise(op, input_types, output_types) + return ([_GPU_STR], [_CPU_STR]) + + def _other_ops_updater(op, input_types, output_types): + non_none_types = [x for x in input_types + output_types if x is not None] + if len(non_none_types) > 0: + the_type = non_none_types[0] + if not all(x == the_type for x in non_none_types): + _updater_raise(op, input_types, output_types) + else: + the_type = None + return ([the_type for _ in op.input], [the_type for _ in op.output]) + + def _device_updater(op, *args, **kwargs): + return { + "CopyCPUToGPU": _copy_cpu_to_gpu_updater, + "CopyGPUToCPU": _copy_gpu_to_cpu_updater, + }.get(op.type, _other_ops_updater)(op, *args, **kwargs) + + return _generic_status_identifier(predict_net, _device_updater, known_status) + + +# ==== torch/utils_caffe2/vis.py =============================================== + + +def _modify_blob_names(ops, blob_rename_f): + ret = [] + + def _replace_list(blob_list, replaced_list): + del blob_list[:] + blob_list.extend(replaced_list) + + for x in ops: + cur = copy.deepcopy(x) + _replace_list(cur.input, list(map(blob_rename_f, cur.input))) + _replace_list(cur.output, list(map(blob_rename_f, cur.output))) + ret.append(cur) + + return ret + + +def _rename_blob(name, blob_sizes, blob_ranges): + def _list_to_str(bsize): + ret = ", ".join([str(x) for x in bsize]) + ret = "[" + ret + "]" + return ret + + ret = name + if blob_sizes is not None and name in blob_sizes: + ret += "\n" + _list_to_str(blob_sizes[name]) + if blob_ranges is not None and name in blob_ranges: + ret += "\n" + _list_to_str(blob_ranges[name]) + + return ret + + +# graph_name could not contain word 'graph' +def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None): + blob_rename_f = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges) + return save_graph_base(net, file_name, graph_name, op_only, blob_rename_f) + + +def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None): + graph = None + ops = net.op + if blob_rename_func is not None: + ops = _modify_blob_names(ops, blob_rename_func) + if not op_only: + graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB") + else: + graph = net_drawer.GetPydotGraphMinimal( + ops, graph_name, rankdir="TB", minimal_dependency=True + ) + + try: + par_dir = os.path.dirname(file_name) + if not os.path.exists(par_dir): + os.makedirs(par_dir) + + format = os.path.splitext(os.path.basename(file_name))[-1] + if format == ".png": + graph.write_png(file_name) + elif format == ".pdf": + graph.write_pdf(file_name) + elif format == ".svg": + graph.write_svg(file_name) + else: + print("Incorrect format {}".format(format)) + except Exception as e: + print("Error when writing graph to image {}".format(e)) + + return graph + + +# ==== torch/utils_toffee/aten_to_caffe2.py ==================================== + + +def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef): + """ + For ONNX exported model, GroupNorm will be represented as ATen op, + this can be a drop in replacement from ATen to GroupNorm + """ + count = 0 + for op in predict_net.op: + if op.type == "ATen": + op_name = get_pb_arg_vals(op, "operator", None) # return byte in py3 + if op_name and op_name.decode() == "group_norm": + op.arg.remove(get_pb_arg(op, "operator")) + + if get_pb_arg_vali(op, "cudnn_enabled", None): + op.arg.remove(get_pb_arg(op, "cudnn_enabled")) + + num_groups = get_pb_arg_vali(op, "num_groups", None) + if num_groups is not None: + op.arg.remove(get_pb_arg(op, "num_groups")) + check_set_pb_arg(op, "group", "i", num_groups) + + op.type = "GroupNorm" + count += 1 + if count > 1: + logger.info("Replaced {} ATen operator to GroupNormOp".format(count)) + + +# ==== torch/utils_toffee/alias.py ============================================= + + +def alias(x, name, is_backward=False): + if not torch.onnx.is_in_onnx_export(): + return x + assert isinstance(x, torch.Tensor) + return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward) + + +def fuse_alias_placeholder(predict_net, init_net): + """Remove AliasWithName placeholder and rename the input/output of it""" + # First we finish all the re-naming + for i, op in enumerate(predict_net.op): + if op.type == "AliasWithName": + assert len(op.input) == 1 + assert len(op.output) == 1 + name = get_pb_arg_vals(op, "name", None).decode() + is_backward = bool(get_pb_arg_vali(op, "is_backward", 0)) + rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward) + rename_op_output(predict_net, i, 0, name) + + # Remove AliasWithName, should be very safe since it's a non-op + new_ops = [] + for op in predict_net.op: + if op.type != "AliasWithName": + new_ops.append(op) + else: + # safety check + assert op.input == op.output + assert op.input[0] == op.arg[0].s.decode() + del predict_net.op[:] + predict_net.op.extend(new_ops) + + +# ==== torch/utils_caffe2/graph_transform.py =================================== + + +class IllegalGraphTransformError(ValueError): + """When a graph transform function call can't be executed.""" + + +def _rename_versioned_blob_in_proto( + proto: caffe2_pb2.NetDef, + old_name: str, + new_name: str, + version: int, + ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]], + start_versions: Dict[str, int], + end_versions: Dict[str, int], +): + """In given proto, rename all blobs with matched version""" + # Operater list + for op, i_th_ssa in zip(proto.op, ssa): + versioned_inputs, versioned_outputs = i_th_ssa + for i in range(len(op.input)): + if versioned_inputs[i] == (old_name, version): + op.input[i] = new_name + for i in range(len(op.output)): + if versioned_outputs[i] == (old_name, version): + op.output[i] = new_name + # external_input + if start_versions.get(old_name, 0) == version: + for i in range(len(proto.external_input)): + if proto.external_input[i] == old_name: + proto.external_input[i] = new_name + # external_output + if end_versions.get(old_name, 0) == version: + for i in range(len(proto.external_output)): + if proto.external_output[i] == old_name: + proto.external_output[i] = new_name + + +def rename_op_input( + predict_net: caffe2_pb2.NetDef, + init_net: caffe2_pb2.NetDef, + op_id: int, + input_id: int, + new_name: str, + from_producer: bool = False, +): + """ + Rename the op_id-th operator in predict_net, change it's input_id-th input's + name to the new_name. It also does automatic re-route and change + external_input and init_net if necessary. + - It requires the input is only consumed by this op. + - This function modifies predict_net and init_net in-place. + - When from_producer is enable, this also updates other operators that consumes + the same input. Be cautious because may trigger unintended behavior. + """ + assert isinstance(predict_net, caffe2_pb2.NetDef) + assert isinstance(init_net, caffe2_pb2.NetDef) + + init_net_ssa, init_net_versions = core.get_ssa(init_net) + predict_net_ssa, predict_net_versions = core.get_ssa( + predict_net, copy.deepcopy(init_net_versions) + ) + + versioned_inputs, versioned_outputs = predict_net_ssa[op_id] + old_name, version = versioned_inputs[input_id] + + if from_producer: + producer_map = get_producer_map(predict_net_ssa) + if not (old_name, version) in producer_map: + raise NotImplementedError( + "Can't find producer, the input {} is probably from" + " init_net, this is not supported yet.".format(old_name) + ) + producer = producer_map[(old_name, version)] + rename_op_output(predict_net, producer[0], producer[1], new_name) + return + + def contain_targets(op_ssa): + return (old_name, version) in op_ssa[0] + + is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa] + if sum(is_consumer) > 1: + raise IllegalGraphTransformError( + ( + "Input '{}' of operator(#{}) are consumed by other ops, please use" + + " rename_op_output on the producer instead. Offending op: \n{}" + ).format(old_name, op_id, predict_net.op[op_id]) + ) + + # update init_net + _rename_versioned_blob_in_proto( + init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions + ) + # update predict_net + _rename_versioned_blob_in_proto( + predict_net, + old_name, + new_name, + version, + predict_net_ssa, + init_net_versions, + predict_net_versions, + ) + + +def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str): + """ + Rename the op_id-th operator in predict_net, change it's output_id-th input's + name to the new_name. It also does automatic re-route and change + external_output and if necessary. + - It allows multiple consumers of its output. + - This function modifies predict_net in-place, doesn't need init_net. + """ + assert isinstance(predict_net, caffe2_pb2.NetDef) + + ssa, blob_versions = core.get_ssa(predict_net) + + versioned_inputs, versioned_outputs = ssa[op_id] + old_name, version = versioned_outputs[output_id] + + # update predict_net + _rename_versioned_blob_in_proto( + predict_net, old_name, new_name, version, ssa, {}, blob_versions + ) + + +def get_sub_graph_external_input_output( + predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int] +) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]: + """ + Return the list of external input/output of sub-graph, + each element is tuple of the name and corresponding version in predict_net. + + external input/output is defined the same way as caffe2 NetDef. + """ + ssa, versions = core.get_ssa(predict_net) + + all_inputs = [] + all_outputs = [] + for op_id in sub_graph_op_indices: + all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs] + all_outputs += list(ssa[op_id][1]) # ssa output won't repeat + + # for versioned blobs, external inputs are just those blob in all_inputs + # but not in all_outputs + ext_inputs = [inp for inp in all_inputs if inp not in all_outputs] + + # external outputs are essentially outputs of this subgraph that are used + # outside of this sub-graph (including predict_net.external_output) + all_other_inputs = sum( + (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices), + [(outp, versions[outp]) for outp in predict_net.external_output], + ) + ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)] + + return ext_inputs, ext_outputs + + +class DiGraph: + """A DAG representation of caffe2 graph, each vertice is a versioned blob.""" + + def __init__(self): + self.vertices = set() + self.graph = collections.defaultdict(list) + + def add_edge(self, u, v): + self.graph[u].append(v) + self.vertices.add(u) + self.vertices.add(v) + + # grab from https://www.geeksforgeeks.org/find-paths-given-source-destination/ + def get_all_paths(self, s, d): + visited = {k: False for k in self.vertices} + path = [] + all_paths = [] + + def _get_all_paths_util(graph, u, d, visited, path): + visited[u] = True + path.append(u) + if u == d: + all_paths.append(copy.deepcopy(path)) + else: + for i in graph[u]: + if not visited[i]: + _get_all_paths_util(graph, i, d, visited, path) + path.pop() + visited[u] = False + + _get_all_paths_util(self.graph, s, d, visited, path) + return all_paths + + @staticmethod + def from_ssa(ssa): + graph = DiGraph() + for op_id in range(len(ssa)): + for inp in ssa[op_id][0]: + for outp in ssa[op_id][1]: + graph.add_edge(inp, outp) + return graph + + +def _get_dependency_chain(ssa, versioned_target, versioned_source): + """ + Return the index list of relevant operator to produce target blob from source blob, + if there's no dependency, return empty list. + """ + + # finding all paths between nodes can be O(N!), thus we can only search + # in the subgraph using the op starting from the first consumer of source blob + # to the producer of the target blob. + consumer_map = get_consumer_map(ssa) + producer_map = get_producer_map(ssa) + start_op = min(x[0] for x in consumer_map[versioned_source]) - 15 + end_op = ( + producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op + ) + sub_graph_ssa = ssa[start_op : end_op + 1] + if len(sub_graph_ssa) > 30: + logger.warning( + "Subgraph bebetween {} and {} is large (from op#{} to op#{}), it" + " might take non-trival time to find all paths between them.".format( + versioned_source, versioned_target, start_op, end_op + ) + ) + + dag = DiGraph.from_ssa(sub_graph_ssa) + paths = dag.get_all_paths(versioned_source, versioned_target) # include two ends + ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths] + return sorted(set().union(*[set(ops) for ops in ops_in_paths])) + + +def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]: + """ + Idenfity the reshape sub-graph in a protobuf. + The reshape sub-graph is defined as matching the following pattern: + + (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -─┐ + └-------------------------------------------> Reshape -> (output_blob) + + Return: + List of sub-graphs, each sub-graph is represented as a list of indices + of the relavent ops, [Op_1, Op_2, ..., Op_N, Reshape] + """ + + ssa, _ = core.get_ssa(predict_net) + + ret = [] + for i, op in enumerate(predict_net.op): + if op.type == "Reshape": + assert len(op.input) == 2 + input_ssa = ssa[i][0] + data_source = input_ssa[0] + shape_source = input_ssa[1] + op_indices = _get_dependency_chain(ssa, shape_source, data_source) + ret.append(op_indices + [i]) + return ret + + +def remove_reshape_for_fc(predict_net, params): + """ + In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape + a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping + doesn't work well with ONNX and Int8 tools, and cause using extra + ops (eg. ExpandDims) that might not be available on mobile. + Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape + after exporting ONNX model. + """ + from caffe2.python import core + + # find all reshape sub-graph that can be removed, which is now all Reshape + # sub-graph whose output is only consumed by FC. + # TODO: to make it safer, we may need the actually value to better determine + # if a Reshape before FC is removable. + reshape_sub_graphs = identify_reshape_sub_graph(predict_net) + sub_graphs_to_remove = [] + for reshape_sub_graph in reshape_sub_graphs: + reshape_op_id = reshape_sub_graph[-1] + assert predict_net.op[reshape_op_id].type == "Reshape" + ssa, _ = core.get_ssa(predict_net) + reshape_output = ssa[reshape_op_id][1][0] + consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]] + if all(predict_net.op[consumer].type == "FC" for consumer in consumers): + # safety check if the sub-graph is isolated, for this reshape sub-graph, + # it means it has one non-param external input and one external output. + ext_inputs, ext_outputs = get_sub_graph_external_input_output( + predict_net, reshape_sub_graph + ) + non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0] + if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1: + sub_graphs_to_remove.append(reshape_sub_graph) + + # perform removing subgraph by: + # 1: rename the Reshape's output to its input, then the graph can be + # seen as in-place itentify, meaning whose external input/output are the same. + # 2: simply remove those ops. + remove_op_ids = [] + params_to_remove = [] + for sub_graph in sub_graphs_to_remove: + logger.info( + "Remove Reshape sub-graph:\n{}".format( + "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph]) + ) + ) + reshape_op_id = sub_graph[-1] + new_reshap_output = predict_net.op[reshape_op_id].input[0] + rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output) + ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph) + non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0] + params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0] + assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1 + assert ext_outputs[0][0] == non_params_ext_inputs[0][0] + assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1 + remove_op_ids.extend(sub_graph) + params_to_remove.extend(params_ext_inputs) + + predict_net = copy.deepcopy(predict_net) + new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids] + del predict_net.op[:] + predict_net.op.extend(new_ops) + for versioned_params in params_to_remove: + name = versioned_params[0] + logger.info("Remove params: {} from init_net and predict_net.external_input".format(name)) + del params[name] + predict_net.external_input.remove(name) + + return predict_net, params + + +def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef): + """ + In-place fuse extra copy ops between cpu/gpu for the following case: + a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1 + -CopyBToA> c2 -NextOp2-> d2 + The fused network will look like: + a -NextOp1-> d1 + -NextOp2-> d2 + """ + + _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"] + + def _fuse_once(predict_net): + ssa, blob_versions = core.get_ssa(predict_net) + consumer_map = get_consumer_map(ssa) + versioned_external_output = [ + (name, blob_versions[name]) for name in predict_net.external_output + ] + + for op_id, op in enumerate(predict_net.op): + if op.type in _COPY_OPS: + fw_copy_versioned_output = ssa[op_id][1][0] + consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]] + reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)] + + is_fusable = ( + len(consumer_ids) > 0 + and fw_copy_versioned_output not in versioned_external_output + and all( + predict_net.op[_op_id].type == reverse_op_type + and ssa[_op_id][1][0] not in versioned_external_output + for _op_id in consumer_ids + ) + ) + + if is_fusable: + for rv_copy_op_id in consumer_ids: + # making each NextOp uses "a" directly and removing Copy ops + rs_copy_versioned_output = ssa[rv_copy_op_id][1][0] + next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0] + predict_net.op[next_op_id].input[inp_id] = op.input[0] + # remove CopyOps + new_ops = [ + op + for i, op in enumerate(predict_net.op) + if i != op_id and i not in consumer_ids + ] + del predict_net.op[:] + predict_net.op.extend(new_ops) + return True + + return False + + # _fuse_once returns False is nothing can be fused + while _fuse_once(predict_net): + pass + + +def remove_dead_end_ops(net_def: caffe2_pb2.NetDef): + """remove ops if its output is not used or not in external_output""" + ssa, versions = core.get_ssa(net_def) + versioned_external_output = [(name, versions[name]) for name in net_def.external_output] + consumer_map = get_consumer_map(ssa) + removed_op_ids = set() + + def _is_dead_end(versioned_blob): + return not ( + versioned_blob in versioned_external_output + or ( + len(consumer_map[versioned_blob]) > 0 + and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob]) + ) + ) + + for i, ssa_i in reversed(list(enumerate(ssa))): + versioned_outputs = ssa_i[1] + if all(_is_dead_end(outp) for outp in versioned_outputs): + removed_op_ids.add(i) + + # simply removing those deadend ops should have no effect to external_output + new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids] + del net_def.op[:] + net_def.op.extend(new_ops) diff --git a/custom_detectron2/export/torchscript.py b/custom_detectron2/export/torchscript.py new file mode 100644 index 0000000000000000000000000000000000000000..19965ad967cf9686a10728c98dcccf9e0fb7c447 --- /dev/null +++ b/custom_detectron2/export/torchscript.py @@ -0,0 +1,132 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import torch + +from custom_detectron2.utils.file_io import PathManager + +from .torchscript_patch import freeze_training_mode, patch_instances + +__all__ = ["scripting_with_instances", "dump_torchscript_IR"] + + +def scripting_with_instances(model, fields): + """ + Run :func:`torch.jit.script` on a model that uses the :class:`Instances` class. Since + attributes of :class:`Instances` are "dynamically" added in eager mode,it is difficult + for scripting to support it out of the box. This function is made to support scripting + a model that uses :class:`Instances`. It does the following: + + 1. Create a scriptable ``new_Instances`` class which behaves similarly to ``Instances``, + but with all attributes been "static". + The attributes need to be statically declared in the ``fields`` argument. + 2. Register ``new_Instances``, and force scripting compiler to + use it when trying to compile ``Instances``. + + After this function, the process will be reverted. User should be able to script another model + using different fields. + + Example: + Assume that ``Instances`` in the model consist of two attributes named + ``proposal_boxes`` and ``objectness_logits`` with type :class:`Boxes` and + :class:`Tensor` respectively during inference. You can call this function like: + :: + fields = {"proposal_boxes": Boxes, "objectness_logits": torch.Tensor} + torchscipt_model = scripting_with_instances(model, fields) + + Note: + It only support models in evaluation mode. + + Args: + model (nn.Module): The input model to be exported by scripting. + fields (Dict[str, type]): Attribute names and corresponding type that + ``Instances`` will use in the model. Note that all attributes used in ``Instances`` + need to be added, regardless of whether they are inputs/outputs of the model. + Data type not defined in detectron2 is not supported for now. + + Returns: + torch.jit.ScriptModule: the model in torchscript format + """ + assert ( + not model.training + ), "Currently we only support exporting models in evaluation mode to torchscript" + + with freeze_training_mode(model), patch_instances(fields): + scripted_model = torch.jit.script(model) + return scripted_model + + +# alias for old name +export_torchscript_with_instances = scripting_with_instances + + +def dump_torchscript_IR(model, dir): + """ + Dump IR of a TracedModule/ScriptModule/Function in various format (code, graph, + inlined graph). Useful for debugging. + + Args: + model (TracedModule/ScriptModule/ScriptFUnction): traced or scripted module + dir (str): output directory to dump files. + """ + dir = os.path.expanduser(dir) + PathManager.mkdirs(dir) + + def _get_script_mod(mod): + if isinstance(mod, torch.jit.TracedModule): + return mod._actual_script_module + return mod + + # Dump pretty-printed code: https://pytorch.org/docs/stable/jit.html#inspecting-code + with PathManager.open(os.path.join(dir, "model_ts_code.txt"), "w") as f: + + def get_code(mod): + # Try a few ways to get code using private attributes. + try: + # This contains more information than just `mod.code` + return _get_script_mod(mod)._c.code + except AttributeError: + pass + try: + return mod.code + except AttributeError: + return None + + def dump_code(prefix, mod): + code = get_code(mod) + name = prefix or "root model" + if code is None: + f.write(f"Could not found code for {name} (type={mod.original_name})\n") + f.write("\n") + else: + f.write(f"\nCode for {name}, type={mod.original_name}:\n") + f.write(code) + f.write("\n") + f.write("-" * 80) + + for name, m in mod.named_children(): + dump_code(prefix + "." + name, m) + + if isinstance(model, torch.jit.ScriptFunction): + f.write(get_code(model)) + else: + dump_code("", model) + + def _get_graph(model): + try: + # Recursively dump IR of all modules + return _get_script_mod(model)._c.dump_to_str(True, False, False) + except AttributeError: + return model.graph.str() + + with PathManager.open(os.path.join(dir, "model_ts_IR.txt"), "w") as f: + f.write(_get_graph(model)) + + # Dump IR of the entire graph (all submodules inlined) + with PathManager.open(os.path.join(dir, "model_ts_IR_inlined.txt"), "w") as f: + f.write(str(model.inlined_graph)) + + if not isinstance(model, torch.jit.ScriptFunction): + # Dump the model structure in pytorch style + with PathManager.open(os.path.join(dir, "model.txt"), "w") as f: + f.write(str(model)) diff --git a/custom_detectron2/export/torchscript_patch.py b/custom_detectron2/export/torchscript_patch.py new file mode 100644 index 0000000000000000000000000000000000000000..9c031d75579addc30ddbc9c7f46280724078c328 --- /dev/null +++ b/custom_detectron2/export/torchscript_patch.py @@ -0,0 +1,406 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import os +import sys +import tempfile +from contextlib import ExitStack, contextmanager +from copy import deepcopy +from unittest import mock +import torch +from torch import nn + +# need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964 +import custom_detectron2 # noqa F401 +from custom_detectron2.structures import Boxes, Instances +from custom_detectron2.utils.env import _import_file + +_counter = 0 + + +def _clear_jit_cache(): + from torch.jit._recursive import concrete_type_store + from torch.jit._state import _jit_caching_layer + + concrete_type_store.type_store.clear() # for modules + _jit_caching_layer.clear() # for free functions + + +def _add_instances_conversion_methods(newInstances): + """ + Add from_instances methods to the scripted Instances class. + """ + cls_name = newInstances.__name__ + + @torch.jit.unused + def from_instances(instances: Instances): + """ + Create scripted Instances from original Instances + """ + fields = instances.get_fields() + image_size = instances.image_size + ret = newInstances(image_size) + for name, val in fields.items(): + assert hasattr(ret, f"_{name}"), f"No attribute named {name} in {cls_name}" + setattr(ret, name, deepcopy(val)) + return ret + + newInstances.from_instances = from_instances + + +@contextmanager +def patch_instances(fields): + """ + A contextmanager, under which the Instances class in detectron2 is replaced + by a statically-typed scriptable class, defined by `fields`. + See more in `scripting_with_instances`. + """ + + with tempfile.TemporaryDirectory(prefix="detectron2") as dir, tempfile.NamedTemporaryFile( + mode="w", encoding="utf-8", suffix=".py", dir=dir, delete=False + ) as f: + try: + # Objects that use Instances should not reuse previously-compiled + # results in cache, because `Instances` could be a new class each time. + _clear_jit_cache() + + cls_name, s = _gen_instance_module(fields) + f.write(s) + f.flush() + f.close() + + module = _import(f.name) + new_instances = getattr(module, cls_name) + _ = torch.jit.script(new_instances) + # let torchscript think Instances was scripted already + Instances.__torch_script_class__ = True + # let torchscript find new_instances when looking for the jit type of Instances + Instances._jit_override_qualname = torch._jit_internal._qualified_name(new_instances) + + _add_instances_conversion_methods(new_instances) + yield new_instances + finally: + try: + del Instances.__torch_script_class__ + del Instances._jit_override_qualname + except AttributeError: + pass + sys.modules.pop(module.__name__) + + +def _gen_instance_class(fields): + """ + Args: + fields (dict[name: type]) + """ + + class _FieldType: + def __init__(self, name, type_): + assert isinstance(name, str), f"Field name must be str, got {name}" + self.name = name + self.type_ = type_ + self.annotation = f"{type_.__module__}.{type_.__name__}" + + fields = [_FieldType(k, v) for k, v in fields.items()] + + def indent(level, s): + return " " * 4 * level + s + + lines = [] + + global _counter + _counter += 1 + + cls_name = "ScriptedInstances{}".format(_counter) + + field_names = tuple(x.name for x in fields) + extra_args = ", ".join([f"{f.name}: Optional[{f.annotation}] = None" for f in fields]) + lines.append( + f""" +class {cls_name}: + def __init__(self, image_size: Tuple[int, int], {extra_args}): + self.image_size = image_size + self._field_names = {field_names} +""" + ) + + for f in fields: + lines.append( + indent(2, f"self._{f.name} = torch.jit.annotate(Optional[{f.annotation}], {f.name})") + ) + + for f in fields: + lines.append( + f""" + @property + def {f.name}(self) -> {f.annotation}: + # has to use a local for type refinement + # https://pytorch.org/docs/stable/jit_language_reference.html#optional-type-refinement + t = self._{f.name} + assert t is not None, "{f.name} is None and cannot be accessed!" + return t + + @{f.name}.setter + def {f.name}(self, value: {f.annotation}) -> None: + self._{f.name} = value +""" + ) + + # support method `__len__` + lines.append( + """ + def __len__(self) -> int: +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + return len(t) +""" + ) + lines.append( + """ + raise NotImplementedError("Empty Instances does not support __len__!") +""" + ) + + # support method `has` + lines.append( + """ + def has(self, name: str) -> bool: +""" + ) + for f in fields: + lines.append( + f""" + if name == "{f.name}": + return self._{f.name} is not None +""" + ) + lines.append( + """ + return False +""" + ) + + # support method `to` + none_args = ", None" * len(fields) + lines.append( + f""" + def to(self, device: torch.device) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + if hasattr(f.type_, "to"): + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret._{f.name} = t.to(device) +""" + ) + else: + # For now, ignore fields that cannot be moved to devices. + # Maybe can support other tensor-like classes (e.g. __torch_function__) + pass + lines.append( + """ + return ret +""" + ) + + # support method `getitem` + none_args = ", None" * len(fields) + lines.append( + f""" + def __getitem__(self, item) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret._{f.name} = t[item] +""" + ) + lines.append( + """ + return ret +""" + ) + + # support method `cat` + # this version does not contain checks that all instances have same size and fields + none_args = ", None" * len(fields) + lines.append( + f""" + def cat(self, instances: List["{cls_name}"]) -> "{cls_name}": + ret = {cls_name}(self.image_size{none_args}) +""" + ) + for f in fields: + lines.append( + f""" + t = self._{f.name} + if t is not None: + values: List[{f.annotation}] = [x.{f.name} for x in instances] + if torch.jit.isinstance(t, torch.Tensor): + ret._{f.name} = torch.cat(values, dim=0) + else: + ret._{f.name} = t.cat(values) +""" + ) + lines.append( + """ + return ret""" + ) + + # support method `get_fields()` + lines.append( + """ + def get_fields(self) -> Dict[str, Tensor]: + ret = {} + """ + ) + for f in fields: + if f.type_ == Boxes: + stmt = "t.tensor" + elif f.type_ == torch.Tensor: + stmt = "t" + else: + stmt = f'assert False, "unsupported type {str(f.type_)}"' + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret["{f.name}"] = {stmt} + """ + ) + lines.append( + """ + return ret""" + ) + return cls_name, os.linesep.join(lines) + + +def _gen_instance_module(fields): + # TODO: find a more automatic way to enable import of other classes + s = """ +from copy import deepcopy +import torch +from torch import Tensor +import typing +from typing import * + +import custom_detectron2 +from custom_detectron2.structures import Boxes, Instances + +""" + + cls_name, cls_def = _gen_instance_class(fields) + s += cls_def + return cls_name, s + + +def _import(path): + return _import_file( + "{}{}".format(sys.modules[__name__].__name__, _counter), path, make_importable=True + ) + + +@contextmanager +def patch_builtin_len(modules=()): + """ + Patch the builtin len() function of a few detectron2 modules + to use __len__ instead, because __len__ does not convert values to + integers and therefore is friendly to tracing. + + Args: + modules (list[stsr]): names of extra modules to patch len(), in + addition to those in detectron2. + """ + + def _new_len(obj): + return obj.__len__() + + with ExitStack() as stack: + MODULES = [ + "detectron2.modeling.roi_heads.fast_rcnn", + "detectron2.modeling.roi_heads.mask_head", + "detectron2.modeling.roi_heads.keypoint_head", + ] + list(modules) + ctxs = [stack.enter_context(mock.patch(mod + ".len")) for mod in MODULES] + for m in ctxs: + m.side_effect = _new_len + yield + + +def patch_nonscriptable_classes(): + """ + Apply patches on a few nonscriptable detectron2 classes. + Should not have side-effects on eager usage. + """ + # __prepare_scriptable__ can also be added to models for easier maintenance. + # But it complicates the clean model code. + + from custom_detectron2.modeling.backbone import ResNet, FPN + + # Due to https://github.com/pytorch/pytorch/issues/36061, + # we change backbone to use ModuleList for scripting. + # (note: this changes param names in state_dict) + + def prepare_resnet(self): + ret = deepcopy(self) + ret.stages = nn.ModuleList(ret.stages) + for k in self.stage_names: + delattr(ret, k) + return ret + + ResNet.__prepare_scriptable__ = prepare_resnet + + def prepare_fpn(self): + ret = deepcopy(self) + ret.lateral_convs = nn.ModuleList(ret.lateral_convs) + ret.output_convs = nn.ModuleList(ret.output_convs) + for name, _ in self.named_children(): + if name.startswith("fpn_"): + delattr(ret, name) + return ret + + FPN.__prepare_scriptable__ = prepare_fpn + + # Annotate some attributes to be constants for the purpose of scripting, + # even though they are not constants in eager mode. + from custom_detectron2.modeling.roi_heads import StandardROIHeads + + if hasattr(StandardROIHeads, "__annotations__"): + # copy first to avoid editing annotations of base class + StandardROIHeads.__annotations__ = deepcopy(StandardROIHeads.__annotations__) + StandardROIHeads.__annotations__["mask_on"] = torch.jit.Final[bool] + StandardROIHeads.__annotations__["keypoint_on"] = torch.jit.Final[bool] + + +# These patches are not supposed to have side-effects. +patch_nonscriptable_classes() + + +@contextmanager +def freeze_training_mode(model): + """ + A context manager that annotates the "training" attribute of every submodule + to constant, so that the training codepath in these modules can be + meta-compiled away. Upon exiting, the annotations are reverted. + """ + classes = {type(x) for x in model.modules()} + # __constants__ is the old way to annotate constants and not compatible + # with __annotations__ . + classes = {x for x in classes if not hasattr(x, "__constants__")} + for cls in classes: + cls.__annotations__["training"] = torch.jit.Final[bool] + yield + for cls in classes: + cls.__annotations__["training"] = bool diff --git a/custom_detectron2/layers/__init__.py b/custom_detectron2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..761a3d1c7afa049e9779ee9fc4d299e9aae38cad --- /dev/null +++ b/custom_detectron2/layers/__init__.py @@ -0,0 +1,26 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .batch_norm import FrozenBatchNorm2d, get_norm, NaiveSyncBatchNorm, CycleBatchNormList +from .deform_conv import DeformConv, ModulatedDeformConv +from .mask_ops import paste_masks_in_image +from .nms import batched_nms, batched_nms_rotated, nms, nms_rotated +from .roi_align import ROIAlign, roi_align +from .roi_align_rotated import ROIAlignRotated, roi_align_rotated +from .shape_spec import ShapeSpec +from .wrappers import ( + BatchNorm2d, + Conv2d, + ConvTranspose2d, + cat, + interpolate, + Linear, + nonzero_tuple, + cross_entropy, + empty_input_loss_func_wrapper, + shapes_to_tensor, + move_device_like, +) +from .blocks import CNNBlockBase, DepthwiseSeparableConv2d +from .aspp import ASPP +from .losses import ciou_loss, diou_loss + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/layers/aspp.py b/custom_detectron2/layers/aspp.py new file mode 100644 index 0000000000000000000000000000000000000000..14861aa9ede4fea6a69a49f189bcab997b558148 --- /dev/null +++ b/custom_detectron2/layers/aspp.py @@ -0,0 +1,144 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from copy import deepcopy +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from .batch_norm import get_norm +from .blocks import DepthwiseSeparableConv2d +from .wrappers import Conv2d + + +class ASPP(nn.Module): + """ + Atrous Spatial Pyramid Pooling (ASPP). + """ + + def __init__( + self, + in_channels, + out_channels, + dilations, + *, + norm, + activation, + pool_kernel_size=None, + dropout: float = 0.0, + use_depthwise_separable_conv=False, + ): + """ + Args: + in_channels (int): number of input channels for ASPP. + out_channels (int): number of output channels. + dilations (list): a list of 3 dilations in ASPP. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. norm is + applied to all conv layers except the conv following + global average pooling. + activation (callable): activation function. + pool_kernel_size (tuple, list): the average pooling size (kh, kw) + for image pooling layer in ASPP. If set to None, it always + performs global average pooling. If not None, it must be + divisible by the shape of inputs in forward(). It is recommended + to use a fixed input feature size in training, and set this + option to match this size, so that it performs global average + pooling in training, and the size of the pooling window stays + consistent in inference. + dropout (float): apply dropout on the output of ASPP. It is used in + the official DeepLab implementation with a rate of 0.1: + https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/model.py#L532 # noqa + use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d + for 3x3 convs in ASPP, proposed in :paper:`DeepLabV3+`. + """ + super(ASPP, self).__init__() + assert len(dilations) == 3, "ASPP expects 3 dilations, got {}".format(len(dilations)) + self.pool_kernel_size = pool_kernel_size + self.dropout = dropout + use_bias = norm == "" + self.convs = nn.ModuleList() + # conv 1x1 + self.convs.append( + Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + activation=deepcopy(activation), + ) + ) + weight_init.c2_xavier_fill(self.convs[-1]) + # atrous convs + for dilation in dilations: + if use_depthwise_separable_conv: + self.convs.append( + DepthwiseSeparableConv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + norm1=norm, + activation1=deepcopy(activation), + norm2=norm, + activation2=deepcopy(activation), + ) + ) + else: + self.convs.append( + Conv2d( + in_channels, + out_channels, + kernel_size=3, + padding=dilation, + dilation=dilation, + bias=use_bias, + norm=get_norm(norm, out_channels), + activation=deepcopy(activation), + ) + ) + weight_init.c2_xavier_fill(self.convs[-1]) + # image pooling + # We do not add BatchNorm because the spatial resolution is 1x1, + # the original TF implementation has BatchNorm. + if pool_kernel_size is None: + image_pooling = nn.Sequential( + nn.AdaptiveAvgPool2d(1), + Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)), + ) + else: + image_pooling = nn.Sequential( + nn.AvgPool2d(kernel_size=pool_kernel_size, stride=1), + Conv2d(in_channels, out_channels, 1, bias=True, activation=deepcopy(activation)), + ) + weight_init.c2_xavier_fill(image_pooling[1]) + self.convs.append(image_pooling) + + self.project = Conv2d( + 5 * out_channels, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + activation=deepcopy(activation), + ) + weight_init.c2_xavier_fill(self.project) + + def forward(self, x): + size = x.shape[-2:] + if self.pool_kernel_size is not None: + if size[0] % self.pool_kernel_size[0] or size[1] % self.pool_kernel_size[1]: + raise ValueError( + "`pool_kernel_size` must be divisible by the shape of inputs. " + "Input size: {} `pool_kernel_size`: {}".format(size, self.pool_kernel_size) + ) + res = [] + for conv in self.convs: + res.append(conv(x)) + res[-1] = F.interpolate(res[-1], size=size, mode="bilinear", align_corners=False) + res = torch.cat(res, dim=1) + res = self.project(res) + res = F.dropout(res, self.dropout, training=self.training) if self.dropout > 0 else res + return res diff --git a/custom_detectron2/layers/batch_norm.py b/custom_detectron2/layers/batch_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..24899c56420a3e8db3793f580566ee9d8d44f84c --- /dev/null +++ b/custom_detectron2/layers/batch_norm.py @@ -0,0 +1,300 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +import torch.distributed as dist +from fvcore.nn.distributed import differentiable_all_reduce +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.utils import comm, env + +from .wrappers import BatchNorm2d + + +class FrozenBatchNorm2d(nn.Module): + """ + BatchNorm2d where the batch statistics and the affine parameters are fixed. + + It contains non-trainable buffers called + "weight" and "bias", "running_mean", "running_var", + initialized to perform identity transformation. + + The pre-trained backbone models from Caffe2 only contain "weight" and "bias", + which are computed from the original four parameters of BN. + The affine transform `x * weight + bias` will perform the equivalent + computation of `(x - running_mean) / sqrt(running_var) * weight + bias`. + When loading a backbone model from Caffe2, "running_mean" and "running_var" + will be left unchanged as identity transformation. + + Other pre-trained backbone models may contain all 4 parameters. + + The forward is implemented by `F.batch_norm(..., training=False)`. + """ + + _version = 3 + + def __init__(self, num_features, eps=1e-5): + super().__init__() + self.num_features = num_features + self.eps = eps + self.register_buffer("weight", torch.ones(num_features)) + self.register_buffer("bias", torch.zeros(num_features)) + self.register_buffer("running_mean", torch.zeros(num_features)) + self.register_buffer("running_var", torch.ones(num_features) - eps) + + def forward(self, x): + if x.requires_grad: + # When gradients are needed, F.batch_norm will use extra memory + # because its backward op computes gradients for weight/bias as well. + scale = self.weight * (self.running_var + self.eps).rsqrt() + bias = self.bias - self.running_mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + out_dtype = x.dtype # may be half + return x * scale.to(out_dtype) + bias.to(out_dtype) + else: + # When gradients are not needed, F.batch_norm is a single fused op + # and provide more optimization opportunities. + return F.batch_norm( + x, + self.running_mean, + self.running_var, + self.weight, + self.bias, + training=False, + eps=self.eps, + ) + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # No running_mean/var in early versions + # This will silent the warnings + if prefix + "running_mean" not in state_dict: + state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean) + if prefix + "running_var" not in state_dict: + state_dict[prefix + "running_var"] = torch.ones_like(self.running_var) + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) + + def __repr__(self): + return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps) + + @classmethod + def convert_frozen_batchnorm(cls, module): + """ + Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm. + + Args: + module (torch.nn.Module): + + Returns: + If module is BatchNorm/SyncBatchNorm, returns a new module. + Otherwise, in-place convert module and return it. + + Similar to convert_sync_batchnorm in + https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py + """ + bn_module = nn.modules.batchnorm + bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm) + res = module + if isinstance(module, bn_module): + res = cls(module.num_features) + if module.affine: + res.weight.data = module.weight.data.clone().detach() + res.bias.data = module.bias.data.clone().detach() + res.running_mean.data = module.running_mean.data + res.running_var.data = module.running_var.data + res.eps = module.eps + else: + for name, child in module.named_children(): + new_child = cls.convert_frozen_batchnorm(child) + if new_child is not child: + res.add_module(name, new_child) + return res + + +def get_norm(norm, out_channels): + """ + Args: + norm (str or callable): either one of BN, SyncBN, FrozenBN, GN; + or a callable that takes a channel number and returns + the normalization layer as a nn.Module. + + Returns: + nn.Module or None: the normalization layer + """ + if norm is None: + return None + if isinstance(norm, str): + if len(norm) == 0: + return None + norm = { + "BN": BatchNorm2d, + # Fixed in https://github.com/pytorch/pytorch/pull/36382 + "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm, + "FrozenBN": FrozenBatchNorm2d, + "GN": lambda channels: nn.GroupNorm(32, channels), + # for debugging: + "nnSyncBN": nn.SyncBatchNorm, + "naiveSyncBN": NaiveSyncBatchNorm, + # expose stats_mode N as an option to caller, required for zero-len inputs + "naiveSyncBN_N": lambda channels: NaiveSyncBatchNorm(channels, stats_mode="N"), + "LN": lambda channels: LayerNorm(channels), + }[norm] + return norm(out_channels) + + +class NaiveSyncBatchNorm(BatchNorm2d): + """ + In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient + when the batch size on each worker is different. + (e.g., when scale augmentation is used, or when it is applied to mask head). + + This is a slower but correct alternative to `nn.SyncBatchNorm`. + + Note: + There isn't a single definition of Sync BatchNorm. + + When ``stats_mode==""``, this module computes overall statistics by using + statistics of each worker with equal weight. The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (N, H, W). This mode does not support inputs with zero batch size. + + When ``stats_mode=="N"``, this module computes overall statistics by weighting + the statistics of each worker by their ``N``. The result is true statistics + of all samples (as if they are all on one worker) only when all workers + have the same (H, W). It is slower than ``stats_mode==""``. + + Even though the result of this module may not be the true statistics of all samples, + it may still be reasonable because it might be preferrable to assign equal weights + to all workers, regardless of their (H, W) dimension, instead of putting larger weight + on larger images. From preliminary experiments, little difference is found between such + a simplified implementation and an accurate computation of overall mean & variance. + """ + + def __init__(self, *args, stats_mode="", **kwargs): + super().__init__(*args, **kwargs) + assert stats_mode in ["", "N"] + self._stats_mode = stats_mode + + def forward(self, input): + if comm.get_world_size() == 1 or not self.training: + return super().forward(input) + + B, C = input.shape[0], input.shape[1] + + half_input = input.dtype == torch.float16 + if half_input: + # fp16 does not have good enough numerics for the reduction here + input = input.float() + mean = torch.mean(input, dim=[0, 2, 3]) + meansqr = torch.mean(input * input, dim=[0, 2, 3]) + + if self._stats_mode == "": + assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.' + vec = torch.cat([mean, meansqr], dim=0) + vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size()) + mean, meansqr = torch.split(vec, C) + momentum = self.momentum + else: + if B == 0: + vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype) + vec = vec + input.sum() # make sure there is gradient w.r.t input + else: + vec = torch.cat( + [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0 + ) + vec = differentiable_all_reduce(vec * B) + + total_batch = vec[-1].detach() + momentum = total_batch.clamp(max=1) * self.momentum # no update if total_batch is 0 + mean, meansqr, _ = torch.split(vec / total_batch.clamp(min=1), C) # avoid div-by-zero + + var = meansqr - mean * mean + invstd = torch.rsqrt(var + self.eps) + scale = self.weight * invstd + bias = self.bias - mean * scale + scale = scale.reshape(1, -1, 1, 1) + bias = bias.reshape(1, -1, 1, 1) + + self.running_mean += momentum * (mean.detach() - self.running_mean) + self.running_var += momentum * (var.detach() - self.running_var) + ret = input * scale + bias + if half_input: + ret = ret.half() + return ret + + +class CycleBatchNormList(nn.ModuleList): + """ + Implement domain-specific BatchNorm by cycling. + + When a BatchNorm layer is used for multiple input domains or input + features, it might need to maintain a separate test-time statistics + for each domain. See Sec 5.2 in :paper:`rethinking-batchnorm`. + + This module implements it by using N separate BN layers + and it cycles through them every time a forward() is called. + + NOTE: The caller of this module MUST guarantee to always call + this module by multiple of N times. Otherwise its test-time statistics + will be incorrect. + """ + + def __init__(self, length: int, bn_class=nn.BatchNorm2d, **kwargs): + """ + Args: + length: number of BatchNorm layers to cycle. + bn_class: the BatchNorm class to use + kwargs: arguments of the BatchNorm class, such as num_features. + """ + self._affine = kwargs.pop("affine", True) + super().__init__([bn_class(**kwargs, affine=False) for k in range(length)]) + if self._affine: + # shared affine, domain-specific BN + channels = self[0].num_features + self.weight = nn.Parameter(torch.ones(channels)) + self.bias = nn.Parameter(torch.zeros(channels)) + self._pos = 0 + + def forward(self, x): + ret = self[self._pos](x) + self._pos = (self._pos + 1) % len(self) + + if self._affine: + w = self.weight.reshape(1, -1, 1, 1) + b = self.bias.reshape(1, -1, 1, 1) + return ret * w + b + else: + return ret + + def extra_repr(self): + return f"affine={self._affine}" + + +class LayerNorm(nn.Module): + """ + A LayerNorm variant, popularized by Transformers, that performs point-wise mean and + variance normalization over the channel dimension for inputs that have shape + (batch_size, channels, height, width). + https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119 # noqa B950 + """ + + def __init__(self, normalized_shape, eps=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x diff --git a/custom_detectron2/layers/blocks.py b/custom_detectron2/layers/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..1995a4bf7339e8deb7eaaffda4f819dda55e7ac7 --- /dev/null +++ b/custom_detectron2/layers/blocks.py @@ -0,0 +1,111 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import fvcore.nn.weight_init as weight_init +from torch import nn + +from .batch_norm import FrozenBatchNorm2d, get_norm +from .wrappers import Conv2d + + +""" +CNN building blocks. +""" + + +class CNNBlockBase(nn.Module): + """ + A CNN block is assumed to have input channels, output channels and a stride. + The input and output of `forward()` method must be NCHW tensors. + The method can perform arbitrary computation but must match the given + channels and stride specification. + + Attribute: + in_channels (int): + out_channels (int): + stride (int): + """ + + def __init__(self, in_channels, out_channels, stride): + """ + The `__init__` method of any subclass should also contain these arguments. + + Args: + in_channels (int): + out_channels (int): + stride (int): + """ + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + + def freeze(self): + """ + Make this block not trainable. + This method sets all parameters to `requires_grad=False`, + and convert all BatchNorm layers to FrozenBatchNorm + + Returns: + the block itself + """ + for p in self.parameters(): + p.requires_grad = False + FrozenBatchNorm2d.convert_frozen_batchnorm(self) + return self + + +class DepthwiseSeparableConv2d(nn.Module): + """ + A kxk depthwise convolution + a 1x1 convolution. + + In :paper:`xception`, norm & activation are applied on the second conv. + :paper:`mobilenet` uses norm & activation on both convs. + """ + + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + padding=1, + dilation=1, + *, + norm1=None, + activation1=None, + norm2=None, + activation2=None, + ): + """ + Args: + norm1, norm2 (str or callable): normalization for the two conv layers. + activation1, activation2 (callable(Tensor) -> Tensor): activation + function for the two conv layers. + """ + super().__init__() + self.depthwise = Conv2d( + in_channels, + in_channels, + kernel_size=kernel_size, + padding=padding, + dilation=dilation, + groups=in_channels, + bias=not norm1, + norm=get_norm(norm1, in_channels), + activation=activation1, + ) + self.pointwise = Conv2d( + in_channels, + out_channels, + kernel_size=1, + bias=not norm2, + norm=get_norm(norm2, out_channels), + activation=activation2, + ) + + # default initialization + weight_init.c2_msra_fill(self.depthwise) + weight_init.c2_msra_fill(self.pointwise) + + def forward(self, x): + return self.pointwise(self.depthwise(x)) diff --git a/custom_detectron2/layers/csrc/README.md b/custom_detectron2/layers/csrc/README.md new file mode 100644 index 0000000000000000000000000000000000000000..778ed3da0bae89820831bcd8a72ff7b9cad8d4dd --- /dev/null +++ b/custom_detectron2/layers/csrc/README.md @@ -0,0 +1,7 @@ + + +To add a new Op: + +1. Create a new directory +2. Implement new ops there +3. Delcare its Python interface in `vision.cpp`. diff --git a/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h new file mode 100644 index 0000000000000000000000000000000000000000..03f4211003f42f601f0cfcf4a690f5da4a0a1f67 --- /dev/null +++ b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated.h @@ -0,0 +1,115 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +at::Tensor ROIAlignRotated_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + +at::Tensor ROIAlignRotated_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor ROIAlignRotated_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio); + +at::Tensor ROIAlignRotated_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio); +#endif + +// Interface for Python +inline at::Tensor ROIAlignRotated_forward( + const at::Tensor& input, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t sampling_ratio) { + if (input.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return ROIAlignRotated_forward_cuda( + input, + rois, + spatial_scale, + pooled_height, + pooled_width, + sampling_ratio); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + return ROIAlignRotated_forward_cpu( + input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio); +} + +inline at::Tensor ROIAlignRotated_backward( + const at::Tensor& grad, + const at::Tensor& rois, + const double spatial_scale, + const int64_t pooled_height, + const int64_t pooled_width, + const int64_t batch_size, + const int64_t channels, + const int64_t height, + const int64_t width, + const int64_t sampling_ratio) { + if (grad.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return ROIAlignRotated_backward_cuda( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + return ROIAlignRotated_backward_cpu( + grad, + rois, + spatial_scale, + pooled_height, + pooled_width, + batch_size, + channels, + height, + width, + sampling_ratio); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2a3d3056cc71a4acaafb570739a9dd247a7eb1ed --- /dev/null +++ b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cpu.cpp @@ -0,0 +1,522 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include "ROIAlignRotated.h" + +// Note: this implementation originates from the Caffe2 ROIAlignRotated Op +// and PyTorch ROIAlign (non-rotated) Op implementations. +// The key difference between this implementation and those ones is +// we don't do "legacy offset" in this version, as there aren't many previous +// works, if any, using the "legacy" ROIAlignRotated Op. +// This would make the interface a bit cleaner. + +namespace detectron2 { + +namespace { +template +struct PreCalc { + int pos1; + int pos2; + int pos3; + int pos4; + T w1; + T w2; + T w3; + T w4; +}; + +template +void pre_calc_for_bilinear_interpolate( + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int iy_upper, + const int ix_upper, + T roi_start_h, + T roi_start_w, + T bin_size_h, + T bin_size_w, + int roi_bin_grid_h, + int roi_bin_grid_w, + T roi_center_h, + T roi_center_w, + T cos_theta, + T sin_theta, + std::vector>& pre_calc) { + int pre_calc_index = 0; + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + for (int iy = 0; iy < iy_upper; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < ix_upper; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + // In image space, (y, x) is the order for Right Handed System, + // and this is essentially multiplying the point by a rotation matrix + // to rotate it counterclockwise through angle theta. + T y = yy * cos_theta - xx * sin_theta + roi_center_h; + T x = yy * sin_theta + xx * cos_theta + roi_center_w; + // deal with: inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + PreCalc pc; + pc.pos1 = 0; + pc.pos2 = 0; + pc.pos3 = 0; + pc.pos4 = 0; + pc.w1 = 0; + pc.w2 = 0; + pc.w3 = 0; + pc.w4 = 0; + pre_calc[pre_calc_index] = pc; + pre_calc_index += 1; + continue; + } + + if (y < 0) { + y = 0; + } + if (x < 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + // save weights and indices + PreCalc pc; + pc.pos1 = y_low * width + x_low; + pc.pos2 = y_low * width + x_high; + pc.pos3 = y_high * width + x_low; + pc.pos4 = y_high * width + x_high; + pc.w1 = w1; + pc.w2 = w2; + pc.w3 = w3; + pc.w4 = w4; + pre_calc[pre_calc_index] = pc; + + pre_calc_index += 1; + } + } + } + } +} + +template +void bilinear_interpolate_gradient( + const int height, + const int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y < 0) { + y = 0; + } + + if (x < 0) { + x = 0; + } + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +template +inline void add(T* address, const T& val) { + *address += val; +} + +} // namespace + +template +void ROIAlignRotatedForward( + const int nthreads, + const T* input, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + T* output) { + int n_rois = nthreads / channels / pooled_width / pooled_height; + // (n, c, ph, pw) is an element in the pooled output + // can be parallelized using omp + // #pragma omp parallel for num_threads(32) + for (int n = 0; n < n_rois; n++) { + int index_n = n * channels * pooled_width * pooled_height; + + const T* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + // ROIAlignRotated supports align == true, i.e., continuous coordinate + // by default, thus the 0.5 offset + T offset = (T)0.5; + T roi_center_w = current_roi[1] * spatial_scale - offset; + T roi_center_h = current_roi[2] * spatial_scale - offset; + T roi_width = current_roi[3] * spatial_scale; + T roi_height = current_roi[4] * spatial_scale; + T theta = current_roi[5] * M_PI / 180.0; + T cos_theta = cos(theta); + T sin_theta = sin(theta); + + AT_ASSERTM( + roi_width >= 0 && roi_height >= 0, + "ROIs in ROIAlignRotated do not have non-negative size!"); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // We do average (integral) pooling inside a bin + const T count = std::max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + // we want to precalculate indices and weights shared by all channels, + // this is the key point of optimization + std::vector> pre_calc( + roi_bin_grid_h * roi_bin_grid_w * pooled_width * pooled_height); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + T roi_start_h = -roi_height / 2.0; + T roi_start_w = -roi_width / 2.0; + + pre_calc_for_bilinear_interpolate( + height, + width, + pooled_height, + pooled_width, + roi_bin_grid_h, + roi_bin_grid_w, + roi_start_h, + roi_start_w, + bin_size_h, + bin_size_w, + roi_bin_grid_h, + roi_bin_grid_w, + roi_center_h, + roi_center_w, + cos_theta, + sin_theta, + pre_calc); + + for (int c = 0; c < channels; c++) { + int index_n_c = index_n + c * pooled_width * pooled_height; + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + int pre_calc_index = 0; + + for (int ph = 0; ph < pooled_height; ph++) { + for (int pw = 0; pw < pooled_width; pw++) { + int index = index_n_c + ph * pooled_width + pw; + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + PreCalc pc = pre_calc[pre_calc_index]; + output_val += pc.w1 * offset_input[pc.pos1] + + pc.w2 * offset_input[pc.pos2] + + pc.w3 * offset_input[pc.pos3] + pc.w4 * offset_input[pc.pos4]; + + pre_calc_index += 1; + } + } + output_val /= count; + + output[index] = output_val; + } // for pw + } // for ph + } // for c + } // for n +} + +template +void ROIAlignRotatedBackward( + const int nthreads, + // may not be contiguous. should index using n_stride, etc + const T* grad_output, + const T& spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + T* grad_input, + const T* rois, + const int n_stride, + const int c_stride, + const int h_stride, + const int w_stride) { + for (int index = 0; index < nthreads; index++) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + // ROIAlignRotated supports align == true, i.e., continuous coordinate + // by default, thus the 0.5 offset + T offset = (T)0.5; + T roi_center_w = current_roi[1] * spatial_scale - offset; + T roi_center_h = current_roi[2] * spatial_scale - offset; + T roi_width = current_roi[3] * spatial_scale; + T roi_height = current_roi[4] * spatial_scale; + T theta = current_roi[5] * M_PI / 180.0; + T cos_theta = cos(theta); + T sin_theta = sin(theta); + + AT_ASSERTM( + roi_width >= 0 && roi_height >= 0, + "ROIs in ROIAlignRotated do not have non-negative size!"); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_grad_input = + grad_input + ((roi_batch_ind * channels + c) * height * width); + + int output_offset = n * n_stride + c * c_stride; + const T* offset_grad_output = grad_output + output_offset; + const T grad_output_this_bin = + offset_grad_output[ph * h_stride + pw * w_stride]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + T roi_start_h = -roi_height / 2.0; + T roi_start_w = -roi_width / 2.0; + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + T y = yy * cos_theta - xx * sin_theta + roi_center_h; + T x = yy * sin_theta + xx * cos_theta + roi_center_w; + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high); + + T g1 = grad_output_this_bin * w1 / count; + T g2 = grad_output_this_bin * w2 / count; + T g3 = grad_output_this_bin * w3 / count; + T g4 = grad_output_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + // atomic add is not needed for now since it is single threaded + add(offset_grad_input + y_low * width + x_low, static_cast(g1)); + add(offset_grad_input + y_low * width + x_high, static_cast(g2)); + add(offset_grad_input + y_high * width + x_low, static_cast(g3)); + add(offset_grad_input + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // for +} // ROIAlignRotatedBackward + +at::Tensor ROIAlignRotated_forward_cpu( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(input.device().is_cpu(), "input must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlign_forward_cpu"; + at::checkAllSameType(c, {input_t, rois_t}); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + at::Tensor output = at::zeros( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + + auto output_size = num_rois * pooled_height * pooled_width * channels; + + if (output.numel() == 0) { + return output; + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.scalar_type(), "ROIAlignRotated_forward", [&] { + ROIAlignRotatedForward( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + output.data_ptr()); + }); + return output; +} + +at::Tensor ROIAlignRotated_backward_cpu( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cpu(), "grad must be a CPU tensor"); + AT_ASSERTM(rois.device().is_cpu(), "rois must be a CPU tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlignRotated_backward_cpu"; + at::checkAllSameType(c, {grad_t, rois_t}); + + at::Tensor grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + // handle possibly empty gradients + if (grad.numel() == 0) { + return grad_input; + } + + // get stride values to ensure indexing into gradients is correct. + int n_stride = grad.stride(0); + int c_stride = grad.stride(1); + int h_stride = grad.stride(2); + int w_stride = grad.stride(3); + + auto rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + grad.scalar_type(), "ROIAlignRotated_forward", [&] { + ROIAlignRotatedBackward( + grad.numel(), + grad.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data_ptr(), + rois_.data_ptr(), + n_stride, + c_stride, + h_stride, + w_stride); + }); + return grad_input; +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..fca186519143b168a912c880a4cf495a0a5a9322 --- /dev/null +++ b/custom_detectron2/layers/csrc/ROIAlignRotated/ROIAlignRotated_cuda.cu @@ -0,0 +1,443 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include +#include +#include + +// TODO make it in a common file +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; \ + i += blockDim.x * gridDim.x) + +// Note: this implementation originates from the Caffe2 ROIAlignRotated Op +// and PyTorch ROIAlign (non-rotated) Op implementations. +// The key difference between this implementation and those ones is +// we don't do "legacy offset" in this version, as there aren't many previous +// works, if any, using the "legacy" ROIAlignRotated Op. +// This would make the interface a bit cleaner. + +namespace detectron2 { + +namespace { + +template +__device__ T bilinear_interpolate( + const T* input, + const int height, + const int width, + T y, + T x) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + return 0; + } + + if (y < 0) { + y = 0; + } + + if (x < 0) { + x = 0; + } + + int y_low = (int)y; + int x_low = (int)x; + int y_high; + int x_high; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + // do bilinear interpolation + T v1 = input[y_low * width + x_low]; + T v2 = input[y_low * width + x_high]; + T v3 = input[y_high * width + x_low]; + T v4 = input[y_high * width + x_high]; + T w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + return val; +} + +template +__device__ void bilinear_interpolate_gradient( + const int height, + const int width, + T y, + T x, + T& w1, + T& w2, + T& w3, + T& w4, + int& x_low, + int& x_high, + int& y_low, + int& y_high) { + // deal with cases that inverse elements are out of feature map boundary + if (y < -1.0 || y > height || x < -1.0 || x > width) { + // empty + w1 = w2 = w3 = w4 = 0.; + x_low = x_high = y_low = y_high = -1; + return; + } + + if (y < 0) { + y = 0; + } + + if (x < 0) { + x = 0; + } + + y_low = (int)y; + x_low = (int)x; + + if (y_low >= height - 1) { + y_high = y_low = height - 1; + y = (T)y_low; + } else { + y_high = y_low + 1; + } + + if (x_low >= width - 1) { + x_high = x_low = width - 1; + x = (T)x_low; + } else { + x_high = x_low + 1; + } + + T ly = y - y_low; + T lx = x - x_low; + T hy = 1. - ly, hx = 1. - lx; + + // reference in forward + // T v1 = input[y_low * width + x_low]; + // T v2 = input[y_low * width + x_high]; + // T v3 = input[y_high * width + x_low]; + // T v4 = input[y_high * width + x_high]; + // T val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + + w1 = hy * hx, w2 = hy * lx, w3 = ly * hx, w4 = ly * lx; + + return; +} + +} // namespace + +template +__global__ void RoIAlignRotatedForward( + const int nthreads, + const T* input, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + const T* rois, + T* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + // ROIAlignRotated supports align == true, i.e., continuous coordinate + // by default, thus the 0.5 offset + T offset = (T)0.5; + T roi_center_w = current_roi[1] * spatial_scale - offset; + T roi_center_h = current_roi[2] * spatial_scale - offset; + T roi_width = current_roi[3] * spatial_scale; + T roi_height = current_roi[4] * spatial_scale; + T theta = current_roi[5] * M_PI / 180.0; + T cos_theta = cos(theta); + T sin_theta = sin(theta); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + const T* offset_input = + input + (roi_batch_ind * channels + c) * height * width; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + T roi_start_h = -roi_height / 2.0; + T roi_start_w = -roi_width / 2.0; + + // We do average (inte gral) pooling inside a bin + const T count = max(roi_bin_grid_h * roi_bin_grid_w, 1); // e.g. = 4 + + T output_val = 0.; + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + T y = yy * cos_theta - xx * sin_theta + roi_center_h; + T x = yy * sin_theta + xx * cos_theta + roi_center_w; + + T val = bilinear_interpolate(offset_input, height, width, y, x); + output_val += val; + } + } + output_val /= count; + + top_data[index] = output_val; + } +} + +template +__global__ void RoIAlignRotatedBackwardFeature( + const int nthreads, + const T* top_diff, + const int num_rois, + const T spatial_scale, + const int channels, + const int height, + const int width, + const int pooled_height, + const int pooled_width, + const int sampling_ratio, + T* bottom_diff, + const T* rois) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // (n, c, ph, pw) is an element in the pooled output + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + + const T* current_roi = rois + n * 6; + int roi_batch_ind = current_roi[0]; + + // Do not use rounding; this implementation detail is critical + // ROIAlignRotated supports align == true, i.e., continuous coordinate + // by default, thus the 0.5 offset + T offset = (T)0.5; + T roi_center_w = current_roi[1] * spatial_scale - offset; + T roi_center_h = current_roi[2] * spatial_scale - offset; + T roi_width = current_roi[3] * spatial_scale; + T roi_height = current_roi[4] * spatial_scale; + T theta = current_roi[5] * M_PI / 180.0; + T cos_theta = cos(theta); + T sin_theta = sin(theta); + + T bin_size_h = static_cast(roi_height) / static_cast(pooled_height); + T bin_size_w = static_cast(roi_width) / static_cast(pooled_width); + + T* offset_bottom_diff = + bottom_diff + (roi_batch_ind * channels + c) * height * width; + + int top_offset = (n * channels + c) * pooled_height * pooled_width; + const T* offset_top_diff = top_diff + top_offset; + const T top_diff_this_bin = offset_top_diff[ph * pooled_width + pw]; + + // We use roi_bin_grid to sample the grid and mimic integral + int roi_bin_grid_h = (sampling_ratio > 0) + ? sampling_ratio + : ceil(roi_height / pooled_height); // e.g., = 2 + int roi_bin_grid_w = + (sampling_ratio > 0) ? sampling_ratio : ceil(roi_width / pooled_width); + + // roi_start_h and roi_start_w are computed wrt the center of RoI (x, y). + // Appropriate translation needs to be applied after. + T roi_start_h = -roi_height / 2.0; + T roi_start_w = -roi_width / 2.0; + + // We do average (integral) pooling inside a bin + const T count = roi_bin_grid_h * roi_bin_grid_w; // e.g. = 4 + + for (int iy = 0; iy < roi_bin_grid_h; iy++) // e.g., iy = 0, 1 + { + const T yy = roi_start_h + ph * bin_size_h + + static_cast(iy + .5f) * bin_size_h / + static_cast(roi_bin_grid_h); // e.g., 0.5, 1.5 + for (int ix = 0; ix < roi_bin_grid_w; ix++) { + const T xx = roi_start_w + pw * bin_size_w + + static_cast(ix + .5f) * bin_size_w / + static_cast(roi_bin_grid_w); + + // Rotate by theta around the center and translate + T y = yy * cos_theta - xx * sin_theta + roi_center_h; + T x = yy * sin_theta + xx * cos_theta + roi_center_w; + + T w1, w2, w3, w4; + int x_low, x_high, y_low, y_high; + + bilinear_interpolate_gradient( + height, width, y, x, w1, w2, w3, w4, x_low, x_high, y_low, y_high); + + T g1 = top_diff_this_bin * w1 / count; + T g2 = top_diff_this_bin * w2 / count; + T g3 = top_diff_this_bin * w3 / count; + T g4 = top_diff_this_bin * w4 / count; + + if (x_low >= 0 && x_high >= 0 && y_low >= 0 && y_high >= 0) { + atomicAdd( + offset_bottom_diff + y_low * width + x_low, static_cast(g1)); + atomicAdd( + offset_bottom_diff + y_low * width + x_high, static_cast(g2)); + atomicAdd( + offset_bottom_diff + y_high * width + x_low, static_cast(g3)); + atomicAdd( + offset_bottom_diff + y_high * width + x_high, static_cast(g4)); + } // if + } // ix + } // iy + } // CUDA_1D_KERNEL_LOOP +} // RoIAlignRotatedBackward + +at::Tensor ROIAlignRotated_forward_cuda( + const at::Tensor& input, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int sampling_ratio) { + AT_ASSERTM(input.device().is_cuda(), "input must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + at::TensorArg input_t{input, "input", 1}, rois_t{rois, "rois", 2}; + + at::CheckedFrom c = "ROIAlignRotated_forward_cuda"; + at::checkAllSameGPU(c, {input_t, rois_t}); + at::checkAllSameType(c, {input_t, rois_t}); + at::cuda::CUDAGuard device_guard(input.device()); + + auto num_rois = rois.size(0); + auto channels = input.size(1); + auto height = input.size(2); + auto width = input.size(3); + + auto output = at::empty( + {num_rois, channels, pooled_height, pooled_width}, input.options()); + auto output_size = num_rois * pooled_height * pooled_width * channels; + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv( + static_cast(output_size), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + if (output.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return output; + } + + auto input_ = input.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "ROIAlignRotated_forward", [&] { + RoIAlignRotatedForward<<>>( + output_size, + input_.data_ptr(), + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + rois_.data_ptr(), + output.data_ptr()); + }); + cudaDeviceSynchronize(); + AT_CUDA_CHECK(cudaGetLastError()); + return output; +} + +// TODO remove the dependency on input and use instead its sizes -> save memory +at::Tensor ROIAlignRotated_backward_cuda( + const at::Tensor& grad, + const at::Tensor& rois, + const float spatial_scale, + const int pooled_height, + const int pooled_width, + const int batch_size, + const int channels, + const int height, + const int width, + const int sampling_ratio) { + AT_ASSERTM(grad.device().is_cuda(), "grad must be a CUDA tensor"); + AT_ASSERTM(rois.device().is_cuda(), "rois must be a CUDA tensor"); + + at::TensorArg grad_t{grad, "grad", 1}, rois_t{rois, "rois", 2}; + at::CheckedFrom c = "ROIAlign_backward_cuda"; + at::checkAllSameGPU(c, {grad_t, rois_t}); + at::checkAllSameType(c, {grad_t, rois_t}); + at::cuda::CUDAGuard device_guard(grad.device()); + + auto num_rois = rois.size(0); + auto grad_input = + at::zeros({batch_size, channels, height, width}, grad.options()); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + dim3 grid(std::min( + at::cuda::ATenCeilDiv( + static_cast(grad.numel()), static_cast(512)), + static_cast(4096))); + dim3 block(512); + + // handle possibly empty gradients + if (grad.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; + } + + auto grad_ = grad.contiguous(), rois_ = rois.contiguous(); + AT_DISPATCH_FLOATING_TYPES( + grad.scalar_type(), "ROIAlignRotated_backward", [&] { + RoIAlignRotatedBackwardFeature<<>>( + grad.numel(), + grad_.data_ptr(), + num_rois, + spatial_scale, + channels, + height, + width, + pooled_height, + pooled_width, + sampling_ratio, + grad_input.data_ptr(), + rois_.data_ptr()); + }); + AT_CUDA_CHECK(cudaGetLastError()); + return grad_input; +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h new file mode 100644 index 0000000000000000000000000000000000000000..3bf383b8ed9b358b5313d433a9682c294dfb77e4 --- /dev/null +++ b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated.h @@ -0,0 +1,35 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +at::Tensor box_iou_rotated_cpu( + const at::Tensor& boxes1, + const at::Tensor& boxes2); + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor box_iou_rotated_cuda( + const at::Tensor& boxes1, + const at::Tensor& boxes2); +#endif + +// Interface for Python +// inline is needed to prevent multiple function definitions when this header is +// included by different cpps +inline at::Tensor box_iou_rotated( + const at::Tensor& boxes1, + const at::Tensor& boxes2) { + assert(boxes1.device().is_cuda() == boxes2.device().is_cuda()); + if (boxes1.device().is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return box_iou_rotated_cuda(boxes1.contiguous(), boxes2.contiguous()); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + + return box_iou_rotated_cpu(boxes1.contiguous(), boxes2.contiguous()); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c843487b5fa4e8077dd27402ec99009266ddda8d --- /dev/null +++ b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cpu.cpp @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include "box_iou_rotated.h" +#include "box_iou_rotated_utils.h" + +namespace detectron2 { + +template +void box_iou_rotated_cpu_kernel( + const at::Tensor& boxes1, + const at::Tensor& boxes2, + at::Tensor& ious) { + auto num_boxes1 = boxes1.size(0); + auto num_boxes2 = boxes2.size(0); + + for (int i = 0; i < num_boxes1; i++) { + for (int j = 0; j < num_boxes2; j++) { + ious[i * num_boxes2 + j] = single_box_iou_rotated( + boxes1[i].data_ptr(), boxes2[j].data_ptr()); + } + } +} + +at::Tensor box_iou_rotated_cpu( + // input must be contiguous: + const at::Tensor& boxes1, + const at::Tensor& boxes2) { + auto num_boxes1 = boxes1.size(0); + auto num_boxes2 = boxes2.size(0); + at::Tensor ious = + at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat)); + + box_iou_rotated_cpu_kernel(boxes1, boxes2, ious); + + // reshape from 1d array to 2d array + auto shape = std::vector{num_boxes1, num_boxes2}; + return ious.reshape(shape); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..952710e53041187907fbd113f8d0d0fa24134a86 --- /dev/null +++ b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_cuda.cu @@ -0,0 +1,130 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include +#include +#include +#include "box_iou_rotated_utils.h" + +namespace detectron2 { + +// 2D block with 32 * 16 = 512 threads per block +const int BLOCK_DIM_X = 32; +const int BLOCK_DIM_Y = 16; + +template +__global__ void box_iou_rotated_cuda_kernel( + const int n_boxes1, + const int n_boxes2, + const T* dev_boxes1, + const T* dev_boxes2, + T* dev_ious) { + const int row_start = blockIdx.x * blockDim.x; + const int col_start = blockIdx.y * blockDim.y; + + const int row_size = min(n_boxes1 - row_start, blockDim.x); + const int col_size = min(n_boxes2 - col_start, blockDim.y); + + __shared__ float block_boxes1[BLOCK_DIM_X * 5]; + __shared__ float block_boxes2[BLOCK_DIM_Y * 5]; + + // It's safe to copy using threadIdx.x since BLOCK_DIM_X >= BLOCK_DIM_Y + if (threadIdx.x < row_size && threadIdx.y == 0) { + block_boxes1[threadIdx.x * 5 + 0] = + dev_boxes1[(row_start + threadIdx.x) * 5 + 0]; + block_boxes1[threadIdx.x * 5 + 1] = + dev_boxes1[(row_start + threadIdx.x) * 5 + 1]; + block_boxes1[threadIdx.x * 5 + 2] = + dev_boxes1[(row_start + threadIdx.x) * 5 + 2]; + block_boxes1[threadIdx.x * 5 + 3] = + dev_boxes1[(row_start + threadIdx.x) * 5 + 3]; + block_boxes1[threadIdx.x * 5 + 4] = + dev_boxes1[(row_start + threadIdx.x) * 5 + 4]; + } + + if (threadIdx.x < col_size && threadIdx.y == 0) { + block_boxes2[threadIdx.x * 5 + 0] = + dev_boxes2[(col_start + threadIdx.x) * 5 + 0]; + block_boxes2[threadIdx.x * 5 + 1] = + dev_boxes2[(col_start + threadIdx.x) * 5 + 1]; + block_boxes2[threadIdx.x * 5 + 2] = + dev_boxes2[(col_start + threadIdx.x) * 5 + 2]; + block_boxes2[threadIdx.x * 5 + 3] = + dev_boxes2[(col_start + threadIdx.x) * 5 + 3]; + block_boxes2[threadIdx.x * 5 + 4] = + dev_boxes2[(col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size && threadIdx.y < col_size) { + int offset = (row_start + threadIdx.x) * n_boxes2 + col_start + threadIdx.y; + dev_ious[offset] = single_box_iou_rotated( + block_boxes1 + threadIdx.x * 5, block_boxes2 + threadIdx.y * 5); + } +} + +at::Tensor box_iou_rotated_cuda( + // input must be contiguous + const at::Tensor& boxes1, + const at::Tensor& boxes2) { + using scalar_t = float; + AT_ASSERTM( + boxes1.scalar_type() == at::kFloat, "boxes1 must be a float tensor"); + AT_ASSERTM( + boxes2.scalar_type() == at::kFloat, "boxes2 must be a float tensor"); + AT_ASSERTM(boxes1.is_cuda(), "boxes1 must be a CUDA tensor"); + AT_ASSERTM(boxes2.is_cuda(), "boxes2 must be a CUDA tensor"); + at::cuda::CUDAGuard device_guard(boxes1.device()); + + auto num_boxes1 = boxes1.size(0); + auto num_boxes2 = boxes2.size(0); + + at::Tensor ious = + at::empty({num_boxes1 * num_boxes2}, boxes1.options().dtype(at::kFloat)); + + bool transpose = false; + if (num_boxes1 > 0 && num_boxes2 > 0) { + scalar_t *data1 = boxes1.data_ptr(), + *data2 = boxes2.data_ptr(); + + if (num_boxes2 > 65535 * BLOCK_DIM_Y) { + AT_ASSERTM( + num_boxes1 <= 65535 * BLOCK_DIM_Y, + "Too many boxes for box_iou_rotated_cuda!"); + // x dim is allowed to be large, but y dim cannot, + // so we transpose the two to avoid "invalid configuration argument" + // error. We assume one of them is small. Otherwise the result is hard to + // fit in memory anyway. + std::swap(num_boxes1, num_boxes2); + std::swap(data1, data2); + transpose = true; + } + + const int blocks_x = + at::cuda::ATenCeilDiv(static_cast(num_boxes1), BLOCK_DIM_X); + const int blocks_y = + at::cuda::ATenCeilDiv(static_cast(num_boxes2), BLOCK_DIM_Y); + + dim3 blocks(blocks_x, blocks_y); + dim3 threads(BLOCK_DIM_X, BLOCK_DIM_Y); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + box_iou_rotated_cuda_kernel<<>>( + num_boxes1, + num_boxes2, + data1, + data2, + (scalar_t*)ious.data_ptr()); + + AT_CUDA_CHECK(cudaGetLastError()); + } + + // reshape from 1d array to 2d array + auto shape = std::vector{num_boxes1, num_boxes2}; + if (transpose) { + return ious.view(shape).t(); + } else { + return ious.view(shape); + } +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..b54a5dde2ca11a74d29c4d8adb7fe1634f5baf9c --- /dev/null +++ b/custom_detectron2/layers/csrc/box_iou_rotated/box_iou_rotated_utils.h @@ -0,0 +1,370 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once + +#include +#include + +#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1 +// Designates functions callable from the host (CPU) and the device (GPU) +#define HOST_DEVICE __host__ __device__ +#define HOST_DEVICE_INLINE HOST_DEVICE __forceinline__ +#else +#include +#define HOST_DEVICE +#define HOST_DEVICE_INLINE HOST_DEVICE inline +#endif + +namespace detectron2 { + +namespace { + +template +struct RotatedBox { + T x_ctr, y_ctr, w, h, a; +}; + +template +struct Point { + T x, y; + HOST_DEVICE_INLINE Point(const T& px = 0, const T& py = 0) : x(px), y(py) {} + HOST_DEVICE_INLINE Point operator+(const Point& p) const { + return Point(x + p.x, y + p.y); + } + HOST_DEVICE_INLINE Point& operator+=(const Point& p) { + x += p.x; + y += p.y; + return *this; + } + HOST_DEVICE_INLINE Point operator-(const Point& p) const { + return Point(x - p.x, y - p.y); + } + HOST_DEVICE_INLINE Point operator*(const T coeff) const { + return Point(x * coeff, y * coeff); + } +}; + +template +HOST_DEVICE_INLINE T dot_2d(const Point& A, const Point& B) { + return A.x * B.x + A.y * B.y; +} + +// R: result type. can be different from input type +template +HOST_DEVICE_INLINE R cross_2d(const Point& A, const Point& B) { + return static_cast(A.x) * static_cast(B.y) - + static_cast(B.x) * static_cast(A.y); +} + +template +HOST_DEVICE_INLINE void get_rotated_vertices( + const RotatedBox& box, + Point (&pts)[4]) { + // M_PI / 180. == 0.01745329251 + double theta = box.a * 0.01745329251; + T cosTheta2 = (T)cos(theta) * 0.5f; + T sinTheta2 = (T)sin(theta) * 0.5f; + + // y: top --> down; x: left --> right + pts[0].x = box.x_ctr + sinTheta2 * box.h + cosTheta2 * box.w; + pts[0].y = box.y_ctr + cosTheta2 * box.h - sinTheta2 * box.w; + pts[1].x = box.x_ctr - sinTheta2 * box.h + cosTheta2 * box.w; + pts[1].y = box.y_ctr - cosTheta2 * box.h - sinTheta2 * box.w; + pts[2].x = 2 * box.x_ctr - pts[0].x; + pts[2].y = 2 * box.y_ctr - pts[0].y; + pts[3].x = 2 * box.x_ctr - pts[1].x; + pts[3].y = 2 * box.y_ctr - pts[1].y; +} + +template +HOST_DEVICE_INLINE int get_intersection_points( + const Point (&pts1)[4], + const Point (&pts2)[4], + Point (&intersections)[24]) { + // Line vector + // A line from p1 to p2 is: p1 + (p2-p1)*t, t=[0,1] + Point vec1[4], vec2[4]; + for (int i = 0; i < 4; i++) { + vec1[i] = pts1[(i + 1) % 4] - pts1[i]; + vec2[i] = pts2[(i + 1) % 4] - pts2[i]; + } + + // When computing the intersection area, it doesn't hurt if we have + // more (duplicated/approximate) intersections/vertices than needed, + // while it can cause drastic difference if we miss an intersection/vertex. + // Therefore, we add an epsilon to relax the comparisons between + // the float point numbers that decide the intersection points. + double EPS = 1e-5; + + // Line test - test all line combos for intersection + int num = 0; // number of intersections + for (int i = 0; i < 4; i++) { + for (int j = 0; j < 4; j++) { + // Solve for 2x2 Ax=b + T det = cross_2d(vec2[j], vec1[i]); + + // This takes care of parallel lines + if (fabs(det) <= 1e-14) { + continue; + } + + auto vec12 = pts2[j] - pts1[i]; + + T t1 = cross_2d(vec2[j], vec12) / det; + T t2 = cross_2d(vec1[i], vec12) / det; + + if (t1 > -EPS && t1 < 1.0f + EPS && t2 > -EPS && t2 < 1.0f + EPS) { + intersections[num++] = pts1[i] + vec1[i] * t1; + } + } + } + + // Check for vertices of rect1 inside rect2 + { + const auto& AB = vec2[0]; + const auto& DA = vec2[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) { + // assume ABCD is the rectangle, and P is the point to be judged + // P is inside ABCD iff. P's projection on AB lies within AB + // and P's projection on AD lies within AD + + auto AP = pts1[i] - pts2[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) && + (APdotAD < ADdotAD + EPS)) { + intersections[num++] = pts1[i]; + } + } + } + + // Reverse the check - check for vertices of rect2 inside rect1 + { + const auto& AB = vec1[0]; + const auto& DA = vec1[3]; + auto ABdotAB = dot_2d(AB, AB); + auto ADdotAD = dot_2d(DA, DA); + for (int i = 0; i < 4; i++) { + auto AP = pts2[i] - pts1[0]; + + auto APdotAB = dot_2d(AP, AB); + auto APdotAD = -dot_2d(AP, DA); + + if ((APdotAB > -EPS) && (APdotAD > -EPS) && (APdotAB < ABdotAB + EPS) && + (APdotAD < ADdotAD + EPS)) { + intersections[num++] = pts2[i]; + } + } + } + + return num; +} + +template +HOST_DEVICE_INLINE int convex_hull_graham( + const Point (&p)[24], + const int& num_in, + Point (&q)[24], + bool shift_to_zero = false) { + assert(num_in >= 2); + + // Step 1: + // Find point with minimum y + // if more than 1 points have the same minimum y, + // pick the one with the minimum x. + int t = 0; + for (int i = 1; i < num_in; i++) { + if (p[i].y < p[t].y || (p[i].y == p[t].y && p[i].x < p[t].x)) { + t = i; + } + } + auto& start = p[t]; // starting point + + // Step 2: + // Subtract starting point from every points (for sorting in the next step) + for (int i = 0; i < num_in; i++) { + q[i] = p[i] - start; + } + + // Swap the starting point to position 0 + auto tmp = q[0]; + q[0] = q[t]; + q[t] = tmp; + + // Step 3: + // Sort point 1 ~ num_in according to their relative cross-product values + // (essentially sorting according to angles) + // If the angles are the same, sort according to their distance to origin + T dist[24]; +#if defined(__CUDACC__) || __HCC__ == 1 || __HIP__ == 1 + // compute distance to origin before sort, and sort them together with the + // points + for (int i = 0; i < num_in; i++) { + dist[i] = dot_2d(q[i], q[i]); + } + + // CUDA version + // In the future, we can potentially use thrust + // for sorting here to improve speed (though not guaranteed) + for (int i = 1; i < num_in - 1; i++) { + for (int j = i + 1; j < num_in; j++) { + T crossProduct = cross_2d(q[i], q[j]); + if ((crossProduct < -1e-6) || + (fabs(crossProduct) < 1e-6 && dist[i] > dist[j])) { + auto q_tmp = q[i]; + q[i] = q[j]; + q[j] = q_tmp; + auto dist_tmp = dist[i]; + dist[i] = dist[j]; + dist[j] = dist_tmp; + } + } + } +#else + // CPU version + std::sort( + q + 1, q + num_in, [](const Point& A, const Point& B) -> bool { + T temp = cross_2d(A, B); + if (fabs(temp) < 1e-6) { + return dot_2d(A, A) < dot_2d(B, B); + } else { + return temp > 0; + } + }); + // compute distance to origin after sort, since the points are now different. + for (int i = 0; i < num_in; i++) { + dist[i] = dot_2d(q[i], q[i]); + } +#endif + + // Step 4: + // Make sure there are at least 2 points (that don't overlap with each other) + // in the stack + int k; // index of the non-overlapped second point + for (k = 1; k < num_in; k++) { + if (dist[k] > 1e-8) { + break; + } + } + if (k == num_in) { + // We reach the end, which means the convex hull is just one point + q[0] = p[t]; + return 1; + } + q[1] = q[k]; + int m = 2; // 2 points in the stack + // Step 5: + // Finally we can start the scanning process. + // When a non-convex relationship between the 3 points is found + // (either concave shape or duplicated points), + // we pop the previous point from the stack + // until the 3-point relationship is convex again, or + // until the stack only contains two points + for (int i = k + 1; i < num_in; i++) { + while (m > 1) { + auto q1 = q[i] - q[m - 2], q2 = q[m - 1] - q[m - 2]; + // cross_2d() uses FMA and therefore computes round(round(q1.x*q2.y) - + // q2.x*q1.y) So it may not return 0 even when q1==q2. Therefore we + // compare round(q1.x*q2.y) and round(q2.x*q1.y) directly. (round means + // round to nearest floating point). + if (q1.x * q2.y >= q2.x * q1.y) + m--; + else + break; + } + // Using double also helps, but float can solve the issue for now. + // while (m > 1 && cross_2d(q[i] - q[m - 2], q[m - 1] - q[m - 2]) + // >= 0) { + // m--; + // } + q[m++] = q[i]; + } + + // Step 6 (Optional): + // In general sense we need the original coordinates, so we + // need to shift the points back (reverting Step 2) + // But if we're only interested in getting the area/perimeter of the shape + // We can simply return. + if (!shift_to_zero) { + for (int i = 0; i < m; i++) { + q[i] += start; + } + } + + return m; +} + +template +HOST_DEVICE_INLINE T polygon_area(const Point (&q)[24], const int& m) { + if (m <= 2) { + return 0; + } + + T area = 0; + for (int i = 1; i < m - 1; i++) { + area += fabs(cross_2d(q[i] - q[0], q[i + 1] - q[0])); + } + + return area / 2.0; +} + +template +HOST_DEVICE_INLINE T rotated_boxes_intersection( + const RotatedBox& box1, + const RotatedBox& box2) { + // There are up to 4 x 4 + 4 + 4 = 24 intersections (including dups) returned + // from rotated_rect_intersection_pts + Point intersectPts[24], orderedPts[24]; + + Point pts1[4]; + Point pts2[4]; + get_rotated_vertices(box1, pts1); + get_rotated_vertices(box2, pts2); + + int num = get_intersection_points(pts1, pts2, intersectPts); + + if (num <= 2) { + return 0.0; + } + + // Convex Hull to order the intersection points in clockwise order and find + // the contour area. + int num_convex = convex_hull_graham(intersectPts, num, orderedPts, true); + return polygon_area(orderedPts, num_convex); +} + +} // namespace + +template +HOST_DEVICE_INLINE T +single_box_iou_rotated(T const* const box1_raw, T const* const box2_raw) { + // shift center to the middle point to achieve higher precision in result + RotatedBox box1, box2; + auto center_shift_x = (box1_raw[0] + box2_raw[0]) / 2.0; + auto center_shift_y = (box1_raw[1] + box2_raw[1]) / 2.0; + box1.x_ctr = box1_raw[0] - center_shift_x; + box1.y_ctr = box1_raw[1] - center_shift_y; + box1.w = box1_raw[2]; + box1.h = box1_raw[3]; + box1.a = box1_raw[4]; + box2.x_ctr = box2_raw[0] - center_shift_x; + box2.y_ctr = box2_raw[1] - center_shift_y; + box2.w = box2_raw[2]; + box2.h = box2_raw[3]; + box2.a = box2_raw[4]; + + T area1 = box1.w * box1.h; + T area2 = box2.w * box2.h; + if (area1 < 1e-14 || area2 < 1e-14) { + return 0.f; + } + + T intersection = rotated_boxes_intersection(box1, box2); + T iou = intersection / (area1 + area2 - intersection); + return iou; +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/cocoeval/cocoeval.cpp b/custom_detectron2/layers/csrc/cocoeval/cocoeval.cpp new file mode 100644 index 0000000000000000000000000000000000000000..0a5b7b907c06720fefc77b0dfd921b8ec3ecf2be --- /dev/null +++ b/custom_detectron2/layers/csrc/cocoeval/cocoeval.cpp @@ -0,0 +1,507 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include "cocoeval.h" +#include +#include +#include +#include + +using namespace pybind11::literals; + +namespace detectron2 { + +namespace COCOeval { + +// Sort detections from highest score to lowest, such that +// detection_instances[detection_sorted_indices[t]] >= +// detection_instances[detection_sorted_indices[t+1]]. Use stable_sort to match +// original COCO API +void SortInstancesByDetectionScore( + const std::vector& detection_instances, + std::vector* detection_sorted_indices) { + detection_sorted_indices->resize(detection_instances.size()); + std::iota( + detection_sorted_indices->begin(), detection_sorted_indices->end(), 0); + std::stable_sort( + detection_sorted_indices->begin(), + detection_sorted_indices->end(), + [&detection_instances](size_t j1, size_t j2) { + return detection_instances[j1].score > detection_instances[j2].score; + }); +} + +// Partition the ground truth objects based on whether or not to ignore them +// based on area +void SortInstancesByIgnore( + const std::array& area_range, + const std::vector& ground_truth_instances, + std::vector* ground_truth_sorted_indices, + std::vector* ignores) { + ignores->clear(); + ignores->reserve(ground_truth_instances.size()); + for (auto o : ground_truth_instances) { + ignores->push_back( + o.ignore || o.area < area_range[0] || o.area > area_range[1]); + } + + ground_truth_sorted_indices->resize(ground_truth_instances.size()); + std::iota( + ground_truth_sorted_indices->begin(), + ground_truth_sorted_indices->end(), + 0); + std::stable_sort( + ground_truth_sorted_indices->begin(), + ground_truth_sorted_indices->end(), + [&ignores](size_t j1, size_t j2) { + return (int)(*ignores)[j1] < (int)(*ignores)[j2]; + }); +} + +// For each IOU threshold, greedily match each detected instance to a ground +// truth instance (if possible) and store the results +void MatchDetectionsToGroundTruth( + const std::vector& detection_instances, + const std::vector& detection_sorted_indices, + const std::vector& ground_truth_instances, + const std::vector& ground_truth_sorted_indices, + const std::vector& ignores, + const std::vector>& ious, + const std::vector& iou_thresholds, + const std::array& area_range, + ImageEvaluation* results) { + // Initialize memory to store return data matches and ignore + const int num_iou_thresholds = iou_thresholds.size(); + const int num_ground_truth = ground_truth_sorted_indices.size(); + const int num_detections = detection_sorted_indices.size(); + std::vector ground_truth_matches( + num_iou_thresholds * num_ground_truth, 0); + std::vector& detection_matches = results->detection_matches; + std::vector& detection_ignores = results->detection_ignores; + std::vector& ground_truth_ignores = results->ground_truth_ignores; + detection_matches.resize(num_iou_thresholds * num_detections, 0); + detection_ignores.resize(num_iou_thresholds * num_detections, false); + ground_truth_ignores.resize(num_ground_truth); + for (auto g = 0; g < num_ground_truth; ++g) { + ground_truth_ignores[g] = ignores[ground_truth_sorted_indices[g]]; + } + + for (auto t = 0; t < num_iou_thresholds; ++t) { + for (auto d = 0; d < num_detections; ++d) { + // information about best match so far (match=-1 -> unmatched) + double best_iou = std::min(iou_thresholds[t], 1 - 1e-10); + int match = -1; + for (auto g = 0; g < num_ground_truth; ++g) { + // if this ground truth instance is already matched and not a + // crowd, it cannot be matched to another detection + if (ground_truth_matches[t * num_ground_truth + g] > 0 && + !ground_truth_instances[ground_truth_sorted_indices[g]].is_crowd) { + continue; + } + + // if detected instance matched to a regular ground truth + // instance, we can break on the first ground truth instance + // tagged as ignore (because they are sorted by the ignore tag) + if (match >= 0 && !ground_truth_ignores[match] && + ground_truth_ignores[g]) { + break; + } + + // if IOU overlap is the best so far, store the match appropriately + if (ious[d][ground_truth_sorted_indices[g]] >= best_iou) { + best_iou = ious[d][ground_truth_sorted_indices[g]]; + match = g; + } + } + // if match was made, store id of match for both detection and + // ground truth + if (match >= 0) { + detection_ignores[t * num_detections + d] = ground_truth_ignores[match]; + detection_matches[t * num_detections + d] = + ground_truth_instances[ground_truth_sorted_indices[match]].id; + ground_truth_matches[t * num_ground_truth + match] = + detection_instances[detection_sorted_indices[d]].id; + } + + // set unmatched detections outside of area range to ignore + const InstanceAnnotation& detection = + detection_instances[detection_sorted_indices[d]]; + detection_ignores[t * num_detections + d] = + detection_ignores[t * num_detections + d] || + (detection_matches[t * num_detections + d] == 0 && + (detection.area < area_range[0] || detection.area > area_range[1])); + } + } + + // store detection score results + results->detection_scores.resize(detection_sorted_indices.size()); + for (size_t d = 0; d < detection_sorted_indices.size(); ++d) { + results->detection_scores[d] = + detection_instances[detection_sorted_indices[d]].score; + } +} + +std::vector EvaluateImages( + const std::vector>& area_ranges, + int max_detections, + const std::vector& iou_thresholds, + const ImageCategoryInstances>& image_category_ious, + const ImageCategoryInstances& + image_category_ground_truth_instances, + const ImageCategoryInstances& + image_category_detection_instances) { + const int num_area_ranges = area_ranges.size(); + const int num_images = image_category_ground_truth_instances.size(); + const int num_categories = + image_category_ious.size() > 0 ? image_category_ious[0].size() : 0; + std::vector detection_sorted_indices; + std::vector ground_truth_sorted_indices; + std::vector ignores; + std::vector results_all( + num_images * num_area_ranges * num_categories); + + // Store results for each image, category, and area range combination. Results + // for each IOU threshold are packed into the same ImageEvaluation object + for (auto i = 0; i < num_images; ++i) { + for (auto c = 0; c < num_categories; ++c) { + const std::vector& ground_truth_instances = + image_category_ground_truth_instances[i][c]; + const std::vector& detection_instances = + image_category_detection_instances[i][c]; + + SortInstancesByDetectionScore( + detection_instances, &detection_sorted_indices); + if ((int)detection_sorted_indices.size() > max_detections) { + detection_sorted_indices.resize(max_detections); + } + + for (size_t a = 0; a < area_ranges.size(); ++a) { + SortInstancesByIgnore( + area_ranges[a], + ground_truth_instances, + &ground_truth_sorted_indices, + &ignores); + + MatchDetectionsToGroundTruth( + detection_instances, + detection_sorted_indices, + ground_truth_instances, + ground_truth_sorted_indices, + ignores, + image_category_ious[i][c], + iou_thresholds, + area_ranges[a], + &results_all + [c * num_area_ranges * num_images + a * num_images + i]); + } + } + } + + return results_all; +} + +// Convert a python list to a vector +template +std::vector list_to_vec(const py::list& l) { + std::vector v(py::len(l)); + for (int i = 0; i < (int)py::len(l); ++i) { + v[i] = l[i].cast(); + } + return v; +} + +// Helper function to Accumulate() +// Considers the evaluation results applicable to a particular category, area +// range, and max_detections parameter setting, which begin at +// evaluations[evaluation_index]. Extracts a sorted list of length n of all +// applicable detection instances concatenated across all images in the dataset, +// which are represented by the outputs evaluation_indices, detection_scores, +// image_detection_indices, and detection_sorted_indices--all of which are +// length n. evaluation_indices[i] stores the applicable index into +// evaluations[] for instance i, which has detection score detection_score[i], +// and is the image_detection_indices[i]'th of the list of detections +// for the image containing i. detection_sorted_indices[] defines a sorted +// permutation of the 3 other outputs +int BuildSortedDetectionList( + const std::vector& evaluations, + const int64_t evaluation_index, + const int64_t num_images, + const int max_detections, + std::vector* evaluation_indices, + std::vector* detection_scores, + std::vector* detection_sorted_indices, + std::vector* image_detection_indices) { + assert(evaluations.size() >= evaluation_index + num_images); + + // Extract a list of object instances of the applicable category, area + // range, and max detections requirements such that they can be sorted + image_detection_indices->clear(); + evaluation_indices->clear(); + detection_scores->clear(); + image_detection_indices->reserve(num_images * max_detections); + evaluation_indices->reserve(num_images * max_detections); + detection_scores->reserve(num_images * max_detections); + int num_valid_ground_truth = 0; + for (auto i = 0; i < num_images; ++i) { + const ImageEvaluation& evaluation = evaluations[evaluation_index + i]; + + for (int d = 0; + d < (int)evaluation.detection_scores.size() && d < max_detections; + ++d) { // detected instances + evaluation_indices->push_back(evaluation_index + i); + image_detection_indices->push_back(d); + detection_scores->push_back(evaluation.detection_scores[d]); + } + for (auto ground_truth_ignore : evaluation.ground_truth_ignores) { + if (!ground_truth_ignore) { + ++num_valid_ground_truth; + } + } + } + + // Sort detections by decreasing score, using stable sort to match + // python implementation + detection_sorted_indices->resize(detection_scores->size()); + std::iota( + detection_sorted_indices->begin(), detection_sorted_indices->end(), 0); + std::stable_sort( + detection_sorted_indices->begin(), + detection_sorted_indices->end(), + [&detection_scores](size_t j1, size_t j2) { + return (*detection_scores)[j1] > (*detection_scores)[j2]; + }); + + return num_valid_ground_truth; +} + +// Helper function to Accumulate() +// Compute a precision recall curve given a sorted list of detected instances +// encoded in evaluations, evaluation_indices, detection_scores, +// detection_sorted_indices, image_detection_indices (see +// BuildSortedDetectionList()). Using vectors precisions and recalls +// and temporary storage, output the results into precisions_out, recalls_out, +// and scores_out, which are large buffers containing many precion/recall curves +// for all possible parameter settings, with precisions_out_index and +// recalls_out_index defining the applicable indices to store results. +void ComputePrecisionRecallCurve( + const int64_t precisions_out_index, + const int64_t precisions_out_stride, + const int64_t recalls_out_index, + const std::vector& recall_thresholds, + const int iou_threshold_index, + const int num_iou_thresholds, + const int num_valid_ground_truth, + const std::vector& evaluations, + const std::vector& evaluation_indices, + const std::vector& detection_scores, + const std::vector& detection_sorted_indices, + const std::vector& image_detection_indices, + std::vector* precisions, + std::vector* recalls, + std::vector* precisions_out, + std::vector* scores_out, + std::vector* recalls_out) { + assert(recalls_out->size() > recalls_out_index); + + // Compute precision/recall for each instance in the sorted list of detections + int64_t true_positives_sum = 0, false_positives_sum = 0; + precisions->clear(); + recalls->clear(); + precisions->reserve(detection_sorted_indices.size()); + recalls->reserve(detection_sorted_indices.size()); + assert(!evaluations.empty() || detection_sorted_indices.empty()); + for (auto detection_sorted_index : detection_sorted_indices) { + const ImageEvaluation& evaluation = + evaluations[evaluation_indices[detection_sorted_index]]; + const auto num_detections = + evaluation.detection_matches.size() / num_iou_thresholds; + const auto detection_index = iou_threshold_index * num_detections + + image_detection_indices[detection_sorted_index]; + assert(evaluation.detection_matches.size() > detection_index); + assert(evaluation.detection_ignores.size() > detection_index); + const int64_t detection_match = + evaluation.detection_matches[detection_index]; + const bool detection_ignores = + evaluation.detection_ignores[detection_index]; + const auto true_positive = detection_match > 0 && !detection_ignores; + const auto false_positive = detection_match == 0 && !detection_ignores; + if (true_positive) { + ++true_positives_sum; + } + if (false_positive) { + ++false_positives_sum; + } + + const double recall = + static_cast(true_positives_sum) / num_valid_ground_truth; + recalls->push_back(recall); + const int64_t num_valid_detections = + true_positives_sum + false_positives_sum; + const double precision = num_valid_detections > 0 + ? static_cast(true_positives_sum) / num_valid_detections + : 0.0; + precisions->push_back(precision); + } + + (*recalls_out)[recalls_out_index] = !recalls->empty() ? recalls->back() : 0; + + for (int64_t i = static_cast(precisions->size()) - 1; i > 0; --i) { + if ((*precisions)[i] > (*precisions)[i - 1]) { + (*precisions)[i - 1] = (*precisions)[i]; + } + } + + // Sample the per instance precision/recall list at each recall threshold + for (size_t r = 0; r < recall_thresholds.size(); ++r) { + // first index in recalls >= recall_thresholds[r] + std::vector::iterator low = std::lower_bound( + recalls->begin(), recalls->end(), recall_thresholds[r]); + size_t precisions_index = low - recalls->begin(); + + const auto results_ind = precisions_out_index + r * precisions_out_stride; + assert(results_ind < precisions_out->size()); + assert(results_ind < scores_out->size()); + if (precisions_index < precisions->size()) { + (*precisions_out)[results_ind] = (*precisions)[precisions_index]; + (*scores_out)[results_ind] = + detection_scores[detection_sorted_indices[precisions_index]]; + } else { + (*precisions_out)[results_ind] = 0; + (*scores_out)[results_ind] = 0; + } + } +} +py::dict Accumulate( + const py::object& params, + const std::vector& evaluations) { + const std::vector recall_thresholds = + list_to_vec(params.attr("recThrs")); + const std::vector max_detections = + list_to_vec(params.attr("maxDets")); + const int num_iou_thresholds = py::len(params.attr("iouThrs")); + const int num_recall_thresholds = py::len(params.attr("recThrs")); + const int num_categories = params.attr("useCats").cast() == 1 + ? py::len(params.attr("catIds")) + : 1; + const int num_area_ranges = py::len(params.attr("areaRng")); + const int num_max_detections = py::len(params.attr("maxDets")); + const int num_images = py::len(params.attr("imgIds")); + + std::vector precisions_out( + num_iou_thresholds * num_recall_thresholds * num_categories * + num_area_ranges * num_max_detections, + -1); + std::vector recalls_out( + num_iou_thresholds * num_categories * num_area_ranges * + num_max_detections, + -1); + std::vector scores_out( + num_iou_thresholds * num_recall_thresholds * num_categories * + num_area_ranges * num_max_detections, + -1); + + // Consider the list of all detected instances in the entire dataset in one + // large list. evaluation_indices, detection_scores, + // image_detection_indices, and detection_sorted_indices all have the same + // length as this list, such that each entry corresponds to one detected + // instance + std::vector evaluation_indices; // indices into evaluations[] + std::vector detection_scores; // detection scores of each instance + std::vector detection_sorted_indices; // sorted indices of all + // instances in the dataset + std::vector + image_detection_indices; // indices into the list of detected instances in + // the same image as each instance + std::vector precisions, recalls; + + for (auto c = 0; c < num_categories; ++c) { + for (auto a = 0; a < num_area_ranges; ++a) { + for (auto m = 0; m < num_max_detections; ++m) { + // The COCO PythonAPI assumes evaluations[] (the return value of + // COCOeval::EvaluateImages() is one long list storing results for each + // combination of category, area range, and image id, with categories in + // the outermost loop and images in the innermost loop. + const int64_t evaluations_index = + c * num_area_ranges * num_images + a * num_images; + int num_valid_ground_truth = BuildSortedDetectionList( + evaluations, + evaluations_index, + num_images, + max_detections[m], + &evaluation_indices, + &detection_scores, + &detection_sorted_indices, + &image_detection_indices); + + if (num_valid_ground_truth == 0) { + continue; + } + + for (auto t = 0; t < num_iou_thresholds; ++t) { + // recalls_out is a flattened vectors representing a + // num_iou_thresholds X num_categories X num_area_ranges X + // num_max_detections matrix + const int64_t recalls_out_index = + t * num_categories * num_area_ranges * num_max_detections + + c * num_area_ranges * num_max_detections + + a * num_max_detections + m; + + // precisions_out and scores_out are flattened vectors + // representing a num_iou_thresholds X num_recall_thresholds X + // num_categories X num_area_ranges X num_max_detections matrix + const int64_t precisions_out_stride = + num_categories * num_area_ranges * num_max_detections; + const int64_t precisions_out_index = t * num_recall_thresholds * + num_categories * num_area_ranges * num_max_detections + + c * num_area_ranges * num_max_detections + + a * num_max_detections + m; + + ComputePrecisionRecallCurve( + precisions_out_index, + precisions_out_stride, + recalls_out_index, + recall_thresholds, + t, + num_iou_thresholds, + num_valid_ground_truth, + evaluations, + evaluation_indices, + detection_scores, + detection_sorted_indices, + image_detection_indices, + &precisions, + &recalls, + &precisions_out, + &scores_out, + &recalls_out); + } + } + } + } + + time_t rawtime; + struct tm local_time; + std::array buffer; + time(&rawtime); +#ifdef _WIN32 + localtime_s(&local_time, &rawtime); +#else + localtime_r(&rawtime, &local_time); +#endif + strftime( + buffer.data(), 200, "%Y-%m-%d %H:%num_max_detections:%S", &local_time); + return py::dict( + "params"_a = params, + "counts"_a = std::vector( + {num_iou_thresholds, + num_recall_thresholds, + num_categories, + num_area_ranges, + num_max_detections}), + "date"_a = buffer, + "precision"_a = precisions_out, + "recall"_a = recalls_out, + "scores"_a = scores_out); +} + +} // namespace COCOeval + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/cocoeval/cocoeval.h b/custom_detectron2/layers/csrc/cocoeval/cocoeval.h new file mode 100644 index 0000000000000000000000000000000000000000..db246e49a026b7cd989b305f4d3d98100be3c912 --- /dev/null +++ b/custom_detectron2/layers/csrc/cocoeval/cocoeval.h @@ -0,0 +1,88 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once + +#include +#include +#include +#include +#include + +namespace py = pybind11; + +namespace detectron2 { + +namespace COCOeval { + +// Annotation data for a single object instance in an image +struct InstanceAnnotation { + InstanceAnnotation( + uint64_t id, + double score, + double area, + bool is_crowd, + bool ignore) + : id{id}, score{score}, area{area}, is_crowd{is_crowd}, ignore{ignore} {} + uint64_t id; + double score = 0.; + double area = 0.; + bool is_crowd = false; + bool ignore = false; +}; + +// Stores intermediate results for evaluating detection results for a single +// image that has D detected instances and G ground truth instances. This stores +// matches between detected and ground truth instances +struct ImageEvaluation { + // For each of the D detected instances, the id of the matched ground truth + // instance, or 0 if unmatched + std::vector detection_matches; + + // The detection score of each of the D detected instances + std::vector detection_scores; + + // Marks whether or not each of G instances was ignored from evaluation (e.g., + // because it's outside area_range) + std::vector ground_truth_ignores; + + // Marks whether or not each of D instances was ignored from evaluation (e.g., + // because it's outside aRng) + std::vector detection_ignores; +}; + +template +using ImageCategoryInstances = std::vector>>; + +// C++ implementation of COCO API cocoeval.py::COCOeval.evaluateImg(). For each +// combination of image, category, area range settings, and IOU thresholds to +// evaluate, it matches detected instances to ground truth instances and stores +// the results into a vector of ImageEvaluation results, which will be +// interpreted by the COCOeval::Accumulate() function to produce precion-recall +// curves. The parameters of nested vectors have the following semantics: +// image_category_ious[i][c][d][g] is the intersection over union of the d'th +// detected instance and g'th ground truth instance of +// category category_ids[c] in image image_ids[i] +// image_category_ground_truth_instances[i][c] is a vector of ground truth +// instances in image image_ids[i] of category category_ids[c] +// image_category_detection_instances[i][c] is a vector of detected +// instances in image image_ids[i] of category category_ids[c] +std::vector EvaluateImages( + const std::vector>& area_ranges, // vector of 2-tuples + int max_detections, + const std::vector& iou_thresholds, + const ImageCategoryInstances>& image_category_ious, + const ImageCategoryInstances& + image_category_ground_truth_instances, + const ImageCategoryInstances& + image_category_detection_instances); + +// C++ implementation of COCOeval.accumulate(), which generates precision +// recall curves for each set of category, IOU threshold, detection area range, +// and max number of detections parameters. It is assumed that the parameter +// evaluations is the return value of the functon COCOeval::EvaluateImages(), +// which was called with the same parameter settings params +py::dict Accumulate( + const py::object& params, + const std::vector& evalutations); + +} // namespace COCOeval +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/cuda_version.cu b/custom_detectron2/layers/csrc/cuda_version.cu new file mode 100644 index 0000000000000000000000000000000000000000..6dfe1b90c1f65c443681813fd3e3386c9faa3360 --- /dev/null +++ b/custom_detectron2/layers/csrc/cuda_version.cu @@ -0,0 +1,26 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#include + +namespace detectron2 { +int get_cudart_version() { +// Not a ROCM platform: Either HIP is not used, or +// it is used, but platform is not ROCM (i.e. it is CUDA) +#if !defined(__HIP_PLATFORM_HCC__) + return CUDART_VERSION; +#else + int version = 0; + +#if HIP_VERSION_MAJOR != 0 + // Create a convention similar to that of CUDA, as assumed by other + // parts of the code. + + version = HIP_VERSION_MINOR; + version += (HIP_VERSION_MAJOR * 100); +#else + hipRuntimeGetVersion(&version); +#endif + return version; +#endif +} +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/deformable/deform_conv.h b/custom_detectron2/layers/csrc/deformable/deform_conv.h new file mode 100644 index 0000000000000000000000000000000000000000..965c1bfd47b58f9802d1c3fd69a5962517b2da61 --- /dev/null +++ b/custom_detectron2/layers/csrc/deformable/deform_conv.h @@ -0,0 +1,377 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +#if defined(WITH_CUDA) || defined(WITH_HIP) +int deform_conv_forward_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step); + +int deform_conv_backward_input_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step); + +int deform_conv_backward_parameters_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step); + +void modulated_deform_conv_cuda_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias); + +void modulated_deform_conv_cuda_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias); + +#endif + +inline int deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + if (input.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_forward_cuda( + input, + weight, + offset, + output, + columns, + ones, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline int deform_conv_backward_input( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + if (gradOutput.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_backward_input_cuda( + input, + offset, + gradOutput, + gradInput, + gradOffset, + weight, + columns, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline int deform_conv_backward_filter( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) { + if (gradOutput.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return deform_conv_backward_parameters_cuda( + input, + offset, + gradOutput, + gradWeight, + columns, + ones, + kW, + kH, + dW, + dH, + padW, + padH, + dilationW, + dilationH, + group, + deformable_group, + scale, + im2col_step); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline void modulated_deform_conv_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) { + if (input.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return modulated_deform_conv_cuda_forward( + input, + weight, + bias, + ones, + offset, + mask, + output, + columns, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group, + with_bias); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +inline void modulated_deform_conv_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) { + if (grad_output.is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + TORCH_CHECK(input.is_cuda(), "input tensor is not on GPU!"); + TORCH_CHECK(weight.is_cuda(), "weight tensor is not on GPU!"); + TORCH_CHECK(bias.is_cuda(), "bias tensor is not on GPU!"); + TORCH_CHECK(offset.is_cuda(), "offset tensor is not on GPU!"); + return modulated_deform_conv_cuda_backward( + input, + weight, + bias, + ones, + offset, + mask, + columns, + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group, + with_bias); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + AT_ERROR("This operator is not implemented on CPU"); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/deformable/deform_conv_cuda.cu b/custom_detectron2/layers/csrc/deformable/deform_conv_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..2072bb856ec40b61c3826cead2fb7bb7c971a089 --- /dev/null +++ b/custom_detectron2/layers/csrc/deformable/deform_conv_cuda.cu @@ -0,0 +1,1223 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda.cpp +// Original license: Apache 2.0 + +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda.c +// Original license: Apache 2.0 + +#include + +#include "deform_conv.h" + +#include +#include + +namespace detectron2 { + +void deformable_im2col( + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor data_col); + +void deformable_col2im( + const at::Tensor data_col, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_im); + +void deformable_col2im_coord( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_offset); + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor data_col); + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_im); + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, + at::Tensor grad_mask); + +void shape_check( + at::Tensor input, + at::Tensor offset, + at::Tensor* gradOutput, + at::Tensor weight, + int kH, + int kW, + int dH, + int dW, + int padH, + int padW, + int dilationH, + int dilationW, + int group, + int deformable_group) { + TORCH_CHECK( + weight.ndimension() == 4, + "4D weight tensor (nOutputPlane,nInputPlane,kH,kW) expected, " + "but got: %s", + weight.ndimension()); + + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + TORCH_CHECK( + kW > 0 && kH > 0, + "kernel size should be greater than zero, but got kH: %d kW: %d", + kH, + kW); + + TORCH_CHECK( + (weight.size(2) == kH && weight.size(3) == kW), + "kernel size should be consistent with weight, ", + "but got kH: %d kW: %d weight.size(2): %d, weight.size(3): %d", + kH, + kW, + weight.size(2), + weight.size(3)); + + TORCH_CHECK( + dW > 0 && dH > 0, + "stride should be greater than zero, but got dH: %d dW: %d", + dH, + dW); + + TORCH_CHECK( + dilationW > 0 && dilationH > 0, + "dilation should be greater than 0, but got dilationH: %d dilationW: %d", + dilationH, + dilationW); + + int ndim = input.ndimension(); + int dimf = 0; + int dimh = 1; + int dimw = 2; + + if (ndim == 4) { + dimf++; + dimh++; + dimw++; + } + + TORCH_CHECK( + ndim == 3 || ndim == 4, + "3D or 4D input tensor expected but got: %s", + ndim); + + long nInputPlane = weight.size(1) * group; + long inputHeight = input.size(dimh); + long inputWidth = input.size(dimw); + long nOutputPlane = weight.size(0); + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + + TORCH_CHECK( + nInputPlane % deformable_group == 0, + "input channels must divide deformable group size"); + + if (outputWidth < 1 || outputHeight < 1) + AT_ERROR( + "Given input size: (%ld x %ld x %ld). " + "Calculated output size: (%ld x %ld x %ld). Output size is too small", + nInputPlane, + inputHeight, + inputWidth, + nOutputPlane, + outputHeight, + outputWidth); + + TORCH_CHECK( + input.size(1) == nInputPlane, + "invalid number of input planes, expected: %d, but got: %d", + nInputPlane, + input.size(1)); + + TORCH_CHECK( + (inputHeight + 2 * padH >= kH && inputWidth + 2 * padW >= kW), + "input image is smaller than kernel"); + + TORCH_CHECK( + (offset.size(2) == outputHeight && offset.size(3) == outputWidth), + "invalid spatial size of offset, expected height: %d width: %d, but " + "got height: %d width: %d", + outputHeight, + outputWidth, + offset.size(2), + offset.size(3)); + + TORCH_CHECK( + (offset.size(1) == deformable_group * 2 * kH * kW), + "invalid number of channels of offset"); + + if (gradOutput != NULL) { + TORCH_CHECK( + gradOutput->size(dimf) == nOutputPlane, + "invalid number of gradOutput planes, expected: %d, but got: %d", + nOutputPlane, + gradOutput->size(dimf)); + + TORCH_CHECK( + (gradOutput->size(dimh) == outputHeight && + gradOutput->size(dimw) == outputWidth), + "invalid size of gradOutput, expected height: %d width: %d , but " + "got height: %d width: %d", + outputHeight, + outputWidth, + gradOutput->size(dimh), + gradOutput->size(dimw)); + } +} + +int deform_conv_forward_cuda( + at::Tensor input, + at::Tensor weight, + at::Tensor offset, + at::Tensor output, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + // todo: resize columns to include im2col: done + // todo: add im2col_step as input + // todo: add new output buffer and transpose it to output (or directly + // transpose output) todo: possibly change data indexing because of + // parallel_imgs + + shape_check( + input, + offset, + NULL, + weight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input.unsqueeze_(0); + offset.unsqueeze_(0); + } + + // todo: assert batchsize dividable by im2col_step + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + output = output.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < outputHeight * outputWidth) { + ones = at::ones({outputHeight, outputWidth}, input.options()); + } + + input = input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + at::Tensor output_buffer = at::zeros( + {batchSize / im2col_step, + nOutputPlane, + im2col_step * outputHeight, + outputWidth}, + output.options()); + + output_buffer = output_buffer.view( + {output_buffer.size(0), + group, + output_buffer.size(1) / group, + output_buffer.size(2), + output_buffer.size(3)}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col( + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + + for (int g = 0; g < group; g++) { + output_buffer[elt][g] = output_buffer[elt][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output_buffer[elt][g]); + } + } + + output_buffer = output_buffer.view( + {output_buffer.size(0), + output_buffer.size(1) * output_buffer.size(2), + output_buffer.size(3), + output_buffer.size(4)}); + + output_buffer = output_buffer.view( + {batchSize / im2col_step, + nOutputPlane, + im2col_step, + outputHeight, + outputWidth}); + output_buffer.transpose_(1, 2); + output.copy_(output_buffer); + output = output.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + output = output.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_input_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradInput, + at::Tensor gradOffset, + at::Tensor weight, + at::Tensor columns, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + int im2col_step) { + shape_check( + input, + offset, + &gradOutput, + weight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + weight = weight.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view({1, input.size(0), input.size(1), input.size(2)}); + offset = offset.view({1, offset.size(0), offset.size(1), offset.size(2)}); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = weight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), 3, "invalid batch size of offset"); + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + // change order of grad output + gradOutput = gradOutput.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + gradOutput.transpose_(1, 2); + + gradInput = gradInput.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + input = input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + gradOffset = gradOffset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + // divide into groups + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), + group, + gradOutput.size(1) / group, + gradOutput.size(2), + gradOutput.size(3), + gradOutput.size(4)}); + + for (int g = 0; g < group; g++) { + columns[g] = columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), + gradOutput[elt][g].flatten(1), + 0.0f, + 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradOutput = gradOutput.view( + {gradOutput.size(0), + gradOutput.size(1) * gradOutput.size(2), + gradOutput.size(3), + gradOutput.size(4), + gradOutput.size(5)}); + + deformable_col2im_coord( + columns, + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + gradOffset[elt]); + + deformable_col2im( + columns, + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + gradInput[elt]); + } + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + gradInput = gradInput.view({batchSize, nInputPlane, inputHeight, inputWidth}); + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + gradOffset = gradOffset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + gradInput = gradInput.view({nInputPlane, inputHeight, inputWidth}); + offset = offset.view({offset.size(1), offset.size(2), offset.size(3)}); + gradOffset = + gradOffset.view({offset.size(1), offset.size(2), offset.size(3)}); + } + + return 1; +} + +int deform_conv_backward_parameters_cuda( + at::Tensor input, + at::Tensor offset, + at::Tensor gradOutput, + at::Tensor gradWeight, // at::Tensor gradBias, + at::Tensor columns, + at::Tensor ones, + int kW, + int kH, + int dW, + int dH, + int padW, + int padH, + int dilationW, + int dilationH, + int group, + int deformable_group, + float scale, + int im2col_step) { + // todo: transpose and reshape outGrad + // todo: reshape columns + // todo: add im2col_step as input + + shape_check( + input, + offset, + &gradOutput, + gradWeight, + kH, + kW, + dH, + dW, + padH, + padW, + dilationH, + dilationW, + group, + deformable_group); + + input = input.contiguous(); + offset = offset.contiguous(); + gradOutput = gradOutput.contiguous(); + + int batch = 1; + + if (input.ndimension() == 3) { + // Force batch + batch = 0; + input = input.view( + at::IntList({1, input.size(0), input.size(1), input.size(2)})); + gradOutput = gradOutput.view( + {1, gradOutput.size(0), gradOutput.size(1), gradOutput.size(2)}); + } + + long batchSize = input.size(0); + long nInputPlane = input.size(1); + long inputHeight = input.size(2); + long inputWidth = input.size(3); + + long nOutputPlane = gradWeight.size(0); + + long outputWidth = + (inputWidth + 2 * padW - (dilationW * (kW - 1) + 1)) / dW + 1; + long outputHeight = + (inputHeight + 2 * padH - (dilationH * (kH - 1) + 1)) / dH + 1; + + TORCH_CHECK((offset.size(0) == batchSize), "invalid batch size of offset"); + + columns = at::zeros( + {nInputPlane * kW * kH, im2col_step * outputHeight * outputWidth}, + input.options()); + + gradOutput = gradOutput.view( + {batchSize / im2col_step, + im2col_step, + nOutputPlane, + outputHeight, + outputWidth}); + gradOutput.transpose_(1, 2); + + at::Tensor gradOutputBuffer = at::zeros_like(gradOutput); + gradOutputBuffer = gradOutputBuffer.view( + {batchSize / im2col_step, + nOutputPlane, + im2col_step, + outputHeight, + outputWidth}); + gradOutputBuffer.copy_(gradOutput); + // gradOutput is not contiguous, so we do reshape (instead of view) next + gradOutputBuffer = gradOutputBuffer.reshape( + {batchSize / im2col_step, + nOutputPlane, + im2col_step * outputHeight, + outputWidth}); + + gradOutput.transpose_(1, 2); + gradOutput = + gradOutput.view({batchSize, nOutputPlane, outputHeight, outputWidth}); + + input = input.view( + {batchSize / im2col_step, + im2col_step, + nInputPlane, + inputHeight, + inputWidth}); + offset = offset.view( + {batchSize / im2col_step, + im2col_step, + deformable_group * 2 * kH * kW, + outputHeight, + outputWidth}); + + for (int elt = 0; elt < batchSize / im2col_step; elt++) { + deformable_im2col( + input[elt], + offset[elt], + nInputPlane, + inputHeight, + inputWidth, + kH, + kW, + padH, + padW, + dH, + dW, + dilationH, + dilationW, + im2col_step, + deformable_group, + columns); + + // divide into group + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + group, + gradOutputBuffer.size(1) / group, + gradOutputBuffer.size(2), + gradOutputBuffer.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + gradWeight = gradWeight.view( + {group, + gradWeight.size(0) / group, + gradWeight.size(1), + gradWeight.size(2), + gradWeight.size(3)}); + + for (int g = 0; g < group; g++) { + gradWeight[g] = gradWeight[g] + .flatten(1) + .addmm_( + gradOutputBuffer[elt][g].flatten(1), + columns[g].transpose(1, 0), + 1.0, + scale) + .view_as(gradWeight[g]); + } + gradOutputBuffer = gradOutputBuffer.view( + {gradOutputBuffer.size(0), + gradOutputBuffer.size(1) * gradOutputBuffer.size(2), + gradOutputBuffer.size(3), + gradOutputBuffer.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + gradWeight = gradWeight.view( + {gradWeight.size(0) * gradWeight.size(1), + gradWeight.size(2), + gradWeight.size(3), + gradWeight.size(4)}); + } + + input = input.view({batchSize, nInputPlane, inputHeight, inputWidth}); + offset = offset.view( + {batchSize, deformable_group * 2 * kH * kW, outputHeight, outputWidth}); + + if (batch == 0) { + gradOutput = gradOutput.view({nOutputPlane, outputHeight, outputWidth}); + input = input.view({nInputPlane, inputHeight, inputWidth}); + } + + return 1; +} + +void modulated_deform_conv_cuda_forward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor output, + at::Tensor columns, + int kernel_h, + int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int group, + const int deformable_group, + const bool with_bias) { + shape_check( + input, + offset, + NULL, + weight, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group); + + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_out = weight.size(0); + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR( + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, + kernel_w, + kernel_h_, + kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR( + "Input shape and kernel channels wont match: (%d vs %d).", + channels, + channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // mask shape check + TORCH_CHECK( + (mask.size(2) == height_out && mask.size(3) == width_out), + "invalid spatial size of mask, expected height: %d width: %d, but " + "got height: %d width: %d", + height_out, + width_out, + mask.size(2), + mask.size(3)); + + TORCH_CHECK( + (mask.size(1) == deformable_group * kernel_h * kernel_w), + "invalid number of channels of mask"); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + // resize output + output = output.view({batch, channels_out, height_out, width_out}).zero_(); + // resize temporary columns + columns = at::zeros( + {channels * kernel_h * kernel_w, 1 * height_out * width_out}, + input.options()); + + output = output.view( + {output.size(0), + group, + output.size(1) / group, + output.size(2), + output.size(3)}); + + for (int b = 0; b < batch; b++) { + modulated_deformable_im2col_cuda( + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns); + + // divide into group + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + + for (int g = 0; g < group; g++) { + output[b][g] = output[b][g] + .flatten(1) + .addmm_(weight[g].flatten(1), columns[g]) + .view_as(output[b][g]); + } + + weight = weight.view( + {weight.size(0) * weight.size(1), + weight.size(2), + weight.size(3), + weight.size(4)}); + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + } + + output = output.view( + {output.size(0), + output.size(1) * output.size(2), + output.size(3), + output.size(4)}); + + if (with_bias) { + output += bias.view({1, bias.size(0), 1, 1}); + } +} + +void modulated_deform_conv_cuda_backward( + at::Tensor input, + at::Tensor weight, + at::Tensor bias, + at::Tensor ones, + at::Tensor offset, + at::Tensor mask, + at::Tensor columns, + at::Tensor grad_input, + at::Tensor grad_weight, + at::Tensor grad_bias, + at::Tensor grad_offset, + at::Tensor grad_mask, + at::Tensor grad_output, + int kernel_h, + int kernel_w, + int stride_h, + int stride_w, + int pad_h, + int pad_w, + int dilation_h, + int dilation_w, + int group, + int deformable_group, + const bool with_bias) { + shape_check( + input, + offset, + &grad_output, + weight, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + group, + deformable_group); + + TORCH_CHECK(input.is_contiguous(), "input tensor has to be contiguous"); + TORCH_CHECK(weight.is_contiguous(), "weight tensor has to be contiguous"); + + const int batch = input.size(0); + const int channels = input.size(1); + const int height = input.size(2); + const int width = input.size(3); + + const int channels_kernel = weight.size(1); + const int kernel_h_ = weight.size(2); + const int kernel_w_ = weight.size(3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + AT_ERROR( + "Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, + kernel_w, + kernel_h_, + kernel_w_); + if (channels != channels_kernel * group) + AT_ERROR( + "Input shape and kernel channels wont match: (%d vs %d).", + channels, + channels_kernel * group); + + const int height_out = + (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = + (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + // mask shape check + TORCH_CHECK( + (mask.size(2) == height_out && mask.size(3) == width_out), + "invalid spatial size of mask, expected height: %d width: %d, but " + "got height: %d width: %d", + height_out, + width_out, + mask.size(2), + mask.size(3)); + + TORCH_CHECK( + (mask.size(1) == deformable_group * kernel_h * kernel_w), + "invalid number of channels of mask"); + + if (ones.ndimension() != 2 || + ones.size(0) * ones.size(1) < height_out * width_out) { + // Resize plane and fill with ones... + ones = at::ones({height_out, width_out}, input.options()); + } + + grad_input = grad_input.view({batch, channels, height, width}); + columns = at::zeros( + {channels * kernel_h * kernel_w, height_out * width_out}, + input.options()); + + grad_output = grad_output.view( + {grad_output.size(0), + group, + grad_output.size(1) / group, + grad_output.size(2), + grad_output.size(3)}); + + for (int b = 0; b < batch; b++) { + // divide int group + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + weight = weight.view( + {group, + weight.size(0) / group, + weight.size(1), + weight.size(2), + weight.size(3)}); + + for (int g = 0; g < group; g++) { + columns[g].addmm_( + weight[g].flatten(1).transpose(0, 1), + grad_output[b][g].flatten(1), + 0.0f, + 1.0f); + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + weight = weight.view( + {weight.size(0) * weight.size(1), + weight.size(2), + weight.size(3), + weight.size(4)}); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda( + columns, + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + grad_offset[b], + grad_mask[b]); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda( + columns, + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + grad_input[b]); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and + // group + modulated_deformable_im2col_cuda( + input[b], + offset[b], + mask[b], + 1, + channels, + height, + width, + height_out, + width_out, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + deformable_group, + columns); + + columns = columns.view({group, columns.size(0) / group, columns.size(1)}); + grad_weight = grad_weight.view( + {group, + grad_weight.size(0) / group, + grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3)}); + if (with_bias) + grad_bias = grad_bias.view({group, grad_bias.size(0) / group}); + + for (int g = 0; g < group; g++) { + grad_weight[g] = + grad_weight[g] + .flatten(1) + .addmm_(grad_output[b][g].flatten(1), columns[g].transpose(0, 1)) + .view_as(grad_weight[g]); + if (with_bias) { + grad_bias[g] = + grad_bias[g] + .view({-1, 1}) + .addmm_(grad_output[b][g].flatten(1), ones.view({-1, 1})) + .view(-1); + } + } + + columns = + columns.view({columns.size(0) * columns.size(1), columns.size(2)}); + grad_weight = grad_weight.view( + {grad_weight.size(0) * grad_weight.size(1), + grad_weight.size(2), + grad_weight.size(3), + grad_weight.size(4)}); + if (with_bias) + grad_bias = grad_bias.view({grad_bias.size(0) * grad_bias.size(1)}); + } + grad_output = grad_output.view( + {grad_output.size(0) * grad_output.size(1), + grad_output.size(2), + grad_output.size(3), + grad_output.size(4)}); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu b/custom_detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..f299c7add116685e9c87a187a85ea63f9f808867 --- /dev/null +++ b/custom_detectron2/layers/csrc/deformable/deform_conv_cuda_kernel.cu @@ -0,0 +1,1288 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +// modified from +// https://github.com/open-mmlab/mmdetection/blob/master/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu +// Original license: Apache 2.0 +// clang-format off + +// modify from +// https://github.com/chengdazhi/Deformable-Convolution-V2-PyTorch/blob/mmdetection/mmdet/ops/dcn/src/deform_conv_cuda_kernel.cu + +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer ***************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE + *FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + *DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + *SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + *CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + *OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + *OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ********************* + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.cuh + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1703.06211 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu, Dazhi Cheng + */ + +#include +#include +#include +#include +#include +#include + +using namespace at; + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ + i += blockDim.x * gridDim.x) + + +namespace { + +const int CUDA_NUM_THREADS = 1024; +const int kMaxGridNum = 65535; + +inline int GET_BLOCKS(const int N) { + return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS); +} + +} + +template +__device__ scalar_t deformable_im2col_bilinear( + const scalar_t* bottom_data, + const int data_width, + const int height, + const int width, + scalar_t h, + scalar_t w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t get_gradient_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int h, + const int w, + const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t get_coordinate_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int height, + const int width, + const scalar_t* im_data, + const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void deformable_im2col_gpu_kernel( + const int n, + const scalar_t* data_im, + const scalar_t* data_offset, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* data_col) { + CUDA_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + scalar_t* data_col_ptr = data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const scalar_t* data_im_ptr = data_im + ((b_col * num_channels + c_im) * + // height + h_in) * width + w_in; + const scalar_t* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + // const scalar_t map_h = i * dilation_h + offset_h; + // const scalar_t map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = deformable_im2col_bilinear(data_im_ptr, width, cur_height, + // cur_width, map_h, map_w); + val = deformable_im2col_bilinear( + data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val; + data_col_ptr += batch_size * height_col * width_col; + } + } + } +} + + +template +__global__ void deformable_col2im_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_offset, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_im) { + CUDA_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index]; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = get_gradient_weight( + cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + + +template +__global__ void deformable_col2im_coord_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_im, + const scalar_t* data_offset, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_offset) { + CUDA_KERNEL_LOOP(index, n) { + scalar_t val = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t* data_col_ptr = data_col + + deformable_group_index * channel_per_deformable_group * batch_size * + width_col * height_col; + const scalar_t* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } + const scalar_t weight = get_coordinate_weight( + inv_h, + inv_w, + height, + width, + data_im_ptr + cnt * height * width, + width, + bp_dir); + val += weight * data_col_ptr[col_pos]; + cnt += 1; + } + + grad_offset[index] = val; + } +} + + +namespace detectron2 { + +void deformable_im2col( + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor data_col) { + // num_axes should be smaller than block size + // todo: check parallel_imgs is correctly passed in + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = channels * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + at::cuda::CUDAGuard device_guard(data_im.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "deformable_im2col_gpu", ([&] { + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* data_col_ = data_col.data_ptr(); + + deformable_im2col_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_im_, + data_offset_, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + channels, + deformable_group, + height_col, + width_col, + data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in deformable_im2col: %s\n", cudaGetErrorString(err)); + } +} + + +void deformable_col2im( + const at::Tensor data_col, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_im) { + // todo: make sure parallel_imgs is passed in correctly + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = + channels * ksize_h * ksize_w * height_col * width_col * parallel_imgs; + int channel_per_deformable_group = channels / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* grad_im_ = grad_im.data_ptr(); + + deformable_col2im_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_offset_, + channels, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + deformable_group, + height_col, + width_col, + grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf("error in deformable_col2im: %s\n", cudaGetErrorString(err)); + } +} + + +void deformable_col2im_coord( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const int channels, + const int height, + const int width, + const int ksize_h, + const int ksize_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int parallel_imgs, + const int deformable_group, + at::Tensor grad_offset) { + int height_col = + (height + 2 * pad_h - (dilation_h * (ksize_h - 1) + 1)) / stride_h + 1; + int width_col = + (width + 2 * pad_w - (dilation_w * (ksize_w - 1) + 1)) / stride_w + 1; + int num_kernels = height_col * width_col * 2 * ksize_h * ksize_w * + deformable_group * parallel_imgs; + int channel_per_deformable_group = + channels * ksize_h * ksize_w / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "deformable_col2im_coord_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + scalar_t* grad_offset_ = grad_offset.data_ptr(); + + deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_im_, + data_offset_, + channels, + height, + width, + ksize_h, + ksize_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + parallel_imgs, + 2 * ksize_h * ksize_w * deformable_group, + deformable_group, + height_col, + width_col, + grad_offset_); + })); +} + +} // namespace detectron2 + + +template +__device__ scalar_t dmcn_im2col_bilinear( + const scalar_t* bottom_data, + const int data_width, + const int height, + const int width, + scalar_t h, + scalar_t w) { + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + scalar_t lh = h - h_low; + scalar_t lw = w - w_low; + scalar_t hh = 1 - lh, hw = 1 - lw; + + scalar_t v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + scalar_t v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + scalar_t v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + scalar_t v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + scalar_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + scalar_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +template +__device__ scalar_t dmcn_get_gradient_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int h, + const int w, + const int height, + const int width) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +template +__device__ scalar_t dmcn_get_coordinate_weight( + scalar_t argmax_h, + scalar_t argmax_w, + const int height, + const int width, + const scalar_t* im_data, + const int data_width, + const int bp_dir) { + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || + argmax_w >= width) { + // empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + scalar_t weight = 0; + + if (bp_dir == 0) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } else if (bp_dir == 1) { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * + im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * + im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +template +__global__ void modulated_deformable_im2col_gpu_kernel( + const int n, + const scalar_t* data_im, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int num_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* data_col) { + CUDA_KERNEL_LOOP(index, n) { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + scalar_t* data_col_ptr = data_col + + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + // const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * + // height + h_in) * width + w_in; + const scalar_t* data_im_ptr = + data_im + (b_col * num_channels + c_im) * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + + const scalar_t* data_mask_ptr = data_mask + + (b_col * deformable_group + deformable_group_index) * kernel_h * + kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + + w_col; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t val = static_cast(0); + const scalar_t h_im = h_in + i * dilation_h + offset_h; + const scalar_t w_im = w_in + j * dilation_w + offset_w; + // if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) { + // const float map_h = i * dilation_h + offset_h; + // const float map_w = j * dilation_w + offset_w; + // const int cur_height = height - h_in; + // const int cur_width = width - w_in; + // val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, + // cur_width, map_h, map_w); + val = dmcn_im2col_bilinear( + data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + // data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void modulated_deformable_col2im_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_im) { + CUDA_KERNEL_LOOP(index, n) { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = + (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = + index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const scalar_t* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * + height_col * width_col; + const int data_offset_h_ptr = + ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = + ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = + ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + const scalar_t cur_inv_h_data = h_in + i * dilation_h + offset_h; + const scalar_t cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const scalar_t cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) { + for (int dx = -2; dx <= 2; dx++) { + if (cur_h + dy >= 0 && cur_h + dy < height && cur_w + dx >= 0 && + cur_w + dx < width && abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) { + int cur_bottom_grad_pos = + ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + scalar_t weight = dmcn_get_gradient_weight( + cur_inv_h_data, + cur_inv_w_data, + cur_h + dy, + cur_w + dx, + height, + width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +template +__global__ void modulated_deformable_col2im_coord_gpu_kernel( + const int n, + const scalar_t* data_col, + const scalar_t* data_im, + const scalar_t* data_offset, + const scalar_t* data_mask, + const int channels, + const int height, + const int width, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, + const int offset_channels, + const int deformable_group, + const int height_col, + const int width_col, + scalar_t* grad_offset, + scalar_t* grad_mask) { + CUDA_KERNEL_LOOP(index, n) { + scalar_t val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const scalar_t* data_col_ptr = data_col + + deformable_group_index * channel_per_deformable_group * batch_size * + width_col * height_col; + const scalar_t* data_im_ptr = data_im + + (b * deformable_group + deformable_group_index) * + channel_per_deformable_group / kernel_h / kernel_w * height * width; + const scalar_t* data_offset_ptr = data_offset + + (b * deformable_group + deformable_group_index) * 2 * kernel_h * + kernel_w * height_col * width_col; + const scalar_t* data_mask_ptr = data_mask + + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * + height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; + col_c += col_step) { + const int col_pos = + (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = + (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = + (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = + (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + + w_out); + const int data_mask_hw_ptr = + (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const scalar_t offset_h = data_offset_ptr[data_offset_h_ptr]; + const scalar_t offset_w = data_offset_ptr[data_offset_w_ptr]; + const scalar_t mask = data_mask_ptr[data_mask_hw_ptr]; + scalar_t inv_h = h_in + i * dilation_h + offset_h; + scalar_t inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) { + inv_h = inv_w = -2; + } else { + mval += data_col_ptr[col_pos] * + dmcn_im2col_bilinear( + data_im_ptr + cnt * height * width, + width, + height, + width, + inv_h, + inv_w); + } + const scalar_t weight = dmcn_get_coordinate_weight( + inv_h, + inv_w, + height, + width, + data_im_ptr + cnt * height * width, + width, + bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + + // deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * + // height_col + h) * width_col + w], mask_req, mval); + grad_mask + [(((b * deformable_group + deformable_group_index) * kernel_h * + kernel_w + + offset_c / 2) * + height_col + + h) * + width_col + + w] = mval; + } +} + + +namespace detectron2 { + +void modulated_deformable_im2col_cuda( + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kenerl_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + + at::cuda::CUDAGuard device_guard(data_im.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_im.scalar_type(), "modulated_deformable_im2col_gpu", ([&] { + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + const scalar_t* data_mask_ = data_mask.data_ptr(); + scalar_t* data_col_ = data_col.data_ptr(); + + modulated_deformable_im2col_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_im_, + data_offset_, + data_mask_, + height_im, + width_im, + kernel_h, + kenerl_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + channels, + deformable_group, + height_col, + width_col, + data_col_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf( + "error in modulated_deformable_im2col_cuda: %s\n", + cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda( + const at::Tensor data_col, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_im) { + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = + channels * kernel_h * kernel_w * batch_size * height_col * width_col; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + const scalar_t* data_mask_ = data_mask.data_ptr(); + scalar_t* grad_im_ = grad_im.data_ptr(); + + modulated_deformable_col2im_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_offset_, + data_mask_, + channels, + height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + deformable_group, + height_col, + width_col, + grad_im_); + })); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf( + "error in modulated_deformable_col2im_cuda: %s\n", + cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda( + const at::Tensor data_col, + const at::Tensor data_im, + const at::Tensor data_offset, + const at::Tensor data_mask, + const int batch_size, + const int channels, + const int height_im, + const int width_im, + const int height_col, + const int width_col, + const int kernel_h, + const int kernel_w, + const int pad_h, + const int pad_w, + const int stride_h, + const int stride_w, + const int dilation_h, + const int dilation_w, + const int deformable_group, + at::Tensor grad_offset, + at::Tensor grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * + kernel_w * deformable_group; + const int channel_per_deformable_group = + channels * kernel_h * kernel_w / deformable_group; + + at::cuda::CUDAGuard device_guard(data_col.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + data_col.scalar_type(), "modulated_deformable_col2im_coord_gpu", ([&] { + const scalar_t* data_col_ = data_col.data_ptr(); + const scalar_t* data_im_ = data_im.data_ptr(); + const scalar_t* data_offset_ = data_offset.data_ptr(); + const scalar_t* data_mask_ = data_mask.data_ptr(); + scalar_t* grad_offset_ = grad_offset.data_ptr(); + scalar_t* grad_mask_ = grad_mask.data_ptr(); + + modulated_deformable_col2im_coord_gpu_kernel<<< + GET_BLOCKS(num_kernels), + CUDA_NUM_THREADS, + 0, + stream>>>( + num_kernels, + data_col_, + data_im_, + data_offset_, + data_mask_, + channels, + height_im, + width_im, + kernel_h, + kernel_w, + pad_h, + pad_w, + stride_h, + stride_w, + dilation_h, + dilation_w, + channel_per_deformable_group, + batch_size, + 2 * kernel_h * kernel_w * deformable_group, + deformable_group, + height_col, + width_col, + grad_offset_, + grad_mask_); + })); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) { + printf( + "error in modulated_deformable_col2im_coord_cuda: %s\n", + cudaGetErrorString(err)); + } +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/nms_rotated/nms_rotated.h b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated.h new file mode 100644 index 0000000000000000000000000000000000000000..12aca388e47b12dafd20999f2991a9d42f4b904b --- /dev/null +++ b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated.h @@ -0,0 +1,39 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#pragma once +#include + +namespace detectron2 { + +at::Tensor nms_rotated_cpu( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold); + +#if defined(WITH_CUDA) || defined(WITH_HIP) +at::Tensor nms_rotated_cuda( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold); +#endif + +// Interface for Python +// inline is needed to prevent multiple function definitions when this header is +// included by different cpps +inline at::Tensor nms_rotated( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + assert(dets.device().is_cuda() == scores.device().is_cuda()); + if (dets.device().is_cuda()) { +#if defined(WITH_CUDA) || defined(WITH_HIP) + return nms_rotated_cuda( + dets.contiguous(), scores.contiguous(), iou_threshold); +#else + AT_ERROR("Detectron2 is not compiled with GPU support!"); +#endif + } + + return nms_rotated_cpu(dets.contiguous(), scores.contiguous(), iou_threshold); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..d7556e645b604aa83d86cc702b783fd8ecedffcc --- /dev/null +++ b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cpu.cpp @@ -0,0 +1,75 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include "../box_iou_rotated/box_iou_rotated_utils.h" +#include "nms_rotated.h" + +namespace detectron2 { + +template +at::Tensor nms_rotated_cpu_kernel( + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + // nms_rotated_cpu_kernel is modified from torchvision's nms_cpu_kernel, + // however, the code in this function is much shorter because + // we delegate the IoU computation for rotated boxes to + // the single_box_iou_rotated function in box_iou_rotated_utils.h + AT_ASSERTM(dets.device().is_cpu(), "dets must be a CPU tensor"); + AT_ASSERTM(scores.device().is_cpu(), "scores must be a CPU tensor"); + AT_ASSERTM( + dets.scalar_type() == scores.scalar_type(), + "dets should have the same type as scores"); + + if (dets.numel() == 0) { + return at::empty({0}, dets.options().dtype(at::kLong)); + } + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + + auto ndets = dets.size(0); + at::Tensor suppressed_t = at::zeros({ndets}, dets.options().dtype(at::kByte)); + at::Tensor keep_t = at::zeros({ndets}, dets.options().dtype(at::kLong)); + + auto suppressed = suppressed_t.data_ptr(); + auto keep = keep_t.data_ptr(); + auto order = order_t.data_ptr(); + + int64_t num_to_keep = 0; + + for (int64_t _i = 0; _i < ndets; _i++) { + auto i = order[_i]; + if (suppressed[i] == 1) { + continue; + } + + keep[num_to_keep++] = i; + + for (int64_t _j = _i + 1; _j < ndets; _j++) { + auto j = order[_j]; + if (suppressed[j] == 1) { + continue; + } + + auto ovr = single_box_iou_rotated( + dets[i].data_ptr(), dets[j].data_ptr()); + if (ovr >= iou_threshold) { + suppressed[j] = 1; + } + } + } + return keep_t.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep); +} + +at::Tensor nms_rotated_cpu( + // input must be contiguous + const at::Tensor& dets, + const at::Tensor& scores, + const double iou_threshold) { + auto result = at::empty({0}, dets.options()); + + AT_DISPATCH_FLOATING_TYPES(dets.scalar_type(), "nms_rotated", [&] { + result = nms_rotated_cpu_kernel(dets, scores, iou_threshold); + }); + return result; +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..2a3db5c62e7a2da52ccf5bac980653c943d630fd --- /dev/null +++ b/custom_detectron2/layers/csrc/nms_rotated/nms_rotated_cuda.cu @@ -0,0 +1,145 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +#include +#include +#include +#include +#ifdef WITH_CUDA +#include "../box_iou_rotated/box_iou_rotated_utils.h" +#endif +// TODO avoid this when pytorch supports "same directory" hipification +#ifdef WITH_HIP +#include "box_iou_rotated/box_iou_rotated_utils.h" +#endif + +using namespace detectron2; + +namespace { +int const threadsPerBlock = sizeof(unsigned long long) * 8; +} + +template +__global__ void nms_rotated_cuda_kernel( + const int n_boxes, + const double iou_threshold, + const T* dev_boxes, + unsigned long long* dev_mask) { + // nms_rotated_cuda_kernel is modified from torchvision's nms_cuda_kernel + + const int row_start = blockIdx.y; + const int col_start = blockIdx.x; + + // if (row_start > col_start) return; + + const int row_size = + min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); + const int col_size = + min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); + + // Compared to nms_cuda_kernel, where each box is represented with 4 values + // (x1, y1, x2, y2), each rotated box is represented with 5 values + // (x_center, y_center, width, height, angle_degrees) here. + __shared__ T block_boxes[threadsPerBlock * 5]; + if (threadIdx.x < col_size) { + block_boxes[threadIdx.x * 5 + 0] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; + block_boxes[threadIdx.x * 5 + 1] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; + block_boxes[threadIdx.x * 5 + 2] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; + block_boxes[threadIdx.x * 5 + 3] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; + block_boxes[threadIdx.x * 5 + 4] = + dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; + } + __syncthreads(); + + if (threadIdx.x < row_size) { + const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; + const T* cur_box = dev_boxes + cur_box_idx * 5; + int i = 0; + unsigned long long t = 0; + int start = 0; + if (row_start == col_start) { + start = threadIdx.x + 1; + } + for (i = start; i < col_size; i++) { + // Instead of devIoU used by original horizontal nms, here + // we use the single_box_iou_rotated function from box_iou_rotated_utils.h + if (single_box_iou_rotated(cur_box, block_boxes + i * 5) > + iou_threshold) { + t |= 1ULL << i; + } + } + const int col_blocks = at::cuda::ATenCeilDiv(n_boxes, threadsPerBlock); + dev_mask[cur_box_idx * col_blocks + col_start] = t; + } +} + +namespace detectron2 { + +at::Tensor nms_rotated_cuda( + // input must be contiguous + const at::Tensor& dets, + const at::Tensor& scores, + double iou_threshold) { + // using scalar_t = float; + AT_ASSERTM(dets.is_cuda(), "dets must be a CUDA tensor"); + AT_ASSERTM(scores.is_cuda(), "scores must be a CUDA tensor"); + at::cuda::CUDAGuard device_guard(dets.device()); + + auto order_t = std::get<1>(scores.sort(0, /* descending=*/true)); + auto dets_sorted = dets.index_select(0, order_t); + + auto dets_num = dets.size(0); + + const int col_blocks = + at::cuda::ATenCeilDiv(static_cast(dets_num), threadsPerBlock); + + at::Tensor mask = + at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); + + dim3 blocks(col_blocks, col_blocks); + dim3 threads(threadsPerBlock); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_DISPATCH_FLOATING_TYPES( + dets_sorted.scalar_type(), "nms_rotated_kernel_cuda", [&] { + nms_rotated_cuda_kernel<<>>( + dets_num, + iou_threshold, + dets_sorted.data_ptr(), + (unsigned long long*)mask.data_ptr()); + }); + + at::Tensor mask_cpu = mask.to(at::kCPU); + unsigned long long* mask_host = + (unsigned long long*)mask_cpu.data_ptr(); + + std::vector remv(col_blocks); + memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); + + at::Tensor keep = + at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); + int64_t* keep_out = keep.data_ptr(); + + int num_to_keep = 0; + for (int i = 0; i < dets_num; i++) { + int nblock = i / threadsPerBlock; + int inblock = i % threadsPerBlock; + + if (!(remv[nblock] & (1ULL << inblock))) { + keep_out[num_to_keep++] = i; + unsigned long long* p = mask_host + i * col_blocks; + for (int j = nblock; j < col_blocks; j++) { + remv[j] |= p[j]; + } + } + } + + AT_CUDA_CHECK(cudaGetLastError()); + return order_t.index( + {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) + .to(order_t.device(), keep.scalar_type())}); +} + +} // namespace detectron2 diff --git a/custom_detectron2/layers/csrc/vision.cpp b/custom_detectron2/layers/csrc/vision.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9a2cd4f20e6f58be1c5783d67c64232dd59b560 --- /dev/null +++ b/custom_detectron2/layers/csrc/vision.cpp @@ -0,0 +1,117 @@ +// Copyright (c) Facebook, Inc. and its affiliates. + +#include +#include "ROIAlignRotated/ROIAlignRotated.h" +#include "box_iou_rotated/box_iou_rotated.h" +#include "cocoeval/cocoeval.h" +#include "deformable/deform_conv.h" +#include "nms_rotated/nms_rotated.h" + +namespace detectron2 { + +#if defined(WITH_CUDA) || defined(WITH_HIP) +extern int get_cudart_version(); +#endif + +std::string get_cuda_version() { +#if defined(WITH_CUDA) || defined(WITH_HIP) + std::ostringstream oss; + +#if defined(WITH_CUDA) + oss << "CUDA "; +#else + oss << "HIP "; +#endif + + // copied from + // https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/cuda/detail/CUDAHooks.cpp#L231 + auto printCudaStyleVersion = [&](int v) { + oss << (v / 1000) << "." << (v / 10 % 100); + if (v % 10 != 0) { + oss << "." << (v % 10); + } + }; + printCudaStyleVersion(get_cudart_version()); + return oss.str(); +#else // neither CUDA nor HIP + return std::string("not available"); +#endif +} + +bool has_cuda() { +#if defined(WITH_CUDA) + return true; +#else + return false; +#endif +} + +// similar to +// https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Version.cpp +std::string get_compiler_version() { + std::ostringstream ss; +#if defined(__GNUC__) +#ifndef __clang__ + +#if ((__GNUC__ <= 4) && (__GNUC_MINOR__ <= 8)) +#error "GCC >= 4.9 is required!" +#endif + + { ss << "GCC " << __GNUC__ << "." << __GNUC_MINOR__; } +#endif +#endif + +#if defined(__clang_major__) + { + ss << "clang " << __clang_major__ << "." << __clang_minor__ << "." + << __clang_patchlevel__; + } +#endif + +#if defined(_MSC_VER) + { ss << "MSVC " << _MSC_FULL_VER; } +#endif + return ss.str(); +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("get_compiler_version", &get_compiler_version, "get_compiler_version"); + m.def("get_cuda_version", &get_cuda_version, "get_cuda_version"); + m.def("has_cuda", &has_cuda, "has_cuda"); + + m.def("deform_conv_forward", &deform_conv_forward, "deform_conv_forward"); + m.def( + "deform_conv_backward_input", + &deform_conv_backward_input, + "deform_conv_backward_input"); + m.def( + "deform_conv_backward_filter", + &deform_conv_backward_filter, + "deform_conv_backward_filter"); + m.def( + "modulated_deform_conv_forward", + &modulated_deform_conv_forward, + "modulated_deform_conv_forward"); + m.def( + "modulated_deform_conv_backward", + &modulated_deform_conv_backward, + "modulated_deform_conv_backward"); + + m.def("COCOevalAccumulate", &COCOeval::Accumulate, "COCOeval::Accumulate"); + m.def( + "COCOevalEvaluateImages", + &COCOeval::EvaluateImages, + "COCOeval::EvaluateImages"); + pybind11::class_(m, "InstanceAnnotation") + .def(pybind11::init()); + pybind11::class_(m, "ImageEvaluation") + .def(pybind11::init<>()); +} + +TORCH_LIBRARY(detectron2, m) { + m.def("nms_rotated", &nms_rotated); + m.def("box_iou_rotated", &box_iou_rotated); + m.def("roi_align_rotated_forward", &ROIAlignRotated_forward); + m.def("roi_align_rotated_backward", &ROIAlignRotated_backward); +} +} // namespace detectron2 diff --git a/custom_detectron2/layers/deform_conv.py b/custom_detectron2/layers/deform_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..9a8a039649bb71634c8da8fbcd7f295356ebadd5 --- /dev/null +++ b/custom_detectron2/layers/deform_conv.py @@ -0,0 +1,514 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from functools import lru_cache +import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair +from torchvision.ops import deform_conv2d + +from custom_detectron2.utils.develop import create_dummy_class, create_dummy_func + +from .wrappers import _NewEmptyTensorOp + + +class _DeformConv(Function): + @staticmethod + def forward( + ctx, + input, + offset, + weight, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + im2col_step=64, + ): + if input is not None and input.dim() != 4: + raise ValueError( + "Expected 4D tensor as input, got {}D tensor instead.".format(input.dim()) + ) + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.im2col_step = im2col_step + + ctx.save_for_backward(input, offset, weight) + + output = input.new_empty( + _DeformConv._output_size(input, weight, ctx.padding, ctx.dilation, ctx.stride) + ) + + ctx.bufs_ = [input.new_empty(0), input.new_empty(0)] # columns, ones + + if not input.is_cuda: + # TODO: let torchvision support full features of our deformconv. + if deformable_groups != 1: + raise NotImplementedError( + "Deformable Conv with deformable_groups != 1 is not supported on CPUs!" + ) + return deform_conv2d( + input, offset, weight, stride=stride, padding=padding, dilation=dilation + ) + else: + cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + + _C.deform_conv_forward( + input, + weight, + offset, + output, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, weight = ctx.saved_tensors + + grad_input = grad_offset = grad_weight = None + + if not grad_output.is_cuda: + raise NotImplementedError("Deformable Conv is not supported on CPUs!") + else: + cur_im2col_step = _DeformConv._cal_im2col_step(input.shape[0], ctx.im2col_step) + assert (input.shape[0] % cur_im2col_step) == 0, "im2col step must divide batchsize" + + if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + _C.deform_conv_backward_input( + input, + offset, + grad_output, + grad_input, + grad_offset, + weight, + ctx.bufs_[0], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + cur_im2col_step, + ) + + if ctx.needs_input_grad[2]: + grad_weight = torch.zeros_like(weight) + _C.deform_conv_backward_filter( + input, + offset, + grad_output, + grad_weight, + ctx.bufs_[0], + ctx.bufs_[1], + weight.size(3), + weight.size(2), + ctx.stride[1], + ctx.stride[0], + ctx.padding[1], + ctx.padding[0], + ctx.dilation[1], + ctx.dilation[0], + ctx.groups, + ctx.deformable_groups, + 1, + cur_im2col_step, + ) + + return grad_input, grad_offset, grad_weight, None, None, None, None, None, None + + @staticmethod + def _output_size(input, weight, padding, dilation, stride): + channels = weight.size(0) + output_size = (input.size(0), channels) + for d in range(input.dim() - 2): + in_size = input.size(d + 2) + pad = padding[d] + kernel = dilation[d] * (weight.size(d + 2) - 1) + 1 + stride_ = stride[d] + output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,) + if not all(map(lambda s: s > 0, output_size)): + raise ValueError( + "convolution input is too small (output would be {})".format( + "x".join(map(str, output_size)) + ) + ) + return output_size + + @staticmethod + @lru_cache(maxsize=128) + def _cal_im2col_step(input_size, default_size): + """ + Calculate proper im2col step size, which should be divisible by input_size and not larger + than prefer_size. Meanwhile the step size should be as large as possible to be more + efficient. So we choose the largest one among all divisors of input_size which are smaller + than prefer_size. + :param input_size: input batch size . + :param default_size: default preferred im2col step size. + :return: the largest proper step size. + """ + if input_size <= default_size: + return input_size + best_step = 1 + for step in range(2, min(int(math.sqrt(input_size)) + 1, default_size)): + if input_size % step == 0: + if input_size // step <= default_size: + return input_size // step + best_step = step + + return best_step + + +class _ModulatedDeformConv(Function): + @staticmethod + def forward( + ctx, + input, + offset, + mask, + weight, + bias=None, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + ): + ctx.stride = stride + ctx.padding = padding + ctx.dilation = dilation + ctx.groups = groups + ctx.deformable_groups = deformable_groups + ctx.with_bias = bias is not None + if not ctx.with_bias: + bias = input.new_empty(1) # fake tensor + if not input.is_cuda: + raise NotImplementedError("Deformable Conv is not supported on CPUs!") + if ( + weight.requires_grad + or mask.requires_grad + or offset.requires_grad + or input.requires_grad + ): + ctx.save_for_backward(input, offset, mask, weight, bias) + output = input.new_empty(_ModulatedDeformConv._infer_shape(ctx, input, weight)) + ctx._bufs = [input.new_empty(0), input.new_empty(0)] + _C.modulated_deform_conv_forward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + output, + ctx._bufs[1], + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError("Deformable Conv is not supported on CPUs!") + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input = torch.zeros_like(input) + grad_offset = torch.zeros_like(offset) + grad_mask = torch.zeros_like(mask) + grad_weight = torch.zeros_like(weight) + grad_bias = torch.zeros_like(bias) + _C.modulated_deform_conv_backward( + input, + weight, + bias, + ctx._bufs[0], + offset, + mask, + ctx._bufs[1], + grad_input, + grad_weight, + grad_bias, + grad_offset, + grad_mask, + grad_output, + weight.shape[2], + weight.shape[3], + ctx.stride, + ctx.stride, + ctx.padding, + ctx.padding, + ctx.dilation, + ctx.dilation, + ctx.groups, + ctx.deformable_groups, + ctx.with_bias, + ) + if not ctx.with_bias: + grad_bias = None + + return ( + grad_input, + grad_offset, + grad_mask, + grad_weight, + grad_bias, + None, + None, + None, + None, + None, + ) + + @staticmethod + def _infer_shape(ctx, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = ( + height + 2 * ctx.padding - (ctx.dilation * (kernel_h - 1) + 1) + ) // ctx.stride + 1 + width_out = ( + width + 2 * ctx.padding - (ctx.dilation * (kernel_w - 1) + 1) + ) // ctx.stride + 1 + return n, channels_out, height_out, width_out + + +deform_conv = _DeformConv.apply +modulated_deform_conv = _ModulatedDeformConv.apply + + +class DeformConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=False, + norm=None, + activation=None, + ): + """ + Deformable convolution from :paper:`deformconv`. + + Arguments are similar to :class:`Conv2D`. Extra arguments: + + Args: + deformable_groups (int): number of groups used in deformable convolution. + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + """ + super(DeformConv, self).__init__() + + assert not bias + assert in_channels % groups == 0, "in_channels {} cannot be divisible by groups {}".format( + in_channels, groups + ) + assert ( + out_channels % groups == 0 + ), "out_channels {} cannot be divisible by groups {}".format(out_channels, groups) + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deformable_groups = deformable_groups + self.norm = norm + self.activation = activation + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // self.groups, *self.kernel_size) + ) + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + + def forward(self, x, offset): + if x.numel() == 0: + # When input is empty, we want to return a empty tensor with "correct" shape, + # So that the following operations will not panic + # if they check for the shape of the tensor. + # This computes the height and width of the output tensor + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // s + 1 + for i, p, di, k, s in zip( + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride + ) + ] + output_shape = [x.shape[0], self.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + x = deform_conv( + x, + offset, + self.weight, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + def extra_repr(self): + tmpstr = "in_channels=" + str(self.in_channels) + tmpstr += ", out_channels=" + str(self.out_channels) + tmpstr += ", kernel_size=" + str(self.kernel_size) + tmpstr += ", stride=" + str(self.stride) + tmpstr += ", padding=" + str(self.padding) + tmpstr += ", dilation=" + str(self.dilation) + tmpstr += ", groups=" + str(self.groups) + tmpstr += ", deformable_groups=" + str(self.deformable_groups) + tmpstr += ", bias=False" + return tmpstr + + +class ModulatedDeformConv(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deformable_groups=1, + bias=True, + norm=None, + activation=None, + ): + """ + Modulated deformable convolution from :paper:`deformconv2`. + + Arguments are similar to :class:`Conv2D`. Extra arguments: + + Args: + deformable_groups (int): number of groups used in deformable convolution. + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + """ + super(ModulatedDeformConv, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deformable_groups = deformable_groups + self.with_bias = bias + self.norm = norm + self.activation = activation + + self.weight = nn.Parameter( + torch.Tensor(out_channels, in_channels // groups, *self.kernel_size) + ) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.bias = None + + nn.init.kaiming_uniform_(self.weight, nonlinearity="relu") + if self.bias is not None: + nn.init.constant_(self.bias, 0) + + def forward(self, x, offset, mask): + if x.numel() == 0: + output_shape = [ + (i + 2 * p - (di * (k - 1) + 1)) // s + 1 + for i, p, di, k, s in zip( + x.shape[-2:], self.padding, self.dilation, self.kernel_size, self.stride + ) + ] + output_shape = [x.shape[0], self.weight.shape[0]] + output_shape + return _NewEmptyTensorOp.apply(x, output_shape) + + x = modulated_deform_conv( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deformable_groups, + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + def extra_repr(self): + tmpstr = "in_channels=" + str(self.in_channels) + tmpstr += ", out_channels=" + str(self.out_channels) + tmpstr += ", kernel_size=" + str(self.kernel_size) + tmpstr += ", stride=" + str(self.stride) + tmpstr += ", padding=" + str(self.padding) + tmpstr += ", dilation=" + str(self.dilation) + tmpstr += ", groups=" + str(self.groups) + tmpstr += ", deformable_groups=" + str(self.deformable_groups) + tmpstr += ", bias=" + str(self.with_bias) + return tmpstr + + +try: + from custom_detectron2 import _C +except ImportError: + # TODO: register ops natively so there is no need to import _C. + _msg = "detectron2 is not compiled successfully, please build following the instructions!" + _args = ("detectron2._C", _msg) + DeformConv = create_dummy_class("DeformConv", *_args) + ModulatedDeformConv = create_dummy_class("ModulatedDeformConv", *_args) + deform_conv = create_dummy_func("deform_conv", *_args) + modulated_deform_conv = create_dummy_func("modulated_deform_conv", *_args) diff --git a/custom_detectron2/layers/losses.py b/custom_detectron2/layers/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..850a852a2f0986d4d1ce89a526d96db42c76e44f --- /dev/null +++ b/custom_detectron2/layers/losses.py @@ -0,0 +1,133 @@ +import math +import torch + + +def diou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Distance Intersection over Union Loss (Zhaohui Zheng et. al) + https://arxiv.org/abs/1911.08287 + Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + eps (float): small number to prevent division by zero + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + # TODO: use torch._assert_async() when pytorch 1.8 support is dropped + assert (x2 >= x1).all(), "bad box: x1 larger than x2" + assert (y2 >= y1).all(), "bad box: y1 larger than y2" + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsct = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps + iou = intsct / union + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps + + # centers of boxes + x_p = (x2 + x1) / 2 + y_p = (y2 + y1) / 2 + x_g = (x1g + x2g) / 2 + y_g = (y1g + y2g) / 2 + distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) + + # Eqn. (7) + loss = 1 - iou + (distance / diag_len) + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss + + +def ciou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + reduction: str = "none", + eps: float = 1e-7, +) -> torch.Tensor: + """ + Complete Intersection over Union Loss (Zhaohui Zheng et. al) + https://arxiv.org/abs/1911.08287 + Args: + boxes1, boxes2 (Tensor): box locations in XYXY format, shape (N, 4) or (4,). + reduction: 'none' | 'mean' | 'sum' + 'none': No reduction will be applied to the output. + 'mean': The output will be averaged. + 'sum': The output will be summed. + eps (float): small number to prevent division by zero + """ + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + # TODO: use torch._assert_async() when pytorch 1.8 support is dropped + assert (x2 >= x1).all(), "bad box: x1 larger than x2" + assert (y2 >= y1).all(), "bad box: y1 larger than y2" + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsct = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps + iou = intsct / union + + # smallest enclosing box + xc1 = torch.min(x1, x1g) + yc1 = torch.min(y1, y1g) + xc2 = torch.max(x2, x2g) + yc2 = torch.max(y2, y2g) + diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps + + # centers of boxes + x_p = (x2 + x1) / 2 + y_p = (y2 + y1) / 2 + x_g = (x1g + x2g) / 2 + y_g = (y1g + y2g) / 2 + distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) + + # width and height of boxes + w_pred = x2 - x1 + h_pred = y2 - y1 + w_gt = x2g - x1g + h_gt = y2g - y1g + v = (4 / (math.pi**2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) + with torch.no_grad(): + alpha = v / (1 - iou + v + eps) + + # Eqn. (10) + loss = 1 - iou + (distance / diag_len) + alpha * v + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + + return loss diff --git a/custom_detectron2/layers/mask_ops.py b/custom_detectron2/layers/mask_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..990d04abbb120e40fe07a21d024dfead471bc998 --- /dev/null +++ b/custom_detectron2/layers/mask_ops.py @@ -0,0 +1,275 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import Tuple +import torch +from PIL import Image +from torch.nn import functional as F + +__all__ = ["paste_masks_in_image"] + + +BYTES_PER_FLOAT = 4 +# TODO: This memory limit may be too much or too little. It would be better to +# determine it based on available resources. +GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit + + +def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True): + """ + Args: + masks: N, 1, H, W + boxes: N, 4 + img_h, img_w (int): + skip_empty (bool): only paste masks within the region that + tightly bound all boxes, and returns the results this region only. + An important optimization for CPU. + + Returns: + if skip_empty == False, a mask of shape (N, img_h, img_w) + if skip_empty == True, a mask of shape (N, h', w'), and the slice + object for the corresponding region. + """ + # On GPU, paste all masks together (up to chunk size) + # by using the entire image to sample the masks + # Compared to pasting them one by one, + # this has more operations but is faster on COCO-scale dataset. + device = masks.device + + if skip_empty and not torch.jit.is_scripting(): + x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to( + dtype=torch.int32 + ) + x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32) + y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32) + else: + x0_int, y0_int = 0, 0 + x1_int, y1_int = img_w, img_h + x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1 + + N = masks.shape[0] + + img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5 + img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5 + img_y = (img_y - y0) / (y1 - y0) * 2 - 1 + img_x = (img_x - x0) / (x1 - x0) * 2 - 1 + # img_x, img_y have shapes (N, w), (N, h) + + gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1)) + gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1)) + grid = torch.stack([gx, gy], dim=3) + + if not torch.jit.is_scripting(): + if not masks.dtype.is_floating_point: + masks = masks.float() + img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False) + + if skip_empty and not torch.jit.is_scripting(): + return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int)) + else: + return img_masks[:, 0], () + + +# Annotate boxes as Tensor (but not Boxes) in order to use scripting +@torch.jit.script_if_tracing +def paste_masks_in_image( + masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5 +): + """ + Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image. + The location, height, and width for pasting each mask is determined by their + corresponding bounding boxes in boxes. + + Note: + This is a complicated but more accurate implementation. In actual deployment, it is + often enough to use a faster but less accurate implementation. + See :func:`paste_mask_in_image_old` in this file for an alternative implementation. + + Args: + masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of + detected object instances in the image and Hmask, Wmask are the mask width and mask + height of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1]. + boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4). + boxes[i] and masks[i] correspond to the same object instance. + image_shape (tuple): height, width + threshold (float): A threshold in [0, 1] for converting the (soft) masks to + binary masks. + + Returns: + img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the + number of detected object instances and Himage, Wimage are the image width + and height. img_masks[i] is a binary mask for object instance i. + """ + + assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported" + N = len(masks) + if N == 0: + return masks.new_empty((0,) + image_shape, dtype=torch.uint8) + if not isinstance(boxes, torch.Tensor): + boxes = boxes.tensor + device = boxes.device + assert len(boxes) == N, boxes.shape + + img_h, img_w = image_shape + + # The actual implementation split the input into chunks, + # and paste them chunk by chunk. + if device.type == "cpu" or torch.jit.is_scripting(): + # CPU is most efficient when they are pasted one by one with skip_empty=True + # so that it performs minimal number of operations. + num_chunks = N + else: + # GPU benefits from parallelism for larger chunks, but may have memory issue + # int(img_h) because shape may be tensors in tracing + num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT)) + assert ( + num_chunks <= N + ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it" + chunks = torch.chunk(torch.arange(N, device=device), num_chunks) + + img_masks = torch.zeros( + N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8 + ) + for inds in chunks: + masks_chunk, spatial_inds = _do_paste_mask( + masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu" + ) + + if threshold >= 0: + masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool) + else: + # for visualization and debugging + masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8) + + if torch.jit.is_scripting(): # Scripting does not use the optimized codepath + img_masks[inds] = masks_chunk + else: + img_masks[(inds,) + spatial_inds] = masks_chunk + return img_masks + + +# The below are the original paste function (from Detectron1) which has +# larger quantization error. +# It is faster on CPU, while the aligned one is faster on GPU thanks to grid_sample. + + +def paste_mask_in_image_old(mask, box, img_h, img_w, threshold): + """ + Paste a single mask in an image. + This is a per-box implementation of :func:`paste_masks_in_image`. + This function has larger quantization error due to incorrect pixel + modeling and is not used any more. + + Args: + mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single + object instance. Values are in [0, 1]. + box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners + of the object instance. + img_h, img_w (int): Image height and width. + threshold (float): Mask binarization threshold in [0, 1]. + + Returns: + im_mask (Tensor): + The resized and binarized object mask pasted into the original + image plane (a tensor of shape (img_h, img_w)). + """ + # Conversion from continuous box coordinates to discrete pixel coordinates + # via truncation (cast to int32). This determines which pixels to paste the + # mask onto. + box = box.to(dtype=torch.int32) # Continuous to discrete coordinate conversion + # An example (1D) box with continuous coordinates (x0=0.7, x1=4.3) will map to + # a discrete coordinates (x0=0, x1=4). Note that box is mapped to 5 = x1 - x0 + 1 + # pixels (not x1 - x0 pixels). + samples_w = box[2] - box[0] + 1 # Number of pixel samples, *not* geometric width + samples_h = box[3] - box[1] + 1 # Number of pixel samples, *not* geometric height + + # Resample the mask from it's original grid to the new samples_w x samples_h grid + mask = Image.fromarray(mask.cpu().numpy()) + mask = mask.resize((samples_w, samples_h), resample=Image.BILINEAR) + mask = np.array(mask, copy=False) + + if threshold >= 0: + mask = np.array(mask > threshold, dtype=np.uint8) + mask = torch.from_numpy(mask) + else: + # for visualization and debugging, we also + # allow it to return an unmodified mask + mask = torch.from_numpy(mask * 255).to(torch.uint8) + + im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8) + x_0 = max(box[0], 0) + x_1 = min(box[2] + 1, img_w) + y_0 = max(box[1], 0) + y_1 = min(box[3] + 1, img_h) + + im_mask[y_0:y_1, x_0:x_1] = mask[ + (y_0 - box[1]) : (y_1 - box[1]), (x_0 - box[0]) : (x_1 - box[0]) + ] + return im_mask + + +# Our pixel modeling requires extrapolation for any continuous +# coordinate < 0.5 or > length - 0.5. When sampling pixels on the masks, +# we would like this extrapolation to be an interpolation between boundary values and zero, +# instead of using absolute zero or boundary values. +# Therefore `paste_mask_in_image_old` is often used with zero padding around the masks like this: +# masks, scale = pad_masks(masks[:, 0, :, :], 1) +# boxes = scale_boxes(boxes.tensor, scale) + + +def pad_masks(masks, padding): + """ + Args: + masks (tensor): A tensor of shape (B, M, M) representing B masks. + padding (int): Number of cells to pad on all sides. + + Returns: + The padded masks and the scale factor of the padding size / original size. + """ + B = masks.shape[0] + M = masks.shape[-1] + pad2 = 2 * padding + scale = float(M + pad2) / M + padded_masks = masks.new_zeros((B, M + pad2, M + pad2)) + padded_masks[:, padding:-padding, padding:-padding] = masks + return padded_masks, scale + + +def scale_boxes(boxes, scale): + """ + Args: + boxes (tensor): A tensor of shape (B, 4) representing B boxes with 4 + coords representing the corners x0, y0, x1, y1, + scale (float): The box scaling factor. + + Returns: + Scaled boxes. + """ + w_half = (boxes[:, 2] - boxes[:, 0]) * 0.5 + h_half = (boxes[:, 3] - boxes[:, 1]) * 0.5 + x_c = (boxes[:, 2] + boxes[:, 0]) * 0.5 + y_c = (boxes[:, 3] + boxes[:, 1]) * 0.5 + + w_half *= scale + h_half *= scale + + scaled_boxes = torch.zeros_like(boxes) + scaled_boxes[:, 0] = x_c - w_half + scaled_boxes[:, 2] = x_c + w_half + scaled_boxes[:, 1] = y_c - h_half + scaled_boxes[:, 3] = y_c + h_half + return scaled_boxes + + +@torch.jit.script_if_tracing +def _paste_masks_tensor_shape( + masks: torch.Tensor, + boxes: torch.Tensor, + image_shape: Tuple[torch.Tensor, torch.Tensor], + threshold: float = 0.5, +): + """ + A wrapper of paste_masks_in_image where image_shape is Tensor. + During tracing, shapes might be tensors instead of ints. The Tensor->int + conversion should be scripted rather than traced. + """ + return paste_masks_in_image(masks, boxes, (int(image_shape[0]), int(image_shape[1])), threshold) diff --git a/custom_detectron2/layers/nms.py b/custom_detectron2/layers/nms.py new file mode 100644 index 0000000000000000000000000000000000000000..1019e7f4c8c58f2def34a019e4c3a0573c5f69bb --- /dev/null +++ b/custom_detectron2/layers/nms.py @@ -0,0 +1,144 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import torch +from torchvision.ops import boxes as box_ops +from torchvision.ops import nms # noqa . for compatibility + + +def batched_nms( + boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float +): + """ + Same as torchvision.ops.boxes.batched_nms, but with float(). + """ + assert boxes.shape[-1] == 4 + # Note: Torchvision already has a strategy (https://github.com/pytorch/vision/issues/1311) + # to decide whether to use coordinate trick or for loop to implement batched_nms. So we + # just call it directly. + # Fp16 does not have enough range for batched NMS, so adding float(). + return box_ops.batched_nms(boxes.float(), scores, idxs, iou_threshold) + + +# Note: this function (nms_rotated) might be moved into +# torchvision/ops/boxes.py in the future +def nms_rotated(boxes: torch.Tensor, scores: torch.Tensor, iou_threshold: float): + """ + Performs non-maximum suppression (NMS) on the rotated boxes according + to their intersection-over-union (IoU). + + Rotated NMS iteratively removes lower scoring rotated boxes which have an + IoU greater than iou_threshold with another (higher scoring) rotated box. + + Note that RotatedBox (5, 3, 4, 2, -90) covers exactly the same region as + RotatedBox (5, 3, 4, 2, 90) does, and their IoU will be 1. However, they + can be representing completely different objects in certain tasks, e.g., OCR. + + As for the question of whether rotated-NMS should treat them as faraway boxes + even though their IOU is 1, it depends on the application and/or ground truth annotation. + + As an extreme example, consider a single character v and the square box around it. + + If the angle is 0 degree, the object (text) would be read as 'v'; + + If the angle is 90 degrees, the object (text) would become '>'; + + If the angle is 180 degrees, the object (text) would become '^'; + + If the angle is 270/-90 degrees, the object (text) would become '<' + + All of these cases have IoU of 1 to each other, and rotated NMS that only + uses IoU as criterion would only keep one of them with the highest score - + which, practically, still makes sense in most cases because typically + only one of theses orientations is the correct one. Also, it does not matter + as much if the box is only used to classify the object (instead of transcribing + them with a sequential OCR recognition model) later. + + On the other hand, when we use IoU to filter proposals that are close to the + ground truth during training, we should definitely take the angle into account if + we know the ground truth is labeled with the strictly correct orientation (as in, + upside-down words are annotated with -180 degrees even though they can be covered + with a 0/90/-90 degree box, etc.) + + The way the original dataset is annotated also matters. For example, if the dataset + is a 4-point polygon dataset that does not enforce ordering of vertices/orientation, + we can estimate a minimum rotated bounding box to this polygon, but there's no way + we can tell the correct angle with 100% confidence (as shown above, there could be 4 different + rotated boxes, with angles differed by 90 degrees to each other, covering the exactly + same region). In that case we have to just use IoU to determine the box + proximity (as many detection benchmarks (even for text) do) unless there're other + assumptions we can make (like width is always larger than height, or the object is not + rotated by more than 90 degrees CCW/CW, etc.) + + In summary, not considering angles in rotated NMS seems to be a good option for now, + but we should be aware of its implications. + + Args: + boxes (Tensor[N, 5]): Rotated boxes to perform NMS on. They are expected to be in + (x_center, y_center, width, height, angle_degrees) format. + scores (Tensor[N]): Scores for each one of the rotated boxes + iou_threshold (float): Discards all overlapping rotated boxes with IoU < iou_threshold + + Returns: + keep (Tensor): int64 tensor with the indices of the elements that have been kept + by Rotated NMS, sorted in decreasing order of scores + """ + return torch.ops.detectron2.nms_rotated(boxes, scores, iou_threshold) + + +# Note: this function (batched_nms_rotated) might be moved into +# torchvision/ops/boxes.py in the future + + +@torch.jit.script_if_tracing +def batched_nms_rotated( + boxes: torch.Tensor, scores: torch.Tensor, idxs: torch.Tensor, iou_threshold: float +): + """ + Performs non-maximum suppression in a batched fashion. + + Each index value correspond to a category, and NMS + will not be applied between elements of different categories. + + Args: + boxes (Tensor[N, 5]): + boxes where NMS will be performed. They + are expected to be in (x_ctr, y_ctr, width, height, angle_degrees) format + scores (Tensor[N]): + scores for each one of the boxes + idxs (Tensor[N]): + indices of the categories for each one of the boxes. + iou_threshold (float): + discards all overlapping boxes + with IoU < iou_threshold + + Returns: + Tensor: + int64 tensor with the indices of the elements that have been kept + by NMS, sorted in decreasing order of scores + """ + assert boxes.shape[-1] == 5 + + if boxes.numel() == 0: + return torch.empty((0,), dtype=torch.int64, device=boxes.device) + boxes = boxes.float() # fp16 does not have enough range for batched NMS + # Strategy: in order to perform NMS independently per class, + # we add an offset to all the boxes. The offset is dependent + # only on the class idx, and is large enough so that boxes + # from different classes do not overlap + + # Note that batched_nms in torchvision/ops/boxes.py only uses max_coordinate, + # which won't handle negative coordinates correctly. + # Here by using min_coordinate we can make sure the negative coordinates are + # correctly handled. + max_coordinate = ( + torch.max(boxes[:, 0], boxes[:, 1]) + torch.max(boxes[:, 2], boxes[:, 3]) / 2 + ).max() + min_coordinate = ( + torch.min(boxes[:, 0], boxes[:, 1]) - torch.max(boxes[:, 2], boxes[:, 3]) / 2 + ).min() + offsets = idxs.to(boxes) * (max_coordinate - min_coordinate + 1) + boxes_for_nms = boxes.clone() # avoid modifying the original values in boxes + boxes_for_nms[:, :2] += offsets[:, None] + keep = nms_rotated(boxes_for_nms, scores, iou_threshold) + return keep diff --git a/custom_detectron2/layers/roi_align.py b/custom_detectron2/layers/roi_align.py new file mode 100644 index 0000000000000000000000000000000000000000..163462e1f194e1e4100da92d76d9516f7cc22e35 --- /dev/null +++ b/custom_detectron2/layers/roi_align.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from torch import nn +from torchvision.ops import roi_align + + +# NOTE: torchvision's RoIAlign has a different default aligned=False +class ROIAlign(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True): + """ + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each output + sample. 0 to take samples densely. + aligned (bool): if False, use the legacy implementation in + Detectron. If True, align the results more perfectly. + + Note: + The meaning of aligned=True: + + Given a continuous coordinate c, its two neighboring pixel indices (in our + pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, + c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled + from the underlying signal at continuous coordinates 0.5 and 1.5). But the original + roi_align (aligned=False) does not subtract the 0.5 when computing neighboring + pixel indices and therefore it uses pixels with a slightly incorrect alignment + (relative to our pixel model) when performing bilinear interpolation. + + With `aligned=True`, + we first appropriately scale the ROI and then shift it by -0.5 + prior to calling roi_align. This produces the correct neighbors; see + detectron2/tests/test_roi_align.py for verification. + + The difference does not make a difference to the model's performance if + ROIAlign is used together with conv layers. + """ + super().__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + self.aligned = aligned + + from torchvision import __version__ + + version = tuple(int(x) for x in __version__.split(".")[:2]) + # https://github.com/pytorch/vision/pull/2438 + assert version >= (0, 7), "Require torchvision >= 0.7" + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx5 boxes. First column is the index into N. The other 4 columns are xyxy. + """ + assert rois.dim() == 2 and rois.size(1) == 5 + if input.is_quantized: + input = input.dequantize() + return roi_align( + input, + rois.to(dtype=input.dtype), + self.output_size, + self.spatial_scale, + self.sampling_ratio, + self.aligned, + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ", aligned=" + str(self.aligned) + tmpstr += ")" + return tmpstr diff --git a/custom_detectron2/layers/roi_align_rotated.py b/custom_detectron2/layers/roi_align_rotated.py new file mode 100644 index 0000000000000000000000000000000000000000..2a523992e7c736262ad5a158f209aae7875f6f0b --- /dev/null +++ b/custom_detectron2/layers/roi_align_rotated.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +from torch import nn +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.nn.modules.utils import _pair + + +class _ROIAlignRotated(Function): + @staticmethod + def forward(ctx, input, roi, output_size, spatial_scale, sampling_ratio): + ctx.save_for_backward(roi) + ctx.output_size = _pair(output_size) + ctx.spatial_scale = spatial_scale + ctx.sampling_ratio = sampling_ratio + ctx.input_shape = input.size() + output = torch.ops.detectron2.roi_align_rotated_forward( + input, roi, spatial_scale, output_size[0], output_size[1], sampling_ratio + ) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + (rois,) = ctx.saved_tensors + output_size = ctx.output_size + spatial_scale = ctx.spatial_scale + sampling_ratio = ctx.sampling_ratio + bs, ch, h, w = ctx.input_shape + grad_input = torch.ops.detectron2.roi_align_rotated_backward( + grad_output, + rois, + spatial_scale, + output_size[0], + output_size[1], + bs, + ch, + h, + w, + sampling_ratio, + ) + return grad_input, None, None, None, None, None + + +roi_align_rotated = _ROIAlignRotated.apply + + +class ROIAlignRotated(nn.Module): + def __init__(self, output_size, spatial_scale, sampling_ratio): + """ + Args: + output_size (tuple): h, w + spatial_scale (float): scale the input boxes by this number + sampling_ratio (int): number of inputs samples to take for each output + sample. 0 to take samples densely. + + Note: + ROIAlignRotated supports continuous coordinate by default: + Given a continuous coordinate c, its two neighboring pixel indices (in our + pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example, + c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled + from the underlying signal at continuous coordinates 0.5 and 1.5). + """ + super(ROIAlignRotated, self).__init__() + self.output_size = output_size + self.spatial_scale = spatial_scale + self.sampling_ratio = sampling_ratio + + def forward(self, input, rois): + """ + Args: + input: NCHW images + rois: Bx6 boxes. First column is the index into N. + The other 5 columns are (x_ctr, y_ctr, width, height, angle_degrees). + """ + assert rois.dim() == 2 and rois.size(1) == 6 + orig_dtype = input.dtype + if orig_dtype == torch.float16: + input = input.float() + rois = rois.float() + output_size = _pair(self.output_size) + + # Scripting for Autograd is currently unsupported. + # This is a quick fix without having to rewrite code on the C++ side + if torch.jit.is_scripting() or torch.jit.is_tracing(): + return torch.ops.detectron2.roi_align_rotated_forward( + input, rois, self.spatial_scale, output_size[0], output_size[1], self.sampling_ratio + ).to(dtype=orig_dtype) + + return roi_align_rotated( + input, rois, self.output_size, self.spatial_scale, self.sampling_ratio + ).to(dtype=orig_dtype) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ")" + return tmpstr diff --git a/custom_detectron2/layers/rotated_boxes.py b/custom_detectron2/layers/rotated_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..03f73b3bb99275931a887ad9b2d8c0ac9f412bf3 --- /dev/null +++ b/custom_detectron2/layers/rotated_boxes.py @@ -0,0 +1,21 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import absolute_import, division, print_function, unicode_literals +import torch + + +def pairwise_iou_rotated(boxes1, boxes2): + """ + Return intersection-over-union (Jaccard index) of boxes. + + Both sets of boxes are expected to be in + (x_center, y_center, width, height, angle) format. + + Arguments: + boxes1 (Tensor[N, 5]) + boxes2 (Tensor[M, 5]) + + Returns: + iou (Tensor[N, M]): the NxM matrix containing the pairwise + IoU values for every element in boxes1 and boxes2 + """ + return torch.ops.detectron2.box_iou_rotated(boxes1, boxes2) diff --git a/custom_detectron2/layers/shape_spec.py b/custom_detectron2/layers/shape_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..8dac3c59b96576710656abebe9b5eac25868abbb --- /dev/null +++ b/custom_detectron2/layers/shape_spec.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ShapeSpec: + """ + A simple structure that contains basic shape specification about a tensor. + It is often used as the auxiliary inputs/outputs of models, + to complement the lack of shape inference ability among pytorch modules. + """ + + channels: Optional[int] = None + height: Optional[int] = None + width: Optional[int] = None + stride: Optional[int] = None diff --git a/custom_detectron2/layers/wrappers.py b/custom_detectron2/layers/wrappers.py new file mode 100644 index 0000000000000000000000000000000000000000..3736f54335574a90d21e4812bedf982eae6ca681 --- /dev/null +++ b/custom_detectron2/layers/wrappers.py @@ -0,0 +1,162 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Wrappers around on some nn functions, mainly to support empty tensors. + +Ideally, add support directly in PyTorch to empty tensors in those functions. + +These can be removed once https://github.com/pytorch/pytorch/issues/12013 +is implemented +""" + +import warnings +from typing import List, Optional +import torch +from torch.nn import functional as F + +from custom_detectron2.utils.env import TORCH_VERSION + + +def shapes_to_tensor(x: List[int], device: Optional[torch.device] = None) -> torch.Tensor: + """ + Turn a list of integer scalars or integer Tensor scalars into a vector, + in a way that's both traceable and scriptable. + + In tracing, `x` should be a list of scalar Tensor, so the output can trace to the inputs. + In scripting or eager, `x` should be a list of int. + """ + if torch.jit.is_scripting(): + return torch.as_tensor(x, device=device) + if torch.jit.is_tracing(): + assert all( + [isinstance(t, torch.Tensor) for t in x] + ), "Shape should be tensor during tracing!" + # as_tensor should not be used in tracing because it records a constant + ret = torch.stack(x) + if ret.device != device: # avoid recording a hard-coded device if not necessary + ret = ret.to(device=device) + return ret + return torch.as_tensor(x, device=device) + + +def check_if_dynamo_compiling(): + if TORCH_VERSION >= (1, 14): + from torch._dynamo import is_compiling + + return is_compiling() + else: + return False + + +def cat(tensors: List[torch.Tensor], dim: int = 0): + """ + Efficient version of torch.cat that avoids a copy if there is only a single element in a list + """ + assert isinstance(tensors, (list, tuple)) + if len(tensors) == 1: + return tensors[0] + return torch.cat(tensors, dim) + + +def empty_input_loss_func_wrapper(loss_func): + def wrapped_loss_func(input, target, *, reduction="mean", **kwargs): + """ + Same as `loss_func`, but returns 0 (instead of nan) for empty inputs. + """ + if target.numel() == 0 and reduction == "mean": + return input.sum() * 0.0 # connect the gradient + return loss_func(input, target, reduction=reduction, **kwargs) + + return wrapped_loss_func + + +cross_entropy = empty_input_loss_func_wrapper(F.cross_entropy) + + +class _NewEmptyTensorOp(torch.autograd.Function): + @staticmethod + def forward(ctx, x, new_shape): + ctx.shape = x.shape + return x.new_empty(new_shape) + + @staticmethod + def backward(ctx, grad): + shape = ctx.shape + return _NewEmptyTensorOp.apply(grad, shape), None + + +class Conv2d(torch.nn.Conv2d): + """ + A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features. + """ + + def __init__(self, *args, **kwargs): + """ + Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`: + + Args: + norm (nn.Module, optional): a normalization layer + activation (callable(Tensor) -> Tensor): a callable activation function + + It assumes that norm layer is used before activation. + """ + norm = kwargs.pop("norm", None) + activation = kwargs.pop("activation", None) + super().__init__(*args, **kwargs) + + self.norm = norm + self.activation = activation + + def forward(self, x): + # torchscript does not support SyncBatchNorm yet + # https://github.com/pytorch/pytorch/issues/40507 + # and we skip these codes in torchscript since: + # 1. currently we only support torchscript in evaluation mode + # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or + # later version, `Conv2d` in these PyTorch versions has already supported empty inputs. + if not torch.jit.is_scripting(): + # Dynamo doesn't support context managers yet + is_dynamo_compiling = check_if_dynamo_compiling() + if not is_dynamo_compiling: + with warnings.catch_warnings(record=True): + if x.numel() == 0 and self.training: + # https://github.com/pytorch/pytorch/issues/12013 + assert not isinstance( + self.norm, torch.nn.SyncBatchNorm + ), "SyncBatchNorm does not support empty inputs!" + + x = F.conv2d( + x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups + ) + if self.norm is not None: + x = self.norm(x) + if self.activation is not None: + x = self.activation(x) + return x + + +ConvTranspose2d = torch.nn.ConvTranspose2d +BatchNorm2d = torch.nn.BatchNorm2d +interpolate = F.interpolate +Linear = torch.nn.Linear + + +def nonzero_tuple(x): + """ + A 'as_tuple=True' version of torch.nonzero to support torchscript. + because of https://github.com/pytorch/pytorch/issues/38718 + """ + if torch.jit.is_scripting(): + if x.dim() == 0: + return x.unsqueeze(0).nonzero().unbind(1) + return x.nonzero().unbind(1) + else: + return x.nonzero(as_tuple=True) + + +@torch.jit.script_if_tracing +def move_device_like(src: torch.Tensor, dst: torch.Tensor) -> torch.Tensor: + """ + Tracing friendly way to cast tensor to another tensor's device. Device will be treated + as constant during tracing, scripting the casting process as whole can workaround this issue. + """ + return src.to(dst.device) diff --git a/custom_detectron2/model_zoo/__init__.py b/custom_detectron2/model_zoo/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6204208198d813728cf6419e8eef4a733f20c18f --- /dev/null +++ b/custom_detectron2/model_zoo/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +Model Zoo API for Detectron2: a collection of functions to create common model architectures +listed in `MODEL_ZOO.md `_, +and optionally load their pre-trained weights. +""" + +from .model_zoo import get, get_config_file, get_checkpoint_url, get_config + +__all__ = ["get_checkpoint_url", "get", "get_config_file", "get_config"] diff --git a/custom_detectron2/model_zoo/model_zoo.py b/custom_detectron2/model_zoo/model_zoo.py new file mode 100644 index 0000000000000000000000000000000000000000..fdbdf300bbf638ff4e01bc05226e8fd22499a83e --- /dev/null +++ b/custom_detectron2/model_zoo/model_zoo.py @@ -0,0 +1,213 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import os +from typing import Optional +import pkg_resources +import torch + +from custom_detectron2.checkpoint import DetectionCheckpointer +from custom_detectron2.config import CfgNode, LazyConfig, get_cfg, instantiate +from custom_detectron2.modeling import build_model + + +class _ModelZooUrls(object): + """ + Mapping from names to officially released Detectron2 pre-trained models. + """ + + S3_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/" + + # format: {config_path.yaml} -> model_id/model_final_{commit}.pkl + CONFIG_PATH_TO_URL_SUFFIX = { + # COCO Detection with Faster R-CNN + "COCO-Detection/faster_rcnn_R_50_C4_1x": "137257644/model_final_721ade.pkl", + "COCO-Detection/faster_rcnn_R_50_DC5_1x": "137847829/model_final_51d356.pkl", + "COCO-Detection/faster_rcnn_R_50_FPN_1x": "137257794/model_final_b275ba.pkl", + "COCO-Detection/faster_rcnn_R_50_C4_3x": "137849393/model_final_f97cb7.pkl", + "COCO-Detection/faster_rcnn_R_50_DC5_3x": "137849425/model_final_68d202.pkl", + "COCO-Detection/faster_rcnn_R_50_FPN_3x": "137849458/model_final_280758.pkl", + "COCO-Detection/faster_rcnn_R_101_C4_3x": "138204752/model_final_298dad.pkl", + "COCO-Detection/faster_rcnn_R_101_DC5_3x": "138204841/model_final_3e0943.pkl", + "COCO-Detection/faster_rcnn_R_101_FPN_3x": "137851257/model_final_f6e8b1.pkl", + "COCO-Detection/faster_rcnn_X_101_32x8d_FPN_3x": "139173657/model_final_68b088.pkl", + # COCO Detection with RetinaNet + "COCO-Detection/retinanet_R_50_FPN_1x": "190397773/model_final_bfca0b.pkl", + "COCO-Detection/retinanet_R_50_FPN_3x": "190397829/model_final_5bd44e.pkl", + "COCO-Detection/retinanet_R_101_FPN_3x": "190397697/model_final_971ab9.pkl", + # COCO Detection with RPN and Fast R-CNN + "COCO-Detection/rpn_R_50_C4_1x": "137258005/model_final_450694.pkl", + "COCO-Detection/rpn_R_50_FPN_1x": "137258492/model_final_02ce48.pkl", + "COCO-Detection/fast_rcnn_R_50_FPN_1x": "137635226/model_final_e5f7ce.pkl", + # COCO Instance Segmentation Baselines with Mask R-CNN + "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_1x": "137259246/model_final_9243eb.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_1x": "137260150/model_final_4f86c3.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "137260431/model_final_a54504.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_50_C4_3x": "137849525/model_final_4ce675.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_50_DC5_3x": "137849551/model_final_84107b.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x": "137849600/model_final_f10217.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_101_C4_3x": "138363239/model_final_a2914c.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_101_DC5_3x": "138363294/model_final_0464b7.pkl", + "COCO-InstanceSegmentation/mask_rcnn_R_101_FPN_3x": "138205316/model_final_a3ec72.pkl", + "COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x": "139653917/model_final_2d9806.pkl", # noqa + # New baselines using Large-Scale Jitter and Longer Training Schedule + "new_baselines/mask_rcnn_R_50_FPN_100ep_LSJ": "42047764/model_final_bb69de.pkl", + "new_baselines/mask_rcnn_R_50_FPN_200ep_LSJ": "42047638/model_final_89a8d3.pkl", + "new_baselines/mask_rcnn_R_50_FPN_400ep_LSJ": "42019571/model_final_14d201.pkl", + "new_baselines/mask_rcnn_R_101_FPN_100ep_LSJ": "42025812/model_final_4f7b58.pkl", + "new_baselines/mask_rcnn_R_101_FPN_200ep_LSJ": "42131867/model_final_0bb7ae.pkl", + "new_baselines/mask_rcnn_R_101_FPN_400ep_LSJ": "42073830/model_final_f96b26.pkl", + "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_100ep_LSJ": "42047771/model_final_b7fbab.pkl", # noqa + "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_200ep_LSJ": "42132721/model_final_5d87c1.pkl", # noqa + "new_baselines/mask_rcnn_regnetx_4gf_dds_FPN_400ep_LSJ": "42025447/model_final_f1362d.pkl", # noqa + "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_100ep_LSJ": "42047784/model_final_6ba57e.pkl", # noqa + "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_200ep_LSJ": "42047642/model_final_27b9c1.pkl", # noqa + "new_baselines/mask_rcnn_regnety_4gf_dds_FPN_400ep_LSJ": "42045954/model_final_ef3a80.pkl", # noqa + # COCO Person Keypoint Detection Baselines with Keypoint R-CNN + "COCO-Keypoints/keypoint_rcnn_R_50_FPN_1x": "137261548/model_final_04e291.pkl", + "COCO-Keypoints/keypoint_rcnn_R_50_FPN_3x": "137849621/model_final_a6e10b.pkl", + "COCO-Keypoints/keypoint_rcnn_R_101_FPN_3x": "138363331/model_final_997cc7.pkl", + "COCO-Keypoints/keypoint_rcnn_X_101_32x8d_FPN_3x": "139686956/model_final_5ad38f.pkl", + # COCO Panoptic Segmentation Baselines with Panoptic FPN + "COCO-PanopticSegmentation/panoptic_fpn_R_50_1x": "139514544/model_final_dbfeb4.pkl", + "COCO-PanopticSegmentation/panoptic_fpn_R_50_3x": "139514569/model_final_c10459.pkl", + "COCO-PanopticSegmentation/panoptic_fpn_R_101_3x": "139514519/model_final_cafdb1.pkl", + # LVIS Instance Segmentation Baselines with Mask R-CNN + "LVISv0.5-InstanceSegmentation/mask_rcnn_R_50_FPN_1x": "144219072/model_final_571f7c.pkl", # noqa + "LVISv0.5-InstanceSegmentation/mask_rcnn_R_101_FPN_1x": "144219035/model_final_824ab5.pkl", # noqa + "LVISv0.5-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_1x": "144219108/model_final_5e3439.pkl", # noqa + # Cityscapes & Pascal VOC Baselines + "Cityscapes/mask_rcnn_R_50_FPN": "142423278/model_final_af9cf5.pkl", + "PascalVOC-Detection/faster_rcnn_R_50_C4": "142202221/model_final_b1acc2.pkl", + # Other Settings + "Misc/mask_rcnn_R_50_FPN_1x_dconv_c3-c5": "138602867/model_final_65c703.pkl", + "Misc/mask_rcnn_R_50_FPN_3x_dconv_c3-c5": "144998336/model_final_821d0b.pkl", + "Misc/cascade_mask_rcnn_R_50_FPN_1x": "138602847/model_final_e9d89b.pkl", + "Misc/cascade_mask_rcnn_R_50_FPN_3x": "144998488/model_final_480dd8.pkl", + "Misc/mask_rcnn_R_50_FPN_3x_syncbn": "169527823/model_final_3b3c51.pkl", + "Misc/mask_rcnn_R_50_FPN_3x_gn": "138602888/model_final_dc5d9e.pkl", + "Misc/scratch_mask_rcnn_R_50_FPN_3x_gn": "138602908/model_final_01ca85.pkl", + "Misc/scratch_mask_rcnn_R_50_FPN_9x_gn": "183808979/model_final_da7b4c.pkl", + "Misc/scratch_mask_rcnn_R_50_FPN_9x_syncbn": "184226666/model_final_5ce33e.pkl", + "Misc/panoptic_fpn_R_101_dconv_cascade_gn_3x": "139797668/model_final_be35db.pkl", + "Misc/cascade_mask_rcnn_X_152_32x8d_FPN_IN5k_gn_dconv": "18131413/model_0039999_e76410.pkl", # noqa + # D1 Comparisons + "Detectron1-Comparisons/faster_rcnn_R_50_FPN_noaug_1x": "137781054/model_final_7ab50c.pkl", # noqa + "Detectron1-Comparisons/mask_rcnn_R_50_FPN_noaug_1x": "137781281/model_final_62ca52.pkl", # noqa + "Detectron1-Comparisons/keypoint_rcnn_R_50_FPN_1x": "137781195/model_final_cce136.pkl", + } + + @staticmethod + def query(config_path: str) -> Optional[str]: + """ + Args: + config_path: relative config filename + """ + name = config_path.replace(".yaml", "").replace(".py", "") + if name in _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX: + suffix = _ModelZooUrls.CONFIG_PATH_TO_URL_SUFFIX[name] + return _ModelZooUrls.S3_PREFIX + name + "/" + suffix + return None + + +def get_checkpoint_url(config_path): + """ + Returns the URL to the model trained using the given config + + Args: + config_path (str): config file name relative to detectron2's "configs/" + directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml" + + Returns: + str: a URL to the model + """ + url = _ModelZooUrls.query(config_path) + if url is None: + raise RuntimeError("Pretrained model for {} is not available!".format(config_path)) + return url + + +def get_config_file(config_path): + """ + Returns path to a builtin config file. + + Args: + config_path (str): config file name relative to detectron2's "configs/" + directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml" + + Returns: + str: the real path to the config file. + """ + cfg_file = pkg_resources.resource_filename( + "detectron2.model_zoo", os.path.join("configs", config_path) + ) + if not os.path.exists(cfg_file): + raise RuntimeError("{} not available in Model Zoo!".format(config_path)) + return cfg_file + + +def get_config(config_path, trained: bool = False): + """ + Returns a config object for a model in model zoo. + + Args: + config_path (str): config file name relative to detectron2's "configs/" + directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml" + trained (bool): If True, will set ``MODEL.WEIGHTS`` to trained model zoo weights. + If False, the checkpoint specified in the config file's ``MODEL.WEIGHTS`` is used + instead; this will typically (though not always) initialize a subset of weights using + an ImageNet pre-trained model, while randomly initializing the other weights. + + Returns: + CfgNode or omegaconf.DictConfig: a config object + """ + cfg_file = get_config_file(config_path) + if cfg_file.endswith(".yaml"): + cfg = get_cfg() + cfg.merge_from_file(cfg_file) + if trained: + cfg.MODEL.WEIGHTS = get_checkpoint_url(config_path) + return cfg + elif cfg_file.endswith(".py"): + cfg = LazyConfig.load(cfg_file) + if trained: + url = get_checkpoint_url(config_path) + if "train" in cfg and "init_checkpoint" in cfg.train: + cfg.train.init_checkpoint = url + else: + raise NotImplementedError + return cfg + + +def get(config_path, trained: bool = False, device: Optional[str] = None): + """ + Get a model specified by relative path under Detectron2's official ``configs/`` directory. + + Args: + config_path (str): config file name relative to detectron2's "configs/" + directory, e.g., "COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml" + trained (bool): see :func:`get_config`. + device (str or None): overwrite the device in config, if given. + + Returns: + nn.Module: a detectron2 model. Will be in training mode. + + Example: + :: + from custom_detectron2 import model_zoo + model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_1x.yaml", trained=True) + """ + cfg = get_config(config_path, trained) + if device is None and not torch.cuda.is_available(): + device = "cpu" + if device is not None and isinstance(cfg, CfgNode): + cfg.MODEL.DEVICE = device + + if isinstance(cfg, CfgNode): + model = build_model(cfg) + DetectionCheckpointer(model).load(cfg.MODEL.WEIGHTS) + else: + model = instantiate(cfg.model) + if device is not None: + model = model.to(device) + if "train" in cfg and "init_checkpoint" in cfg.train: + DetectionCheckpointer(model).load(cfg.train.init_checkpoint) + return model diff --git a/custom_detectron2/modeling/__init__.py b/custom_detectron2/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..247d457aabadb1264f7b5dc464cae5761b6a78e7 --- /dev/null +++ b/custom_detectron2/modeling/__init__.py @@ -0,0 +1,64 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from custom_detectron2.layers import ShapeSpec + +from .anchor_generator import build_anchor_generator, ANCHOR_GENERATOR_REGISTRY +from .backbone import ( + BACKBONE_REGISTRY, + FPN, + Backbone, + ResNet, + ResNetBlockBase, + build_backbone, + build_resnet_backbone, + make_stage, + ViT, + SimpleFeaturePyramid, + get_vit_lr_decay_rate, + MViT, + SwinTransformer, +) +from .meta_arch import ( + META_ARCH_REGISTRY, + SEM_SEG_HEADS_REGISTRY, + GeneralizedRCNN, + PanopticFPN, + ProposalNetwork, + RetinaNet, + SemanticSegmentor, + build_model, + build_sem_seg_head, + FCOS, +) +from .postprocessing import detector_postprocess +from .proposal_generator import ( + PROPOSAL_GENERATOR_REGISTRY, + build_proposal_generator, + RPN_HEAD_REGISTRY, + build_rpn_head, +) +from .roi_heads import ( + ROI_BOX_HEAD_REGISTRY, + ROI_HEADS_REGISTRY, + ROI_KEYPOINT_HEAD_REGISTRY, + ROI_MASK_HEAD_REGISTRY, + ROIHeads, + StandardROIHeads, + BaseMaskRCNNHead, + BaseKeypointRCNNHead, + FastRCNNOutputLayers, + build_box_head, + build_keypoint_head, + build_mask_head, + build_roi_heads, +) +from .test_time_augmentation import DatasetMapperTTA, GeneralizedRCNNWithTTA +from .mmdet_wrapper import MMDetBackbone, MMDetDetector + +_EXCLUDE = {"ShapeSpec"} +__all__ = [k for k in globals().keys() if k not in _EXCLUDE and not k.startswith("_")] + + +from custom_detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/custom_detectron2/modeling/anchor_generator.py b/custom_detectron2/modeling/anchor_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..6768f7d06415fee4cb707092dddc7bbc764d59b0 --- /dev/null +++ b/custom_detectron2/modeling/anchor_generator.py @@ -0,0 +1,386 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import collections +import math +from typing import List +import torch +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ShapeSpec, move_device_like +from custom_detectron2.structures import Boxes, RotatedBoxes +from custom_detectron2.utils.registry import Registry + +ANCHOR_GENERATOR_REGISTRY = Registry("ANCHOR_GENERATOR") +ANCHOR_GENERATOR_REGISTRY.__doc__ = """ +Registry for modules that creates object detection anchors for feature maps. + +The registered object will be called with `obj(cfg, input_shape)`. +""" + + +class BufferList(nn.Module): + """ + Similar to nn.ParameterList, but for buffers + """ + + def __init__(self, buffers): + super().__init__() + for i, buffer in enumerate(buffers): + # Use non-persistent buffer so the values are not saved in checkpoint + self.register_buffer(str(i), buffer, persistent=False) + + def __len__(self): + return len(self._buffers) + + def __iter__(self): + return iter(self._buffers.values()) + + +def _create_grid_offsets( + size: List[int], stride: int, offset: float, target_device_tensor: torch.Tensor +): + grid_height, grid_width = size + shifts_x = move_device_like( + torch.arange(offset * stride, grid_width * stride, step=stride, dtype=torch.float32), + target_device_tensor, + ) + shifts_y = move_device_like( + torch.arange(offset * stride, grid_height * stride, step=stride, dtype=torch.float32), + target_device_tensor, + ) + + shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) + shift_x = shift_x.reshape(-1) + shift_y = shift_y.reshape(-1) + return shift_x, shift_y + + +def _broadcast_params(params, num_features, name): + """ + If one size (or aspect ratio) is specified and there are multiple feature + maps, we "broadcast" anchors of that single size (or aspect ratio) + over all feature maps. + + If params is list[float], or list[list[float]] with len(params) == 1, repeat + it num_features time. + + Returns: + list[list[float]]: param for each feature + """ + assert isinstance( + params, collections.abc.Sequence + ), f"{name} in anchor generator has to be a list! Got {params}." + assert len(params), f"{name} in anchor generator cannot be empty!" + if not isinstance(params[0], collections.abc.Sequence): # params is list[float] + return [params] * num_features + if len(params) == 1: + return list(params) * num_features + assert len(params) == num_features, ( + f"Got {name} of length {len(params)} in anchor generator, " + f"but the number of input features is {num_features}!" + ) + return params + + +@ANCHOR_GENERATOR_REGISTRY.register() +class DefaultAnchorGenerator(nn.Module): + """ + Compute anchors in the standard ways described in + "Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks". + """ + + box_dim: torch.jit.Final[int] = 4 + """ + the dimension of each anchor box. + """ + + @configurable + def __init__(self, *, sizes, aspect_ratios, strides, offset=0.5): + """ + This interface is experimental. + + Args: + sizes (list[list[float]] or list[float]): + If ``sizes`` is list[list[float]], ``sizes[i]`` is the list of anchor sizes + (i.e. sqrt of anchor area) to use for the i-th feature map. + If ``sizes`` is list[float], ``sizes`` is used for all feature maps. + Anchor sizes are given in absolute lengths in units of + the input image; they do not dynamically scale if the input image size changes. + aspect_ratios (list[list[float]] or list[float]): list of aspect ratios + (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies. + strides (list[int]): stride of each input feature. + offset (float): Relative offset between the center of the first anchor and the top-left + corner of the image. Value has to be in [0, 1). + Recommend to use 0.5, which means half stride. + """ + super().__init__() + + self.strides = strides + self.num_features = len(self.strides) + sizes = _broadcast_params(sizes, self.num_features, "sizes") + aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios") + self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios) + + self.offset = offset + assert 0.0 <= self.offset < 1.0, self.offset + + @classmethod + def from_config(cls, cfg, input_shape: List[ShapeSpec]): + return { + "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES, + "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS, + "strides": [x.stride for x in input_shape], + "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET, + } + + def _calculate_anchors(self, sizes, aspect_ratios): + cell_anchors = [ + self.generate_cell_anchors(s, a).float() for s, a in zip(sizes, aspect_ratios) + ] + return BufferList(cell_anchors) + + @property + @torch.jit.unused + def num_cell_anchors(self): + """ + Alias of `num_anchors`. + """ + return self.num_anchors + + @property + @torch.jit.unused + def num_anchors(self): + """ + Returns: + list[int]: Each int is the number of anchors at every pixel + location, on that feature map. + For example, if at every pixel we use anchors of 3 aspect + ratios and 5 sizes, the number of anchors is 15. + (See also ANCHOR_GENERATOR.SIZES and ANCHOR_GENERATOR.ASPECT_RATIOS in config) + + In standard RPN models, `num_anchors` on every feature map is the same. + """ + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def _grid_anchors(self, grid_sizes: List[List[int]]): + """ + Returns: + list[Tensor]: #featuremap tensors, each is (#locations x #cell_anchors) x 4 + """ + anchors = [] + # buffers() not supported by torchscript. use named_buffers() instead + buffers: List[torch.Tensor] = [x[1] for x in self.cell_anchors.named_buffers()] + for size, stride, base_anchors in zip(grid_sizes, self.strides, buffers): + shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors) + shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1) + + anchors.append((shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)) + + return anchors + + def generate_cell_anchors(self, sizes=(32, 64, 128, 256, 512), aspect_ratios=(0.5, 1, 2)): + """ + Generate a tensor storing canonical anchor boxes, which are all anchor + boxes of different sizes and aspect_ratios centered at (0, 0). + We can later build the set of anchors for a full feature map by + shifting and tiling these tensors (see `meth:_grid_anchors`). + + Args: + sizes (tuple[float]): + aspect_ratios (tuple[float]]): + + Returns: + Tensor of shape (len(sizes) * len(aspect_ratios), 4) storing anchor boxes + in XYXY format. + """ + + # This is different from the anchor generator defined in the original Faster R-CNN + # code or Detectron. They yield the same AP, however the old version defines cell + # anchors in a less natural way with a shift relative to the feature grid and + # quantization that results in slightly different sizes for different aspect ratios. + # See also https://github.com/facebookresearch/Detectron/issues/227 + + anchors = [] + for size in sizes: + area = size**2.0 + for aspect_ratio in aspect_ratios: + # s * s = w * h + # a = h / w + # ... some algebra ... + # w = sqrt(s * s / a) + # h = a * w + w = math.sqrt(area / aspect_ratio) + h = aspect_ratio * w + x0, y0, x1, y1 = -w / 2.0, -h / 2.0, w / 2.0, h / 2.0 + anchors.append([x0, y0, x1, y1]) + return torch.tensor(anchors) + + def forward(self, features: List[torch.Tensor]): + """ + Args: + features (list[Tensor]): list of backbone feature maps on which to generate anchors. + + Returns: + list[Boxes]: a list of Boxes containing all the anchors for each feature map + (i.e. the cell anchors repeated over all locations in the feature map). + The number of anchors of each feature map is Hi x Wi x num_cell_anchors, + where Hi, Wi are resolution of the feature map divided by anchor stride. + """ + grid_sizes = [feature_map.shape[-2:] for feature_map in features] + anchors_over_all_feature_maps = self._grid_anchors(grid_sizes) + return [Boxes(x) for x in anchors_over_all_feature_maps] + + +@ANCHOR_GENERATOR_REGISTRY.register() +class RotatedAnchorGenerator(nn.Module): + """ + Compute rotated anchors used by Rotated RPN (RRPN), described in + "Arbitrary-Oriented Scene Text Detection via Rotation Proposals". + """ + + box_dim: int = 5 + """ + the dimension of each anchor box. + """ + + @configurable + def __init__(self, *, sizes, aspect_ratios, strides, angles, offset=0.5): + """ + This interface is experimental. + + Args: + sizes (list[list[float]] or list[float]): + If sizes is list[list[float]], sizes[i] is the list of anchor sizes + (i.e. sqrt of anchor area) to use for the i-th feature map. + If sizes is list[float], the sizes are used for all feature maps. + Anchor sizes are given in absolute lengths in units of + the input image; they do not dynamically scale if the input image size changes. + aspect_ratios (list[list[float]] or list[float]): list of aspect ratios + (i.e. height / width) to use for anchors. Same "broadcast" rule for `sizes` applies. + strides (list[int]): stride of each input feature. + angles (list[list[float]] or list[float]): list of angles (in degrees CCW) + to use for anchors. Same "broadcast" rule for `sizes` applies. + offset (float): Relative offset between the center of the first anchor and the top-left + corner of the image. Value has to be in [0, 1). + Recommend to use 0.5, which means half stride. + """ + super().__init__() + + self.strides = strides + self.num_features = len(self.strides) + sizes = _broadcast_params(sizes, self.num_features, "sizes") + aspect_ratios = _broadcast_params(aspect_ratios, self.num_features, "aspect_ratios") + angles = _broadcast_params(angles, self.num_features, "angles") + self.cell_anchors = self._calculate_anchors(sizes, aspect_ratios, angles) + + self.offset = offset + assert 0.0 <= self.offset < 1.0, self.offset + + @classmethod + def from_config(cls, cfg, input_shape: List[ShapeSpec]): + return { + "sizes": cfg.MODEL.ANCHOR_GENERATOR.SIZES, + "aspect_ratios": cfg.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS, + "strides": [x.stride for x in input_shape], + "offset": cfg.MODEL.ANCHOR_GENERATOR.OFFSET, + "angles": cfg.MODEL.ANCHOR_GENERATOR.ANGLES, + } + + def _calculate_anchors(self, sizes, aspect_ratios, angles): + cell_anchors = [ + self.generate_cell_anchors(size, aspect_ratio, angle).float() + for size, aspect_ratio, angle in zip(sizes, aspect_ratios, angles) + ] + return BufferList(cell_anchors) + + @property + def num_cell_anchors(self): + """ + Alias of `num_anchors`. + """ + return self.num_anchors + + @property + def num_anchors(self): + """ + Returns: + list[int]: Each int is the number of anchors at every pixel + location, on that feature map. + For example, if at every pixel we use anchors of 3 aspect + ratios, 2 sizes and 5 angles, the number of anchors is 30. + (See also ANCHOR_GENERATOR.SIZES, ANCHOR_GENERATOR.ASPECT_RATIOS + and ANCHOR_GENERATOR.ANGLES in config) + + In standard RRPN models, `num_anchors` on every feature map is the same. + """ + return [len(cell_anchors) for cell_anchors in self.cell_anchors] + + def _grid_anchors(self, grid_sizes): + anchors = [] + for size, stride, base_anchors in zip(grid_sizes, self.strides, self.cell_anchors): + shift_x, shift_y = _create_grid_offsets(size, stride, self.offset, base_anchors) + zeros = torch.zeros_like(shift_x) + shifts = torch.stack((shift_x, shift_y, zeros, zeros, zeros), dim=1) + + anchors.append((shifts.view(-1, 1, 5) + base_anchors.view(1, -1, 5)).reshape(-1, 5)) + + return anchors + + def generate_cell_anchors( + self, + sizes=(32, 64, 128, 256, 512), + aspect_ratios=(0.5, 1, 2), + angles=(-90, -60, -30, 0, 30, 60, 90), + ): + """ + Generate a tensor storing canonical anchor boxes, which are all anchor + boxes of different sizes, aspect_ratios, angles centered at (0, 0). + We can later build the set of anchors for a full feature map by + shifting and tiling these tensors (see `meth:_grid_anchors`). + + Args: + sizes (tuple[float]): + aspect_ratios (tuple[float]]): + angles (tuple[float]]): + + Returns: + Tensor of shape (len(sizes) * len(aspect_ratios) * len(angles), 5) + storing anchor boxes in (x_ctr, y_ctr, w, h, angle) format. + """ + anchors = [] + for size in sizes: + area = size**2.0 + for aspect_ratio in aspect_ratios: + # s * s = w * h + # a = h / w + # ... some algebra ... + # w = sqrt(s * s / a) + # h = a * w + w = math.sqrt(area / aspect_ratio) + h = aspect_ratio * w + anchors.extend([0, 0, w, h, a] for a in angles) + + return torch.tensor(anchors) + + def forward(self, features): + """ + Args: + features (list[Tensor]): list of backbone feature maps on which to generate anchors. + + Returns: + list[RotatedBoxes]: a list of Boxes containing all the anchors for each feature map + (i.e. the cell anchors repeated over all locations in the feature map). + The number of anchors of each feature map is Hi x Wi x num_cell_anchors, + where Hi, Wi are resolution of the feature map divided by anchor stride. + """ + grid_sizes = [feature_map.shape[-2:] for feature_map in features] + anchors_over_all_feature_maps = self._grid_anchors(grid_sizes) + return [RotatedBoxes(x) for x in anchors_over_all_feature_maps] + + +def build_anchor_generator(cfg, input_shape): + """ + Built an anchor generator from `cfg.MODEL.ANCHOR_GENERATOR.NAME`. + """ + anchor_generator = cfg.MODEL.ANCHOR_GENERATOR.NAME + return ANCHOR_GENERATOR_REGISTRY.get(anchor_generator)(cfg, input_shape) diff --git a/custom_detectron2/modeling/backbone/__init__.py b/custom_detectron2/modeling/backbone/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5b3358a4061b143c78eba8e7bf81fe9f7ffac1aa --- /dev/null +++ b/custom_detectron2/modeling/backbone/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .build import build_backbone, BACKBONE_REGISTRY # noqa F401 isort:skip + +from .backbone import Backbone +from .fpn import FPN +from .regnet import RegNet +from .resnet import ( + BasicStem, + ResNet, + ResNetBlockBase, + build_resnet_backbone, + make_stage, + BottleneckBlock, +) +from .vit import ViT, SimpleFeaturePyramid, get_vit_lr_decay_rate +from .mvit import MViT +from .swin import SwinTransformer + +__all__ = [k for k in globals().keys() if not k.startswith("_")] +# TODO can expose more resnet blocks after careful consideration diff --git a/custom_detectron2/modeling/backbone/backbone.py b/custom_detectron2/modeling/backbone/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac190f7e16ba41c73e8d2a335442b200319b203 --- /dev/null +++ b/custom_detectron2/modeling/backbone/backbone.py @@ -0,0 +1,74 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from abc import ABCMeta, abstractmethod +from typing import Dict +import torch.nn as nn + +from custom_detectron2.layers import ShapeSpec + +__all__ = ["Backbone"] + + +class Backbone(nn.Module, metaclass=ABCMeta): + """ + Abstract base class for network backbones. + """ + + def __init__(self): + """ + The `__init__` method of any subclass can specify its own set of arguments. + """ + super().__init__() + + @abstractmethod + def forward(self): + """ + Subclasses must override this method, but adhere to the same return type. + + Returns: + dict[str->Tensor]: mapping from feature name (e.g., "res2") to tensor + """ + pass + + @property + def size_divisibility(self) -> int: + """ + Some backbones require the input height and width to be divisible by a + specific integer. This is typically true for encoder / decoder type networks + with lateral connection (e.g., FPN) for which feature maps need to match + dimension in the "bottom up" and "top down" paths. Set to 0 if no specific + input size divisibility is required. + """ + return 0 + + @property + def padding_constraints(self) -> Dict[str, int]: + """ + This property is a generalization of size_divisibility. Some backbones and training + recipes require specific padding constraints, such as enforcing divisibility by a specific + integer (e.g., FPN) or padding to a square (e.g., ViTDet with large-scale jitter + in :paper:vitdet). `padding_constraints` contains these optional items like: + { + "size_divisibility": int, + "square_size": int, + # Future options are possible + } + `size_divisibility` will read from here if presented and `square_size` indicates the + square padding size if `square_size` > 0. + + TODO: use type of Dict[str, int] to avoid torchscipt issues. The type of padding_constraints + could be generalized as TypedDict (Python 3.8+) to support more types in the future. + """ + return {} + + def output_shape(self): + """ + Returns: + dict[str->ShapeSpec] + """ + # this is a backward-compatible default + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } diff --git a/custom_detectron2/modeling/backbone/build.py b/custom_detectron2/modeling/backbone/build.py new file mode 100644 index 0000000000000000000000000000000000000000..378d6dd9dbb8d1d19646f7ef23b443024dbfeec1 --- /dev/null +++ b/custom_detectron2/modeling/backbone/build.py @@ -0,0 +1,33 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from custom_detectron2.layers import ShapeSpec +from custom_detectron2.utils.registry import Registry + +from .backbone import Backbone + +BACKBONE_REGISTRY = Registry("BACKBONE") +BACKBONE_REGISTRY.__doc__ = """ +Registry for backbones, which extract feature maps from images + +The registered object must be a callable that accepts two arguments: + +1. A :class:`detectron2.config.CfgNode` +2. A :class:`detectron2.layers.ShapeSpec`, which contains the input shape specification. + +Registered object must return instance of :class:`Backbone`. +""" + + +def build_backbone(cfg, input_shape=None): + """ + Build a backbone from `cfg.MODEL.BACKBONE.NAME`. + + Returns: + an instance of :class:`Backbone` + """ + if input_shape is None: + input_shape = ShapeSpec(channels=len(cfg.MODEL.PIXEL_MEAN)) + + backbone_name = cfg.MODEL.BACKBONE.NAME + backbone = BACKBONE_REGISTRY.get(backbone_name)(cfg, input_shape) + assert isinstance(backbone, Backbone) + return backbone diff --git a/custom_detectron2/modeling/backbone/fpn.py b/custom_detectron2/modeling/backbone/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..476241a28c1876adbdcb0434d3ab7a6060dd4c25 --- /dev/null +++ b/custom_detectron2/modeling/backbone/fpn.py @@ -0,0 +1,268 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F +from torch import nn + +from custom_detectron2.layers import Conv2d, ShapeSpec, get_norm + +from .backbone import Backbone +from .build import BACKBONE_REGISTRY +from .resnet import build_resnet_backbone + +__all__ = ["build_resnet_fpn_backbone", "build_retinanet_resnet_fpn_backbone", "FPN"] + + +class FPN(Backbone): + """ + This module implements :paper:`FPN`. + It creates pyramid features built on top of some input feature maps. + """ + + _fuse_type: torch.jit.Final[str] + + def __init__( + self, + bottom_up, + in_features, + out_channels, + norm="", + top_block=None, + fuse_type="sum", + square_pad=0, + ): + """ + Args: + bottom_up (Backbone): module representing the bottom up subnetwork. + Must be a subclass of :class:`Backbone`. The multi-scale feature + maps generated by the bottom up network, and listed in `in_features`, + are used to generate FPN levels. + in_features (list[str]): names of the input feature maps coming + from the backbone to which FPN is attached. For example, if the + backbone produces ["res2", "res3", "res4"], any *contiguous* sublist + of these may be used; order must be from high to low resolution. + out_channels (int): number of channels in the output feature maps. + norm (str): the normalization to use. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + FPN output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra FPN levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). + fuse_type (str): types for fusing the top down features and the lateral + ones. It can be "sum" (default), which sums up element-wise; or "avg", + which takes the element-wise mean of the two. + square_pad (int): If > 0, require input images to be padded to specific square size. + """ + super(FPN, self).__init__() + assert isinstance(bottom_up, Backbone) + assert in_features, in_features + + # Feature map strides and channels from the bottom up network (e.g. ResNet) + input_shapes = bottom_up.output_shape() + strides = [input_shapes[f].stride for f in in_features] + in_channels_per_feature = [input_shapes[f].channels for f in in_features] + + _assert_strides_are_log2_contiguous(strides) + lateral_convs = [] + output_convs = [] + + use_bias = norm == "" + for idx, in_channels in enumerate(in_channels_per_feature): + lateral_norm = get_norm(norm, out_channels) + output_norm = get_norm(norm, out_channels) + + lateral_conv = Conv2d( + in_channels, out_channels, kernel_size=1, bias=use_bias, norm=lateral_norm + ) + output_conv = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=use_bias, + norm=output_norm, + ) + weight_init.c2_xavier_fill(lateral_conv) + weight_init.c2_xavier_fill(output_conv) + stage = int(math.log2(strides[idx])) + self.add_module("fpn_lateral{}".format(stage), lateral_conv) + self.add_module("fpn_output{}".format(stage), output_conv) + + lateral_convs.append(lateral_conv) + output_convs.append(output_conv) + # Place convs into top-down order (from low to high resolution) + # to make the top-down computation in forward clearer. + self.lateral_convs = lateral_convs[::-1] + self.output_convs = output_convs[::-1] + self.top_block = top_block + self.in_features = tuple(in_features) + self.bottom_up = bottom_up + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. + if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + assert fuse_type in {"avg", "sum"} + self._fuse_type = fuse_type + + @property + def size_divisibility(self): + return self._size_divisibility + + @property + def padding_constraints(self): + return {"square_size": self._square_pad} + + def forward(self, x): + """ + Args: + input (dict[str->Tensor]): mapping feature map name (e.g., "res5") to + feature map tensor for each feature level in high to low resolution order. + + Returns: + dict[str->Tensor]: + mapping from feature map name to FPN feature map tensor + in high to low resolution order. Returned feature names follow the FPN + paper convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. + """ + bottom_up_features = self.bottom_up(x) + results = [] + prev_features = self.lateral_convs[0](bottom_up_features[self.in_features[-1]]) + results.append(self.output_convs[0](prev_features)) + + # Reverse feature maps into top-down order (from low to high resolution) + for idx, (lateral_conv, output_conv) in enumerate( + zip(self.lateral_convs, self.output_convs) + ): + # Slicing of ModuleList is not supported https://github.com/pytorch/pytorch/issues/47336 + # Therefore we loop over all modules but skip the first one + if idx > 0: + features = self.in_features[-idx - 1] + features = bottom_up_features[features] + top_down_features = F.interpolate(prev_features, scale_factor=2.0, mode="nearest") + lateral_features = lateral_conv(features) + prev_features = lateral_features + top_down_features + if self._fuse_type == "avg": + prev_features /= 2 + results.insert(0, output_conv(prev_features)) + + if self.top_block is not None: + if self.top_block.in_feature in bottom_up_features: + top_block_in_feature = bottom_up_features[self.top_block.in_feature] + else: + top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)] + results.extend(self.top_block(top_block_in_feature)) + assert len(self._out_features) == len(results) + return {f: res for f, res in zip(self._out_features, results)} + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + +def _assert_strides_are_log2_contiguous(strides): + """ + Assert that each stride is 2x times its preceding stride, i.e. "contiguous in log2". + """ + for i, stride in enumerate(strides[1:], 1): + assert stride == 2 * strides[i - 1], "Strides {} {} are not log2 contiguous".format( + stride, strides[i - 1] + ) + + +class LastLevelMaxPool(nn.Module): + """ + This module is used in the original FPN to generate a downsampled + P6 feature from P5. + """ + + def __init__(self): + super().__init__() + self.num_levels = 1 + self.in_feature = "p5" + + def forward(self, x): + return [F.max_pool2d(x, kernel_size=1, stride=2, padding=0)] + + +class LastLevelP6P7(nn.Module): + """ + This module is used in RetinaNet to generate extra layers, P6 and P7 from + C5 feature. + """ + + def __init__(self, in_channels, out_channels, in_feature="res5"): + super().__init__() + self.num_levels = 2 + self.in_feature = in_feature + self.p6 = nn.Conv2d(in_channels, out_channels, 3, 2, 1) + self.p7 = nn.Conv2d(out_channels, out_channels, 3, 2, 1) + for module in [self.p6, self.p7]: + weight_init.c2_xavier_fill(module) + + def forward(self, c5): + p6 = self.p6(c5) + p7 = self.p7(F.relu(p6)) + return [p6, p7] + + +@BACKBONE_REGISTRY.register() +def build_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelMaxPool(), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone + + +@BACKBONE_REGISTRY.register() +def build_retinanet_resnet_fpn_backbone(cfg, input_shape: ShapeSpec): + """ + Args: + cfg: a detectron2 CfgNode + + Returns: + backbone (Backbone): backbone module, must be a subclass of :class:`Backbone`. + """ + bottom_up = build_resnet_backbone(cfg, input_shape) + in_features = cfg.MODEL.FPN.IN_FEATURES + out_channels = cfg.MODEL.FPN.OUT_CHANNELS + in_channels_p6p7 = bottom_up.output_shape()["res5"].channels + backbone = FPN( + bottom_up=bottom_up, + in_features=in_features, + out_channels=out_channels, + norm=cfg.MODEL.FPN.NORM, + top_block=LastLevelP6P7(in_channels_p6p7, out_channels), + fuse_type=cfg.MODEL.FPN.FUSE_TYPE, + ) + return backbone diff --git a/custom_detectron2/modeling/backbone/mvit.py b/custom_detectron2/modeling/backbone/mvit.py new file mode 100644 index 0000000000000000000000000000000000000000..a12e00ff1819a09f90a02715f6c56aa35cbd6d63 --- /dev/null +++ b/custom_detectron2/modeling/backbone/mvit.py @@ -0,0 +1,448 @@ +import logging +import numpy as np +import torch +import torch.nn as nn + +from .backbone import Backbone +from .utils import ( + PatchEmbed, + add_decomposed_rel_pos, + get_abs_pos, + window_partition, + window_unpartition, +) + +logger = logging.getLogger(__name__) + + +__all__ = ["MViT"] + + +def attention_pool(x, pool, norm=None): + # (B, H, W, C) -> (B, C, H, W) + x = x.permute(0, 3, 1, 2) + x = pool(x) + # (B, C, H1, W1) -> (B, H1, W1, C) + x = x.permute(0, 2, 3, 1) + if norm: + x = norm(x) + + return x + + +class MultiScaleAttention(nn.Module): + """Multiscale Multi-head Attention block.""" + + def __init__( + self, + dim, + dim_out, + num_heads, + qkv_bias=True, + norm_layer=nn.LayerNorm, + pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + residual_pooling=True, + window_size=0, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + norm_layer (nn.Module): Normalization layer. + pool_kernel (tuple): kernel size for qkv pooling layers. + stride_q (int): stride size for q pooling layer. + stride_kv (int): stride size for kv pooling layer. + residual_pooling (bool): If true, enable residual pooling. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim_out // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim_out * 3, bias=qkv_bias) + self.proj = nn.Linear(dim_out, dim_out) + + # qkv pooling + pool_padding = [k // 2 for k in pool_kernel] + dim_conv = dim_out // num_heads + self.pool_q = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_q, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_q = norm_layer(dim_conv) + self.pool_k = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_kv, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_k = norm_layer(dim_conv) + self.pool_v = nn.Conv2d( + dim_conv, + dim_conv, + pool_kernel, + stride=stride_kv, + padding=pool_padding, + groups=dim_conv, + bias=False, + ) + self.norm_v = norm_layer(dim_conv) + + self.window_size = window_size + if window_size: + self.q_win_size = window_size // stride_q + self.kv_win_size = window_size // stride_kv + self.residual_pooling = residual_pooling + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + assert input_size[0] == input_size[1] + size = input_size[0] + rel_dim = 2 * max(size // stride_q, size // stride_kv) - 1 + self.rel_pos_h = nn.Parameter(torch.zeros(rel_dim, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(rel_dim, head_dim)) + + if not rel_pos_zero_init: + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H, W, C) + qkv = self.qkv(x).reshape(B, H, W, 3, self.num_heads, -1).permute(3, 0, 4, 1, 2, 5) + # q, k, v with shape (B * nHead, H, W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H, W, -1).unbind(0) + + q = attention_pool(q, self.pool_q, self.norm_q) + k = attention_pool(k, self.pool_k, self.norm_k) + v = attention_pool(v, self.pool_v, self.norm_v) + + ori_q = q + if self.window_size: + q, q_hw_pad = window_partition(q, self.q_win_size) + k, kv_hw_pad = window_partition(k, self.kv_win_size) + v, _ = window_partition(v, self.kv_win_size) + q_hw = (self.q_win_size, self.q_win_size) + kv_hw = (self.kv_win_size, self.kv_win_size) + else: + q_hw = q.shape[1:3] + kv_hw = k.shape[1:3] + + q = q.view(q.shape[0], np.prod(q_hw), -1) + k = k.view(k.shape[0], np.prod(kv_hw), -1) + v = v.view(v.shape[0], np.prod(kv_hw), -1) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, q_hw, kv_hw) + + attn = attn.softmax(dim=-1) + x = attn @ v + + x = x.view(x.shape[0], q_hw[0], q_hw[1], -1) + + if self.window_size: + x = window_unpartition(x, self.q_win_size, q_hw_pad, ori_q.shape[1:3]) + + if self.residual_pooling: + x += ori_q + + H, W = x.shape[1], x.shape[2] + x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class MultiScaleBlock(nn.Module): + """Multiscale Transformer blocks""" + + def __init__( + self, + dim, + dim_out, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + qkv_pool_kernel=(3, 3), + stride_q=1, + stride_kv=1, + residual_pooling=True, + window_size=0, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + dim_out (int): Number of output channels. + num_heads (int): Number of attention heads in the MViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + stride_q (int): stride size for q pooling layer. + stride_kv (int): stride size for kv pooling layer. + residual_pooling (bool): If true, enable residual pooling. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = MultiScaleAttention( + dim, + dim_out, + num_heads=num_heads, + qkv_bias=qkv_bias, + norm_layer=norm_layer, + pool_kernel=qkv_pool_kernel, + stride_q=stride_q, + stride_kv=stride_kv, + residual_pooling=residual_pooling, + window_size=window_size, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size, + ) + + from custom_timm.models.layers import DropPath, Mlp + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim_out) + self.mlp = Mlp( + in_features=dim_out, + hidden_features=int(dim_out * mlp_ratio), + out_features=dim_out, + act_layer=act_layer, + ) + + if dim != dim_out: + self.proj = nn.Linear(dim, dim_out) + + if stride_q > 1: + kernel_skip = stride_q + 1 + padding_skip = int(kernel_skip // 2) + self.pool_skip = nn.MaxPool2d(kernel_skip, stride_q, padding_skip, ceil_mode=False) + + def forward(self, x): + x_norm = self.norm1(x) + x_block = self.attn(x_norm) + + if hasattr(self, "proj"): + x = self.proj(x_norm) + if hasattr(self, "pool_skip"): + x = attention_pool(x, self.pool_skip) + + x = x + self.drop_path(x_block) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class MViT(Backbone): + """ + This module implements Multiscale Vision Transformer (MViT) backbone in :paper:'mvitv2'. + """ + + def __init__( + self, + img_size=224, + patch_kernel=(7, 7), + patch_stride=(4, 4), + patch_padding=(3, 3), + in_chans=3, + embed_dim=96, + depth=16, + num_heads=1, + last_block_indexes=(0, 2, 11, 15), + qkv_pool_kernel=(3, 3), + adaptive_kv_stride=4, + adaptive_window_size=56, + residual_pooling=True, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=False, + use_rel_pos=True, + rel_pos_zero_init=True, + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_features=("scale2", "scale3", "scale4", "scale5"), + ): + """ + Args: + img_size (int): Input image size. + patch_kernel (tuple): kernel size for patch embedding. + patch_stride (tuple): stride size for patch embedding. + patch_padding (tuple): padding size for patch embedding. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of MViT. + num_heads (int): Number of base attention heads in each MViT block. + last_block_indexes (tuple): Block indexes for last blocks in each stage. + qkv_pool_kernel (tuple): kernel size for qkv pooling layers. + adaptive_kv_stride (int): adaptive stride size for kv pooling. + adaptive_window_size (int): adaptive window size for window attention blocks. + residual_pooling (bool): If true, enable residual pooling. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative postional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + out_features (tuple): name of the feature maps from each stage. + """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + self.patch_embed = PatchEmbed( + kernel_size=patch_kernel, + stride=patch_stride, + padding=patch_padding, + in_chans=in_chans, + embed_dim=embed_dim, + ) + + if use_abs_pos: + # Initialize absoluate positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_stride[0]) * ( + pretrain_img_size // patch_stride[1] + ) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + dim_out = embed_dim + stride_kv = adaptive_kv_stride + window_size = adaptive_window_size + input_size = (img_size // patch_stride[0], img_size // patch_stride[1]) + stage = 2 + stride = patch_stride[0] + self._out_feature_strides = {} + self._out_feature_channels = {} + self.blocks = nn.ModuleList() + for i in range(depth): + # Multiply stride_kv by 2 if it's the last block of stage2 and stage3. + if i == last_block_indexes[1] or i == last_block_indexes[2]: + stride_kv_ = stride_kv * 2 + else: + stride_kv_ = stride_kv + # hybrid window attention: global attention in last three stages. + window_size_ = 0 if i in last_block_indexes[1:] else window_size + block = MultiScaleBlock( + dim=embed_dim, + dim_out=dim_out, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + qkv_pool_kernel=qkv_pool_kernel, + stride_q=2 if i - 1 in last_block_indexes else 1, + stride_kv=stride_kv_, + residual_pooling=residual_pooling, + window_size=window_size_, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size, + ) + if use_act_checkpoint: + # TODO: use torch.utils.checkpoint + from fairscale.nn.checkpoint import checkpoint_wrapper + + block = checkpoint_wrapper(block) + self.blocks.append(block) + + embed_dim = dim_out + if i in last_block_indexes: + name = f"scale{stage}" + if name in out_features: + self._out_feature_channels[name] = dim_out + self._out_feature_strides[name] = stride + self.add_module(f"{name}_norm", norm_layer(dim_out)) + + dim_out *= 2 + num_heads *= 2 + stride_kv = max(stride_kv // 2, 1) + stride *= 2 + stage += 1 + if i - 1 in last_block_indexes: + window_size = window_size // 2 + input_size = [s // 2 for s in input_size] + + self._out_features = out_features + self._last_block_indexes = last_block_indexes + + if self.pos_embed is not None: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + + if self.pos_embed is not None: + x = x + get_abs_pos(self.pos_embed, self.pretrain_use_cls_token, x.shape[1:3]) + + outputs = {} + stage = 2 + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in self._last_block_indexes: + name = f"scale{stage}" + if name in self._out_features: + x_out = getattr(self, f"{name}_norm")(x) + outputs[name] = x_out.permute(0, 3, 1, 2) + stage += 1 + + return outputs diff --git a/custom_detectron2/modeling/backbone/regnet.py b/custom_detectron2/modeling/backbone/regnet.py new file mode 100644 index 0000000000000000000000000000000000000000..7283fa12d365349eda70efc8994e396af4dc6389 --- /dev/null +++ b/custom_detectron2/modeling/backbone/regnet.py @@ -0,0 +1,452 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Implementation of RegNet models from :paper:`dds` and :paper:`scaling`. + +This code is adapted from https://github.com/facebookresearch/pycls with minimal modifications. +Some code duplication exists between RegNet and ResNets (e.g., ResStem) in order to simplify +model loading. +""" + +import numpy as np +from torch import nn + +from custom_detectron2.layers import CNNBlockBase, ShapeSpec, get_norm + +from .backbone import Backbone + +__all__ = [ + "AnyNet", + "RegNet", + "ResStem", + "SimpleStem", + "VanillaBlock", + "ResBasicBlock", + "ResBottleneckBlock", +] + + +def conv2d(w_in, w_out, k, *, stride=1, groups=1, bias=False): + """Helper for building a conv2d layer.""" + assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." + s, p, g, b = stride, (k - 1) // 2, groups, bias + return nn.Conv2d(w_in, w_out, k, stride=s, padding=p, groups=g, bias=b) + + +def gap2d(): + """Helper for building a global average pooling layer.""" + return nn.AdaptiveAvgPool2d((1, 1)) + + +def pool2d(k, *, stride=1): + """Helper for building a pool2d layer.""" + assert k % 2 == 1, "Only odd size kernels supported to avoid padding issues." + return nn.MaxPool2d(k, stride=stride, padding=(k - 1) // 2) + + +def init_weights(m): + """Performs ResNet-style weight initialization.""" + if isinstance(m, nn.Conv2d): + # Note that there is no bias due to BN + fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels + m.weight.data.normal_(mean=0.0, std=np.sqrt(2.0 / fan_out)) + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(mean=0.0, std=0.01) + m.bias.data.zero_() + + +class ResStem(CNNBlockBase): + """ResNet stem for ImageNet: 7x7, BN, AF, MaxPool.""" + + def __init__(self, w_in, w_out, norm, activation_class): + super().__init__(w_in, w_out, 4) + self.conv = conv2d(w_in, w_out, 7, stride=2) + self.bn = get_norm(norm, w_out) + self.af = activation_class() + self.pool = pool2d(3, stride=2) + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class SimpleStem(CNNBlockBase): + """Simple stem for ImageNet: 3x3, BN, AF.""" + + def __init__(self, w_in, w_out, norm, activation_class): + super().__init__(w_in, w_out, 2) + self.conv = conv2d(w_in, w_out, 3, stride=2) + self.bn = get_norm(norm, w_out) + self.af = activation_class() + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class SE(nn.Module): + """Squeeze-and-Excitation (SE) block: AvgPool, FC, Act, FC, Sigmoid.""" + + def __init__(self, w_in, w_se, activation_class): + super().__init__() + self.avg_pool = gap2d() + self.f_ex = nn.Sequential( + conv2d(w_in, w_se, 1, bias=True), + activation_class(), + conv2d(w_se, w_in, 1, bias=True), + nn.Sigmoid(), + ) + + def forward(self, x): + return x * self.f_ex(self.avg_pool(x)) + + +class VanillaBlock(CNNBlockBase): + """Vanilla block: [3x3 conv, BN, Relu] x2.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, _params): + super().__init__(w_in, w_out, stride) + self.a = conv2d(w_in, w_out, 3, stride=stride) + self.a_bn = get_norm(norm, w_out) + self.a_af = activation_class() + self.b = conv2d(w_out, w_out, 3) + self.b_bn = get_norm(norm, w_out) + self.b_af = activation_class() + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class BasicTransform(nn.Module): + """Basic transformation: [3x3 conv, BN, Relu] x2.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, _params): + super().__init__() + self.a = conv2d(w_in, w_out, 3, stride=stride) + self.a_bn = get_norm(norm, w_out) + self.a_af = activation_class() + self.b = conv2d(w_out, w_out, 3) + self.b_bn = get_norm(norm, w_out) + self.b_bn.final_bn = True + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class ResBasicBlock(CNNBlockBase): + """Residual basic block: x + f(x), f = basic transform.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__(w_in, w_out, stride) + self.proj, self.bn = None, None + if (w_in != w_out) or (stride != 1): + self.proj = conv2d(w_in, w_out, 1, stride=stride) + self.bn = get_norm(norm, w_out) + self.f = BasicTransform(w_in, w_out, stride, norm, activation_class, params) + self.af = activation_class() + + def forward(self, x): + x_p = self.bn(self.proj(x)) if self.proj else x + return self.af(x_p + self.f(x)) + + +class BottleneckTransform(nn.Module): + """Bottleneck transformation: 1x1, 3x3 [+SE], 1x1.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__() + w_b = int(round(w_out * params["bot_mul"])) + w_se = int(round(w_in * params["se_r"])) + groups = w_b // params["group_w"] + self.a = conv2d(w_in, w_b, 1) + self.a_bn = get_norm(norm, w_b) + self.a_af = activation_class() + self.b = conv2d(w_b, w_b, 3, stride=stride, groups=groups) + self.b_bn = get_norm(norm, w_b) + self.b_af = activation_class() + self.se = SE(w_b, w_se, activation_class) if w_se else None + self.c = conv2d(w_b, w_out, 1) + self.c_bn = get_norm(norm, w_out) + self.c_bn.final_bn = True + + def forward(self, x): + for layer in self.children(): + x = layer(x) + return x + + +class ResBottleneckBlock(CNNBlockBase): + """Residual bottleneck block: x + f(x), f = bottleneck transform.""" + + def __init__(self, w_in, w_out, stride, norm, activation_class, params): + super().__init__(w_in, w_out, stride) + self.proj, self.bn = None, None + if (w_in != w_out) or (stride != 1): + self.proj = conv2d(w_in, w_out, 1, stride=stride) + self.bn = get_norm(norm, w_out) + self.f = BottleneckTransform(w_in, w_out, stride, norm, activation_class, params) + self.af = activation_class() + + def forward(self, x): + x_p = self.bn(self.proj(x)) if self.proj else x + return self.af(x_p + self.f(x)) + + +class AnyStage(nn.Module): + """AnyNet stage (sequence of blocks w/ the same output shape).""" + + def __init__(self, w_in, w_out, stride, d, block_class, norm, activation_class, params): + super().__init__() + for i in range(d): + block = block_class(w_in, w_out, stride, norm, activation_class, params) + self.add_module("b{}".format(i + 1), block) + stride, w_in = 1, w_out + + def forward(self, x): + for block in self.children(): + x = block(x) + return x + + +class AnyNet(Backbone): + """AnyNet model. See :paper:`dds`.""" + + def __init__( + self, + *, + stem_class, + stem_width, + block_class, + depths, + widths, + group_widths, + strides, + bottleneck_ratios, + se_ratio, + activation_class, + freeze_at=0, + norm="BN", + out_features=None, + ): + """ + Args: + stem_class (callable): A callable taking 4 arguments (channels in, channels out, + normalization, callable returning an activation function) that returns another + callable implementing the stem module. + stem_width (int): The number of output channels that the stem produces. + block_class (callable): A callable taking 6 arguments (channels in, channels out, + stride, normalization, callable returning an activation function, a dict of + block-specific parameters) that returns another callable implementing the repeated + block module. + depths (list[int]): Number of blocks in each stage. + widths (list[int]): For each stage, the number of output channels of each block. + group_widths (list[int]): For each stage, the number of channels per group in group + convolution, if the block uses group convolution. + strides (list[int]): The stride that each network stage applies to its input. + bottleneck_ratios (list[float]): For each stage, the ratio of the number of bottleneck + channels to the number of block input channels (or, equivalently, output channels), + if the block uses a bottleneck. + se_ratio (float): The ratio of the number of channels used inside the squeeze-excitation + (SE) module to it number of input channels, if SE the block uses SE. + activation_class (callable): A callable taking no arguments that returns another + callable implementing an activation function. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. RegNet's use "stem" and "s1", "s2", etc for the stages after + the stem. If None, will return the output of the last layer. + """ + super().__init__() + self.stem = stem_class(3, stem_width, norm, activation_class) + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + self.stages_and_names = [] + prev_w = stem_width + + for i, (d, w, s, b, g) in enumerate( + zip(depths, widths, strides, bottleneck_ratios, group_widths) + ): + params = {"bot_mul": b, "group_w": g, "se_r": se_ratio} + stage = AnyStage(prev_w, w, s, d, block_class, norm, activation_class, params) + name = "s{}".format(i + 1) + self.add_module(name, stage) + self.stages_and_names.append((stage, name)) + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in stage.children()]) + ) + self._out_feature_channels[name] = list(stage.children())[-1].out_channels + prev_w = w + + self.apply(init_weights) + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {} does not include {}".format( + ", ".join(children), out_feature + ) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"Model takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for stage, name in self.stages_and_names: + x = stage(x) + if name in self._out_features: + outputs[name] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the model. Commonly used in fine-tuning. + + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. + + Returns: + nn.Module: this model itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, (stage, _) in enumerate(self.stages_and_names, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + +def adjust_block_compatibility(ws, bs, gs): + """Adjusts the compatibility of widths, bottlenecks, and groups.""" + assert len(ws) == len(bs) == len(gs) + assert all(w > 0 and b > 0 and g > 0 for w, b, g in zip(ws, bs, gs)) + vs = [int(max(1, w * b)) for w, b in zip(ws, bs)] + gs = [int(min(g, v)) for g, v in zip(gs, vs)] + ms = [np.lcm(g, b) if b > 1 else g for g, b in zip(gs, bs)] + vs = [max(m, int(round(v / m) * m)) for v, m in zip(vs, ms)] + ws = [int(v / b) for v, b in zip(vs, bs)] + assert all(w * b % g == 0 for w, b, g in zip(ws, bs, gs)) + return ws, bs, gs + + +def generate_regnet_parameters(w_a, w_0, w_m, d, q=8): + """Generates per stage widths and depths from RegNet parameters.""" + assert w_a >= 0 and w_0 > 0 and w_m > 1 and w_0 % q == 0 + # Generate continuous per-block ws + ws_cont = np.arange(d) * w_a + w_0 + # Generate quantized per-block ws + ks = np.round(np.log(ws_cont / w_0) / np.log(w_m)) + ws_all = w_0 * np.power(w_m, ks) + ws_all = np.round(np.divide(ws_all, q)).astype(int) * q + # Generate per stage ws and ds (assumes ws_all are sorted) + ws, ds = np.unique(ws_all, return_counts=True) + # Compute number of actual stages and total possible stages + num_stages, total_stages = len(ws), ks.max() + 1 + # Convert numpy arrays to lists and return + ws, ds, ws_all, ws_cont = (x.tolist() for x in (ws, ds, ws_all, ws_cont)) + return ws, ds, num_stages, total_stages, ws_all, ws_cont + + +class RegNet(AnyNet): + """RegNet model. See :paper:`dds`.""" + + def __init__( + self, + *, + stem_class, + stem_width, + block_class, + depth, + w_a, + w_0, + w_m, + group_width, + stride=2, + bottleneck_ratio=1.0, + se_ratio=0.0, + activation_class=None, + freeze_at=0, + norm="BN", + out_features=None, + ): + """ + Build a RegNet from the parameterization described in :paper:`dds` Section 3.3. + + Args: + See :class:`AnyNet` for arguments that are not listed here. + depth (int): Total number of blocks in the RegNet. + w_a (float): Factor by which block width would increase prior to quantizing block widths + by stage. See :paper:`dds` Section 3.3. + w_0 (int): Initial block width. See :paper:`dds` Section 3.3. + w_m (float): Parameter controlling block width quantization. + See :paper:`dds` Section 3.3. + group_width (int): Number of channels per group in group convolution, if the block uses + group convolution. + bottleneck_ratio (float): The ratio of the number of bottleneck channels to the number + of block input channels (or, equivalently, output channels), if the block uses a + bottleneck. + stride (int): The stride that each network stage applies to its input. + """ + ws, ds = generate_regnet_parameters(w_a, w_0, w_m, depth)[0:2] + ss = [stride for _ in ws] + bs = [bottleneck_ratio for _ in ws] + gs = [group_width for _ in ws] + ws, bs, gs = adjust_block_compatibility(ws, bs, gs) + + def default_activation_class(): + return nn.ReLU(inplace=True) + + super().__init__( + stem_class=stem_class, + stem_width=stem_width, + block_class=block_class, + depths=ds, + widths=ws, + strides=ss, + group_widths=gs, + bottleneck_ratios=bs, + se_ratio=se_ratio, + activation_class=default_activation_class + if activation_class is None + else activation_class, + freeze_at=freeze_at, + norm=norm, + out_features=out_features, + ) diff --git a/custom_detectron2/modeling/backbone/resnet.py b/custom_detectron2/modeling/backbone/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c373db499750c537ff884a7d2d3d834e57297050 --- /dev/null +++ b/custom_detectron2/modeling/backbone/resnet.py @@ -0,0 +1,694 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn.functional as F +from torch import nn + +from custom_detectron2.layers import ( + CNNBlockBase, + Conv2d, + DeformConv, + ModulatedDeformConv, + ShapeSpec, + get_norm, +) + +from .backbone import Backbone +from .build import BACKBONE_REGISTRY + +__all__ = [ + "ResNetBlockBase", + "BasicBlock", + "BottleneckBlock", + "DeformBottleneckBlock", + "BasicStem", + "ResNet", + "make_stage", + "build_resnet_backbone", +] + + +class BasicBlock(CNNBlockBase): + """ + The basic residual block for ResNet-18 and ResNet-34 defined in :paper:`ResNet`, + with two 3x3 conv layers and a projection shortcut if needed. + """ + + def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + stride (int): Stride for the first conv. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + self.conv2 = Conv2d( + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + out = self.conv2(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block used by ResNet-50, 101 and 152 + defined in :paper:`ResNet`. It contains 3 conv layers with kernels + 1x1, 3x3, 1x1, and a projection shortcut if needed. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + ): + """ + Args: + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + num_groups (int): number of groups for the 3x3 conv layer. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + stride_in_1x1 (bool): when stride>1, whether to put stride in the + first 1x1 convolution or the bottleneck 3x3 convolution. + dilation (int): the dilation rate of the 3x3 conv layer. + """ + super().__init__(in_channels, out_channels, stride) + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + # The original MSRA ResNet models have stride in the first 1x1 conv + # The subsequent fb.torch.resnet and Caffe2 ResNe[X]t implementations have + # stride in the 3x3 conv + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + # Zero-initialize the last normalization in each residual branch, + # so that at the beginning, the residual branch starts with zeros, + # and each residual block behaves like an identity. + # See Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "For BN layers, the learnable scaling coefficient γ is initialized + # to be 1, except for each residual block's last BN + # where γ is initialized to be 0." + + # nn.init.constant_(self.conv3.norm.weight, 0) + # TODO this somehow hurts performance when training GN models from scratch. + # Add it as an option when we need to use this code to train a backbone. + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + out = self.conv2(out) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class DeformBottleneckBlock(CNNBlockBase): + """ + Similar to :class:`BottleneckBlock`, but with :paper:`deformable conv ` + in the 3x3 convolution. + """ + + def __init__( + self, + in_channels, + out_channels, + *, + bottleneck_channels, + stride=1, + num_groups=1, + norm="BN", + stride_in_1x1=False, + dilation=1, + deform_modulated=False, + deform_num_groups=1, + ): + super().__init__(in_channels, out_channels, stride) + self.deform_modulated = deform_modulated + + if in_channels != out_channels: + self.shortcut = Conv2d( + in_channels, + out_channels, + kernel_size=1, + stride=stride, + bias=False, + norm=get_norm(norm, out_channels), + ) + else: + self.shortcut = None + + stride_1x1, stride_3x3 = (stride, 1) if stride_in_1x1 else (1, stride) + + self.conv1 = Conv2d( + in_channels, + bottleneck_channels, + kernel_size=1, + stride=stride_1x1, + bias=False, + norm=get_norm(norm, bottleneck_channels), + ) + + if deform_modulated: + deform_conv_op = ModulatedDeformConv + # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size + offset_channels = 27 + else: + deform_conv_op = DeformConv + offset_channels = 18 + + self.conv2_offset = Conv2d( + bottleneck_channels, + offset_channels * deform_num_groups, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + dilation=dilation, + ) + self.conv2 = deform_conv_op( + bottleneck_channels, + bottleneck_channels, + kernel_size=3, + stride=stride_3x3, + padding=1 * dilation, + bias=False, + groups=num_groups, + dilation=dilation, + deformable_groups=deform_num_groups, + norm=get_norm(norm, bottleneck_channels), + ) + + self.conv3 = Conv2d( + bottleneck_channels, + out_channels, + kernel_size=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + + for layer in [self.conv1, self.conv2, self.conv3, self.shortcut]: + if layer is not None: # shortcut can be None + weight_init.c2_msra_fill(layer) + + nn.init.constant_(self.conv2_offset.weight, 0) + nn.init.constant_(self.conv2_offset.bias, 0) + + def forward(self, x): + out = self.conv1(x) + out = F.relu_(out) + + if self.deform_modulated: + offset_mask = self.conv2_offset(out) + offset_x, offset_y, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((offset_x, offset_y), dim=1) + mask = mask.sigmoid() + out = self.conv2(out, offset, mask) + else: + offset = self.conv2_offset(out) + out = self.conv2(out, offset) + out = F.relu_(out) + + out = self.conv3(out) + + if self.shortcut is not None: + shortcut = self.shortcut(x) + else: + shortcut = x + + out += shortcut + out = F.relu_(out) + return out + + +class BasicStem(CNNBlockBase): + """ + The standard ResNet stem (layers before the first residual block), + with a conv, relu and max_pool. + """ + + def __init__(self, in_channels=3, out_channels=64, norm="BN"): + """ + Args: + norm (str or callable): norm after the first conv layer. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, 4) + self.in_channels = in_channels + self.conv1 = Conv2d( + in_channels, + out_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False, + norm=get_norm(norm, out_channels), + ) + weight_init.c2_msra_fill(self.conv1) + + def forward(self, x): + x = self.conv1(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +class ResNet(Backbone): + """ + Implement :paper:`ResNet`. + """ + + def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0): + """ + Args: + stem (nn.Module): a stem module + stages (list[list[CNNBlockBase]]): several (typically 4) stages, + each contains multiple :class:`CNNBlockBase`. + num_classes (None or int): if None, will not perform classification. + Otherwise, will create a linear layer. + out_features (list[str]): name of the layers whose outputs should + be returned in forward. Can be anything in "stem", "linear", or "res2" ... + If None, will return the output of the last layer. + freeze_at (int): The number of stages at the beginning to freeze. + see :meth:`freeze` for detailed explanation. + """ + super().__init__() + self.stem = stem + self.num_classes = num_classes + + current_stride = self.stem.stride + self._out_feature_strides = {"stem": current_stride} + self._out_feature_channels = {"stem": self.stem.out_channels} + + self.stage_names, self.stages = [], [] + + if out_features is not None: + # Avoid keeping unused layers in this module. They consume extra memory + # and may cause allreduce to fail + num_stages = max( + [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features] + ) + stages = stages[:num_stages] + for i, blocks in enumerate(stages): + assert len(blocks) > 0, len(blocks) + for block in blocks: + assert isinstance(block, CNNBlockBase), block + + name = "res" + str(i + 2) + stage = nn.Sequential(*blocks) + + self.add_module(name, stage) + self.stage_names.append(name) + self.stages.append(stage) + + self._out_feature_strides[name] = current_stride = int( + current_stride * np.prod([k.stride for k in blocks]) + ) + self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels + self.stage_names = tuple(self.stage_names) # Make it static for scripting + + if num_classes is not None: + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.linear = nn.Linear(curr_channels, num_classes) + + # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour": + # "The 1000-way fully-connected layer is initialized by + # drawing weights from a zero-mean Gaussian with standard deviation of 0.01." + nn.init.normal_(self.linear.weight, std=0.01) + name = "linear" + + if out_features is None: + out_features = [name] + self._out_features = out_features + assert len(self._out_features) + children = [x[0] for x in self.named_children()] + for out_feature in self._out_features: + assert out_feature in children, "Available children: {}".format(", ".join(children)) + self.freeze(freeze_at) + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: names and the corresponding features + """ + assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!" + outputs = {} + x = self.stem(x) + if "stem" in self._out_features: + outputs["stem"] = x + for name, stage in zip(self.stage_names, self.stages): + x = stage(x) + if name in self._out_features: + outputs[name] = x + if self.num_classes is not None: + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.linear(x) + if "linear" in self._out_features: + outputs["linear"] = x + return outputs + + def output_shape(self): + return { + name: ShapeSpec( + channels=self._out_feature_channels[name], stride=self._out_feature_strides[name] + ) + for name in self._out_features + } + + def freeze(self, freeze_at=0): + """ + Freeze the first several stages of the ResNet. Commonly used in + fine-tuning. + + Layers that produce the same feature map spatial size are defined as one + "stage" by :paper:`FPN`. + + Args: + freeze_at (int): number of stages to freeze. + `1` means freezing the stem. `2` means freezing the stem and + one residual stage, etc. + + Returns: + nn.Module: this ResNet itself + """ + if freeze_at >= 1: + self.stem.freeze() + for idx, stage in enumerate(self.stages, start=2): + if freeze_at >= idx: + for block in stage.children(): + block.freeze() + return self + + @staticmethod + def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs): + """ + Create a list of blocks of the same type that forms one ResNet stage. + + Args: + block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this + stage. A module of this type must not change spatial resolution of inputs unless its + stride != 1. + num_blocks (int): number of blocks in this stage + in_channels (int): input channels of the entire stage. + out_channels (int): output channels of **every block** in the stage. + kwargs: other arguments passed to the constructor of + `block_class`. If the argument name is "xx_per_block", the + argument is a list of values to be passed to each block in the + stage. Otherwise, the same argument is passed to every block + in the stage. + + Returns: + list[CNNBlockBase]: a list of block module. + + Examples: + :: + stage = ResNet.make_stage( + BottleneckBlock, 3, in_channels=16, out_channels=64, + bottleneck_channels=16, num_groups=1, + stride_per_block=[2, 1, 1], + dilations_per_block=[1, 1, 2] + ) + + Usually, layers that produce the same feature map spatial size are defined as one + "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should + all be 1. + """ + blocks = [] + for i in range(num_blocks): + curr_kwargs = {} + for k, v in kwargs.items(): + if k.endswith("_per_block"): + assert len(v) == num_blocks, ( + f"Argument '{k}' of make_stage should have the " + f"same length as num_blocks={num_blocks}." + ) + newk = k[: -len("_per_block")] + assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!" + curr_kwargs[newk] = v[i] + else: + curr_kwargs[k] = v + + blocks.append( + block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs) + ) + in_channels = out_channels + return blocks + + @staticmethod + def make_default_stages(depth, block_class=None, **kwargs): + """ + Created list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152). + If it doesn't create the ResNet variant you need, please use :meth:`make_stage` + instead for fine-grained customization. + + Args: + depth (int): depth of ResNet + block_class (type): the CNN block class. Has to accept + `bottleneck_channels` argument for depth > 50. + By default it is BasicBlock or BottleneckBlock, based on the + depth. + kwargs: + other arguments to pass to `make_stage`. Should not contain + stride and channels, as they are predefined for each depth. + + Returns: + list[list[CNNBlockBase]]: modules in all stages; see arguments of + :class:`ResNet.__init__`. + """ + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + if block_class is None: + block_class = BasicBlock if depth < 50 else BottleneckBlock + if depth < 50: + in_channels = [64, 64, 128, 256] + out_channels = [64, 128, 256, 512] + else: + in_channels = [64, 256, 512, 1024] + out_channels = [256, 512, 1024, 2048] + ret = [] + for (n, s, i, o) in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels): + if depth >= 50: + kwargs["bottleneck_channels"] = o // 4 + ret.append( + ResNet.make_stage( + block_class=block_class, + num_blocks=n, + stride_per_block=[s] + [1] * (n - 1), + in_channels=i, + out_channels=o, + **kwargs, + ) + ) + return ret + + +ResNetBlockBase = CNNBlockBase +""" +Alias for backward compatibiltiy. +""" + + +def make_stage(*args, **kwargs): + """ + Deprecated alias for backward compatibiltiy. + """ + return ResNet.make_stage(*args, **kwargs) + + +@BACKBONE_REGISTRY.register() +def build_resnet_backbone(cfg, input_shape): + """ + Create a ResNet instance from config. + + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + + # fmt: off + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + # fmt: on + assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation) + + num_blocks_per_stage = { + 18: [2, 2, 2, 2], + 34: [3, 4, 6, 3], + 50: [3, 4, 6, 3], + 101: [3, 4, 23, 3], + 152: [3, 8, 36, 3], + }[depth] + + if depth in [18, 34]: + assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34" + assert not any( + deform_on_per_stage + ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34" + assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34" + assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34" + + stages = [] + + for idx, stage_idx in enumerate(range(2, 6)): + # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper + dilation = res5_dilation if stage_idx == 5 else 1 + first_stride = 1 if idx == 0 or (stage_idx == 5 and dilation == 2) else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + } + # Use BasicBlock for R18 and R34. + if depth in [18, 34]: + stage_kargs["block_class"] = BasicBlock + else: + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + if deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + else: + stage_kargs["block_class"] = BottleneckBlock + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at) diff --git a/custom_detectron2/modeling/backbone/swin.py b/custom_detectron2/modeling/backbone/swin.py new file mode 100644 index 0000000000000000000000000000000000000000..9794139c1aa84bd447f9fd0ac11448535bb858ee --- /dev/null +++ b/custom_detectron2/modeling/backbone/swin.py @@ -0,0 +1,695 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +""" +Implementation of Swin models from :paper:`swin`. + +This code is adapted from https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/master/mmdet/models/backbones/swin_transformer.py with minimal modifications. # noqa +-------------------------------------------------------- +Swin Transformer +Copyright (c) 2021 Microsoft +Licensed under The MIT License [see LICENSE for details] +Written by Ze Liu, Yutong Lin, Yixuan Wei +-------------------------------------------------------- +LICENSE: https://github.com/SwinTransformer/Swin-Transformer-Object-Detection/blob/461e003166a8083d0b620beacd4662a2df306bd6/LICENSE +""" + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint + +from custom_detectron2.modeling.backbone.backbone import Backbone + +_to_2tuple = nn.modules.utils._ntuple(2) + + +class Mlp(nn.Module): + """Multilayer perceptron.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view(B, H // window_size, window_size, W // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size / window_size)) + x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. + Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0.0, + proj_drop=0.0, + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads) + ) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + nn.init.trunc_normal_(self.relative_position_bias_table, std=0.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B_, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = q @ k.transpose(-2, -1) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1 + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + """Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__( + self, + dim, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size" + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=_to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, + ) + + if drop_path > 0.0: + from custom_timm.models.layers import DropPath + + self.drop_path = DropPath(drop_path) + else: + self.drop_path = nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + + self.H = None + self.W = None + + def forward(self, x, mask_matrix): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + mask_matrix: Attention mask for cyclic shift. + """ + B, L, C = x.shape + H, W = self.H, self.W + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size - W % self.window_size) % self.window_size + pad_b = (self.window_size - H % self.window_size) % self.window_size + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size * self.window_size, C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + """Patch Merging Layer + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + x = x.view(B, H, W, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) + if pad_input: + x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. + Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + dim, + depth, + num_heads, + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + ): + super().__init__() + self.window_size = window_size + self.shift_size = window_size // 2 + self.depth = depth + self.use_checkpoint = use_checkpoint + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + """Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size)) * self.window_size + Wp = int(np.ceil(W / self.window_size)) * self.window_size + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + w_slices = ( + slice(0, -self.window_size), + slice(-self.window_size, -self.shift_size), + slice(-self.shift_size, None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x, H, W, x_down, Wh, Ww + else: + return x, H, W, x, H, W + + +class PatchEmbed(nn.Module): + """Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = _to_2tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + +class SwinTransformer(Backbone): + """Swin Transformer backbone. + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted + Windows` - https://arxiv.org/pdf/2103.14030 + Args: + pretrain_img_size (int): Input image size for training the pretrained model, + used in absolute postion embedding. Default 224. + patch_size (int | tuple(int)): Patch size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + depths (tuple[int]): Depths of each Swin Transformer stage. + num_heads (tuple[int]): Number of attention head of each stage. + window_size (int): Window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. + drop_rate (float): Dropout rate. + attn_drop_rate (float): Attention dropout rate. Default: 0. + drop_path_rate (float): Stochastic depth rate. Default: 0.2. + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False. + patch_norm (bool): If True, add normalization after patch embedding. Default: True. + out_indices (Sequence[int]): Output from which stages. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__( + self, + pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=(2, 2, 6, 2), + num_heads=(3, 6, 12, 24), + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + ): + super().__init__() + + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + + # absolute position embedding + if self.ape: + pretrain_img_size = _to_2tuple(pretrain_img_size) + patch_size = _to_2tuple(patch_size) + patches_resolution = [ + pretrain_img_size[0] // patch_size[0], + pretrain_img_size[1] // patch_size[1], + ] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]) + ) + nn.init.trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + ) + self.layers.append(layer) + + num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f"norm{i_layer}" + self.add_module(layer_name, layer) + + self._freeze_stages() + self._out_features = ["p{}".format(i) for i in self.out_indices] + self._out_feature_channels = { + "p{}".format(i): self.embed_dim * 2**i for i in self.out_indices + } + self._out_feature_strides = {"p{}".format(i): 2 ** (i + 2) for i in self.out_indices} + self._size_devisibility = 32 + + self.apply(self._init_weights) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @property + def size_divisibility(self): + return self._size_divisibility + + def forward(self, x): + """Forward function.""" + x = self.patch_embed(x) + + Wh, Ww = x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, size=(Wh, Ww), mode="bicubic" + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = {} + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww) + + if i in self.out_indices: + norm_layer = getattr(self, f"norm{i}") + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous() + outs["p{}".format(i)] = out + + return outs diff --git a/custom_detectron2/modeling/backbone/utils.py b/custom_detectron2/modeling/backbone/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2b89a4c3fbe079a77fd0cef947cf9ada787fc55d --- /dev/null +++ b/custom_detectron2/modeling/backbone/utils.py @@ -0,0 +1,186 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved +import math +import torch +import torch.nn as nn +import torch.nn.functional as F + +__all__ = [ + "window_partition", + "window_unpartition", + "add_decomposed_rel_pos", + "get_abs_pos", + "PatchEmbed", +] + + +def window_partition(x, window_size): + """ + Partition into non-overlapping windows with padding if needed. + Args: + x (tensor): input tokens with [B, H, W, C]. + window_size (int): window size. + + Returns: + windows: windows after partition with [B * num_windows, window_size, window_size, C]. + (Hp, Wp): padded height and width before partition + """ + B, H, W, C = x.shape + + pad_h = (window_size - H % window_size) % window_size + pad_w = (window_size - W % window_size) % window_size + if pad_h > 0 or pad_w > 0: + x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h)) + Hp, Wp = H + pad_h, W + pad_w + + x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C) + windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C) + return windows, (Hp, Wp) + + +def window_unpartition(windows, window_size, pad_hw, hw): + """ + Window unpartition into original sequences and removing padding. + Args: + x (tensor): input tokens with [B * num_windows, window_size, window_size, C]. + window_size (int): window size. + pad_hw (Tuple): padded height and width (Hp, Wp). + hw (Tuple): original height and width (H, W) before padding. + + Returns: + x: unpartitioned sequences with [B, H, W, C]. + """ + Hp, Wp = pad_hw + H, W = hw + B = windows.shape[0] // (Hp * Wp // window_size // window_size) + x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1) + + if Hp > H or Wp > W: + x = x[:, :H, :W, :].contiguous() + return x + + +def get_rel_pos(q_size, k_size, rel_pos): + """ + Get relative positional embeddings according to the relative positions of + query and key sizes. + Args: + q_size (int): size of query q. + k_size (int): size of key k. + rel_pos (Tensor): relative position embeddings (L, C). + + Returns: + Extracted positional embeddings according to relative positions. + """ + max_rel_dist = int(2 * max(q_size, k_size) - 1) + # Interpolate rel pos if needed. + if rel_pos.shape[0] != max_rel_dist: + # Interpolate rel pos. + rel_pos_resized = F.interpolate( + rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1), + size=max_rel_dist, + mode="linear", + ) + rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0) + else: + rel_pos_resized = rel_pos + + # Scale the coords with short length if shapes for q and k are different. + q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0) + k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0) + relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0) + + return rel_pos_resized[relative_coords.long()] + + +def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size): + """ + Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`. + https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py # noqa B950 + Args: + attn (Tensor): attention map. + q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C). + rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis. + rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis. + q_size (Tuple): spatial sequence size of query q with (q_h, q_w). + k_size (Tuple): spatial sequence size of key k with (k_h, k_w). + + Returns: + attn (Tensor): attention map with added relative positional embeddings. + """ + q_h, q_w = q_size + k_h, k_w = k_size + Rh = get_rel_pos(q_h, k_h, rel_pos_h) + Rw = get_rel_pos(q_w, k_w, rel_pos_w) + + B, _, dim = q.shape + r_q = q.reshape(B, q_h, q_w, dim) + rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh) + rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw) + + attn = ( + attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :] + ).view(B, q_h * q_w, k_h * k_w) + + return attn + + +def get_abs_pos(abs_pos, has_cls_token, hw): + """ + Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token + dimension for the original embeddings. + Args: + abs_pos (Tensor): absolute positional embeddings with (1, num_position, C). + has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token. + hw (Tuple): size of input image tokens. + + Returns: + Absolute positional embeddings after processing with shape (1, H, W, C) + """ + h, w = hw + if has_cls_token: + abs_pos = abs_pos[:, 1:] + xy_num = abs_pos.shape[1] + size = int(math.sqrt(xy_num)) + assert size * size == xy_num + + if size != h or size != w: + new_abs_pos = F.interpolate( + abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2), + size=(h, w), + mode="bicubic", + align_corners=False, + ) + + return new_abs_pos.permute(0, 2, 3, 1) + else: + return abs_pos.reshape(1, h, w, -1) + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding. + """ + + def __init__( + self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768 + ): + """ + Args: + kernel_size (Tuple): kernel size of the projection layer. + stride (Tuple): stride of the projection layer. + padding (Tuple): padding size of the projection layer. + in_chans (int): Number of input image channels. + embed_dim (int): embed_dim (int): Patch embedding dimension. + """ + super().__init__() + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding + ) + + def forward(self, x): + x = self.proj(x) + # B C H W -> B H W C + x = x.permute(0, 2, 3, 1) + return x diff --git a/custom_detectron2/modeling/backbone/vit.py b/custom_detectron2/modeling/backbone/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..3505fc095a67f39e17c4bea1a4c52463912c3c13 --- /dev/null +++ b/custom_detectron2/modeling/backbone/vit.py @@ -0,0 +1,524 @@ +import logging +import math +import fvcore.nn.weight_init as weight_init +import torch +import torch.nn as nn + +from custom_detectron2.layers import CNNBlockBase, Conv2d, get_norm +from custom_detectron2.modeling.backbone.fpn import _assert_strides_are_log2_contiguous + +from .backbone import Backbone +from .utils import ( + PatchEmbed, + add_decomposed_rel_pos, + get_abs_pos, + window_partition, + window_unpartition, +) + +logger = logging.getLogger(__name__) + + +__all__ = ["ViT", "SimpleFeaturePyramid", "get_vit_lr_decay_rate"] + + +class Attention(nn.Module): + """Multi-head Attention block with relative position embeddings.""" + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=True, + use_rel_pos=False, + rel_pos_zero_init=True, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + qkv_bias (bool: If True, add a learnable bias to query, key, value. + rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.use_rel_pos = use_rel_pos + if self.use_rel_pos: + # initialize relative positional embeddings + self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, head_dim)) + self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, head_dim)) + + if not rel_pos_zero_init: + nn.init.trunc_normal_(self.rel_pos_h, std=0.02) + nn.init.trunc_normal_(self.rel_pos_w, std=0.02) + + def forward(self, x): + B, H, W, _ = x.shape + # qkv with shape (3, B, nHead, H * W, C) + qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + # q, k, v with shape (B * nHead, H * W, C) + q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0) + + attn = (q * self.scale) @ k.transpose(-2, -1) + + if self.use_rel_pos: + attn = add_decomposed_rel_pos(attn, q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W)) + + attn = attn.softmax(dim=-1) + x = (attn @ v).view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1) + x = self.proj(x) + + return x + + +class ResBottleneckBlock(CNNBlockBase): + """ + The standard bottleneck residual block without the last activation layer. + It contains 3 conv layers with kernels 1x1, 3x3, 1x1. + """ + + def __init__( + self, + in_channels, + out_channels, + bottleneck_channels, + norm="LN", + act_layer=nn.GELU, + ): + """ + Args: + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + bottleneck_channels (int): number of output channels for the 3x3 + "bottleneck" conv layers. + norm (str or callable): normalization for all conv layers. + See :func:`layers.get_norm` for supported format. + act_layer (callable): activation for all conv layers. + """ + super().__init__(in_channels, out_channels, 1) + + self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False) + self.norm1 = get_norm(norm, bottleneck_channels) + self.act1 = act_layer() + + self.conv2 = Conv2d( + bottleneck_channels, + bottleneck_channels, + 3, + padding=1, + bias=False, + ) + self.norm2 = get_norm(norm, bottleneck_channels) + self.act2 = act_layer() + + self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False) + self.norm3 = get_norm(norm, out_channels) + + for layer in [self.conv1, self.conv2, self.conv3]: + weight_init.c2_msra_fill(layer) + for layer in [self.norm1, self.norm2]: + layer.weight.data.fill_(1.0) + layer.bias.data.zero_() + # zero init last norm layer. + self.norm3.weight.data.zero_() + self.norm3.bias.data.zero_() + + def forward(self, x): + out = x + for layer in self.children(): + out = layer(out) + + out = x + out + return out + + +class Block(nn.Module): + """Transformer blocks with support of window attention and residual propagation blocks""" + + def __init__( + self, + dim, + num_heads, + mlp_ratio=4.0, + qkv_bias=True, + drop_path=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + use_residual_block=False, + input_size=None, + ): + """ + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. If it equals 0, then not + use window attention. + use_residual_block (bool): If True, use a residual block after the MLP block. + input_size (int or None): Input resolution for calculating the relative positional + parameter size. + """ + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + input_size=input_size if window_size == 0 else (window_size, window_size), + ) + + from custom_timm.models.layers import DropPath, Mlp + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer) + + self.window_size = window_size + + self.use_residual_block = use_residual_block + if use_residual_block: + # Use a residual block with bottleneck channel as dim // 2 + self.residual = ResBottleneckBlock( + in_channels=dim, + out_channels=dim, + bottleneck_channels=dim // 2, + norm="LN", + act_layer=act_layer, + ) + + def forward(self, x): + shortcut = x + x = self.norm1(x) + # Window partition + if self.window_size > 0: + H, W = x.shape[1], x.shape[2] + x, pad_hw = window_partition(x, self.window_size) + + x = self.attn(x) + # Reverse window partition + if self.window_size > 0: + x = window_unpartition(x, self.window_size, pad_hw, (H, W)) + + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + if self.use_residual_block: + x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1) + + return x + + +class ViT(Backbone): + """ + This module implements Vision Transformer (ViT) backbone in :paper:`vitdet`. + "Exploring Plain Vision Transformer Backbones for Object Detection", + https://arxiv.org/abs/2203.16527 + """ + + def __init__( + self, + img_size=1024, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + drop_path_rate=0.0, + norm_layer=nn.LayerNorm, + act_layer=nn.GELU, + use_abs_pos=True, + use_rel_pos=False, + rel_pos_zero_init=True, + window_size=0, + window_block_indexes=(), + residual_block_indexes=(), + use_act_checkpoint=False, + pretrain_img_size=224, + pretrain_use_cls_token=True, + out_feature="last_feat", + ): + """ + Args: + img_size (int): Input image size. + patch_size (int): Patch size. + in_chans (int): Number of input image channels. + embed_dim (int): Patch embedding dimension. + depth (int): Depth of ViT. + num_heads (int): Number of attention heads in each ViT block. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool): If True, add a learnable bias to query, key, value. + drop_path_rate (float): Stochastic depth rate. + norm_layer (nn.Module): Normalization layer. + act_layer (nn.Module): Activation layer. + use_abs_pos (bool): If True, use absolute positional embeddings. + use_rel_pos (bool): If True, add relative positional embeddings to the attention map. + rel_pos_zero_init (bool): If True, zero initialize relative positional parameters. + window_size (int): Window size for window attention blocks. + window_block_indexes (list): Indexes for blocks using window attention. + residual_block_indexes (list): Indexes for blocks using conv propagation. + use_act_checkpoint (bool): If True, use activation checkpointing. + pretrain_img_size (int): input image size for pretraining models. + pretrain_use_cls_token (bool): If True, pretrainig models use class token. + out_feature (str): name of the feature from the last block. + """ + super().__init__() + self.pretrain_use_cls_token = pretrain_use_cls_token + + self.patch_embed = PatchEmbed( + kernel_size=(patch_size, patch_size), + stride=(patch_size, patch_size), + in_chans=in_chans, + embed_dim=embed_dim, + ) + + if use_abs_pos: + # Initialize absolute positional embedding with pretrain image size. + num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size) + num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches + self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim)) + else: + self.pos_embed = None + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] + + self.blocks = nn.ModuleList() + for i in range(depth): + block = Block( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + use_rel_pos=use_rel_pos, + rel_pos_zero_init=rel_pos_zero_init, + window_size=window_size if i in window_block_indexes else 0, + use_residual_block=i in residual_block_indexes, + input_size=(img_size // patch_size, img_size // patch_size), + ) + if use_act_checkpoint: + # TODO: use torch.utils.checkpoint + from fairscale.nn.checkpoint import checkpoint_wrapper + + block = checkpoint_wrapper(block) + self.blocks.append(block) + + self._out_feature_channels = {out_feature: embed_dim} + self._out_feature_strides = {out_feature: patch_size} + self._out_features = [out_feature] + + if self.pos_embed is not None: + nn.init.trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + x = self.patch_embed(x) + if self.pos_embed is not None: + x = x + get_abs_pos( + self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2]) + ) + + for blk in self.blocks: + x = blk(x) + + outputs = {self._out_features[0]: x.permute(0, 3, 1, 2)} + return outputs + + +class SimpleFeaturePyramid(Backbone): + """ + This module implements SimpleFeaturePyramid in :paper:`vitdet`. + It creates pyramid features built on top of the input feature map. + """ + + def __init__( + self, + net, + in_feature, + out_channels, + scale_factors, + top_block=None, + norm="LN", + square_pad=0, + ): + """ + Args: + net (Backbone): module representing the subnetwork backbone. + Must be a subclass of :class:`Backbone`. + in_feature (str): names of the input feature maps coming + from the net. + out_channels (int): number of channels in the output feature maps. + scale_factors (list[float]): list of scaling factors to upsample or downsample + the input features for creating pyramid features. + top_block (nn.Module or None): if provided, an extra operation will + be performed on the output of the last (smallest resolution) + pyramid output, and the result will extend the result list. The top_block + further downsamples the feature map. It must have an attribute + "num_levels", meaning the number of extra pyramid levels added by + this block, and "in_feature", which is a string representing + its input feature (e.g., p5). + norm (str): the normalization to use. + square_pad (int): If > 0, require input images to be padded to specific square size. + """ + super(SimpleFeaturePyramid, self).__init__() + assert isinstance(net, Backbone) + + self.scale_factors = scale_factors + + input_shapes = net.output_shape() + strides = [int(input_shapes[in_feature].stride / scale) for scale in scale_factors] + _assert_strides_are_log2_contiguous(strides) + + dim = input_shapes[in_feature].channels + self.stages = [] + use_bias = norm == "" + for idx, scale in enumerate(scale_factors): + out_dim = dim + if scale == 4.0: + layers = [ + nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2), + get_norm(norm, dim // 2), + nn.GELU(), + nn.ConvTranspose2d(dim // 2, dim // 4, kernel_size=2, stride=2), + ] + out_dim = dim // 4 + elif scale == 2.0: + layers = [nn.ConvTranspose2d(dim, dim // 2, kernel_size=2, stride=2)] + out_dim = dim // 2 + elif scale == 1.0: + layers = [] + elif scale == 0.5: + layers = [nn.MaxPool2d(kernel_size=2, stride=2)] + else: + raise NotImplementedError(f"scale_factor={scale} is not supported yet.") + + layers.extend( + [ + Conv2d( + out_dim, + out_channels, + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + Conv2d( + out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, out_channels), + ), + ] + ) + layers = nn.Sequential(*layers) + + stage = int(math.log2(strides[idx])) + self.add_module(f"simfp_{stage}", layers) + self.stages.append(layers) + + self.net = net + self.in_feature = in_feature + self.top_block = top_block + # Return feature names are "p", like ["p2", "p3", ..., "p6"] + self._out_feature_strides = {"p{}".format(int(math.log2(s))): s for s in strides} + # top block output feature maps. + if self.top_block is not None: + for s in range(stage, stage + self.top_block.num_levels): + self._out_feature_strides["p{}".format(s + 1)] = 2 ** (s + 1) + + self._out_features = list(self._out_feature_strides.keys()) + self._out_feature_channels = {k: out_channels for k in self._out_features} + self._size_divisibility = strides[-1] + self._square_pad = square_pad + + @property + def padding_constraints(self): + return { + "size_divisiblity": self._size_divisibility, + "square_size": self._square_pad, + } + + def forward(self, x): + """ + Args: + x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``. + + Returns: + dict[str->Tensor]: + mapping from feature map name to pyramid feature map tensor + in high to low resolution order. Returned feature names follow the FPN + convention: "p", where stage has stride = 2 ** stage e.g., + ["p2", "p3", ..., "p6"]. + """ + bottom_up_features = self.net(x) + features = bottom_up_features[self.in_feature] + results = [] + + for stage in self.stages: + results.append(stage(features)) + + if self.top_block is not None: + if self.top_block.in_feature in bottom_up_features: + top_block_in_feature = bottom_up_features[self.top_block.in_feature] + else: + top_block_in_feature = results[self._out_features.index(self.top_block.in_feature)] + results.extend(self.top_block(top_block_in_feature)) + assert len(self._out_features) == len(results) + return {f: res for f, res in zip(self._out_features, results)} + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone"): + if ".pos_embed" in name or ".patch_embed" in name: + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) diff --git a/custom_detectron2/modeling/box_regression.py b/custom_detectron2/modeling/box_regression.py new file mode 100644 index 0000000000000000000000000000000000000000..944a41b81a022b21e4a5e52ac925558187da428a --- /dev/null +++ b/custom_detectron2/modeling/box_regression.py @@ -0,0 +1,369 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from typing import List, Tuple, Union +import torch +from fvcore.nn import giou_loss, smooth_l1_loss +from torch.nn import functional as F + +from custom_detectron2.layers import cat, ciou_loss, diou_loss +from custom_detectron2.structures import Boxes + +# Value for clamping large dw and dh predictions. The heuristic is that we clamp +# such that dw and dh are no larger than what would transform a 16px box into a +# 1000px box (based on a small anchor, 16px, and a typical image size, 1000px). +_DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16) + + +__all__ = ["Box2BoxTransform", "Box2BoxTransformRotated", "Box2BoxTransformLinear"] + + +@torch.jit.script +class Box2BoxTransform(object): + """ + The box-to-box transform defined in R-CNN. The transformation is parameterized + by 4 deltas: (dx, dy, dw, dh). The transformation scales the box's width and height + by exp(dw), exp(dh) and shifts a box's center by the offset (dx * width, dy * height). + """ + + def __init__( + self, weights: Tuple[float, float, float, float], scale_clamp: float = _DEFAULT_SCALE_CLAMP + ): + """ + Args: + weights (4-element tuple): Scaling factors that are applied to the + (dx, dy, dw, dh) deltas. In Fast R-CNN, these were originally set + such that the deltas have unit variance; now they are treated as + hyperparameters of the system. + scale_clamp (float): When predicting deltas, the predicted box scaling + factors (dw and dh) are clamped such that they are <= scale_clamp. + """ + self.weights = weights + self.scale_clamp = scale_clamp + + def get_deltas(self, src_boxes, target_boxes): + """ + Get box regression transformation deltas (dx, dy, dw, dh) that can be used + to transform the `src_boxes` into the `target_boxes`. That is, the relation + ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless + any delta is too large and is clamped). + + Args: + src_boxes (Tensor): source boxes, e.g., object proposals + target_boxes (Tensor): target of the transformation, e.g., ground-truth + boxes. + """ + assert isinstance(src_boxes, torch.Tensor), type(src_boxes) + assert isinstance(target_boxes, torch.Tensor), type(target_boxes) + + src_widths = src_boxes[:, 2] - src_boxes[:, 0] + src_heights = src_boxes[:, 3] - src_boxes[:, 1] + src_ctr_x = src_boxes[:, 0] + 0.5 * src_widths + src_ctr_y = src_boxes[:, 1] + 0.5 * src_heights + + target_widths = target_boxes[:, 2] - target_boxes[:, 0] + target_heights = target_boxes[:, 3] - target_boxes[:, 1] + target_ctr_x = target_boxes[:, 0] + 0.5 * target_widths + target_ctr_y = target_boxes[:, 1] + 0.5 * target_heights + + wx, wy, ww, wh = self.weights + dx = wx * (target_ctr_x - src_ctr_x) / src_widths + dy = wy * (target_ctr_y - src_ctr_y) / src_heights + dw = ww * torch.log(target_widths / src_widths) + dh = wh * torch.log(target_heights / src_heights) + + deltas = torch.stack((dx, dy, dw, dh), dim=1) + assert (src_widths > 0).all().item(), "Input boxes to Box2BoxTransform are not valid!" + return deltas + + def apply_deltas(self, deltas, boxes): + """ + Apply transformation `deltas` (dx, dy, dw, dh) to `boxes`. + + Args: + deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1. + deltas[i] represents k potentially different class-specific + box transformations for the single box boxes[i]. + boxes (Tensor): boxes to transform, of shape (N, 4) + """ + deltas = deltas.float() # ensure fp32 for decoding precision + boxes = boxes.to(deltas.dtype) + + widths = boxes[:, 2] - boxes[:, 0] + heights = boxes[:, 3] - boxes[:, 1] + ctr_x = boxes[:, 0] + 0.5 * widths + ctr_y = boxes[:, 1] + 0.5 * heights + + wx, wy, ww, wh = self.weights + dx = deltas[:, 0::4] / wx + dy = deltas[:, 1::4] / wy + dw = deltas[:, 2::4] / ww + dh = deltas[:, 3::4] / wh + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.scale_clamp) + dh = torch.clamp(dh, max=self.scale_clamp) + + pred_ctr_x = dx * widths[:, None] + ctr_x[:, None] + pred_ctr_y = dy * heights[:, None] + ctr_y[:, None] + pred_w = torch.exp(dw) * widths[:, None] + pred_h = torch.exp(dh) * heights[:, None] + + x1 = pred_ctr_x - 0.5 * pred_w + y1 = pred_ctr_y - 0.5 * pred_h + x2 = pred_ctr_x + 0.5 * pred_w + y2 = pred_ctr_y + 0.5 * pred_h + pred_boxes = torch.stack((x1, y1, x2, y2), dim=-1) + return pred_boxes.reshape(deltas.shape) + + +@torch.jit.script +class Box2BoxTransformRotated(object): + """ + The box-to-box transform defined in Rotated R-CNN. The transformation is parameterized + by 5 deltas: (dx, dy, dw, dh, da). The transformation scales the box's width and height + by exp(dw), exp(dh), shifts a box's center by the offset (dx * width, dy * height), + and rotate a box's angle by da (radians). + Note: angles of deltas are in radians while angles of boxes are in degrees. + """ + + def __init__( + self, + weights: Tuple[float, float, float, float, float], + scale_clamp: float = _DEFAULT_SCALE_CLAMP, + ): + """ + Args: + weights (5-element tuple): Scaling factors that are applied to the + (dx, dy, dw, dh, da) deltas. These are treated as + hyperparameters of the system. + scale_clamp (float): When predicting deltas, the predicted box scaling + factors (dw and dh) are clamped such that they are <= scale_clamp. + """ + self.weights = weights + self.scale_clamp = scale_clamp + + def get_deltas(self, src_boxes, target_boxes): + """ + Get box regression transformation deltas (dx, dy, dw, dh, da) that can be used + to transform the `src_boxes` into the `target_boxes`. That is, the relation + ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true (unless + any delta is too large and is clamped). + + Args: + src_boxes (Tensor): Nx5 source boxes, e.g., object proposals + target_boxes (Tensor): Nx5 target of the transformation, e.g., ground-truth + boxes. + """ + assert isinstance(src_boxes, torch.Tensor), type(src_boxes) + assert isinstance(target_boxes, torch.Tensor), type(target_boxes) + + src_ctr_x, src_ctr_y, src_widths, src_heights, src_angles = torch.unbind(src_boxes, dim=1) + + target_ctr_x, target_ctr_y, target_widths, target_heights, target_angles = torch.unbind( + target_boxes, dim=1 + ) + + wx, wy, ww, wh, wa = self.weights + dx = wx * (target_ctr_x - src_ctr_x) / src_widths + dy = wy * (target_ctr_y - src_ctr_y) / src_heights + dw = ww * torch.log(target_widths / src_widths) + dh = wh * torch.log(target_heights / src_heights) + # Angles of deltas are in radians while angles of boxes are in degrees. + # the conversion to radians serve as a way to normalize the values + da = target_angles - src_angles + da = (da + 180.0) % 360.0 - 180.0 # make it in [-180, 180) + da *= wa * math.pi / 180.0 + + deltas = torch.stack((dx, dy, dw, dh, da), dim=1) + assert ( + (src_widths > 0).all().item() + ), "Input boxes to Box2BoxTransformRotated are not valid!" + return deltas + + def apply_deltas(self, deltas, boxes): + """ + Apply transformation `deltas` (dx, dy, dw, dh, da) to `boxes`. + + Args: + deltas (Tensor): transformation deltas of shape (N, k*5). + deltas[i] represents box transformation for the single box boxes[i]. + boxes (Tensor): boxes to transform, of shape (N, 5) + """ + assert deltas.shape[1] % 5 == 0 and boxes.shape[1] == 5 + + boxes = boxes.to(deltas.dtype).unsqueeze(2) + + ctr_x = boxes[:, 0] + ctr_y = boxes[:, 1] + widths = boxes[:, 2] + heights = boxes[:, 3] + angles = boxes[:, 4] + + wx, wy, ww, wh, wa = self.weights + + dx = deltas[:, 0::5] / wx + dy = deltas[:, 1::5] / wy + dw = deltas[:, 2::5] / ww + dh = deltas[:, 3::5] / wh + da = deltas[:, 4::5] / wa + + # Prevent sending too large values into torch.exp() + dw = torch.clamp(dw, max=self.scale_clamp) + dh = torch.clamp(dh, max=self.scale_clamp) + + pred_boxes = torch.zeros_like(deltas) + pred_boxes[:, 0::5] = dx * widths + ctr_x # x_ctr + pred_boxes[:, 1::5] = dy * heights + ctr_y # y_ctr + pred_boxes[:, 2::5] = torch.exp(dw) * widths # width + pred_boxes[:, 3::5] = torch.exp(dh) * heights # height + + # Following original RRPN implementation, + # angles of deltas are in radians while angles of boxes are in degrees. + pred_angle = da * 180.0 / math.pi + angles + pred_angle = (pred_angle + 180.0) % 360.0 - 180.0 # make it in [-180, 180) + + pred_boxes[:, 4::5] = pred_angle + + return pred_boxes + + +class Box2BoxTransformLinear(object): + """ + The linear box-to-box transform defined in FCOS. The transformation is parameterized + by the distance from the center of (square) src box to 4 edges of the target box. + """ + + def __init__(self, normalize_by_size=True): + """ + Args: + normalize_by_size: normalize deltas by the size of src (anchor) boxes. + """ + self.normalize_by_size = normalize_by_size + + def get_deltas(self, src_boxes, target_boxes): + """ + Get box regression transformation deltas (dx1, dy1, dx2, dy2) that can be used + to transform the `src_boxes` into the `target_boxes`. That is, the relation + ``target_boxes == self.apply_deltas(deltas, src_boxes)`` is true. + The center of src must be inside target boxes. + + Args: + src_boxes (Tensor): square source boxes, e.g., anchors + target_boxes (Tensor): target of the transformation, e.g., ground-truth + boxes. + """ + assert isinstance(src_boxes, torch.Tensor), type(src_boxes) + assert isinstance(target_boxes, torch.Tensor), type(target_boxes) + + src_ctr_x = 0.5 * (src_boxes[:, 0] + src_boxes[:, 2]) + src_ctr_y = 0.5 * (src_boxes[:, 1] + src_boxes[:, 3]) + + target_l = src_ctr_x - target_boxes[:, 0] + target_t = src_ctr_y - target_boxes[:, 1] + target_r = target_boxes[:, 2] - src_ctr_x + target_b = target_boxes[:, 3] - src_ctr_y + + deltas = torch.stack((target_l, target_t, target_r, target_b), dim=1) + if self.normalize_by_size: + stride_w = src_boxes[:, 2] - src_boxes[:, 0] + stride_h = src_boxes[:, 3] - src_boxes[:, 1] + strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + deltas = deltas / strides + + return deltas + + def apply_deltas(self, deltas, boxes): + """ + Apply transformation `deltas` (dx1, dy1, dx2, dy2) to `boxes`. + + Args: + deltas (Tensor): transformation deltas of shape (N, k*4), where k >= 1. + deltas[i] represents k potentially different class-specific + box transformations for the single box boxes[i]. + boxes (Tensor): boxes to transform, of shape (N, 4) + """ + # Ensure the output is a valid box. See Sec 2.1 of https://arxiv.org/abs/2006.09214 + deltas = F.relu(deltas) + boxes = boxes.to(deltas.dtype) + + ctr_x = 0.5 * (boxes[:, 0] + boxes[:, 2]) + ctr_y = 0.5 * (boxes[:, 1] + boxes[:, 3]) + if self.normalize_by_size: + stride_w = boxes[:, 2] - boxes[:, 0] + stride_h = boxes[:, 3] - boxes[:, 1] + strides = torch.stack([stride_w, stride_h, stride_w, stride_h], axis=1) + deltas = deltas * strides + + l = deltas[:, 0::4] + t = deltas[:, 1::4] + r = deltas[:, 2::4] + b = deltas[:, 3::4] + + pred_boxes = torch.zeros_like(deltas) + pred_boxes[:, 0::4] = ctr_x[:, None] - l # x1 + pred_boxes[:, 1::4] = ctr_y[:, None] - t # y1 + pred_boxes[:, 2::4] = ctr_x[:, None] + r # x2 + pred_boxes[:, 3::4] = ctr_y[:, None] + b # y2 + return pred_boxes + + +def _dense_box_regression_loss( + anchors: List[Union[Boxes, torch.Tensor]], + box2box_transform: Box2BoxTransform, + pred_anchor_deltas: List[torch.Tensor], + gt_boxes: List[torch.Tensor], + fg_mask: torch.Tensor, + box_reg_loss_type="smooth_l1", + smooth_l1_beta=0.0, +): + """ + Compute loss for dense multi-level box regression. + Loss is accumulated over ``fg_mask``. + + Args: + anchors: #lvl anchor boxes, each is (HixWixA, 4) + pred_anchor_deltas: #lvl predictions, each is (N, HixWixA, 4) + gt_boxes: N ground truth boxes, each has shape (R, 4) (R = sum(Hi * Wi * A)) + fg_mask: the foreground boolean mask of shape (N, R) to compute loss on + box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou", + "diou", "ciou". + smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to + use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1" + """ + if isinstance(anchors[0], Boxes): + anchors = type(anchors[0]).cat(anchors).tensor # (R, 4) + else: + anchors = cat(anchors) + if box_reg_loss_type == "smooth_l1": + gt_anchor_deltas = [box2box_transform.get_deltas(anchors, k) for k in gt_boxes] + gt_anchor_deltas = torch.stack(gt_anchor_deltas) # (N, R, 4) + loss_box_reg = smooth_l1_loss( + cat(pred_anchor_deltas, dim=1)[fg_mask], + gt_anchor_deltas[fg_mask], + beta=smooth_l1_beta, + reduction="sum", + ) + elif box_reg_loss_type == "giou": + pred_boxes = [ + box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1) + ] + loss_box_reg = giou_loss( + torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum" + ) + elif box_reg_loss_type == "diou": + pred_boxes = [ + box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1) + ] + loss_box_reg = diou_loss( + torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum" + ) + elif box_reg_loss_type == "ciou": + pred_boxes = [ + box2box_transform.apply_deltas(k, anchors) for k in cat(pred_anchor_deltas, dim=1) + ] + loss_box_reg = ciou_loss( + torch.stack(pred_boxes)[fg_mask], torch.stack(gt_boxes)[fg_mask], reduction="sum" + ) + else: + raise ValueError(f"Invalid dense box regression loss type '{box_reg_loss_type}'") + return loss_box_reg diff --git a/custom_detectron2/modeling/matcher.py b/custom_detectron2/modeling/matcher.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1c16e3f80e1ad90899df46d7efda448c68cff8 --- /dev/null +++ b/custom_detectron2/modeling/matcher.py @@ -0,0 +1,127 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import List +import torch + +from custom_detectron2.layers import nonzero_tuple + + +# TODO: the name is too general +class Matcher(object): + """ + This class assigns to each predicted "element" (e.g., a box) a ground-truth + element. Each predicted element will have exactly zero or one matches; each + ground-truth element may be matched to zero or more predicted elements. + + The matching is determined by the MxN match_quality_matrix, that characterizes + how well each (ground-truth, prediction)-pair match each other. For example, + if the elements are boxes, this matrix may contain box intersection-over-union + overlap values. + + The matcher returns (a) a vector of length N containing the index of the + ground-truth element m in [0, M) that matches to prediction n in [0, N). + (b) a vector of length N containing the labels for each prediction. + """ + + def __init__( + self, thresholds: List[float], labels: List[int], allow_low_quality_matches: bool = False + ): + """ + Args: + thresholds (list): a list of thresholds used to stratify predictions + into levels. + labels (list): a list of values to label predictions belonging at + each level. A label can be one of {-1, 0, 1} signifying + {ignore, negative class, positive class}, respectively. + allow_low_quality_matches (bool): if True, produce additional matches + for predictions with maximum match quality lower than high_threshold. + See set_low_quality_matches_ for more details. + + For example, + thresholds = [0.3, 0.5] + labels = [0, -1, 1] + All predictions with iou < 0.3 will be marked with 0 and + thus will be considered as false positives while training. + All predictions with 0.3 <= iou < 0.5 will be marked with -1 and + thus will be ignored. + All predictions with 0.5 <= iou will be marked with 1 and + thus will be considered as true positives. + """ + # Add -inf and +inf to first and last position in thresholds + thresholds = thresholds[:] + assert thresholds[0] > 0 + thresholds.insert(0, -float("inf")) + thresholds.append(float("inf")) + # Currently torchscript does not support all + generator + assert all([low <= high for (low, high) in zip(thresholds[:-1], thresholds[1:])]) + assert all([l in [-1, 0, 1] for l in labels]) + assert len(labels) == len(thresholds) - 1 + self.thresholds = thresholds + self.labels = labels + self.allow_low_quality_matches = allow_low_quality_matches + + def __call__(self, match_quality_matrix): + """ + Args: + match_quality_matrix (Tensor[float]): an MxN tensor, containing the + pairwise quality between M ground-truth elements and N predicted + elements. All elements must be >= 0 (due to the us of `torch.nonzero` + for selecting indices in :meth:`set_low_quality_matches_`). + + Returns: + matches (Tensor[int64]): a vector of length N, where matches[i] is a matched + ground-truth index in [0, M) + match_labels (Tensor[int8]): a vector of length N, where pred_labels[i] indicates + whether a prediction is a true or false positive or ignored + """ + assert match_quality_matrix.dim() == 2 + if match_quality_matrix.numel() == 0: + default_matches = match_quality_matrix.new_full( + (match_quality_matrix.size(1),), 0, dtype=torch.int64 + ) + # When no gt boxes exist, we define IOU = 0 and therefore set labels + # to `self.labels[0]`, which usually defaults to background class 0 + # To choose to ignore instead, can make labels=[-1,0,-1,1] + set appropriate thresholds + default_match_labels = match_quality_matrix.new_full( + (match_quality_matrix.size(1),), self.labels[0], dtype=torch.int8 + ) + return default_matches, default_match_labels + + assert torch.all(match_quality_matrix >= 0) + + # match_quality_matrix is M (gt) x N (predicted) + # Max over gt elements (dim 0) to find best gt candidate for each prediction + matched_vals, matches = match_quality_matrix.max(dim=0) + + match_labels = matches.new_full(matches.size(), 1, dtype=torch.int8) + + for (l, low, high) in zip(self.labels, self.thresholds[:-1], self.thresholds[1:]): + low_high = (matched_vals >= low) & (matched_vals < high) + match_labels[low_high] = l + + if self.allow_low_quality_matches: + self.set_low_quality_matches_(match_labels, match_quality_matrix) + + return matches, match_labels + + def set_low_quality_matches_(self, match_labels, match_quality_matrix): + """ + Produce additional matches for predictions that have only low-quality matches. + Specifically, for each ground-truth G find the set of predictions that have + maximum overlap with it (including ties); for each prediction in that set, if + it is unmatched, then match it to the ground-truth G. + + This function implements the RPN assignment case (i) in Sec. 3.1.2 of + :paper:`Faster R-CNN`. + """ + # For each gt, find the prediction with which it has highest quality + highest_quality_foreach_gt, _ = match_quality_matrix.max(dim=1) + # Find the highest quality match available, even if it is low, including ties. + # Note that the matches qualities must be positive due to the use of + # `torch.nonzero`. + _, pred_inds_with_highest_quality = nonzero_tuple( + match_quality_matrix == highest_quality_foreach_gt[:, None] + ) + # If an anchor was labeled positive only due to a low-quality match + # with gt_A, but it has larger overlap with gt_B, it's matched index will still be gt_B. + # This follows the implementation in Detectron, and is found to have no significant impact. + match_labels[pred_inds_with_highest_quality] = 1 diff --git a/custom_detectron2/modeling/meta_arch/__init__.py b/custom_detectron2/modeling/meta_arch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6b0668157052ce7b796ef50bc7ee85361e7605b9 --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/__init__.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +from .build import META_ARCH_REGISTRY, build_model # isort:skip + +from .panoptic_fpn import PanopticFPN + +# import all the meta_arch, so they will be registered +from .rcnn import GeneralizedRCNN, ProposalNetwork +from .dense_detector import DenseDetector +from .retinanet import RetinaNet +from .fcos import FCOS +from .semantic_seg import SEM_SEG_HEADS_REGISTRY, SemanticSegmentor, build_sem_seg_head + + +__all__ = list(globals().keys()) diff --git a/custom_detectron2/modeling/meta_arch/build.py b/custom_detectron2/modeling/meta_arch/build.py new file mode 100644 index 0000000000000000000000000000000000000000..545ebbd1c481e6f1fb6b322770afdcb661126d68 --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/build.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch + +from custom_detectron2.utils.logger import _log_api_usage +from custom_detectron2.utils.registry import Registry + +META_ARCH_REGISTRY = Registry("META_ARCH") # noqa F401 isort:skip +META_ARCH_REGISTRY.__doc__ = """ +Registry for meta-architectures, i.e. the whole model. + +The registered object will be called with `obj(cfg)` +and expected to return a `nn.Module` object. +""" + + +def build_model(cfg): + """ + Build the whole model architecture, defined by ``cfg.MODEL.META_ARCHITECTURE``. + Note that it does not load any weights from ``cfg``. + """ + meta_arch = cfg.MODEL.META_ARCHITECTURE + model = META_ARCH_REGISTRY.get(meta_arch)(cfg) + _log_api_usage("modeling.meta_arch." + meta_arch) + return model diff --git a/custom_detectron2/modeling/meta_arch/dense_detector.py b/custom_detectron2/modeling/meta_arch/dense_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..148fcd6cf0d37068346b525b54c84af5feb731c2 --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/dense_detector.py @@ -0,0 +1,294 @@ +import numpy as np +from typing import Dict, List, Optional, Tuple +import torch +from torch import Tensor, nn + +from custom_detectron2.data.detection_utils import convert_image_to_rgb +from custom_detectron2.layers import move_device_like +from custom_detectron2.modeling import Backbone +from custom_detectron2.structures import Boxes, ImageList, Instances +from custom_detectron2.utils.events import get_event_storage + +from ..postprocessing import detector_postprocess + + +def permute_to_N_HWA_K(tensor, K: int): + """ + Transpose/reshape a tensor from (N, (Ai x K), H, W) to (N, (HxWxAi), K) + """ + assert tensor.dim() == 4, tensor.shape + N, _, H, W = tensor.shape + tensor = tensor.view(N, -1, K, H, W) + tensor = tensor.permute(0, 3, 4, 1, 2) + tensor = tensor.reshape(N, -1, K) # Size=(N,HWA,K) + return tensor + + +class DenseDetector(nn.Module): + """ + Base class for dense detector. We define a dense detector as a fully-convolutional model that + makes per-pixel (i.e. dense) predictions. + """ + + def __init__( + self, + backbone: Backbone, + head: nn.Module, + head_in_features: Optional[List[str]] = None, + *, + pixel_mean, + pixel_std, + ): + """ + Args: + backbone: backbone module + head: head module + head_in_features: backbone features to use in head. Default to all backbone features. + pixel_mean (Tuple[float]): + Values to be used for image normalization (BGR order). + To train on images of different number of channels, set different mean & std. + Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675] + pixel_std (Tuple[float]): + When using pre-trained models in Detectron1 or any MSRA models, + std has been absorbed into its conv1 weights, so the std needs to be set 1. + Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std) + """ + super().__init__() + + self.backbone = backbone + self.head = head + if head_in_features is None: + shapes = self.backbone.output_shape() + self.head_in_features = sorted(shapes.keys(), key=lambda x: shapes[x].stride) + else: + self.head_in_features = head_in_features + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + + @property + def device(self): + return self.pixel_mean.device + + def _move_to_current_device(self, x): + return move_device_like(x, self.pixel_mean) + + def forward(self, batched_inputs: List[Dict[str, Tensor]]): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper` . + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + + * image: Tensor, image in (C, H, W) format. + * instances: Instances + + Other information that's included in the original dicts, such as: + + * "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + In training, dict[str, Tensor]: mapping from a named loss to a tensor storing the + loss. Used during training only. In inference, the standard output format, described + in :doc:`/tutorials/models`. + """ + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + features = [features[f] for f in self.head_in_features] + predictions = self.head(features) + + if self.training: + assert not torch.jit.is_scripting(), "Not supported" + assert "instances" in batched_inputs[0], "Instance annotations are missing in training!" + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + return self.forward_training(images, features, predictions, gt_instances) + else: + results = self.forward_inference(images, features, predictions) + if torch.jit.is_scripting(): + return results + + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"instances": r}) + return processed_results + + def forward_training(self, images, features, predictions, gt_instances): + raise NotImplementedError() + + def preprocess_image(self, batched_inputs: List[Dict[str, Tensor]]): + """ + Normalize, pad and batch the input images. + """ + images = [self._move_to_current_device(x["image"]) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + return images + + def _transpose_dense_predictions( + self, predictions: List[List[Tensor]], dims_per_anchor: List[int] + ) -> List[List[Tensor]]: + """ + Transpose the dense per-level predictions. + + Args: + predictions: a list of outputs, each is a list of per-level + predictions with shape (N, Ai x K, Hi, Wi), where N is the + number of images, Ai is the number of anchors per location on + level i, K is the dimension of predictions per anchor. + dims_per_anchor: the value of K for each predictions. e.g. 4 for + box prediction, #classes for classification prediction. + + Returns: + List[List[Tensor]]: each prediction is transposed to (N, Hi x Wi x Ai, K). + """ + assert len(predictions) == len(dims_per_anchor) + res: List[List[Tensor]] = [] + for pred, dim_per_anchor in zip(predictions, dims_per_anchor): + pred = [permute_to_N_HWA_K(x, dim_per_anchor) for x in pred] + res.append(pred) + return res + + def _ema_update(self, name: str, value: float, initial_value: float, momentum: float = 0.9): + """ + Apply EMA update to `self.name` using `value`. + + This is mainly used for loss normalizer. In Detectron1, loss is normalized by number + of foreground samples in the batch. When batch size is 1 per GPU, #foreground has a + large variance and using it lead to lower performance. Therefore we maintain an EMA of + #foreground to stabilize the normalizer. + + Args: + name: name of the normalizer + value: the new value to update + initial_value: the initial value to start with + momentum: momentum of EMA + + Returns: + float: the updated EMA value + """ + if hasattr(self, name): + old = getattr(self, name) + else: + old = initial_value + new = old * momentum + value * (1 - momentum) + setattr(self, name, new) + return new + + def _decode_per_level_predictions( + self, + anchors: Boxes, + pred_scores: Tensor, + pred_deltas: Tensor, + score_thresh: float, + topk_candidates: int, + image_size: Tuple[int, int], + ) -> Instances: + """ + Decode boxes and classification predictions of one featuer level, by + the following steps: + 1. filter the predictions based on score threshold and top K scores. + 2. transform the box regression outputs + 3. return the predicted scores, classes and boxes + + Args: + anchors: Boxes, anchor for this feature level + pred_scores: HxWxA,K + pred_deltas: HxWxA,4 + + Returns: + Instances: with field "scores", "pred_boxes", "pred_classes". + """ + # Apply two filtering to make NMS faster. + # 1. Keep boxes with confidence score higher than threshold + keep_idxs = pred_scores > score_thresh + pred_scores = pred_scores[keep_idxs] + topk_idxs = torch.nonzero(keep_idxs) # Kx2 + + # 2. Keep top k top scoring boxes only + topk_idxs_size = topk_idxs.shape[0] + if isinstance(topk_idxs_size, Tensor): + # It's a tensor in tracing + num_topk = torch.clamp(topk_idxs_size, max=topk_candidates) + else: + num_topk = min(topk_idxs_size, topk_candidates) + pred_scores, idxs = pred_scores.topk(num_topk) + topk_idxs = topk_idxs[idxs] + + anchor_idxs, classes_idxs = topk_idxs.unbind(dim=1) + + pred_boxes = self.box2box_transform.apply_deltas( + pred_deltas[anchor_idxs], anchors.tensor[anchor_idxs] + ) + return Instances( + image_size, pred_boxes=Boxes(pred_boxes), scores=pred_scores, pred_classes=classes_idxs + ) + + def _decode_multi_level_predictions( + self, + anchors: List[Boxes], + pred_scores: List[Tensor], + pred_deltas: List[Tensor], + score_thresh: float, + topk_candidates: int, + image_size: Tuple[int, int], + ) -> Instances: + """ + Run `_decode_per_level_predictions` for all feature levels and concat the results. + """ + predictions = [ + self._decode_per_level_predictions( + anchors_i, + box_cls_i, + box_reg_i, + self.test_score_thresh, + self.test_topk_candidates, + image_size, + ) + # Iterate over every feature level + for box_cls_i, box_reg_i, anchors_i in zip(pred_scores, pred_deltas, anchors) + ] + return predictions[0].cat(predictions) # 'Instances.cat' is not scriptale but this is + + def visualize_training(self, batched_inputs, results): + """ + A function used to visualize ground truth images and final network predictions. + It shows ground truth bounding boxes on the original image and up to 20 + predicted object bounding boxes on the original image. + + Args: + batched_inputs (list): a list that contains input to the model. + results (List[Instances]): a list of #images elements returned by forward_inference(). + """ + from custom_detectron2.utils.visualizer import Visualizer + + assert len(batched_inputs) == len( + results + ), "Cannot visualize inputs and results of different sizes" + storage = get_event_storage() + max_boxes = 20 + + image_index = 0 # only visualize a single image + img = batched_inputs[image_index]["image"] + img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format) + v_gt = Visualizer(img, None) + v_gt = v_gt.overlay_instances(boxes=batched_inputs[image_index]["instances"].gt_boxes) + anno_img = v_gt.get_image() + processed_results = detector_postprocess(results[image_index], img.shape[0], img.shape[1]) + predicted_boxes = processed_results.pred_boxes.tensor.detach().cpu().numpy() + + v_pred = Visualizer(img, None) + v_pred = v_pred.overlay_instances(boxes=predicted_boxes[0:max_boxes]) + prop_img = v_pred.get_image() + vis_img = np.vstack((anno_img, prop_img)) + vis_img = vis_img.transpose(2, 0, 1) + vis_name = f"Top: GT bounding boxes; Bottom: {max_boxes} Highest Scoring Results" + storage.put_image(vis_name, vis_img) diff --git a/custom_detectron2/modeling/meta_arch/fcos.py b/custom_detectron2/modeling/meta_arch/fcos.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f8603f776e1a0ab405684c4f8380ba78cf2600 --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/fcos.py @@ -0,0 +1,328 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +from typing import List, Optional, Tuple +import torch +from fvcore.nn import sigmoid_focal_loss_jit +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.layers import ShapeSpec, batched_nms +from custom_detectron2.structures import Boxes, ImageList, Instances, pairwise_point_box_distance +from custom_detectron2.utils.events import get_event_storage + +from ..anchor_generator import DefaultAnchorGenerator +from ..backbone import Backbone +from ..box_regression import Box2BoxTransformLinear, _dense_box_regression_loss +from .dense_detector import DenseDetector +from .retinanet import RetinaNetHead + +__all__ = ["FCOS"] + +logger = logging.getLogger(__name__) + + +class FCOS(DenseDetector): + """ + Implement FCOS in :paper:`fcos`. + """ + + def __init__( + self, + *, + backbone: Backbone, + head: nn.Module, + head_in_features: Optional[List[str]] = None, + box2box_transform=None, + num_classes, + center_sampling_radius: float = 1.5, + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + test_score_thresh=0.2, + test_topk_candidates=1000, + test_nms_thresh=0.6, + max_detections_per_image=100, + pixel_mean, + pixel_std, + ): + """ + Args: + center_sampling_radius: radius of the "center" of a groundtruth box, + within which all anchor points are labeled positive. + Other arguments mean the same as in :class:`RetinaNet`. + """ + super().__init__( + backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std + ) + + self.num_classes = num_classes + + # FCOS uses one anchor point per location. + # We represent the anchor point by a box whose size equals the anchor stride. + feature_shapes = backbone.output_shape() + fpn_strides = [feature_shapes[k].stride for k in self.head_in_features] + self.anchor_generator = DefaultAnchorGenerator( + sizes=[[k] for k in fpn_strides], aspect_ratios=[1.0], strides=fpn_strides + ) + + # FCOS parameterizes box regression by a linear transform, + # where predictions are normalized by anchor stride (equal to anchor size). + if box2box_transform is None: + box2box_transform = Box2BoxTransformLinear(normalize_by_size=True) + self.box2box_transform = box2box_transform + + self.center_sampling_radius = float(center_sampling_radius) + + # Loss parameters: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + + # Inference parameters: + self.test_score_thresh = test_score_thresh + self.test_topk_candidates = test_topk_candidates + self.test_nms_thresh = test_nms_thresh + self.max_detections_per_image = max_detections_per_image + + def forward_training(self, images, features, predictions, gt_instances): + # Transpose the Hi*Wi*A dimension to the middle: + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( + predictions, [self.num_classes, 4, 1] + ) + anchors = self.anchor_generator(features) + gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) + return self.losses( + anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness + ) + + @torch.no_grad() + def _match_anchors(self, gt_boxes: Boxes, anchors: List[Boxes]): + """ + Match ground-truth boxes to a set of multi-level anchors. + + Args: + gt_boxes: Ground-truth boxes from instances of an image. + anchors: List of anchors for each feature map (of different scales). + + Returns: + torch.Tensor + A tensor of shape `(M, R)`, given `M` ground-truth boxes and total + `R` anchor points from all feature levels, indicating the quality + of match between m-th box and r-th anchor. Higher value indicates + better match. + """ + # Naming convention: (M = ground-truth boxes, R = anchor points) + # Anchor points are represented as square boxes of size = stride. + num_anchors_per_level = [len(x) for x in anchors] + anchors = Boxes.cat(anchors) # (R, 4) + anchor_centers = anchors.get_centers() # (R, 2) + anchor_sizes = anchors.tensor[:, 2] - anchors.tensor[:, 0] # (R, ) + + lower_bound = anchor_sizes * 4 + lower_bound[: num_anchors_per_level[0]] = 0 + upper_bound = anchor_sizes * 8 + upper_bound[-num_anchors_per_level[-1] :] = float("inf") + + gt_centers = gt_boxes.get_centers() + + # FCOS with center sampling: anchor point must be close enough to + # ground-truth box center. + center_dists = (anchor_centers[None, :, :] - gt_centers[:, None, :]).abs_() + sampling_regions = self.center_sampling_radius * anchor_sizes[None, :] + + match_quality_matrix = center_dists.max(dim=2).values < sampling_regions + + pairwise_dist = pairwise_point_box_distance(anchor_centers, gt_boxes) + pairwise_dist = pairwise_dist.permute(1, 0, 2) # (M, R, 4) + + # The original FCOS anchor matching rule: anchor point must be inside GT. + match_quality_matrix &= pairwise_dist.min(dim=2).values > 0 + + # Multilevel anchor matching in FCOS: each anchor is only responsible + # for certain scale range. + pairwise_dist = pairwise_dist.max(dim=2).values + match_quality_matrix &= (pairwise_dist > lower_bound[None, :]) & ( + pairwise_dist < upper_bound[None, :] + ) + # Match the GT box with minimum area, if there are multiple GT matches. + gt_areas = gt_boxes.area() # (M, ) + + match_quality_matrix = match_quality_matrix.to(torch.float32) + match_quality_matrix *= 1e8 - gt_areas[:, None] + return match_quality_matrix # (M, R) + + @torch.no_grad() + def label_anchors(self, anchors: List[Boxes], gt_instances: List[Instances]): + """ + Same interface as :meth:`RetinaNet.label_anchors`, but implemented with FCOS + anchor matching rule. + + Unlike RetinaNet, there are no ignored anchors. + """ + + gt_labels, matched_gt_boxes = [], [] + + for inst in gt_instances: + if len(inst) > 0: + match_quality_matrix = self._match_anchors(inst.gt_boxes, anchors) + + # Find matched ground-truth box per anchor. Un-matched anchors are + # assigned -1. This is equivalent to using an anchor matcher as used + # in R-CNN/RetinaNet: `Matcher(thresholds=[1e-5], labels=[0, 1])` + match_quality, matched_idxs = match_quality_matrix.max(dim=0) + matched_idxs[match_quality < 1e-5] = -1 + + matched_gt_boxes_i = inst.gt_boxes.tensor[matched_idxs.clip(min=0)] + gt_labels_i = inst.gt_classes[matched_idxs.clip(min=0)] + + # Anchors with matched_idxs = -1 are labeled background. + gt_labels_i[matched_idxs < 0] = self.num_classes + else: + matched_gt_boxes_i = torch.zeros_like(Boxes.cat(anchors).tensor) + gt_labels_i = torch.full( + (len(matched_gt_boxes_i),), + fill_value=self.num_classes, + dtype=torch.long, + device=matched_gt_boxes_i.device, + ) + + gt_labels.append(gt_labels_i) + matched_gt_boxes.append(matched_gt_boxes_i) + + return gt_labels, matched_gt_boxes + + def losses( + self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes, pred_centerness + ): + """ + This method is almost identical to :meth:`RetinaNet.losses`, with an extra + "loss_centerness" in the returned dict. + """ + num_images = len(gt_labels) + gt_labels = torch.stack(gt_labels) # (M, R) + + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) + num_pos_anchors = pos_mask.sum().item() + get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images) + normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 300) + + # classification and regression loss + gt_labels_target = F.one_hot(gt_labels, num_classes=self.num_classes + 1)[ + :, :, :-1 + ] # no loss for the last (background) class + loss_cls = sigmoid_focal_loss_jit( + torch.cat(pred_logits, dim=1), + gt_labels_target.to(pred_logits[0].dtype), + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) + + loss_box_reg = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + gt_boxes, + pos_mask, + box_reg_loss_type="giou", + ) + + ctrness_targets = self.compute_ctrness_targets(anchors, gt_boxes) # (M, R) + pred_centerness = torch.cat(pred_centerness, dim=1).squeeze(dim=2) # (M, R) + ctrness_loss = F.binary_cross_entropy_with_logits( + pred_centerness[pos_mask], ctrness_targets[pos_mask], reduction="sum" + ) + return { + "loss_fcos_cls": loss_cls / normalizer, + "loss_fcos_loc": loss_box_reg / normalizer, + "loss_fcos_ctr": ctrness_loss / normalizer, + } + + def compute_ctrness_targets(self, anchors: List[Boxes], gt_boxes: List[torch.Tensor]): + anchors = Boxes.cat(anchors).tensor # Rx4 + reg_targets = [self.box2box_transform.get_deltas(anchors, m) for m in gt_boxes] + reg_targets = torch.stack(reg_targets, dim=0) # NxRx4 + if len(reg_targets) == 0: + return reg_targets.new_zeros(len(reg_targets)) + left_right = reg_targets[:, :, [0, 2]] + top_bottom = reg_targets[:, :, [1, 3]] + ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * ( + top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0] + ) + return torch.sqrt(ctrness) + + def forward_inference( + self, + images: ImageList, + features: List[torch.Tensor], + predictions: List[List[torch.Tensor]], + ): + pred_logits, pred_anchor_deltas, pred_centerness = self._transpose_dense_predictions( + predictions, [self.num_classes, 4, 1] + ) + anchors = self.anchor_generator(features) + + results: List[Instances] = [] + for img_idx, image_size in enumerate(images.image_sizes): + scores_per_image = [ + # Multiply and sqrt centerness & classification scores + # (See eqn. 4 in https://arxiv.org/abs/2006.09214) + torch.sqrt(x[img_idx].sigmoid_() * y[img_idx].sigmoid_()) + for x, y in zip(pred_logits, pred_centerness) + ] + deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] + results_per_image = self.inference_single_image( + anchors, scores_per_image, deltas_per_image, image_size + ) + results.append(results_per_image) + return results + + def inference_single_image( + self, + anchors: List[Boxes], + box_cls: List[torch.Tensor], + box_delta: List[torch.Tensor], + image_size: Tuple[int, int], + ): + """ + Identical to :meth:`RetinaNet.inference_single_image. + """ + pred = self._decode_multi_level_predictions( + anchors, + box_cls, + box_delta, + self.test_score_thresh, + self.test_topk_candidates, + image_size, + ) + keep = batched_nms( + pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh + ) + return pred[keep[: self.max_detections_per_image]] + + +class FCOSHead(RetinaNetHead): + """ + The head used in :paper:`fcos`. It adds an additional centerness + prediction branch on top of :class:`RetinaNetHead`. + """ + + def __init__(self, *, input_shape: List[ShapeSpec], conv_dims: List[int], **kwargs): + super().__init__(input_shape=input_shape, conv_dims=conv_dims, num_anchors=1, **kwargs) + # Unlike original FCOS, we do not add an additional learnable scale layer + # because it's found to have no benefits after normalizing regression targets by stride. + self._num_features = len(input_shape) + self.ctrness = nn.Conv2d(conv_dims[-1], 1, kernel_size=3, stride=1, padding=1) + torch.nn.init.normal_(self.ctrness.weight, std=0.01) + torch.nn.init.constant_(self.ctrness.bias, 0) + + def forward(self, features): + assert len(features) == self._num_features + logits = [] + bbox_reg = [] + ctrness = [] + for feature in features: + logits.append(self.cls_score(self.cls_subnet(feature))) + bbox_feature = self.bbox_subnet(feature) + bbox_reg.append(self.bbox_pred(bbox_feature)) + ctrness.append(self.ctrness(bbox_feature)) + return logits, bbox_reg, ctrness diff --git a/custom_detectron2/modeling/meta_arch/panoptic_fpn.py b/custom_detectron2/modeling/meta_arch/panoptic_fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..e3784b62574d6cff63f434c681ad4aa1ec20d0f3 --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/panoptic_fpn.py @@ -0,0 +1,269 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +from typing import Dict, List +import torch +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.structures import ImageList + +from ..postprocessing import detector_postprocess, sem_seg_postprocess +from .build import META_ARCH_REGISTRY +from .rcnn import GeneralizedRCNN +from .semantic_seg import build_sem_seg_head + +__all__ = ["PanopticFPN"] + + +@META_ARCH_REGISTRY.register() +class PanopticFPN(GeneralizedRCNN): + """ + Implement the paper :paper:`PanopticFPN`. + """ + + @configurable + def __init__( + self, + *, + sem_seg_head: nn.Module, + combine_overlap_thresh: float = 0.5, + combine_stuff_area_thresh: float = 4096, + combine_instances_score_thresh: float = 0.5, + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + sem_seg_head: a module for the semantic segmentation head. + combine_overlap_thresh: combine masks into one instances if + they have enough overlap + combine_stuff_area_thresh: ignore stuff areas smaller than this threshold + combine_instances_score_thresh: ignore instances whose score is + smaller than this threshold + + Other arguments are the same as :class:`GeneralizedRCNN`. + """ + super().__init__(**kwargs) + self.sem_seg_head = sem_seg_head + # options when combining instance & semantic outputs + self.combine_overlap_thresh = combine_overlap_thresh + self.combine_stuff_area_thresh = combine_stuff_area_thresh + self.combine_instances_score_thresh = combine_instances_score_thresh + + @classmethod + def from_config(cls, cfg): + ret = super().from_config(cfg) + ret.update( + { + "combine_overlap_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH, + "combine_stuff_area_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT, + "combine_instances_score_thresh": cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH, # noqa + } + ) + ret["sem_seg_head"] = build_sem_seg_head(cfg, ret["backbone"].output_shape()) + logger = logging.getLogger(__name__) + if not cfg.MODEL.PANOPTIC_FPN.COMBINE.ENABLED: + logger.warning( + "PANOPTIC_FPN.COMBINED.ENABLED is no longer used. " + " model.inference(do_postprocess=) should be used to toggle postprocessing." + ) + if cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT != 1.0: + w = cfg.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT + logger.warning( + "PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT should be replaced by weights on each ROI head." + ) + + def update_weight(x): + if isinstance(x, dict): + return {k: v * w for k, v in x.items()} + else: + return x * w + + roi_heads = ret["roi_heads"] + roi_heads.box_predictor.loss_weight = update_weight(roi_heads.box_predictor.loss_weight) + roi_heads.mask_head.loss_weight = update_weight(roi_heads.mask_head.loss_weight) + return ret + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + + For now, each item in the list is a dict that contains: + + * "image": Tensor, image in (C, H, W) format. + * "instances": Instances + * "sem_seg": semantic segmentation ground truth. + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + list[dict]: + each dict has the results for one image. The dict contains the following keys: + + * "instances": see :meth:`GeneralizedRCNN.forward` for its format. + * "sem_seg": see :meth:`SemanticSegmentor.forward` for its format. + * "panoptic_seg": See the return value of + :func:`combine_semantic_and_instance_outputs` for its format. + """ + if not self.training: + return self.inference(batched_inputs) + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + + assert "sem_seg" in batched_inputs[0] + gt_sem_seg = [x["sem_seg"].to(self.device) for x in batched_inputs] + gt_sem_seg = ImageList.from_tensors( + gt_sem_seg, + self.backbone.size_divisibility, + self.sem_seg_head.ignore_value, + self.backbone.padding_constraints, + ).tensor + sem_seg_results, sem_seg_losses = self.sem_seg_head(features, gt_sem_seg) + + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + detector_results, detector_losses = self.roi_heads( + images, features, proposals, gt_instances + ) + + losses = sem_seg_losses + losses.update(proposal_losses) + losses.update(detector_losses) + return losses + + def inference(self, batched_inputs: List[Dict[str, torch.Tensor]], do_postprocess: bool = True): + """ + Run inference on the given inputs. + + Args: + batched_inputs (list[dict]): same as in :meth:`forward` + do_postprocess (bool): whether to apply post-processing on the outputs. + + Returns: + When do_postprocess=True, see docs in :meth:`forward`. + Otherwise, returns a (list[Instances], list[Tensor]) that contains + the raw detector outputs, and raw semantic segmentation outputs. + """ + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + sem_seg_results, sem_seg_losses = self.sem_seg_head(features, None) + proposals, _ = self.proposal_generator(images, features, None) + detector_results, _ = self.roi_heads(images, features, proposals, None) + + if do_postprocess: + processed_results = [] + for sem_seg_result, detector_result, input_per_image, image_size in zip( + sem_seg_results, detector_results, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + sem_seg_r = sem_seg_postprocess(sem_seg_result, image_size, height, width) + detector_r = detector_postprocess(detector_result, height, width) + + processed_results.append({"sem_seg": sem_seg_r, "instances": detector_r}) + + panoptic_r = combine_semantic_and_instance_outputs( + detector_r, + sem_seg_r.argmax(dim=0), + self.combine_overlap_thresh, + self.combine_stuff_area_thresh, + self.combine_instances_score_thresh, + ) + processed_results[-1]["panoptic_seg"] = panoptic_r + return processed_results + else: + return detector_results, sem_seg_results + + +def combine_semantic_and_instance_outputs( + instance_results, + semantic_results, + overlap_threshold, + stuff_area_thresh, + instances_score_thresh, +): + """ + Implement a simple combining logic following + "combine_semantic_and_instance_predictions.py" in panopticapi + to produce panoptic segmentation outputs. + + Args: + instance_results: output of :func:`detector_postprocess`. + semantic_results: an (H, W) tensor, each element is the contiguous semantic + category id + + Returns: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each segment. + segments_info (list[dict]): Describe each segment in `panoptic_seg`. + Each dict contains keys "id", "category_id", "isthing". + """ + panoptic_seg = torch.zeros_like(semantic_results, dtype=torch.int32) + + # sort instance outputs by scores + sorted_inds = torch.argsort(-instance_results.scores) + + current_segment_id = 0 + segments_info = [] + + instance_masks = instance_results.pred_masks.to(dtype=torch.bool, device=panoptic_seg.device) + + # Add instances one-by-one, check for overlaps with existing ones + for inst_id in sorted_inds: + score = instance_results.scores[inst_id].item() + if score < instances_score_thresh: + break + mask = instance_masks[inst_id] # H,W + mask_area = mask.sum().item() + + if mask_area == 0: + continue + + intersect = (mask > 0) & (panoptic_seg > 0) + intersect_area = intersect.sum().item() + + if intersect_area * 1.0 / mask_area > overlap_threshold: + continue + + if intersect_area > 0: + mask = mask & (panoptic_seg == 0) + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + segments_info.append( + { + "id": current_segment_id, + "isthing": True, + "score": score, + "category_id": instance_results.pred_classes[inst_id].item(), + "instance_id": inst_id.item(), + } + ) + + # Add semantic results to remaining empty areas + semantic_labels = torch.unique(semantic_results).cpu().tolist() + for semantic_label in semantic_labels: + if semantic_label == 0: # 0 is a special "thing" class + continue + mask = (semantic_results == semantic_label) & (panoptic_seg == 0) + mask_area = mask.sum().item() + if mask_area < stuff_area_thresh: + continue + + current_segment_id += 1 + panoptic_seg[mask] = current_segment_id + segments_info.append( + { + "id": current_segment_id, + "isthing": False, + "category_id": semantic_label, + "area": mask_area, + } + ) + + return panoptic_seg, segments_info diff --git a/custom_detectron2/modeling/meta_arch/rcnn.py b/custom_detectron2/modeling/meta_arch/rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..31d54c8a0f463be6266ef3b3248c7cde53b85bba --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/rcnn.py @@ -0,0 +1,341 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +from typing import Dict, List, Optional, Tuple +import torch +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.data.detection_utils import convert_image_to_rgb +from custom_detectron2.layers import move_device_like +from custom_detectron2.structures import ImageList, Instances +from custom_detectron2.utils.events import get_event_storage +from custom_detectron2.utils.logger import log_first_n + +from ..backbone import Backbone, build_backbone +from ..postprocessing import detector_postprocess +from ..proposal_generator import build_proposal_generator +from ..roi_heads import build_roi_heads +from .build import META_ARCH_REGISTRY + +__all__ = ["GeneralizedRCNN", "ProposalNetwork"] + + +@META_ARCH_REGISTRY.register() +class GeneralizedRCNN(nn.Module): + """ + Generalized R-CNN. Any models that contains the following three components: + 1. Per-image feature extraction (aka backbone) + 2. Region proposal generation + 3. Per-region feature extraction and prediction + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + proposal_generator: nn.Module, + roi_heads: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + input_format: Optional[str] = None, + vis_period: int = 0, + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + proposal_generator: a module that generates proposals using backbone features + roi_heads: a ROI head that performs per-region computation + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + input_format: describe the meaning of channels of input. Needed by visualization + vis_period: the period to run visualization. Set to 0 to disable. + """ + super().__init__() + self.backbone = backbone + self.proposal_generator = proposal_generator + self.roi_heads = roi_heads + + self.input_format = input_format + self.vis_period = vis_period + if vis_period > 0: + assert input_format is not None, "input_format is required for visualization!" + + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + assert ( + self.pixel_mean.shape == self.pixel_std.shape + ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + return { + "backbone": backbone, + "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), + "roi_heads": build_roi_heads(cfg, backbone.output_shape()), + "input_format": cfg.INPUT.FORMAT, + "vis_period": cfg.VIS_PERIOD, + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def _move_to_current_device(self, x): + return move_device_like(x, self.pixel_mean) + + def visualize_training(self, batched_inputs, proposals): + """ + A function used to visualize images and proposals. It shows ground truth + bounding boxes on the original image and up to 20 top-scoring predicted + object proposals on the original image. Users can implement different + visualization functions for different models. + + Args: + batched_inputs (list): a list that contains input to the model. + proposals (list): a list that contains predicted proposals. Both + batched_inputs and proposals should have the same length. + """ + from custom_detectron2.utils.visualizer import Visualizer + + storage = get_event_storage() + max_vis_prop = 20 + + for input, prop in zip(batched_inputs, proposals): + img = input["image"] + img = convert_image_to_rgb(img.permute(1, 2, 0), self.input_format) + v_gt = Visualizer(img, None) + v_gt = v_gt.overlay_instances(boxes=input["instances"].gt_boxes) + anno_img = v_gt.get_image() + box_size = min(len(prop.proposal_boxes), max_vis_prop) + v_pred = Visualizer(img, None) + v_pred = v_pred.overlay_instances( + boxes=prop.proposal_boxes[0:box_size].tensor.cpu().numpy() + ) + prop_img = v_pred.get_image() + vis_img = np.concatenate((anno_img, prop_img), axis=1) + vis_img = vis_img.transpose(2, 0, 1) + vis_name = "Left: GT bounding boxes; Right: Predicted proposals" + storage.put_image(vis_name, vis_img) + break # only visualize one image in a batch + + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper` . + Each item in the list contains the inputs for one image. + For now, each item in the list is a dict that contains: + + * image: Tensor, image in (C, H, W) format. + * instances (optional): groundtruth :class:`Instances` + * proposals (optional): :class:`Instances`, precomputed proposals. + + Other information that's included in the original dicts, such as: + + * "height", "width" (int): the output resolution of the model, used in inference. + See :meth:`postprocess` for details. + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "instances" whose value is a :class:`Instances`. + The :class:`Instances` object has the following keys: + "pred_boxes", "pred_classes", "scores", "pred_masks", "pred_keypoints" + """ + if not self.training: + return self.inference(batched_inputs) + + images = self.preprocess_image(batched_inputs) + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + else: + gt_instances = None + + features = self.backbone(images.tensor) + + if self.proposal_generator is not None: + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + else: + assert "proposals" in batched_inputs[0] + proposals = [x["proposals"].to(self.device) for x in batched_inputs] + proposal_losses = {} + + _, detector_losses = self.roi_heads(images, features, proposals, gt_instances) + if self.vis_period > 0: + storage = get_event_storage() + if storage.iter % self.vis_period == 0: + self.visualize_training(batched_inputs, proposals) + + losses = {} + losses.update(detector_losses) + losses.update(proposal_losses) + return losses + + def inference( + self, + batched_inputs: List[Dict[str, torch.Tensor]], + detected_instances: Optional[List[Instances]] = None, + do_postprocess: bool = True, + ): + """ + Run inference on the given inputs. + + Args: + batched_inputs (list[dict]): same as in :meth:`forward` + detected_instances (None or list[Instances]): if not None, it + contains an `Instances` object per image. The `Instances` + object contains "pred_boxes" and "pred_classes" which are + known boxes in the image. + The inference will then skip the detection of bounding boxes, + and only predict other per-ROI outputs. + do_postprocess (bool): whether to apply post-processing on the outputs. + + Returns: + When do_postprocess=True, same as in :meth:`forward`. + Otherwise, a list[Instances] containing raw network outputs. + """ + assert not self.training + + images = self.preprocess_image(batched_inputs) + features = self.backbone(images.tensor) + + if detected_instances is None: + if self.proposal_generator is not None: + proposals, _ = self.proposal_generator(images, features, None) + else: + assert "proposals" in batched_inputs[0] + proposals = [x["proposals"].to(self.device) for x in batched_inputs] + + results, _ = self.roi_heads(images, features, proposals, None) + else: + detected_instances = [x.to(self.device) for x in detected_instances] + results = self.roi_heads.forward_with_given_boxes(features, detected_instances) + + if do_postprocess: + assert not torch.jit.is_scripting(), "Scripting is not supported for postprocess." + return GeneralizedRCNN._postprocess(results, batched_inputs, images.image_sizes) + return results + + def preprocess_image(self, batched_inputs: List[Dict[str, torch.Tensor]]): + """ + Normalize, pad and batch the input images. + """ + images = [self._move_to_current_device(x["image"]) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + return images + + @staticmethod + def _postprocess(instances, batched_inputs: List[Dict[str, torch.Tensor]], image_sizes): + """ + Rescale the output instances to the target size. + """ + # note: private function; subject to changes + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + instances, batched_inputs, image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"instances": r}) + return processed_results + + +@META_ARCH_REGISTRY.register() +class ProposalNetwork(nn.Module): + """ + A meta architecture that only predicts object proposals. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + proposal_generator: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + proposal_generator: a module that generates proposals using backbone features + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.proposal_generator = proposal_generator + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + return { + "backbone": backbone, + "proposal_generator": build_proposal_generator(cfg, backbone.output_shape()), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def _move_to_current_device(self, x): + return move_device_like(x, self.pixel_mean) + + def forward(self, batched_inputs): + """ + Args: + Same as in :class:`GeneralizedRCNN.forward` + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "proposals" whose value is a + :class:`Instances` with keys "proposal_boxes" and "objectness_logits". + """ + images = [self._move_to_current_device(x["image"]) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + features = self.backbone(images.tensor) + + if "instances" in batched_inputs[0]: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + elif "targets" in batched_inputs[0]: + log_first_n( + logging.WARN, "'targets' in the model inputs is now renamed to 'instances'!", n=10 + ) + gt_instances = [x["targets"].to(self.device) for x in batched_inputs] + else: + gt_instances = None + proposals, proposal_losses = self.proposal_generator(images, features, gt_instances) + # In training, the proposals are not useful at all but we generate them anyway. + # This makes RPN-only models about 5% slower. + if self.training: + return proposal_losses + + processed_results = [] + for results_per_image, input_per_image, image_size in zip( + proposals, batched_inputs, images.image_sizes + ): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = detector_postprocess(results_per_image, height, width) + processed_results.append({"proposals": r}) + return processed_results diff --git a/custom_detectron2/modeling/meta_arch/retinanet.py b/custom_detectron2/modeling/meta_arch/retinanet.py new file mode 100644 index 0000000000000000000000000000000000000000..cdfca1fface9d53b0f22572699bdb6625978722e --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/retinanet.py @@ -0,0 +1,439 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import math +from typing import List, Tuple +import torch +from fvcore.nn import sigmoid_focal_loss_jit +from torch import Tensor, nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.layers import CycleBatchNormList, ShapeSpec, batched_nms, cat, get_norm +from custom_detectron2.structures import Boxes, ImageList, Instances, pairwise_iou +from custom_detectron2.utils.events import get_event_storage + +from ..anchor_generator import build_anchor_generator +from ..backbone import Backbone, build_backbone +from ..box_regression import Box2BoxTransform, _dense_box_regression_loss +from ..matcher import Matcher +from .build import META_ARCH_REGISTRY +from .dense_detector import DenseDetector, permute_to_N_HWA_K # noqa + +__all__ = ["RetinaNet"] + + +logger = logging.getLogger(__name__) + + +@META_ARCH_REGISTRY.register() +class RetinaNet(DenseDetector): + """ + Implement RetinaNet in :paper:`RetinaNet`. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + head: nn.Module, + head_in_features, + anchor_generator, + box2box_transform, + anchor_matcher, + num_classes, + focal_loss_alpha=0.25, + focal_loss_gamma=2.0, + smooth_l1_beta=0.0, + box_reg_loss_type="smooth_l1", + test_score_thresh=0.05, + test_topk_candidates=1000, + test_nms_thresh=0.5, + max_detections_per_image=100, + pixel_mean, + pixel_std, + vis_period=0, + input_format="BGR", + ): + """ + NOTE: this interface is experimental. + + Args: + backbone: a backbone module, must follow detectron2's backbone interface + head (nn.Module): a module that predicts logits and regression deltas + for each level from a list of per-level features + head_in_features (Tuple[str]): Names of the input feature maps to be used in head + anchor_generator (nn.Module): a module that creates anchors from a + list of features. Usually an instance of :class:`AnchorGenerator` + box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to + instance boxes + anchor_matcher (Matcher): label the anchors by matching them with ground truth. + num_classes (int): number of classes. Used to label background proposals. + + # Loss parameters: + focal_loss_alpha (float): focal_loss_alpha + focal_loss_gamma (float): focal_loss_gamma + smooth_l1_beta (float): smooth_l1_beta + box_reg_loss_type (str): Options are "smooth_l1", "giou", "diou", "ciou" + + # Inference parameters: + test_score_thresh (float): Inference cls score threshold, only anchors with + score > INFERENCE_TH are considered for inference (to improve speed) + test_topk_candidates (int): Select topk candidates before NMS + test_nms_thresh (float): Overlap threshold used for non-maximum suppression + (suppress boxes with IoU >= this threshold) + max_detections_per_image (int): + Maximum number of detections to return per image during inference + (100 is based on the limit established for the COCO dataset). + + pixel_mean, pixel_std: see :class:`DenseDetector`. + """ + super().__init__( + backbone, head, head_in_features, pixel_mean=pixel_mean, pixel_std=pixel_std + ) + self.num_classes = num_classes + + # Anchors + self.anchor_generator = anchor_generator + self.box2box_transform = box2box_transform + self.anchor_matcher = anchor_matcher + + # Loss parameters: + self.focal_loss_alpha = focal_loss_alpha + self.focal_loss_gamma = focal_loss_gamma + self.smooth_l1_beta = smooth_l1_beta + self.box_reg_loss_type = box_reg_loss_type + # Inference parameters: + self.test_score_thresh = test_score_thresh + self.test_topk_candidates = test_topk_candidates + self.test_nms_thresh = test_nms_thresh + self.max_detections_per_image = max_detections_per_image + # Vis parameters + self.vis_period = vis_period + self.input_format = input_format + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + backbone_shape = backbone.output_shape() + feature_shapes = [backbone_shape[f] for f in cfg.MODEL.RETINANET.IN_FEATURES] + head = RetinaNetHead(cfg, feature_shapes) + anchor_generator = build_anchor_generator(cfg, feature_shapes) + return { + "backbone": backbone, + "head": head, + "anchor_generator": anchor_generator, + "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.RETINANET.BBOX_REG_WEIGHTS), + "anchor_matcher": Matcher( + cfg.MODEL.RETINANET.IOU_THRESHOLDS, + cfg.MODEL.RETINANET.IOU_LABELS, + allow_low_quality_matches=True, + ), + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + "num_classes": cfg.MODEL.RETINANET.NUM_CLASSES, + "head_in_features": cfg.MODEL.RETINANET.IN_FEATURES, + # Loss parameters: + "focal_loss_alpha": cfg.MODEL.RETINANET.FOCAL_LOSS_ALPHA, + "focal_loss_gamma": cfg.MODEL.RETINANET.FOCAL_LOSS_GAMMA, + "smooth_l1_beta": cfg.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA, + "box_reg_loss_type": cfg.MODEL.RETINANET.BBOX_REG_LOSS_TYPE, + # Inference parameters: + "test_score_thresh": cfg.MODEL.RETINANET.SCORE_THRESH_TEST, + "test_topk_candidates": cfg.MODEL.RETINANET.TOPK_CANDIDATES_TEST, + "test_nms_thresh": cfg.MODEL.RETINANET.NMS_THRESH_TEST, + "max_detections_per_image": cfg.TEST.DETECTIONS_PER_IMAGE, + # Vis parameters + "vis_period": cfg.VIS_PERIOD, + "input_format": cfg.INPUT.FORMAT, + } + + def forward_training(self, images, features, predictions, gt_instances): + # Transpose the Hi*Wi*A dimension to the middle: + pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( + predictions, [self.num_classes, 4] + ) + anchors = self.anchor_generator(features) + gt_labels, gt_boxes = self.label_anchors(anchors, gt_instances) + return self.losses(anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes) + + def losses(self, anchors, pred_logits, gt_labels, pred_anchor_deltas, gt_boxes): + """ + Args: + anchors (list[Boxes]): a list of #feature level Boxes + gt_labels, gt_boxes: see output of :meth:`RetinaNet.label_anchors`. + Their shapes are (N, R) and (N, R, 4), respectively, where R is + the total number of anchors across levels, i.e. sum(Hi x Wi x Ai) + pred_logits, pred_anchor_deltas: both are list[Tensor]. Each element in the + list corresponds to one level and has shape (N, Hi * Wi * Ai, K or 4). + Where K is the number of classes used in `pred_logits`. + + Returns: + dict[str, Tensor]: + mapping from a named loss to a scalar tensor storing the loss. + Used during training only. The dict keys are: "loss_cls" and "loss_box_reg" + """ + num_images = len(gt_labels) + gt_labels = torch.stack(gt_labels) # (N, R) + + valid_mask = gt_labels >= 0 + pos_mask = (gt_labels >= 0) & (gt_labels != self.num_classes) + num_pos_anchors = pos_mask.sum().item() + get_event_storage().put_scalar("num_pos_anchors", num_pos_anchors / num_images) + normalizer = self._ema_update("loss_normalizer", max(num_pos_anchors, 1), 100) + + # classification and regression loss + gt_labels_target = F.one_hot(gt_labels[valid_mask], num_classes=self.num_classes + 1)[ + :, :-1 + ] # no loss for the last (background) class + loss_cls = sigmoid_focal_loss_jit( + cat(pred_logits, dim=1)[valid_mask], + gt_labels_target.to(pred_logits[0].dtype), + alpha=self.focal_loss_alpha, + gamma=self.focal_loss_gamma, + reduction="sum", + ) + + loss_box_reg = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + gt_boxes, + pos_mask, + box_reg_loss_type=self.box_reg_loss_type, + smooth_l1_beta=self.smooth_l1_beta, + ) + + return { + "loss_cls": loss_cls / normalizer, + "loss_box_reg": loss_box_reg / normalizer, + } + + @torch.no_grad() + def label_anchors(self, anchors, gt_instances): + """ + Args: + anchors (list[Boxes]): A list of #feature level Boxes. + The Boxes contains anchors of this image on the specific feature level. + gt_instances (list[Instances]): a list of N `Instances`s. The i-th + `Instances` contains the ground-truth per-instance annotations + for the i-th input image. + + Returns: + list[Tensor]: List of #img tensors. i-th element is a vector of labels whose length is + the total number of anchors across all feature maps (sum(Hi * Wi * A)). + Label values are in {-1, 0, ..., K}, with -1 means ignore, and K means background. + + list[Tensor]: i-th element is a Rx4 tensor, where R is the total number of anchors + across feature maps. The values are the matched gt boxes for each anchor. + Values are undefined for those anchors not labeled as foreground. + """ + anchors = Boxes.cat(anchors) # Rx4 + + gt_labels = [] + matched_gt_boxes = [] + for gt_per_image in gt_instances: + match_quality_matrix = pairwise_iou(gt_per_image.gt_boxes, anchors) + matched_idxs, anchor_labels = self.anchor_matcher(match_quality_matrix) + del match_quality_matrix + + if len(gt_per_image) > 0: + matched_gt_boxes_i = gt_per_image.gt_boxes.tensor[matched_idxs] + + gt_labels_i = gt_per_image.gt_classes[matched_idxs] + # Anchors with label 0 are treated as background. + gt_labels_i[anchor_labels == 0] = self.num_classes + # Anchors with label -1 are ignored. + gt_labels_i[anchor_labels == -1] = -1 + else: + matched_gt_boxes_i = torch.zeros_like(anchors.tensor) + gt_labels_i = torch.zeros_like(matched_idxs) + self.num_classes + + gt_labels.append(gt_labels_i) + matched_gt_boxes.append(matched_gt_boxes_i) + + return gt_labels, matched_gt_boxes + + def forward_inference( + self, images: ImageList, features: List[Tensor], predictions: List[List[Tensor]] + ): + pred_logits, pred_anchor_deltas = self._transpose_dense_predictions( + predictions, [self.num_classes, 4] + ) + anchors = self.anchor_generator(features) + + results: List[Instances] = [] + for img_idx, image_size in enumerate(images.image_sizes): + scores_per_image = [x[img_idx].sigmoid_() for x in pred_logits] + deltas_per_image = [x[img_idx] for x in pred_anchor_deltas] + results_per_image = self.inference_single_image( + anchors, scores_per_image, deltas_per_image, image_size + ) + results.append(results_per_image) + return results + + def inference_single_image( + self, + anchors: List[Boxes], + box_cls: List[Tensor], + box_delta: List[Tensor], + image_size: Tuple[int, int], + ): + """ + Single-image inference. Return bounding-box detection results by thresholding + on scores and applying non-maximum suppression (NMS). + + Arguments: + anchors (list[Boxes]): list of #feature levels. Each entry contains + a Boxes object, which contains all the anchors in that feature level. + box_cls (list[Tensor]): list of #feature levels. Each entry contains + tensor of size (H x W x A, K) + box_delta (list[Tensor]): Same shape as 'box_cls' except that K becomes 4. + image_size (tuple(H, W)): a tuple of the image height and width. + + Returns: + Same as `inference`, but for only one image. + """ + pred = self._decode_multi_level_predictions( + anchors, + box_cls, + box_delta, + self.test_score_thresh, + self.test_topk_candidates, + image_size, + ) + keep = batched_nms( # per-class NMS + pred.pred_boxes.tensor, pred.scores, pred.pred_classes, self.test_nms_thresh + ) + return pred[keep[: self.max_detections_per_image]] + + +class RetinaNetHead(nn.Module): + """ + The head used in RetinaNet for object classification and box regression. + It has two subnets for the two tasks, with a common structure but separate parameters. + """ + + @configurable + def __init__( + self, + *, + input_shape: List[ShapeSpec], + num_classes, + num_anchors, + conv_dims: List[int], + norm="", + prior_prob=0.01, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape (List[ShapeSpec]): input shape + num_classes (int): number of classes. Used to label background proposals. + num_anchors (int): number of generated anchors + conv_dims (List[int]): dimensions for each convolution layer + norm (str or callable): + Normalization for conv layers except for the two output layers. + See :func:`detectron2.layers.get_norm` for supported types. + prior_prob (float): Prior weight for computing bias + """ + super().__init__() + + self._num_features = len(input_shape) + if norm == "BN" or norm == "SyncBN": + logger.info( + f"Using domain-specific {norm} in RetinaNetHead with len={self._num_features}." + ) + bn_class = nn.BatchNorm2d if norm == "BN" else nn.SyncBatchNorm + + def norm(c): + return CycleBatchNormList( + length=self._num_features, bn_class=bn_class, num_features=c + ) + + else: + norm_name = str(type(get_norm(norm, 32))) + if "BN" in norm_name: + logger.warning( + f"Shared BatchNorm (type={norm_name}) may not work well in RetinaNetHead." + ) + + cls_subnet = [] + bbox_subnet = [] + for in_channels, out_channels in zip( + [input_shape[0].channels] + list(conv_dims), conv_dims + ): + cls_subnet.append( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + ) + if norm: + cls_subnet.append(get_norm(norm, out_channels)) + cls_subnet.append(nn.ReLU()) + bbox_subnet.append( + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + ) + if norm: + bbox_subnet.append(get_norm(norm, out_channels)) + bbox_subnet.append(nn.ReLU()) + + self.cls_subnet = nn.Sequential(*cls_subnet) + self.bbox_subnet = nn.Sequential(*bbox_subnet) + self.cls_score = nn.Conv2d( + conv_dims[-1], num_anchors * num_classes, kernel_size=3, stride=1, padding=1 + ) + self.bbox_pred = nn.Conv2d( + conv_dims[-1], num_anchors * 4, kernel_size=3, stride=1, padding=1 + ) + + # Initialization + for modules in [self.cls_subnet, self.bbox_subnet, self.cls_score, self.bbox_pred]: + for layer in modules.modules(): + if isinstance(layer, nn.Conv2d): + torch.nn.init.normal_(layer.weight, mean=0, std=0.01) + torch.nn.init.constant_(layer.bias, 0) + + # Use prior in model initialization to improve stability + bias_value = -(math.log((1 - prior_prob) / prior_prob)) + torch.nn.init.constant_(self.cls_score.bias, bias_value) + + @classmethod + def from_config(cls, cfg, input_shape: List[ShapeSpec]): + num_anchors = build_anchor_generator(cfg, input_shape).num_cell_anchors + assert ( + len(set(num_anchors)) == 1 + ), "Using different number of anchors between levels is not currently supported!" + num_anchors = num_anchors[0] + + return { + "input_shape": input_shape, + "num_classes": cfg.MODEL.RETINANET.NUM_CLASSES, + "conv_dims": [input_shape[0].channels] * cfg.MODEL.RETINANET.NUM_CONVS, + "prior_prob": cfg.MODEL.RETINANET.PRIOR_PROB, + "norm": cfg.MODEL.RETINANET.NORM, + "num_anchors": num_anchors, + } + + def forward(self, features: List[Tensor]): + """ + Arguments: + features (list[Tensor]): FPN feature map tensors in high to low resolution. + Each tensor in the list correspond to different feature levels. + + Returns: + logits (list[Tensor]): #lvl tensors, each has shape (N, AxK, Hi, Wi). + The tensor predicts the classification probability + at each spatial position for each of the A anchors and K object + classes. + bbox_reg (list[Tensor]): #lvl tensors, each has shape (N, Ax4, Hi, Wi). + The tensor predicts 4-vector (dx,dy,dw,dh) box + regression values for every anchor. These values are the + relative offset between the anchor and the ground truth box. + """ + assert len(features) == self._num_features + logits = [] + bbox_reg = [] + for feature in features: + logits.append(self.cls_score(self.cls_subnet(feature))) + bbox_reg.append(self.bbox_pred(self.bbox_subnet(feature))) + return logits, bbox_reg diff --git a/custom_detectron2/modeling/meta_arch/semantic_seg.py b/custom_detectron2/modeling/meta_arch/semantic_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..e723b854f99da18850da0d7ee7914bd8a7bcd80d --- /dev/null +++ b/custom_detectron2/modeling/meta_arch/semantic_seg.py @@ -0,0 +1,267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import Callable, Dict, Optional, Tuple, Union +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.layers import Conv2d, ShapeSpec, get_norm +from custom_detectron2.structures import ImageList +from custom_detectron2.utils.registry import Registry + +from ..backbone import Backbone, build_backbone +from ..postprocessing import sem_seg_postprocess +from .build import META_ARCH_REGISTRY + +__all__ = [ + "SemanticSegmentor", + "SEM_SEG_HEADS_REGISTRY", + "SemSegFPNHead", + "build_sem_seg_head", +] + + +SEM_SEG_HEADS_REGISTRY = Registry("SEM_SEG_HEADS") +SEM_SEG_HEADS_REGISTRY.__doc__ = """ +Registry for semantic segmentation heads, which make semantic segmentation predictions +from feature maps. +""" + + +@META_ARCH_REGISTRY.register() +class SemanticSegmentor(nn.Module): + """ + Main class for semantic segmentation architectures. + """ + + @configurable + def __init__( + self, + *, + backbone: Backbone, + sem_seg_head: nn.Module, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + backbone: a backbone module, must follow detectron2's backbone interface + sem_seg_head: a module that predicts semantic segmentation from backbone features + pixel_mean, pixel_std: list or tuple with #channels element, representing + the per-channel mean and std to be used to normalize the input image + """ + super().__init__() + self.backbone = backbone + self.sem_seg_head = sem_seg_head + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + + @classmethod + def from_config(cls, cfg): + backbone = build_backbone(cfg) + sem_seg_head = build_sem_seg_head(cfg, backbone.output_shape()) + return { + "backbone": backbone, + "sem_seg_head": sem_seg_head, + "pixel_mean": cfg.MODEL.PIXEL_MEAN, + "pixel_std": cfg.MODEL.PIXEL_STD, + } + + @property + def device(self): + return self.pixel_mean.device + + def forward(self, batched_inputs): + """ + Args: + batched_inputs: a list, batched outputs of :class:`DatasetMapper`. + Each item in the list contains the inputs for one image. + + For now, each item in the list is a dict that contains: + + * "image": Tensor, image in (C, H, W) format. + * "sem_seg": semantic segmentation ground truth + * Other information that's included in the original dicts, such as: + "height", "width" (int): the output resolution of the model (may be different + from input resolution), used in inference. + + + Returns: + list[dict]: + Each dict is the output for one input image. + The dict contains one key "sem_seg" whose value is a + Tensor that represents the + per-pixel segmentation prediced by the head. + The prediction has shape KxHxW that represents the logits of + each class for each pixel. + """ + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors( + images, + self.backbone.size_divisibility, + padding_constraints=self.backbone.padding_constraints, + ) + + features = self.backbone(images.tensor) + + if "sem_seg" in batched_inputs[0]: + targets = [x["sem_seg"].to(self.device) for x in batched_inputs] + targets = ImageList.from_tensors( + targets, + self.backbone.size_divisibility, + self.sem_seg_head.ignore_value, + self.backbone.padding_constraints, + ).tensor + else: + targets = None + results, losses = self.sem_seg_head(features, targets) + + if self.training: + return losses + + processed_results = [] + for result, input_per_image, image_size in zip(results, batched_inputs, images.image_sizes): + height = input_per_image.get("height", image_size[0]) + width = input_per_image.get("width", image_size[1]) + r = sem_seg_postprocess(result, image_size, height, width) + processed_results.append({"sem_seg": r}) + return processed_results + + +def build_sem_seg_head(cfg, input_shape): + """ + Build a semantic segmentation head from `cfg.MODEL.SEM_SEG_HEAD.NAME`. + """ + name = cfg.MODEL.SEM_SEG_HEAD.NAME + return SEM_SEG_HEADS_REGISTRY.get(name)(cfg, input_shape) + + +@SEM_SEG_HEADS_REGISTRY.register() +class SemSegFPNHead(nn.Module): + """ + A semantic segmentation head described in :paper:`PanopticFPN`. + It takes a list of FPN features as input, and applies a sequence of + 3x3 convs and upsampling to scale all of them to the stride defined by + ``common_stride``. Then these features are added and used to make final + predictions by another 1x1 conv layer. + """ + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + num_classes: int, + conv_dims: int, + common_stride: int, + loss_weight: float = 1.0, + norm: Optional[Union[str, Callable]] = None, + ignore_value: int = -1, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape: shapes (channels and stride) of the input features + num_classes: number of classes to predict + conv_dims: number of output channels for the intermediate conv layers. + common_stride: the common stride that all features will be upscaled to + loss_weight: loss weight + norm (str or callable): normalization for all conv layers + ignore_value: category id to be ignored during training. + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + if not len(input_shape): + raise ValueError("SemSegFPNHead(input_shape=) cannot be empty!") + self.in_features = [k for k, v in input_shape] + feature_strides = [v.stride for k, v in input_shape] + feature_channels = [v.channels for k, v in input_shape] + + self.ignore_value = ignore_value + self.common_stride = common_stride + self.loss_weight = loss_weight + + self.scale_heads = [] + for in_feature, stride, channels in zip( + self.in_features, feature_strides, feature_channels + ): + head_ops = [] + head_length = max(1, int(np.log2(stride) - np.log2(self.common_stride))) + for k in range(head_length): + norm_module = get_norm(norm, conv_dims) + conv = Conv2d( + channels if k == 0 else conv_dims, + conv_dims, + kernel_size=3, + stride=1, + padding=1, + bias=not norm, + norm=norm_module, + activation=F.relu, + ) + weight_init.c2_msra_fill(conv) + head_ops.append(conv) + if stride != self.common_stride: + head_ops.append( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False) + ) + self.scale_heads.append(nn.Sequential(*head_ops)) + self.add_module(in_feature, self.scale_heads[-1]) + self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0) + weight_init.c2_msra_fill(self.predictor) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + return { + "input_shape": { + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + "ignore_value": cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + "num_classes": cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + "conv_dims": cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM, + "common_stride": cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE, + "norm": cfg.MODEL.SEM_SEG_HEAD.NORM, + "loss_weight": cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + } + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + x = self.layers(features) + if self.training: + return None, self.losses(x, targets) + else: + x = F.interpolate( + x, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return x, {} + + def layers(self, features): + for i, f in enumerate(self.in_features): + if i == 0: + x = self.scale_heads[i](features[f]) + else: + x = x + self.scale_heads[i](features[f]) + x = self.predictor(x) + return x + + def losses(self, predictions, targets): + predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 + predictions = F.interpolate( + predictions, + scale_factor=self.common_stride, + mode="bilinear", + align_corners=False, + ) + loss = F.cross_entropy( + predictions, targets, reduction="mean", ignore_index=self.ignore_value + ) + losses = {"loss_sem_seg": loss * self.loss_weight} + return losses diff --git a/custom_detectron2/modeling/mmdet_wrapper.py b/custom_detectron2/modeling/mmdet_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..ae02094cbf7b2d8d664a845ad02fff5386c52e8b --- /dev/null +++ b/custom_detectron2/modeling/mmdet_wrapper.py @@ -0,0 +1,273 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +import numpy as np +from collections import OrderedDict +from collections.abc import Mapping +from typing import Dict, List, Optional, Tuple, Union +import torch +from omegaconf import DictConfig, OmegaConf +from torch import Tensor, nn + +from custom_detectron2.layers import ShapeSpec +from custom_detectron2.structures import BitMasks, Boxes, ImageList, Instances +from custom_detectron2.utils.events import get_event_storage + +from .backbone import Backbone + +logger = logging.getLogger(__name__) + + +def _to_container(cfg): + """ + mmdet will assert the type of dict/list. + So convert omegaconf objects to dict/list. + """ + if isinstance(cfg, DictConfig): + cfg = OmegaConf.to_container(cfg, resolve=True) + from custom_mmpkg.custom_mmcv.utils import ConfigDict + + return ConfigDict(cfg) + + +class MMDetBackbone(Backbone): + """ + Wrapper of mmdetection backbones to use in detectron2. + + mmdet backbones produce list/tuple of tensors, while detectron2 backbones + produce a dict of tensors. This class wraps the given backbone to produce + output in detectron2's convention, so it can be used in place of detectron2 + backbones. + """ + + def __init__( + self, + backbone: Union[nn.Module, Mapping], + neck: Union[nn.Module, Mapping, None] = None, + *, + output_shapes: List[ShapeSpec], + output_names: Optional[List[str]] = None, + ): + """ + Args: + backbone: either a backbone module or a mmdet config dict that defines a + backbone. The backbone takes a 4D image tensor and returns a + sequence of tensors. + neck: either a backbone module or a mmdet config dict that defines a + neck. The neck takes outputs of backbone and returns a + sequence of tensors. If None, no neck is used. + output_shapes: shape for every output of the backbone (or neck, if given). + stride and channels are often needed. + output_names: names for every output of the backbone (or neck, if given). + By default, will use "out0", "out1", ... + """ + super().__init__() + if isinstance(backbone, Mapping): + from mmdet.models import build_backbone + + backbone = build_backbone(_to_container(backbone)) + self.backbone = backbone + + if isinstance(neck, Mapping): + from mmdet.models import build_neck + + neck = build_neck(_to_container(neck)) + self.neck = neck + + # "Neck" weights, if any, are part of neck itself. This is the interface + # of mmdet so we follow it. Reference: + # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/two_stage.py + logger.info("Initializing mmdet backbone weights...") + self.backbone.init_weights() + # train() in mmdet modules is non-trivial, and has to be explicitly + # called. Reference: + # https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/backbones/resnet.py + self.backbone.train() + if self.neck is not None: + logger.info("Initializing mmdet neck weights ...") + if isinstance(self.neck, nn.Sequential): + for m in self.neck: + m.init_weights() + else: + self.neck.init_weights() + self.neck.train() + + self._output_shapes = output_shapes + if not output_names: + output_names = [f"out{i}" for i in range(len(output_shapes))] + self._output_names = output_names + + def forward(self, x) -> Dict[str, Tensor]: + outs = self.backbone(x) + if self.neck is not None: + outs = self.neck(outs) + assert isinstance( + outs, (list, tuple) + ), "mmdet backbone should return a list/tuple of tensors!" + if len(outs) != len(self._output_shapes): + raise ValueError( + "Length of output_shapes does not match outputs from the mmdet backbone: " + f"{len(outs)} != {len(self._output_shapes)}" + ) + return {k: v for k, v in zip(self._output_names, outs)} + + def output_shape(self) -> Dict[str, ShapeSpec]: + return {k: v for k, v in zip(self._output_names, self._output_shapes)} + + +class MMDetDetector(nn.Module): + """ + Wrapper of a mmdetection detector model, for detection and instance segmentation. + Input/output formats of this class follow detectron2's convention, so a + mmdetection model can be trained and evaluated in detectron2. + """ + + def __init__( + self, + detector: Union[nn.Module, Mapping], + *, + # Default is 32 regardless of model: + # https://github.com/open-mmlab/mmdetection/tree/master/configs/_base_/datasets + size_divisibility=32, + pixel_mean: Tuple[float], + pixel_std: Tuple[float], + ): + """ + Args: + detector: a mmdet detector, or a mmdet config dict that defines a detector. + size_divisibility: pad input images to multiple of this number + pixel_mean: per-channel mean to normalize input image + pixel_std: per-channel stddev to normalize input image + """ + super().__init__() + if isinstance(detector, Mapping): + from mmdet.models import build_detector + + detector = build_detector(_to_container(detector)) + self.detector = detector + self.detector.init_weights() + self.size_divisibility = size_divisibility + + self.register_buffer("pixel_mean", torch.tensor(pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(pixel_std).view(-1, 1, 1), False) + assert ( + self.pixel_mean.shape == self.pixel_std.shape + ), f"{self.pixel_mean} and {self.pixel_std} have different shapes!" + + def forward(self, batched_inputs: List[Dict[str, torch.Tensor]]): + images = [x["image"].to(self.device) for x in batched_inputs] + images = [(x - self.pixel_mean) / self.pixel_std for x in images] + images = ImageList.from_tensors(images, size_divisibility=self.size_divisibility).tensor + metas = [] + rescale = {"height" in x for x in batched_inputs} + if len(rescale) != 1: + raise ValueError("Some inputs have original height/width, but some don't!") + rescale = list(rescale)[0] + output_shapes = [] + for input in batched_inputs: + meta = {} + c, h, w = input["image"].shape + meta["img_shape"] = meta["ori_shape"] = (h, w, c) + if rescale: + scale_factor = np.array( + [w / input["width"], h / input["height"]] * 2, dtype="float32" + ) + ori_shape = (input["height"], input["width"]) + output_shapes.append(ori_shape) + meta["ori_shape"] = ori_shape + (c,) + else: + scale_factor = 1.0 + output_shapes.append((h, w)) + meta["scale_factor"] = scale_factor + meta["flip"] = False + padh, padw = images.shape[-2:] + meta["pad_shape"] = (padh, padw, c) + metas.append(meta) + + if self.training: + gt_instances = [x["instances"].to(self.device) for x in batched_inputs] + if gt_instances[0].has("gt_masks"): + from mmdet.core import PolygonMasks as mm_PolygonMasks, BitmapMasks as mm_BitMasks + + def convert_mask(m, shape): + # mmdet mask format + if isinstance(m, BitMasks): + return mm_BitMasks(m.tensor.cpu().numpy(), shape[0], shape[1]) + else: + return mm_PolygonMasks(m.polygons, shape[0], shape[1]) + + gt_masks = [convert_mask(x.gt_masks, x.image_size) for x in gt_instances] + losses_and_metrics = self.detector.forward_train( + images, + metas, + [x.gt_boxes.tensor for x in gt_instances], + [x.gt_classes for x in gt_instances], + gt_masks=gt_masks, + ) + else: + losses_and_metrics = self.detector.forward_train( + images, + metas, + [x.gt_boxes.tensor for x in gt_instances], + [x.gt_classes for x in gt_instances], + ) + return _parse_losses(losses_and_metrics) + else: + results = self.detector.simple_test(images, metas, rescale=rescale) + results = [ + {"instances": _convert_mmdet_result(r, shape)} + for r, shape in zip(results, output_shapes) + ] + return results + + @property + def device(self): + return self.pixel_mean.device + + +# Reference: show_result() in +# https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/base.py +def _convert_mmdet_result(result, shape: Tuple[int, int]) -> Instances: + if isinstance(result, tuple): + bbox_result, segm_result = result + if isinstance(segm_result, tuple): + segm_result = segm_result[0] + else: + bbox_result, segm_result = result, None + + bboxes = torch.from_numpy(np.vstack(bbox_result)) # Nx5 + bboxes, scores = bboxes[:, :4], bboxes[:, -1] + labels = [ + torch.full((bbox.shape[0],), i, dtype=torch.int32) for i, bbox in enumerate(bbox_result) + ] + labels = torch.cat(labels) + inst = Instances(shape) + inst.pred_boxes = Boxes(bboxes) + inst.scores = scores + inst.pred_classes = labels + + if segm_result is not None and len(labels) > 0: + segm_result = list(itertools.chain(*segm_result)) + segm_result = [torch.from_numpy(x) if isinstance(x, np.ndarray) else x for x in segm_result] + segm_result = torch.stack(segm_result, dim=0) + inst.pred_masks = segm_result + return inst + + +# reference: https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/detectors/base.py +def _parse_losses(losses: Dict[str, Tensor]) -> Dict[str, Tensor]: + log_vars = OrderedDict() + for loss_name, loss_value in losses.items(): + if isinstance(loss_value, torch.Tensor): + log_vars[loss_name] = loss_value.mean() + elif isinstance(loss_value, list): + log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value) + else: + raise TypeError(f"{loss_name} is not a tensor or list of tensors") + + if "loss" not in loss_name: + # put metrics to storage; don't return them + storage = get_event_storage() + value = log_vars.pop(loss_name).cpu().item() + storage.put_scalar(loss_name, value) + return log_vars diff --git a/custom_detectron2/modeling/poolers.py b/custom_detectron2/modeling/poolers.py new file mode 100644 index 0000000000000000000000000000000000000000..7a077a9af5e9541069b834ee290118d3628c9cb2 --- /dev/null +++ b/custom_detectron2/modeling/poolers.py @@ -0,0 +1,263 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from typing import List, Optional +import torch +from torch import nn +from torchvision.ops import RoIPool + +from custom_detectron2.layers import ROIAlign, ROIAlignRotated, cat, nonzero_tuple, shapes_to_tensor +from custom_detectron2.structures import Boxes +from custom_detectron2.utils.tracing import assert_fx_safe, is_fx_tracing + +""" +To export ROIPooler to torchscript, in this file, variables that should be annotated with +`Union[List[Boxes], List[RotatedBoxes]]` are only annotated with `List[Boxes]`. + +TODO: Correct these annotations when torchscript support `Union`. +https://github.com/pytorch/pytorch/issues/41412 +""" + +__all__ = ["ROIPooler"] + + +def assign_boxes_to_levels( + box_lists: List[Boxes], + min_level: int, + max_level: int, + canonical_box_size: int, + canonical_level: int, +): + """ + Map each box in `box_lists` to a feature map level index and return the assignment + vector. + + Args: + box_lists (list[Boxes] | list[RotatedBoxes]): A list of N Boxes or N RotatedBoxes, + where N is the number of images in the batch. + min_level (int): Smallest feature map level index. The input is considered index 0, + the output of stage 1 is index 1, and so. + max_level (int): Largest feature map level index. + canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). + canonical_level (int): The feature map level index on which a canonically-sized box + should be placed. + + Returns: + A tensor of length M, where M is the total number of boxes aggregated over all + N batch images. The memory layout corresponds to the concatenation of boxes + from all images. Each element is the feature map index, as an offset from + `self.min_level`, for the corresponding box (so value i means the box is at + `self.min_level + i`). + """ + box_sizes = torch.sqrt(cat([boxes.area() for boxes in box_lists])) + # Eqn.(1) in FPN paper + level_assignments = torch.floor( + canonical_level + torch.log2(box_sizes / canonical_box_size + 1e-8) + ) + # clamp level to (min, max), in case the box size is too large or too small + # for the available feature maps + level_assignments = torch.clamp(level_assignments, min=min_level, max=max_level) + return level_assignments.to(torch.int64) - min_level + + +# script the module to avoid hardcoded device type +@torch.jit.script_if_tracing +def _convert_boxes_to_pooler_format(boxes: torch.Tensor, sizes: torch.Tensor) -> torch.Tensor: + sizes = sizes.to(device=boxes.device) + indices = torch.repeat_interleave( + torch.arange(len(sizes), dtype=boxes.dtype, device=boxes.device), sizes + ) + return cat([indices[:, None], boxes], dim=1) + + +def convert_boxes_to_pooler_format(box_lists: List[Boxes]): + """ + Convert all boxes in `box_lists` to the low-level format used by ROI pooling ops + (see description under Returns). + + Args: + box_lists (list[Boxes] | list[RotatedBoxes]): + A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. + + Returns: + When input is list[Boxes]: + A tensor of shape (M, 5), where M is the total number of boxes aggregated over all + N batch images. + The 5 columns are (batch index, x0, y0, x1, y1), where batch index + is the index in [0, N) identifying which batch image the box with corners at + (x0, y0, x1, y1) comes from. + When input is list[RotatedBoxes]: + A tensor of shape (M, 6), where M is the total number of boxes aggregated over all + N batch images. + The 6 columns are (batch index, x_ctr, y_ctr, width, height, angle_degrees), + where batch index is the index in [0, N) identifying which batch image the + rotated box (x_ctr, y_ctr, width, height, angle_degrees) comes from. + """ + boxes = torch.cat([x.tensor for x in box_lists], dim=0) + # __len__ returns Tensor in tracing. + sizes = shapes_to_tensor([x.__len__() for x in box_lists]) + return _convert_boxes_to_pooler_format(boxes, sizes) + + +@torch.jit.script_if_tracing +def _create_zeros( + batch_target: Optional[torch.Tensor], + channels: int, + height: int, + width: int, + like_tensor: torch.Tensor, +) -> torch.Tensor: + batches = batch_target.shape[0] if batch_target is not None else 0 + sizes = (batches, channels, height, width) + return torch.zeros(sizes, dtype=like_tensor.dtype, device=like_tensor.device) + + +class ROIPooler(nn.Module): + """ + Region of interest feature map pooler that supports pooling from one or more + feature maps. + """ + + def __init__( + self, + output_size, + scales, + sampling_ratio, + pooler_type, + canonical_box_size=224, + canonical_level=4, + ): + """ + Args: + output_size (int, tuple[int] or list[int]): output size of the pooled region, + e.g., 14 x 14. If tuple or list is given, the length must be 2. + scales (list[float]): The scale for each low-level pooling op relative to + the input image. For a feature map with stride s relative to the input + image, scale is defined as 1/s. The stride must be power of 2. + When there are multiple scales, they must form a pyramid, i.e. they must be + a monotically decreasing geometric sequence with a factor of 1/2. + sampling_ratio (int): The `sampling_ratio` parameter for the ROIAlign op. + pooler_type (string): Name of the type of pooling operation that should be applied. + For instance, "ROIPool" or "ROIAlignV2". + canonical_box_size (int): A canonical box size in pixels (sqrt(box area)). The default + is heuristically defined as 224 pixels in the FPN paper (based on ImageNet + pre-training). + canonical_level (int): The feature map level index from which a canonically-sized box + should be placed. The default is defined as level 4 (stride=16) in the FPN paper, + i.e., a box of size 224x224 will be placed on the feature with stride=16. + The box placement for all boxes will be determined from their sizes w.r.t + canonical_box_size. For example, a box whose area is 4x that of a canonical box + should be used to pool features from feature level ``canonical_level+1``. + + Note that the actual input feature maps given to this module may not have + sufficiently many levels for the input boxes. If the boxes are too large or too + small for the input feature maps, the closest level will be used. + """ + super().__init__() + + if isinstance(output_size, int): + output_size = (output_size, output_size) + assert len(output_size) == 2 + assert isinstance(output_size[0], int) and isinstance(output_size[1], int) + self.output_size = output_size + + if pooler_type == "ROIAlign": + self.level_poolers = nn.ModuleList( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=False + ) + for scale in scales + ) + elif pooler_type == "ROIAlignV2": + self.level_poolers = nn.ModuleList( + ROIAlign( + output_size, spatial_scale=scale, sampling_ratio=sampling_ratio, aligned=True + ) + for scale in scales + ) + elif pooler_type == "ROIPool": + self.level_poolers = nn.ModuleList( + RoIPool(output_size, spatial_scale=scale) for scale in scales + ) + elif pooler_type == "ROIAlignRotated": + self.level_poolers = nn.ModuleList( + ROIAlignRotated(output_size, spatial_scale=scale, sampling_ratio=sampling_ratio) + for scale in scales + ) + else: + raise ValueError("Unknown pooler type: {}".format(pooler_type)) + + # Map scale (defined as 1 / stride) to its feature map level under the + # assumption that stride is a power of 2. + min_level = -(math.log2(scales[0])) + max_level = -(math.log2(scales[-1])) + assert math.isclose(min_level, int(min_level)) and math.isclose( + max_level, int(max_level) + ), "Featuremap stride is not power of 2!" + self.min_level = int(min_level) + self.max_level = int(max_level) + assert ( + len(scales) == self.max_level - self.min_level + 1 + ), "[ROIPooler] Sizes of input featuremaps do not form a pyramid!" + assert 0 <= self.min_level and self.min_level <= self.max_level + self.canonical_level = canonical_level + assert canonical_box_size > 0 + self.canonical_box_size = canonical_box_size + + def forward(self, x: List[torch.Tensor], box_lists: List[Boxes]): + """ + Args: + x (list[Tensor]): A list of feature maps of NCHW shape, with scales matching those + used to construct this module. + box_lists (list[Boxes] | list[RotatedBoxes]): + A list of N Boxes or N RotatedBoxes, where N is the number of images in the batch. + The box coordinates are defined on the original image and + will be scaled by the `scales` argument of :class:`ROIPooler`. + + Returns: + Tensor: + A tensor of shape (M, C, output_size, output_size) where M is the total number of + boxes aggregated over all N batch images and C is the number of channels in `x`. + """ + num_level_assignments = len(self.level_poolers) + + if not is_fx_tracing(): + torch._assert( + isinstance(x, list) and isinstance(box_lists, list), + "Arguments to pooler must be lists", + ) + assert_fx_safe( + len(x) == num_level_assignments, + "unequal value, num_level_assignments={}, but x is list of {} Tensors".format( + num_level_assignments, len(x) + ), + ) + assert_fx_safe( + len(box_lists) == x[0].size(0), + "unequal value, x[0] batch dim 0 is {}, but box_list has length {}".format( + x[0].size(0), len(box_lists) + ), + ) + if len(box_lists) == 0: + return _create_zeros(None, x[0].shape[1], *self.output_size, x[0]) + + pooler_fmt_boxes = convert_boxes_to_pooler_format(box_lists) + + if num_level_assignments == 1: + return self.level_poolers[0](x[0], pooler_fmt_boxes) + + level_assignments = assign_boxes_to_levels( + box_lists, self.min_level, self.max_level, self.canonical_box_size, self.canonical_level + ) + + num_channels = x[0].shape[1] + output_size = self.output_size[0] + + output = _create_zeros(pooler_fmt_boxes, num_channels, output_size, output_size, x[0]) + + for level, pooler in enumerate(self.level_poolers): + inds = nonzero_tuple(level_assignments == level)[0] + pooler_fmt_boxes_level = pooler_fmt_boxes[inds] + # Use index_put_ instead of advance indexing, to avoid pytorch/issues/49852 + output.index_put_((inds,), pooler(x[level], pooler_fmt_boxes_level)) + + return output diff --git a/custom_detectron2/modeling/postprocessing.py b/custom_detectron2/modeling/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..97ac46bb548a3133084227db337d62aefb1c9d5f --- /dev/null +++ b/custom_detectron2/modeling/postprocessing.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +from torch.nn import functional as F + +from custom_detectron2.structures import Instances, ROIMasks + + +# perhaps should rename to "resize_instance" +def detector_postprocess( + results: Instances, output_height: int, output_width: int, mask_threshold: float = 0.5 +): + """ + Resize the output instances. + The input images are often resized when entering an object detector. + As a result, we often need the outputs of the detector in a different + resolution from its inputs. + + This function will resize the raw outputs of an R-CNN detector + to produce outputs according to the desired output resolution. + + Args: + results (Instances): the raw outputs from the detector. + `results.image_size` contains the input image resolution the detector sees. + This object might be modified in-place. + output_height, output_width: the desired output resolution. + Returns: + Instances: the resized output from the model, based on the output resolution + """ + if isinstance(output_width, torch.Tensor): + # This shape might (but not necessarily) be tensors during tracing. + # Converts integer tensors to float temporaries to ensure true + # division is performed when computing scale_x and scale_y. + output_width_tmp = output_width.float() + output_height_tmp = output_height.float() + new_size = torch.stack([output_height, output_width]) + else: + new_size = (output_height, output_width) + output_width_tmp = output_width + output_height_tmp = output_height + + scale_x, scale_y = ( + output_width_tmp / results.image_size[1], + output_height_tmp / results.image_size[0], + ) + results = Instances(new_size, **results.get_fields()) + + if results.has("pred_boxes"): + output_boxes = results.pred_boxes + elif results.has("proposal_boxes"): + output_boxes = results.proposal_boxes + else: + output_boxes = None + assert output_boxes is not None, "Predictions must contain boxes!" + + output_boxes.scale(scale_x, scale_y) + output_boxes.clip(results.image_size) + + results = results[output_boxes.nonempty()] + + if results.has("pred_masks"): + if isinstance(results.pred_masks, ROIMasks): + roi_masks = results.pred_masks + else: + # pred_masks is a tensor of shape (N, 1, M, M) + roi_masks = ROIMasks(results.pred_masks[:, 0, :, :]) + results.pred_masks = roi_masks.to_bitmasks( + results.pred_boxes, output_height, output_width, mask_threshold + ).tensor # TODO return ROIMasks/BitMask object in the future + + if results.has("pred_keypoints"): + results.pred_keypoints[:, :, 0] *= scale_x + results.pred_keypoints[:, :, 1] *= scale_y + + return results + + +def sem_seg_postprocess(result, img_size, output_height, output_width): + """ + Return semantic segmentation predictions in the original resolution. + + The input images are often resized when entering semantic segmentor. Moreover, in same + cases, they also padded inside segmentor to be divisible by maximum network stride. + As a result, we often need the predictions of the segmentor in a different + resolution from its inputs. + + Args: + result (Tensor): semantic segmentation prediction logits. A tensor of shape (C, H, W), + where C is the number of classes, and H, W are the height and width of the prediction. + img_size (tuple): image size that segmentor is taking as input. + output_height, output_width: the desired output resolution. + + Returns: + semantic segmentation prediction (Tensor): A tensor of the shape + (C, output_height, output_width) that contains per-pixel soft predictions. + """ + result = result[:, : img_size[0], : img_size[1]].expand(1, -1, -1, -1) + result = F.interpolate( + result, size=(output_height, output_width), mode="bilinear", align_corners=False + )[0] + return result diff --git a/custom_detectron2/modeling/proposal_generator/__init__.py b/custom_detectron2/modeling/proposal_generator/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3f4e4df7645c67b7a013295207b98fe70b2e574c --- /dev/null +++ b/custom_detectron2/modeling/proposal_generator/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .build import PROPOSAL_GENERATOR_REGISTRY, build_proposal_generator +from .rpn import RPN_HEAD_REGISTRY, build_rpn_head, RPN, StandardRPNHead + +__all__ = list(globals().keys()) diff --git a/custom_detectron2/modeling/proposal_generator/build.py b/custom_detectron2/modeling/proposal_generator/build.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc0dca8574598562d10fe66a6900666e6a0cc04 --- /dev/null +++ b/custom_detectron2/modeling/proposal_generator/build.py @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from custom_detectron2.utils.registry import Registry + +PROPOSAL_GENERATOR_REGISTRY = Registry("PROPOSAL_GENERATOR") +PROPOSAL_GENERATOR_REGISTRY.__doc__ = """ +Registry for proposal generator, which produces object proposals from feature maps. + +The registered object will be called with `obj(cfg, input_shape)`. +The call should return a `nn.Module` object. +""" + +from . import rpn, rrpn # noqa F401 isort:skip + + +def build_proposal_generator(cfg, input_shape): + """ + Build a proposal generator from `cfg.MODEL.PROPOSAL_GENERATOR.NAME`. + The name can be "PrecomputedProposals" to use no proposal generator. + """ + name = cfg.MODEL.PROPOSAL_GENERATOR.NAME + if name == "PrecomputedProposals": + return None + + return PROPOSAL_GENERATOR_REGISTRY.get(name)(cfg, input_shape) diff --git a/custom_detectron2/modeling/proposal_generator/proposal_utils.py b/custom_detectron2/modeling/proposal_generator/proposal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ead4ce29ff03a475621e44df910391ea59378f --- /dev/null +++ b/custom_detectron2/modeling/proposal_generator/proposal_utils.py @@ -0,0 +1,205 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import math +from typing import List, Tuple, Union +import torch + +from custom_detectron2.layers import batched_nms, cat, move_device_like +from custom_detectron2.structures import Boxes, Instances + +logger = logging.getLogger(__name__) + + +def _is_tracing(): + # (fixed in TORCH_VERSION >= 1.9) + if torch.jit.is_scripting(): + # https://github.com/pytorch/pytorch/issues/47379 + return False + else: + return torch.jit.is_tracing() + + +def find_top_rpn_proposals( + proposals: List[torch.Tensor], + pred_objectness_logits: List[torch.Tensor], + image_sizes: List[Tuple[int, int]], + nms_thresh: float, + pre_nms_topk: int, + post_nms_topk: int, + min_box_size: float, + training: bool, +): + """ + For each feature map, select the `pre_nms_topk` highest scoring proposals, + apply NMS, clip proposals, and remove small boxes. Return the `post_nms_topk` + highest scoring proposals among all the feature maps for each image. + + Args: + proposals (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A, 4). + All proposal predictions on the feature maps. + pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A). + image_sizes (list[tuple]): sizes (h, w) for each image + nms_thresh (float): IoU threshold to use for NMS + pre_nms_topk (int): number of top k scoring proposals to keep before applying NMS. + When RPN is run on multiple feature maps (as in FPN) this number is per + feature map. + post_nms_topk (int): number of top k scoring proposals to keep after applying NMS. + When RPN is run on multiple feature maps (as in FPN) this number is total, + over all feature maps. + min_box_size (float): minimum proposal box side length in pixels (absolute units + wrt input images). + training (bool): True if proposals are to be used in training, otherwise False. + This arg exists only to support a legacy bug; look for the "NB: Legacy bug ..." + comment. + + Returns: + list[Instances]: list of N Instances. The i-th Instances + stores post_nms_topk object proposals for image i, sorted by their + objectness score in descending order. + """ + num_images = len(image_sizes) + device = ( + proposals[0].device + if torch.jit.is_scripting() + else ("cpu" if torch.jit.is_tracing() else proposals[0].device) + ) + + # 1. Select top-k anchor for every level and every image + topk_scores = [] # #lvl Tensor, each of shape N x topk + topk_proposals = [] + level_ids = [] # #lvl Tensor, each of shape (topk,) + batch_idx = move_device_like(torch.arange(num_images, device=device), proposals[0]) + for level_id, (proposals_i, logits_i) in enumerate(zip(proposals, pred_objectness_logits)): + Hi_Wi_A = logits_i.shape[1] + if isinstance(Hi_Wi_A, torch.Tensor): # it's a tensor in tracing + num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk) + else: + num_proposals_i = min(Hi_Wi_A, pre_nms_topk) + + topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1) + + # each is N x topk + topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 4 + + topk_proposals.append(topk_proposals_i) + topk_scores.append(topk_scores_i) + level_ids.append( + move_device_like( + torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device), + proposals[0], + ) + ) + + # 2. Concat all levels together + topk_scores = cat(topk_scores, dim=1) + topk_proposals = cat(topk_proposals, dim=1) + level_ids = cat(level_ids, dim=0) + + # 3. For each image, run a per-level NMS, and choose topk results. + results: List[Instances] = [] + for n, image_size in enumerate(image_sizes): + boxes = Boxes(topk_proposals[n]) + scores_per_img = topk_scores[n] + lvl = level_ids + + valid_mask = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores_per_img) + if not valid_mask.all(): + if training: + raise FloatingPointError( + "Predicted boxes or scores contain Inf/NaN. Training has diverged." + ) + boxes = boxes[valid_mask] + scores_per_img = scores_per_img[valid_mask] + lvl = lvl[valid_mask] + boxes.clip(image_size) + + # filter empty boxes + keep = boxes.nonempty(threshold=min_box_size) + if _is_tracing() or keep.sum().item() != len(boxes): + boxes, scores_per_img, lvl = boxes[keep], scores_per_img[keep], lvl[keep] + + keep = batched_nms(boxes.tensor, scores_per_img, lvl, nms_thresh) + # In Detectron1, there was different behavior during training vs. testing. + # (https://github.com/facebookresearch/Detectron/issues/459) + # During training, topk is over the proposals from *all* images in the training batch. + # During testing, it is over the proposals for each image separately. + # As a result, the training behavior becomes batch-dependent, + # and the configuration "POST_NMS_TOPK_TRAIN" end up relying on the batch size. + # This bug is addressed in Detectron2 to make the behavior independent of batch size. + keep = keep[:post_nms_topk] # keep is already sorted + + res = Instances(image_size) + res.proposal_boxes = boxes[keep] + res.objectness_logits = scores_per_img[keep] + results.append(res) + return results + + +def add_ground_truth_to_proposals( + gt: Union[List[Instances], List[Boxes]], proposals: List[Instances] +) -> List[Instances]: + """ + Call `add_ground_truth_to_proposals_single_image` for all images. + + Args: + gt(Union[List[Instances], List[Boxes]): list of N elements. Element i is a Instances + representing the ground-truth for image i. + proposals (list[Instances]): list of N elements. Element i is a Instances + representing the proposals for image i. + + Returns: + list[Instances]: list of N Instances. Each is the proposals for the image, + with field "proposal_boxes" and "objectness_logits". + """ + assert gt is not None + + if len(proposals) != len(gt): + raise ValueError("proposals and gt should have the same length as the number of images!") + if len(proposals) == 0: + return proposals + + return [ + add_ground_truth_to_proposals_single_image(gt_i, proposals_i) + for gt_i, proposals_i in zip(gt, proposals) + ] + + +def add_ground_truth_to_proposals_single_image( + gt: Union[Instances, Boxes], proposals: Instances +) -> Instances: + """ + Augment `proposals` with `gt`. + + Args: + Same as `add_ground_truth_to_proposals`, but with gt and proposals + per image. + + Returns: + Same as `add_ground_truth_to_proposals`, but for only one image. + """ + if isinstance(gt, Boxes): + # convert Boxes to Instances + gt = Instances(proposals.image_size, gt_boxes=gt) + + gt_boxes = gt.gt_boxes + device = proposals.objectness_logits.device + # Assign all ground-truth boxes an objectness logit corresponding to + # P(object) = sigmoid(logit) =~ 1. + gt_logit_value = math.log((1.0 - 1e-10) / (1 - (1.0 - 1e-10))) + gt_logits = gt_logit_value * torch.ones(len(gt_boxes), device=device) + + # Concatenating gt_boxes with proposals requires them to have the same fields + gt_proposal = Instances(proposals.image_size, **gt.get_fields()) + gt_proposal.proposal_boxes = gt_boxes + gt_proposal.objectness_logits = gt_logits + + for key in proposals.get_fields().keys(): + assert gt_proposal.has( + key + ), "The attribute '{}' in `proposals` does not exist in `gt`".format(key) + + # NOTE: Instances.cat only use fields from the first item. Extra fields in latter items + # will be thrown away. + new_proposals = Instances.cat([proposals, gt_proposal]) + + return new_proposals diff --git a/custom_detectron2/modeling/proposal_generator/rpn.py b/custom_detectron2/modeling/proposal_generator/rpn.py new file mode 100644 index 0000000000000000000000000000000000000000..dff181bcd489f89cb65a1978a4bf9efac4f30d06 --- /dev/null +++ b/custom_detectron2/modeling/proposal_generator/rpn.py @@ -0,0 +1,533 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Dict, List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.layers import Conv2d, ShapeSpec, cat +from custom_detectron2.structures import Boxes, ImageList, Instances, pairwise_iou +from custom_detectron2.utils.events import get_event_storage +from custom_detectron2.utils.memory import retry_if_cuda_oom +from custom_detectron2.utils.registry import Registry + +from ..anchor_generator import build_anchor_generator +from ..box_regression import Box2BoxTransform, _dense_box_regression_loss +from ..matcher import Matcher +from ..sampling import subsample_labels +from .build import PROPOSAL_GENERATOR_REGISTRY +from .proposal_utils import find_top_rpn_proposals + +RPN_HEAD_REGISTRY = Registry("RPN_HEAD") +RPN_HEAD_REGISTRY.__doc__ = """ +Registry for RPN heads, which take feature maps and perform +objectness classification and bounding box regression for anchors. + +The registered object will be called with `obj(cfg, input_shape)`. +The call should return a `nn.Module` object. +""" + + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + L: number of feature maps per image on which RPN is run + A: number of cell anchors (must be the same for all feature maps) + Hi, Wi: height and width of the i-th feature map + B: size of the box parameterization + +Naming convention: + + objectness: refers to the binary classification of an anchor as object vs. not object. + + deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransform`), or 5d for rotated boxes. + + pred_objectness_logits: predicted objectness scores in [-inf, +inf]; use + sigmoid(pred_objectness_logits) to estimate P(object). + + gt_labels: ground-truth binary classification labels for objectness + + pred_anchor_deltas: predicted box2box transform deltas + + gt_anchor_deltas: ground-truth box2box transform deltas +""" + + +def build_rpn_head(cfg, input_shape): + """ + Build an RPN head defined by `cfg.MODEL.RPN.HEAD_NAME`. + """ + name = cfg.MODEL.RPN.HEAD_NAME + return RPN_HEAD_REGISTRY.get(name)(cfg, input_shape) + + +@RPN_HEAD_REGISTRY.register() +class StandardRPNHead(nn.Module): + """ + Standard RPN classification and regression heads described in :paper:`Faster R-CNN`. + Uses a 3x3 conv to produce a shared hidden state from which one 1x1 conv predicts + objectness logits for each anchor and a second 1x1 conv predicts bounding-box deltas + specifying how to deform each anchor into an object proposal. + """ + + @configurable + def __init__( + self, *, in_channels: int, num_anchors: int, box_dim: int = 4, conv_dims: List[int] = (-1,) + ): + """ + NOTE: this interface is experimental. + + Args: + in_channels (int): number of input feature channels. When using multiple + input features, they must have the same number of channels. + num_anchors (int): number of anchors to predict for *each spatial position* + on the feature map. The total number of anchors for each + feature map will be `num_anchors * H * W`. + box_dim (int): dimension of a box, which is also the number of box regression + predictions to make for each anchor. An axis aligned box has + box_dim=4, while a rotated box has box_dim=5. + conv_dims (list[int]): a list of integers representing the output channels + of N conv layers. Set it to -1 to use the same number of output channels + as input channels. + """ + super().__init__() + cur_channels = in_channels + # Keeping the old variable names and structure for backwards compatiblity. + # Otherwise the old checkpoints will fail to load. + if len(conv_dims) == 1: + out_channels = cur_channels if conv_dims[0] == -1 else conv_dims[0] + # 3x3 conv for the hidden representation + self.conv = self._get_rpn_conv(cur_channels, out_channels) + cur_channels = out_channels + else: + self.conv = nn.Sequential() + for k, conv_dim in enumerate(conv_dims): + out_channels = cur_channels if conv_dim == -1 else conv_dim + if out_channels <= 0: + raise ValueError( + f"Conv output channels should be greater than 0. Got {out_channels}" + ) + conv = self._get_rpn_conv(cur_channels, out_channels) + self.conv.add_module(f"conv{k}", conv) + cur_channels = out_channels + # 1x1 conv for predicting objectness logits + self.objectness_logits = nn.Conv2d(cur_channels, num_anchors, kernel_size=1, stride=1) + # 1x1 conv for predicting box2box transform deltas + self.anchor_deltas = nn.Conv2d(cur_channels, num_anchors * box_dim, kernel_size=1, stride=1) + + # Keeping the order of weights initialization same for backwards compatiblility. + for layer in self.modules(): + if isinstance(layer, nn.Conv2d): + nn.init.normal_(layer.weight, std=0.01) + nn.init.constant_(layer.bias, 0) + + def _get_rpn_conv(self, in_channels, out_channels): + return Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + activation=nn.ReLU(), + ) + + @classmethod + def from_config(cls, cfg, input_shape): + # Standard RPN is shared across levels: + in_channels = [s.channels for s in input_shape] + assert len(set(in_channels)) == 1, "Each level must have the same channel!" + in_channels = in_channels[0] + + # RPNHead should take the same input as anchor generator + # NOTE: it assumes that creating an anchor generator does not have unwanted side effect. + anchor_generator = build_anchor_generator(cfg, input_shape) + num_anchors = anchor_generator.num_anchors + box_dim = anchor_generator.box_dim + assert ( + len(set(num_anchors)) == 1 + ), "Each level must have the same number of anchors per spatial position" + return { + "in_channels": in_channels, + "num_anchors": num_anchors[0], + "box_dim": box_dim, + "conv_dims": cfg.MODEL.RPN.CONV_DIMS, + } + + def forward(self, features: List[torch.Tensor]): + """ + Args: + features (list[Tensor]): list of feature maps + + Returns: + list[Tensor]: A list of L elements. + Element i is a tensor of shape (N, A, Hi, Wi) representing + the predicted objectness logits for all anchors. A is the number of cell anchors. + list[Tensor]: A list of L elements. Element i is a tensor of shape + (N, A*box_dim, Hi, Wi) representing the predicted "deltas" used to transform anchors + to proposals. + """ + pred_objectness_logits = [] + pred_anchor_deltas = [] + for x in features: + t = self.conv(x) + pred_objectness_logits.append(self.objectness_logits(t)) + pred_anchor_deltas.append(self.anchor_deltas(t)) + return pred_objectness_logits, pred_anchor_deltas + + +@PROPOSAL_GENERATOR_REGISTRY.register() +class RPN(nn.Module): + """ + Region Proposal Network, introduced by :paper:`Faster R-CNN`. + """ + + @configurable + def __init__( + self, + *, + in_features: List[str], + head: nn.Module, + anchor_generator: nn.Module, + anchor_matcher: Matcher, + box2box_transform: Box2BoxTransform, + batch_size_per_image: int, + positive_fraction: float, + pre_nms_topk: Tuple[float, float], + post_nms_topk: Tuple[float, float], + nms_thresh: float = 0.7, + min_box_size: float = 0.0, + anchor_boundary_thresh: float = -1.0, + loss_weight: Union[float, Dict[str, float]] = 1.0, + box_reg_loss_type: str = "smooth_l1", + smooth_l1_beta: float = 0.0, + ): + """ + NOTE: this interface is experimental. + + Args: + in_features (list[str]): list of names of input features to use + head (nn.Module): a module that predicts logits and regression deltas + for each level from a list of per-level features + anchor_generator (nn.Module): a module that creates anchors from a + list of features. Usually an instance of :class:`AnchorGenerator` + anchor_matcher (Matcher): label the anchors by matching them with ground truth. + box2box_transform (Box2BoxTransform): defines the transform from anchors boxes to + instance boxes + batch_size_per_image (int): number of anchors per image to sample for training + positive_fraction (float): fraction of foreground anchors to sample for training + pre_nms_topk (tuple[float]): (train, test) that represents the + number of top k proposals to select before NMS, in + training and testing. + post_nms_topk (tuple[float]): (train, test) that represents the + number of top k proposals to select after NMS, in + training and testing. + nms_thresh (float): NMS threshold used to de-duplicate the predicted proposals + min_box_size (float): remove proposal boxes with any side smaller than this threshold, + in the unit of input image pixels + anchor_boundary_thresh (float): legacy option + loss_weight (float|dict): weights to use for losses. Can be single float for weighting + all rpn losses together, or a dict of individual weightings. Valid dict keys are: + "loss_rpn_cls" - applied to classification loss + "loss_rpn_loc" - applied to box regression loss + box_reg_loss_type (str): Loss type to use. Supported losses: "smooth_l1", "giou". + smooth_l1_beta (float): beta parameter for the smooth L1 regression loss. Default to + use L1 loss. Only used when `box_reg_loss_type` is "smooth_l1" + """ + super().__init__() + self.in_features = in_features + self.rpn_head = head + self.anchor_generator = anchor_generator + self.anchor_matcher = anchor_matcher + self.box2box_transform = box2box_transform + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + # Map from self.training state to train/test settings + self.pre_nms_topk = {True: pre_nms_topk[0], False: pre_nms_topk[1]} + self.post_nms_topk = {True: post_nms_topk[0], False: post_nms_topk[1]} + self.nms_thresh = nms_thresh + self.min_box_size = float(min_box_size) + self.anchor_boundary_thresh = anchor_boundary_thresh + if isinstance(loss_weight, float): + loss_weight = {"loss_rpn_cls": loss_weight, "loss_rpn_loc": loss_weight} + self.loss_weight = loss_weight + self.box_reg_loss_type = box_reg_loss_type + self.smooth_l1_beta = smooth_l1_beta + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + in_features = cfg.MODEL.RPN.IN_FEATURES + ret = { + "in_features": in_features, + "min_box_size": cfg.MODEL.PROPOSAL_GENERATOR.MIN_SIZE, + "nms_thresh": cfg.MODEL.RPN.NMS_THRESH, + "batch_size_per_image": cfg.MODEL.RPN.BATCH_SIZE_PER_IMAGE, + "positive_fraction": cfg.MODEL.RPN.POSITIVE_FRACTION, + "loss_weight": { + "loss_rpn_cls": cfg.MODEL.RPN.LOSS_WEIGHT, + "loss_rpn_loc": cfg.MODEL.RPN.BBOX_REG_LOSS_WEIGHT * cfg.MODEL.RPN.LOSS_WEIGHT, + }, + "anchor_boundary_thresh": cfg.MODEL.RPN.BOUNDARY_THRESH, + "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS), + "box_reg_loss_type": cfg.MODEL.RPN.BBOX_REG_LOSS_TYPE, + "smooth_l1_beta": cfg.MODEL.RPN.SMOOTH_L1_BETA, + } + + ret["pre_nms_topk"] = (cfg.MODEL.RPN.PRE_NMS_TOPK_TRAIN, cfg.MODEL.RPN.PRE_NMS_TOPK_TEST) + ret["post_nms_topk"] = (cfg.MODEL.RPN.POST_NMS_TOPK_TRAIN, cfg.MODEL.RPN.POST_NMS_TOPK_TEST) + + ret["anchor_generator"] = build_anchor_generator(cfg, [input_shape[f] for f in in_features]) + ret["anchor_matcher"] = Matcher( + cfg.MODEL.RPN.IOU_THRESHOLDS, cfg.MODEL.RPN.IOU_LABELS, allow_low_quality_matches=True + ) + ret["head"] = build_rpn_head(cfg, [input_shape[f] for f in in_features]) + return ret + + def _subsample_labels(self, label): + """ + Randomly sample a subset of positive and negative examples, and overwrite + the label vector to the ignore value (-1) for all elements that are not + included in the sample. + + Args: + labels (Tensor): a vector of -1, 0, 1. Will be modified in-place and returned. + """ + pos_idx, neg_idx = subsample_labels( + label, self.batch_size_per_image, self.positive_fraction, 0 + ) + # Fill with the ignore label (-1), then set positive and negative labels + label.fill_(-1) + label.scatter_(0, pos_idx, 1) + label.scatter_(0, neg_idx, 0) + return label + + @torch.jit.unused + @torch.no_grad() + def label_and_sample_anchors( + self, anchors: List[Boxes], gt_instances: List[Instances] + ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + """ + Args: + anchors (list[Boxes]): anchors for each feature map. + gt_instances: the ground-truth instances for each image. + + Returns: + list[Tensor]: + List of #img tensors. i-th element is a vector of labels whose length is + the total number of anchors across all feature maps R = sum(Hi * Wi * A). + Label values are in {-1, 0, 1}, with meanings: -1 = ignore; 0 = negative + class; 1 = positive class. + list[Tensor]: + i-th element is a Rx4 tensor. The values are the matched gt boxes for each + anchor. Values are undefined for those anchors not labeled as 1. + """ + anchors = Boxes.cat(anchors) + + gt_boxes = [x.gt_boxes for x in gt_instances] + image_sizes = [x.image_size for x in gt_instances] + del gt_instances + + gt_labels = [] + matched_gt_boxes = [] + for image_size_i, gt_boxes_i in zip(image_sizes, gt_boxes): + """ + image_size_i: (h, w) for the i-th image + gt_boxes_i: ground-truth boxes for i-th image + """ + + match_quality_matrix = retry_if_cuda_oom(pairwise_iou)(gt_boxes_i, anchors) + matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix) + # Matching is memory-expensive and may result in CPU tensors. But the result is small + gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device) + del match_quality_matrix + + if self.anchor_boundary_thresh >= 0: + # Discard anchors that go out of the boundaries of the image + # NOTE: This is legacy functionality that is turned off by default in Detectron2 + anchors_inside_image = anchors.inside_box(image_size_i, self.anchor_boundary_thresh) + gt_labels_i[~anchors_inside_image] = -1 + + # A vector of labels (-1, 0, 1) for each anchor + gt_labels_i = self._subsample_labels(gt_labels_i) + + if len(gt_boxes_i) == 0: + # These values won't be used anyway since the anchor is labeled as background + matched_gt_boxes_i = torch.zeros_like(anchors.tensor) + else: + # TODO wasted indexing computation for ignored boxes + matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor + + gt_labels.append(gt_labels_i) # N,AHW + matched_gt_boxes.append(matched_gt_boxes_i) + return gt_labels, matched_gt_boxes + + @torch.jit.unused + def losses( + self, + anchors: List[Boxes], + pred_objectness_logits: List[torch.Tensor], + gt_labels: List[torch.Tensor], + pred_anchor_deltas: List[torch.Tensor], + gt_boxes: List[torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """ + Return the losses from a set of RPN predictions and their associated ground-truth. + + Args: + anchors (list[Boxes or RotatedBoxes]): anchors for each feature map, each + has shape (Hi*Wi*A, B), where B is box dimension (4 or 5). + pred_objectness_logits (list[Tensor]): A list of L elements. + Element i is a tensor of shape (N, Hi*Wi*A) representing + the predicted objectness logits for all anchors. + gt_labels (list[Tensor]): Output of :meth:`label_and_sample_anchors`. + pred_anchor_deltas (list[Tensor]): A list of L elements. Element i is a tensor of shape + (N, Hi*Wi*A, 4 or 5) representing the predicted "deltas" used to transform anchors + to proposals. + gt_boxes (list[Tensor]): Output of :meth:`label_and_sample_anchors`. + + Returns: + dict[loss name -> loss value]: A dict mapping from loss name to loss value. + Loss names are: `loss_rpn_cls` for objectness classification and + `loss_rpn_loc` for proposal localization. + """ + num_images = len(gt_labels) + gt_labels = torch.stack(gt_labels) # (N, sum(Hi*Wi*Ai)) + + # Log the number of positive/negative anchors per-image that's used in training + pos_mask = gt_labels == 1 + num_pos_anchors = pos_mask.sum().item() + num_neg_anchors = (gt_labels == 0).sum().item() + storage = get_event_storage() + storage.put_scalar("rpn/num_pos_anchors", num_pos_anchors / num_images) + storage.put_scalar("rpn/num_neg_anchors", num_neg_anchors / num_images) + + localization_loss = _dense_box_regression_loss( + anchors, + self.box2box_transform, + pred_anchor_deltas, + gt_boxes, + pos_mask, + box_reg_loss_type=self.box_reg_loss_type, + smooth_l1_beta=self.smooth_l1_beta, + ) + + valid_mask = gt_labels >= 0 + objectness_loss = F.binary_cross_entropy_with_logits( + cat(pred_objectness_logits, dim=1)[valid_mask], + gt_labels[valid_mask].to(torch.float32), + reduction="sum", + ) + normalizer = self.batch_size_per_image * num_images + losses = { + "loss_rpn_cls": objectness_loss / normalizer, + # The original Faster R-CNN paper uses a slightly different normalizer + # for loc loss. But it doesn't matter in practice + "loss_rpn_loc": localization_loss / normalizer, + } + losses = {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} + return losses + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + gt_instances: Optional[List[Instances]] = None, + ): + """ + Args: + images (ImageList): input images of length `N` + features (dict[str, Tensor]): input data as a mapping from feature + map name to tensor. Axis 0 represents the number of images `N` in + the input data; axes 1-3 are channels, height, and width, which may + vary between feature maps (e.g., if a feature pyramid is used). + gt_instances (list[Instances], optional): a length `N` list of `Instances`s. + Each `Instances` stores ground-truth instances for the corresponding image. + + Returns: + proposals: list[Instances]: contains fields "proposal_boxes", "objectness_logits" + loss: dict[Tensor] or None + """ + features = [features[f] for f in self.in_features] + anchors = self.anchor_generator(features) + + pred_objectness_logits, pred_anchor_deltas = self.rpn_head(features) + # Transpose the Hi*Wi*A dimension to the middle: + pred_objectness_logits = [ + # (N, A, Hi, Wi) -> (N, Hi, Wi, A) -> (N, Hi*Wi*A) + score.permute(0, 2, 3, 1).flatten(1) + for score in pred_objectness_logits + ] + pred_anchor_deltas = [ + # (N, A*B, Hi, Wi) -> (N, A, B, Hi, Wi) -> (N, Hi, Wi, A, B) -> (N, Hi*Wi*A, B) + x.view(x.shape[0], -1, self.anchor_generator.box_dim, x.shape[-2], x.shape[-1]) + .permute(0, 3, 4, 1, 2) + .flatten(1, -2) + for x in pred_anchor_deltas + ] + + if self.training: + assert gt_instances is not None, "RPN requires gt_instances in training!" + gt_labels, gt_boxes = self.label_and_sample_anchors(anchors, gt_instances) + losses = self.losses( + anchors, pred_objectness_logits, gt_labels, pred_anchor_deltas, gt_boxes + ) + else: + losses = {} + proposals = self.predict_proposals( + anchors, pred_objectness_logits, pred_anchor_deltas, images.image_sizes + ) + return proposals, losses + + def predict_proposals( + self, + anchors: List[Boxes], + pred_objectness_logits: List[torch.Tensor], + pred_anchor_deltas: List[torch.Tensor], + image_sizes: List[Tuple[int, int]], + ): + """ + Decode all the predicted box regression deltas to proposals. Find the top proposals + by applying NMS and removing boxes that are too small. + + Returns: + proposals (list[Instances]): list of N Instances. The i-th Instances + stores post_nms_topk object proposals for image i, sorted by their + objectness score in descending order. + """ + # The proposals are treated as fixed for joint training with roi heads. + # This approach ignores the derivative w.r.t. the proposal boxes’ coordinates that + # are also network responses. + with torch.no_grad(): + pred_proposals = self._decode_proposals(anchors, pred_anchor_deltas) + return find_top_rpn_proposals( + pred_proposals, + pred_objectness_logits, + image_sizes, + self.nms_thresh, + self.pre_nms_topk[self.training], + self.post_nms_topk[self.training], + self.min_box_size, + self.training, + ) + + def _decode_proposals(self, anchors: List[Boxes], pred_anchor_deltas: List[torch.Tensor]): + """ + Transform anchors into proposals by applying the predicted anchor deltas. + + Returns: + proposals (list[Tensor]): A list of L tensors. Tensor i has shape + (N, Hi*Wi*A, B) + """ + N = pred_anchor_deltas[0].shape[0] + proposals = [] + # For each feature map + for anchors_i, pred_anchor_deltas_i in zip(anchors, pred_anchor_deltas): + B = anchors_i.tensor.size(1) + pred_anchor_deltas_i = pred_anchor_deltas_i.reshape(-1, B) + # Expand anchors to shape (N*Hi*Wi*A, B) + anchors_i = anchors_i.tensor.unsqueeze(0).expand(N, -1, -1).reshape(-1, B) + proposals_i = self.box2box_transform.apply_deltas(pred_anchor_deltas_i, anchors_i) + # Append feature map proposals with shape (N, Hi*Wi*A, B) + proposals.append(proposals_i.view(N, -1, B)) + return proposals diff --git a/custom_detectron2/modeling/proposal_generator/rrpn.py b/custom_detectron2/modeling/proposal_generator/rrpn.py new file mode 100644 index 0000000000000000000000000000000000000000..5c0d038190c1a9ec3cbd6cbb32322623aa3e7278 --- /dev/null +++ b/custom_detectron2/modeling/proposal_generator/rrpn.py @@ -0,0 +1,209 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import logging +from typing import Dict, List +import torch + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ShapeSpec, batched_nms_rotated, cat +from custom_detectron2.structures import Instances, RotatedBoxes, pairwise_iou_rotated +from custom_detectron2.utils.memory import retry_if_cuda_oom + +from ..box_regression import Box2BoxTransformRotated +from .build import PROPOSAL_GENERATOR_REGISTRY +from .proposal_utils import _is_tracing +from .rpn import RPN + +logger = logging.getLogger(__name__) + + +def find_top_rrpn_proposals( + proposals, + pred_objectness_logits, + image_sizes, + nms_thresh, + pre_nms_topk, + post_nms_topk, + min_box_size, + training, +): + """ + For each feature map, select the `pre_nms_topk` highest scoring proposals, + apply NMS, clip proposals, and remove small boxes. Return the `post_nms_topk` + highest scoring proposals among all the feature maps if `training` is True, + otherwise, returns the highest `post_nms_topk` scoring proposals for each + feature map. + + Args: + proposals (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A, 5). + All proposal predictions on the feature maps. + pred_objectness_logits (list[Tensor]): A list of L tensors. Tensor i has shape (N, Hi*Wi*A). + image_sizes (list[tuple]): sizes (h, w) for each image + nms_thresh (float): IoU threshold to use for NMS + pre_nms_topk (int): number of top k scoring proposals to keep before applying NMS. + When RRPN is run on multiple feature maps (as in FPN) this number is per + feature map. + post_nms_topk (int): number of top k scoring proposals to keep after applying NMS. + When RRPN is run on multiple feature maps (as in FPN) this number is total, + over all feature maps. + min_box_size(float): minimum proposal box side length in pixels (absolute units wrt + input images). + training (bool): True if proposals are to be used in training, otherwise False. + This arg exists only to support a legacy bug; look for the "NB: Legacy bug ..." + comment. + + Returns: + proposals (list[Instances]): list of N Instances. The i-th Instances + stores post_nms_topk object proposals for image i. + """ + num_images = len(image_sizes) + device = proposals[0].device + + # 1. Select top-k anchor for every level and every image + topk_scores = [] # #lvl Tensor, each of shape N x topk + topk_proposals = [] + level_ids = [] # #lvl Tensor, each of shape (topk,) + batch_idx = torch.arange(num_images, device=device) + for level_id, proposals_i, logits_i in zip( + itertools.count(), proposals, pred_objectness_logits + ): + Hi_Wi_A = logits_i.shape[1] + if isinstance(Hi_Wi_A, torch.Tensor): # it's a tensor in tracing + num_proposals_i = torch.clamp(Hi_Wi_A, max=pre_nms_topk) + else: + num_proposals_i = min(Hi_Wi_A, pre_nms_topk) + + topk_scores_i, topk_idx = logits_i.topk(num_proposals_i, dim=1) + + # each is N x topk + topk_proposals_i = proposals_i[batch_idx[:, None], topk_idx] # N x topk x 5 + + topk_proposals.append(topk_proposals_i) + topk_scores.append(topk_scores_i) + level_ids.append(torch.full((num_proposals_i,), level_id, dtype=torch.int64, device=device)) + + # 2. Concat all levels together + topk_scores = cat(topk_scores, dim=1) + topk_proposals = cat(topk_proposals, dim=1) + level_ids = cat(level_ids, dim=0) + + # 3. For each image, run a per-level NMS, and choose topk results. + results = [] + for n, image_size in enumerate(image_sizes): + boxes = RotatedBoxes(topk_proposals[n]) + scores_per_img = topk_scores[n] + lvl = level_ids + + valid_mask = torch.isfinite(boxes.tensor).all(dim=1) & torch.isfinite(scores_per_img) + if not valid_mask.all(): + if training: + raise FloatingPointError( + "Predicted boxes or scores contain Inf/NaN. Training has diverged." + ) + boxes = boxes[valid_mask] + scores_per_img = scores_per_img[valid_mask] + lvl = lvl[valid_mask] + boxes.clip(image_size) + + # filter empty boxes + keep = boxes.nonempty(threshold=min_box_size) + if _is_tracing() or keep.sum().item() != len(boxes): + boxes, scores_per_img, lvl = (boxes[keep], scores_per_img[keep], lvl[keep]) + + keep = batched_nms_rotated(boxes.tensor, scores_per_img, lvl, nms_thresh) + # In Detectron1, there was different behavior during training vs. testing. + # (https://github.com/facebookresearch/Detectron/issues/459) + # During training, topk is over the proposals from *all* images in the training batch. + # During testing, it is over the proposals for each image separately. + # As a result, the training behavior becomes batch-dependent, + # and the configuration "POST_NMS_TOPK_TRAIN" end up relying on the batch size. + # This bug is addressed in Detectron2 to make the behavior independent of batch size. + keep = keep[:post_nms_topk] + + res = Instances(image_size) + res.proposal_boxes = boxes[keep] + res.objectness_logits = scores_per_img[keep] + results.append(res) + return results + + +@PROPOSAL_GENERATOR_REGISTRY.register() +class RRPN(RPN): + """ + Rotated Region Proposal Network described in :paper:`RRPN`. + """ + + @configurable + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if self.anchor_boundary_thresh >= 0: + raise NotImplementedError( + "anchor_boundary_thresh is a legacy option not implemented for RRPN." + ) + + @classmethod + def from_config(cls, cfg, input_shape: Dict[str, ShapeSpec]): + ret = super().from_config(cfg, input_shape) + ret["box2box_transform"] = Box2BoxTransformRotated(weights=cfg.MODEL.RPN.BBOX_REG_WEIGHTS) + return ret + + @torch.no_grad() + def label_and_sample_anchors(self, anchors: List[RotatedBoxes], gt_instances: List[Instances]): + """ + Args: + anchors (list[RotatedBoxes]): anchors for each feature map. + gt_instances: the ground-truth instances for each image. + + Returns: + list[Tensor]: + List of #img tensors. i-th element is a vector of labels whose length is + the total number of anchors across feature maps. Label values are in {-1, 0, 1}, + with meanings: -1 = ignore; 0 = negative class; 1 = positive class. + list[Tensor]: + i-th element is a Nx5 tensor, where N is the total number of anchors across + feature maps. The values are the matched gt boxes for each anchor. + Values are undefined for those anchors not labeled as 1. + """ + anchors = RotatedBoxes.cat(anchors) + + gt_boxes = [x.gt_boxes for x in gt_instances] + del gt_instances + + gt_labels = [] + matched_gt_boxes = [] + for gt_boxes_i in gt_boxes: + """ + gt_boxes_i: ground-truth boxes for i-th image + """ + match_quality_matrix = retry_if_cuda_oom(pairwise_iou_rotated)(gt_boxes_i, anchors) + matched_idxs, gt_labels_i = retry_if_cuda_oom(self.anchor_matcher)(match_quality_matrix) + # Matching is memory-expensive and may result in CPU tensors. But the result is small + gt_labels_i = gt_labels_i.to(device=gt_boxes_i.device) + + # A vector of labels (-1, 0, 1) for each anchor + gt_labels_i = self._subsample_labels(gt_labels_i) + + if len(gt_boxes_i) == 0: + # These values won't be used anyway since the anchor is labeled as background + matched_gt_boxes_i = torch.zeros_like(anchors.tensor) + else: + # TODO wasted indexing computation for ignored boxes + matched_gt_boxes_i = gt_boxes_i[matched_idxs].tensor + + gt_labels.append(gt_labels_i) # N,AHW + matched_gt_boxes.append(matched_gt_boxes_i) + return gt_labels, matched_gt_boxes + + @torch.no_grad() + def predict_proposals(self, anchors, pred_objectness_logits, pred_anchor_deltas, image_sizes): + pred_proposals = self._decode_proposals(anchors, pred_anchor_deltas) + return find_top_rrpn_proposals( + pred_proposals, + pred_objectness_logits, + image_sizes, + self.nms_thresh, + self.pre_nms_topk[self.training], + self.post_nms_topk[self.training], + self.min_box_size, + self.training, + ) diff --git a/custom_detectron2/modeling/roi_heads/__init__.py b/custom_detectron2/modeling/roi_heads/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d13e9c57235b982f3e0645bc316de2b75755dfda --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/__init__.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .box_head import ROI_BOX_HEAD_REGISTRY, build_box_head, FastRCNNConvFCHead +from .keypoint_head import ( + ROI_KEYPOINT_HEAD_REGISTRY, + build_keypoint_head, + BaseKeypointRCNNHead, + KRCNNConvDeconvUpsampleHead, +) +from .mask_head import ( + ROI_MASK_HEAD_REGISTRY, + build_mask_head, + BaseMaskRCNNHead, + MaskRCNNConvUpsampleHead, +) +from .roi_heads import ( + ROI_HEADS_REGISTRY, + ROIHeads, + Res5ROIHeads, + StandardROIHeads, + build_roi_heads, + select_foreground_proposals, +) +from .cascade_rcnn import CascadeROIHeads +from .rotated_fast_rcnn import RROIHeads +from .fast_rcnn import FastRCNNOutputLayers + +from . import cascade_rcnn # isort:skip + +__all__ = list(globals().keys()) diff --git a/custom_detectron2/modeling/roi_heads/box_head.py b/custom_detectron2/modeling/roi_heads/box_head.py new file mode 100644 index 0000000000000000000000000000000000000000..c2312d53a6959ba5b8bb6746ea49f8a956112298 --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/box_head.py @@ -0,0 +1,118 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import List +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.layers import Conv2d, ShapeSpec, get_norm +from custom_detectron2.utils.registry import Registry + +__all__ = ["FastRCNNConvFCHead", "build_box_head", "ROI_BOX_HEAD_REGISTRY"] + +ROI_BOX_HEAD_REGISTRY = Registry("ROI_BOX_HEAD") +ROI_BOX_HEAD_REGISTRY.__doc__ = """ +Registry for box heads, which make box predictions from per-region features. + +The registered object will be called with `obj(cfg, input_shape)`. +""" + + +# To get torchscript support, we make the head a subclass of `nn.Sequential`. +# Therefore, to add new layers in this head class, please make sure they are +# added in the order they will be used in forward(). +@ROI_BOX_HEAD_REGISTRY.register() +class FastRCNNConvFCHead(nn.Sequential): + """ + A head with several 3x3 conv layers (each followed by norm & relu) and then + several fc layers (each followed by relu). + """ + + @configurable + def __init__( + self, input_shape: ShapeSpec, *, conv_dims: List[int], fc_dims: List[int], conv_norm="" + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape (ShapeSpec): shape of the input feature. + conv_dims (list[int]): the output dimensions of the conv layers + fc_dims (list[int]): the output dimensions of the fc layers + conv_norm (str or callable): normalization for the conv layers. + See :func:`detectron2.layers.get_norm` for supported types. + """ + super().__init__() + assert len(conv_dims) + len(fc_dims) > 0 + + self._output_size = (input_shape.channels, input_shape.height, input_shape.width) + + self.conv_norm_relus = [] + for k, conv_dim in enumerate(conv_dims): + conv = Conv2d( + self._output_size[0], + conv_dim, + kernel_size=3, + padding=1, + bias=not conv_norm, + norm=get_norm(conv_norm, conv_dim), + activation=nn.ReLU(), + ) + self.add_module("conv{}".format(k + 1), conv) + self.conv_norm_relus.append(conv) + self._output_size = (conv_dim, self._output_size[1], self._output_size[2]) + + self.fcs = [] + for k, fc_dim in enumerate(fc_dims): + if k == 0: + self.add_module("flatten", nn.Flatten()) + fc = nn.Linear(int(np.prod(self._output_size)), fc_dim) + self.add_module("fc{}".format(k + 1), fc) + self.add_module("fc_relu{}".format(k + 1), nn.ReLU()) + self.fcs.append(fc) + self._output_size = fc_dim + + for layer in self.conv_norm_relus: + weight_init.c2_msra_fill(layer) + for layer in self.fcs: + weight_init.c2_xavier_fill(layer) + + @classmethod + def from_config(cls, cfg, input_shape): + num_conv = cfg.MODEL.ROI_BOX_HEAD.NUM_CONV + conv_dim = cfg.MODEL.ROI_BOX_HEAD.CONV_DIM + num_fc = cfg.MODEL.ROI_BOX_HEAD.NUM_FC + fc_dim = cfg.MODEL.ROI_BOX_HEAD.FC_DIM + return { + "input_shape": input_shape, + "conv_dims": [conv_dim] * num_conv, + "fc_dims": [fc_dim] * num_fc, + "conv_norm": cfg.MODEL.ROI_BOX_HEAD.NORM, + } + + def forward(self, x): + for layer in self: + x = layer(x) + return x + + @property + @torch.jit.unused + def output_shape(self): + """ + Returns: + ShapeSpec: the output feature shape + """ + o = self._output_size + if isinstance(o, int): + return ShapeSpec(channels=o) + else: + return ShapeSpec(channels=o[0], height=o[1], width=o[2]) + + +def build_box_head(cfg, input_shape): + """ + Build a box head defined by `cfg.MODEL.ROI_BOX_HEAD.NAME`. + """ + name = cfg.MODEL.ROI_BOX_HEAD.NAME + return ROI_BOX_HEAD_REGISTRY.get(name)(cfg, input_shape) diff --git a/custom_detectron2/modeling/roi_heads/cascade_rcnn.py b/custom_detectron2/modeling/roi_heads/cascade_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..441da8e5741a5c8f008025996d39cc3a9a6b53c6 --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/cascade_rcnn.py @@ -0,0 +1,299 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import List +import torch +from torch import nn +from torch.autograd.function import Function + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ShapeSpec +from custom_detectron2.structures import Boxes, Instances, pairwise_iou +from custom_detectron2.utils.events import get_event_storage + +from ..box_regression import Box2BoxTransform +from ..matcher import Matcher +from ..poolers import ROIPooler +from .box_head import build_box_head +from .fast_rcnn import FastRCNNOutputLayers, fast_rcnn_inference +from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads + + +class _ScaleGradient(Function): + @staticmethod + def forward(ctx, input, scale): + ctx.scale = scale + return input + + @staticmethod + def backward(ctx, grad_output): + return grad_output * ctx.scale, None + + +@ROI_HEADS_REGISTRY.register() +class CascadeROIHeads(StandardROIHeads): + """ + The ROI heads that implement :paper:`Cascade R-CNN`. + """ + + @configurable + def __init__( + self, + *, + box_in_features: List[str], + box_pooler: ROIPooler, + box_heads: List[nn.Module], + box_predictors: List[nn.Module], + proposal_matchers: List[Matcher], + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + box_pooler (ROIPooler): pooler that extracts region features from given boxes + box_heads (list[nn.Module]): box head for each cascade stage + box_predictors (list[nn.Module]): box predictor for each cascade stage + proposal_matchers (list[Matcher]): matcher with different IoU thresholds to + match boxes with ground truth for each stage. The first matcher matches + RPN proposals with ground truth, the other matchers use boxes predicted + by the previous stage as proposals and match them with ground truth. + """ + assert "proposal_matcher" not in kwargs, ( + "CascadeROIHeads takes 'proposal_matchers=' for each stage instead " + "of one 'proposal_matcher='." + ) + # The first matcher matches RPN proposals with ground truth, done in the base class + kwargs["proposal_matcher"] = proposal_matchers[0] + num_stages = self.num_cascade_stages = len(box_heads) + box_heads = nn.ModuleList(box_heads) + box_predictors = nn.ModuleList(box_predictors) + assert len(box_predictors) == num_stages, f"{len(box_predictors)} != {num_stages}!" + assert len(proposal_matchers) == num_stages, f"{len(proposal_matchers)} != {num_stages}!" + super().__init__( + box_in_features=box_in_features, + box_pooler=box_pooler, + box_head=box_heads, + box_predictor=box_predictors, + **kwargs, + ) + self.proposal_matchers = proposal_matchers + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret.pop("proposal_matcher") + return ret + + @classmethod + def _init_box_head(cls, cfg, input_shape): + # fmt: off + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + cascade_bbox_reg_weights = cfg.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS + cascade_ious = cfg.MODEL.ROI_BOX_CASCADE_HEAD.IOUS + assert len(cascade_bbox_reg_weights) == len(cascade_ious) + assert cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG, \ + "CascadeROIHeads only support class-agnostic regression now!" + assert cascade_ious[0] == cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS[0] + # fmt: on + + in_channels = [input_shape[f].channels for f in in_features] + # Check all channel counts are equal + assert len(set(in_channels)) == 1, in_channels + in_channels = in_channels[0] + + box_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + pooled_shape = ShapeSpec( + channels=in_channels, width=pooler_resolution, height=pooler_resolution + ) + + box_heads, box_predictors, proposal_matchers = [], [], [] + for match_iou, bbox_reg_weights in zip(cascade_ious, cascade_bbox_reg_weights): + box_head = build_box_head(cfg, pooled_shape) + box_heads.append(box_head) + box_predictors.append( + FastRCNNOutputLayers( + cfg, + box_head.output_shape, + box2box_transform=Box2BoxTransform(weights=bbox_reg_weights), + ) + ) + proposal_matchers.append(Matcher([match_iou], [0, 1], allow_low_quality_matches=False)) + return { + "box_in_features": in_features, + "box_pooler": box_pooler, + "box_heads": box_heads, + "box_predictors": box_predictors, + "proposal_matchers": proposal_matchers, + } + + def forward(self, images, features, proposals, targets=None): + del images + if self.training: + proposals = self.label_and_sample_proposals(proposals, targets) + + if self.training: + # Need targets to box head + losses = self._forward_box(features, proposals, targets) + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + pred_instances = self._forward_box(features, proposals) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + return pred_instances, {} + + def _forward_box(self, features, proposals, targets=None): + """ + Args: + features, targets: the same as in + Same as in :meth:`ROIHeads.forward`. + proposals (list[Instances]): the per-image object proposals with + their matching ground truth. + Each has fields "proposal_boxes", and "objectness_logits", + "gt_classes", "gt_boxes". + """ + features = [features[f] for f in self.box_in_features] + head_outputs = [] # (predictor, predictions, proposals) + prev_pred_boxes = None + image_sizes = [x.image_size for x in proposals] + for k in range(self.num_cascade_stages): + if k > 0: + # The output boxes of the previous stage are used to create the input + # proposals of the next stage. + proposals = self._create_proposals_from_boxes(prev_pred_boxes, image_sizes) + if self.training: + proposals = self._match_and_label_boxes(proposals, k, targets) + predictions = self._run_stage(features, proposals, k) + prev_pred_boxes = self.box_predictor[k].predict_boxes(predictions, proposals) + head_outputs.append((self.box_predictor[k], predictions, proposals)) + + if self.training: + losses = {} + storage = get_event_storage() + for stage, (predictor, predictions, proposals) in enumerate(head_outputs): + with storage.name_scope("stage{}".format(stage)): + stage_losses = predictor.losses(predictions, proposals) + losses.update({k + "_stage{}".format(stage): v for k, v in stage_losses.items()}) + return losses + else: + # Each is a list[Tensor] of length #image. Each tensor is Ri x (K+1) + scores_per_stage = [h[0].predict_probs(h[1], h[2]) for h in head_outputs] + + # Average the scores across heads + scores = [ + sum(list(scores_per_image)) * (1.0 / self.num_cascade_stages) + for scores_per_image in zip(*scores_per_stage) + ] + # Use the boxes of the last head + predictor, predictions, proposals = head_outputs[-1] + boxes = predictor.predict_boxes(predictions, proposals) + pred_instances, _ = fast_rcnn_inference( + boxes, + scores, + image_sizes, + predictor.test_score_thresh, + predictor.test_nms_thresh, + predictor.test_topk_per_image, + ) + return pred_instances + + @torch.no_grad() + def _match_and_label_boxes(self, proposals, stage, targets): + """ + Match proposals with groundtruth using the matcher at the given stage. + Label the proposals as foreground or background based on the match. + + Args: + proposals (list[Instances]): One Instances for each image, with + the field "proposal_boxes". + stage (int): the current stage + targets (list[Instances]): the ground truth instances + + Returns: + list[Instances]: the same proposals, but with fields "gt_classes" and "gt_boxes" + """ + num_fg_samples, num_bg_samples = [], [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + match_quality_matrix = pairwise_iou( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes + ) + # proposal_labels are 0 or 1 + matched_idxs, proposal_labels = self.proposal_matchers[stage](match_quality_matrix) + if len(targets_per_image) > 0: + gt_classes = targets_per_image.gt_classes[matched_idxs] + # Label unmatched proposals (0 label from matcher) as background (label=num_classes) + gt_classes[proposal_labels == 0] = self.num_classes + gt_boxes = targets_per_image.gt_boxes[matched_idxs] + else: + gt_classes = torch.zeros_like(matched_idxs) + self.num_classes + gt_boxes = Boxes( + targets_per_image.gt_boxes.tensor.new_zeros((len(proposals_per_image), 4)) + ) + proposals_per_image.gt_classes = gt_classes + proposals_per_image.gt_boxes = gt_boxes + + num_fg_samples.append((proposal_labels == 1).sum().item()) + num_bg_samples.append(proposal_labels.numel() - num_fg_samples[-1]) + + # Log the number of fg/bg samples in each stage + storage = get_event_storage() + storage.put_scalar( + "stage{}/roi_head/num_fg_samples".format(stage), + sum(num_fg_samples) / len(num_fg_samples), + ) + storage.put_scalar( + "stage{}/roi_head/num_bg_samples".format(stage), + sum(num_bg_samples) / len(num_bg_samples), + ) + return proposals + + def _run_stage(self, features, proposals, stage): + """ + Args: + features (list[Tensor]): #lvl input features to ROIHeads + proposals (list[Instances]): #image Instances, with the field "proposal_boxes" + stage (int): the current stage + + Returns: + Same output as `FastRCNNOutputLayers.forward()`. + """ + box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals]) + # The original implementation averages the losses among heads, + # but scale up the parameter gradients of the heads. + # This is equivalent to adding the losses among heads, + # but scale down the gradients on features. + if self.training: + box_features = _ScaleGradient.apply(box_features, 1.0 / self.num_cascade_stages) + box_features = self.box_head[stage](box_features) + return self.box_predictor[stage](box_features) + + def _create_proposals_from_boxes(self, boxes, image_sizes): + """ + Args: + boxes (list[Tensor]): per-image predicted boxes, each of shape Ri x 4 + image_sizes (list[tuple]): list of image shapes in (h, w) + + Returns: + list[Instances]: per-image proposals with the given boxes. + """ + # Just like RPN, the proposals should not have gradients + boxes = [Boxes(b.detach()) for b in boxes] + proposals = [] + for boxes_per_image, image_size in zip(boxes, image_sizes): + boxes_per_image.clip(image_size) + if self.training: + # do not filter empty boxes at inference time, + # because the scores from each stage need to be aligned and added later + boxes_per_image = boxes_per_image[boxes_per_image.nonempty()] + prop = Instances(image_size) + prop.proposal_boxes = boxes_per_image + proposals.append(prop) + return proposals diff --git a/custom_detectron2/modeling/roi_heads/fast_rcnn.py b/custom_detectron2/modeling/roi_heads/fast_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..6d2f8476f40c5fff30a16a41d9cabee21e7aad6c --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/fast_rcnn.py @@ -0,0 +1,569 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +from typing import Callable, Dict, List, Optional, Tuple, Union +import torch +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.data.detection_utils import get_fed_loss_cls_weights +from custom_detectron2.layers import ShapeSpec, batched_nms, cat, cross_entropy, nonzero_tuple +from custom_detectron2.modeling.box_regression import Box2BoxTransform, _dense_box_regression_loss +from custom_detectron2.structures import Boxes, Instances +from custom_detectron2.utils.events import get_event_storage + +__all__ = ["fast_rcnn_inference", "FastRCNNOutputLayers"] + + +logger = logging.getLogger(__name__) + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + R: number of ROIs, combined over all images, in the minibatch + Ri: number of ROIs in image i + K: number of foreground classes. E.g.,there are 80 foreground classes in COCO. + +Naming convention: + + deltas: refers to the 4-d (dx, dy, dw, dh) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransform`). + + pred_class_logits: predicted class scores in [-inf, +inf]; use + softmax(pred_class_logits) to estimate P(class). + + gt_classes: ground-truth classification labels in [0, K], where [0, K) represent + foreground object classes and K represents the background class. + + pred_proposal_deltas: predicted box2box transform deltas for transforming proposals + to detection box predictions. + + gt_proposal_deltas: ground-truth box2box transform deltas +""" + + +def fast_rcnn_inference( + boxes: List[torch.Tensor], + scores: List[torch.Tensor], + image_shapes: List[Tuple[int, int]], + score_thresh: float, + nms_thresh: float, + topk_per_image: int, +): + """ + Call `fast_rcnn_inference_single_image` for all images. + + Args: + boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic + boxes for each image. Element i has shape (Ri, K * 4) if doing + class-specific regression, or (Ri, 4) if doing class-agnostic + regression, where Ri is the number of predicted objects for image i. + This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`. + scores (list[Tensor]): A list of Tensors of predicted class scores for each image. + Element i has shape (Ri, K + 1), where Ri is the number of predicted objects + for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`. + image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch. + score_thresh (float): Only return detections with a confidence score exceeding this + threshold. + nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1]. + topk_per_image (int): The number of top scoring detections to return. Set < 0 to return + all detections. + + Returns: + instances: (list[Instances]): A list of N instances, one for each image in the batch, + that stores the topk most confidence detections. + kept_indices: (list[Tensor]): A list of 1D tensor of length of N, each element indicates + the corresponding boxes/scores index in [0, Ri) from the input, for image i. + """ + result_per_image = [ + fast_rcnn_inference_single_image( + boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image + ) + for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes) + ] + return [x[0] for x in result_per_image], [x[1] for x in result_per_image] + + +def _log_classification_stats(pred_logits, gt_classes, prefix="fast_rcnn"): + """ + Log the classification metrics to EventStorage. + + Args: + pred_logits: Rx(K+1) logits. The last column is for background class. + gt_classes: R labels + """ + num_instances = gt_classes.numel() + if num_instances == 0: + return + pred_classes = pred_logits.argmax(dim=1) + bg_class_ind = pred_logits.shape[1] - 1 + + fg_inds = (gt_classes >= 0) & (gt_classes < bg_class_ind) + num_fg = fg_inds.nonzero().numel() + fg_gt_classes = gt_classes[fg_inds] + fg_pred_classes = pred_classes[fg_inds] + + num_false_negative = (fg_pred_classes == bg_class_ind).nonzero().numel() + num_accurate = (pred_classes == gt_classes).nonzero().numel() + fg_num_accurate = (fg_pred_classes == fg_gt_classes).nonzero().numel() + + storage = get_event_storage() + storage.put_scalar(f"{prefix}/cls_accuracy", num_accurate / num_instances) + if num_fg > 0: + storage.put_scalar(f"{prefix}/fg_cls_accuracy", fg_num_accurate / num_fg) + storage.put_scalar(f"{prefix}/false_negative", num_false_negative / num_fg) + + +def fast_rcnn_inference_single_image( + boxes, + scores, + image_shape: Tuple[int, int], + score_thresh: float, + nms_thresh: float, + topk_per_image: int, +): + """ + Single-image inference. Return bounding-box detection results by thresholding + on scores and applying non-maximum suppression (NMS). + + Args: + Same as `fast_rcnn_inference`, but with boxes, scores, and image shapes + per image. + + Returns: + Same as `fast_rcnn_inference`, but for only one image. + """ + valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1) + if not valid_mask.all(): + boxes = boxes[valid_mask] + scores = scores[valid_mask] + + scores = scores[:, :-1] + num_bbox_reg_classes = boxes.shape[1] // 4 + # Convert to Boxes to use the `clip` function ... + boxes = Boxes(boxes.reshape(-1, 4)) + boxes.clip(image_shape) + boxes = boxes.tensor.view(-1, num_bbox_reg_classes, 4) # R x C x 4 + + # 1. Filter results based on detection scores. It can make NMS more efficient + # by filtering out low-confidence detections. + filter_mask = scores > score_thresh # R x K + # R' x 2. First column contains indices of the R predictions; + # Second column contains indices of classes. + filter_inds = filter_mask.nonzero() + if num_bbox_reg_classes == 1: + boxes = boxes[filter_inds[:, 0], 0] + else: + boxes = boxes[filter_mask] + scores = scores[filter_mask] + + # 2. Apply NMS for each class independently. + keep = batched_nms(boxes, scores, filter_inds[:, 1], nms_thresh) + if topk_per_image >= 0: + keep = keep[:topk_per_image] + boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep] + + result = Instances(image_shape) + result.pred_boxes = Boxes(boxes) + result.scores = scores + result.pred_classes = filter_inds[:, 1] + return result, filter_inds[:, 0] + + +class FastRCNNOutputLayers(nn.Module): + """ + Two linear layers for predicting Fast R-CNN outputs: + + 1. proposal-to-detection box regression deltas + 2. classification scores + """ + + @configurable + def __init__( + self, + input_shape: ShapeSpec, + *, + box2box_transform, + num_classes: int, + test_score_thresh: float = 0.0, + test_nms_thresh: float = 0.5, + test_topk_per_image: int = 100, + cls_agnostic_bbox_reg: bool = False, + smooth_l1_beta: float = 0.0, + box_reg_loss_type: str = "smooth_l1", + loss_weight: Union[float, Dict[str, float]] = 1.0, + use_fed_loss: bool = False, + use_sigmoid_ce: bool = False, + get_fed_loss_cls_weights: Optional[Callable] = None, + fed_loss_num_classes: int = 50, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape (ShapeSpec): shape of the input feature to this module + box2box_transform (Box2BoxTransform or Box2BoxTransformRotated): + num_classes (int): number of foreground classes + test_score_thresh (float): threshold to filter predictions results. + test_nms_thresh (float): NMS threshold for prediction results. + test_topk_per_image (int): number of top predictions to produce per image. + cls_agnostic_bbox_reg (bool): whether to use class agnostic for bbox regression + smooth_l1_beta (float): transition point from L1 to L2 loss. Only used if + `box_reg_loss_type` is "smooth_l1" + box_reg_loss_type (str): Box regression loss type. One of: "smooth_l1", "giou", + "diou", "ciou" + loss_weight (float|dict): weights to use for losses. Can be single float for weighting + all losses, or a dict of individual weightings. Valid dict keys are: + * "loss_cls": applied to classification loss + * "loss_box_reg": applied to box regression loss + use_fed_loss (bool): whether to use federated loss which samples additional negative + classes to calculate the loss + use_sigmoid_ce (bool): whether to calculate the loss using weighted average of binary + cross entropy with logits. This could be used together with federated loss + get_fed_loss_cls_weights (Callable): a callable which takes dataset name and frequency + weight power, and returns the probabilities to sample negative classes for + federated loss. The implementation can be found in + detectron2/data/detection_utils.py + fed_loss_num_classes (int): number of federated classes to keep in total + """ + super().__init__() + if isinstance(input_shape, int): # some backward compatibility + input_shape = ShapeSpec(channels=input_shape) + self.num_classes = num_classes + input_size = input_shape.channels * (input_shape.width or 1) * (input_shape.height or 1) + # prediction layer for num_classes foreground classes and one background class (hence + 1) + self.cls_score = nn.Linear(input_size, num_classes + 1) + num_bbox_reg_classes = 1 if cls_agnostic_bbox_reg else num_classes + box_dim = len(box2box_transform.weights) + self.bbox_pred = nn.Linear(input_size, num_bbox_reg_classes * box_dim) + + nn.init.normal_(self.cls_score.weight, std=0.01) + nn.init.normal_(self.bbox_pred.weight, std=0.001) + for l in [self.cls_score, self.bbox_pred]: + nn.init.constant_(l.bias, 0) + + self.box2box_transform = box2box_transform + self.smooth_l1_beta = smooth_l1_beta + self.test_score_thresh = test_score_thresh + self.test_nms_thresh = test_nms_thresh + self.test_topk_per_image = test_topk_per_image + self.box_reg_loss_type = box_reg_loss_type + if isinstance(loss_weight, float): + loss_weight = {"loss_cls": loss_weight, "loss_box_reg": loss_weight} + self.loss_weight = loss_weight + self.use_fed_loss = use_fed_loss + self.use_sigmoid_ce = use_sigmoid_ce + self.fed_loss_num_classes = fed_loss_num_classes + + if self.use_fed_loss: + assert self.use_sigmoid_ce, "Please use sigmoid cross entropy loss with federated loss" + fed_loss_cls_weights = get_fed_loss_cls_weights() + assert ( + len(fed_loss_cls_weights) == self.num_classes + ), "Please check the provided fed_loss_cls_weights. Their size should match num_classes" + self.register_buffer("fed_loss_cls_weights", fed_loss_cls_weights) + + @classmethod + def from_config(cls, cfg, input_shape): + return { + "input_shape": input_shape, + "box2box_transform": Box2BoxTransform(weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS), + # fmt: off + "num_classes" : cfg.MODEL.ROI_HEADS.NUM_CLASSES, + "cls_agnostic_bbox_reg" : cfg.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG, + "smooth_l1_beta" : cfg.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA, + "test_score_thresh" : cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST, + "test_nms_thresh" : cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, + "test_topk_per_image" : cfg.TEST.DETECTIONS_PER_IMAGE, + "box_reg_loss_type" : cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE, + "loss_weight" : {"loss_box_reg": cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT}, # noqa + "use_fed_loss" : cfg.MODEL.ROI_BOX_HEAD.USE_FED_LOSS, + "use_sigmoid_ce" : cfg.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE, + "get_fed_loss_cls_weights" : lambda: get_fed_loss_cls_weights(dataset_names=cfg.DATASETS.TRAIN, freq_weight_power=cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER), # noqa + "fed_loss_num_classes" : cfg.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES, + # fmt: on + } + + def forward(self, x): + """ + Args: + x: per-region features of shape (N, ...) for N bounding boxes to predict. + + Returns: + (Tensor, Tensor): + First tensor: shape (N,K+1), scores for each of the N box. Each row contains the + scores for K object categories and 1 background class. + + Second tensor: bounding box regression deltas for each box. Shape is shape (N,Kx4), + or (N,4) for class-agnostic regression. + """ + if x.dim() > 2: + x = torch.flatten(x, start_dim=1) + scores = self.cls_score(x) + proposal_deltas = self.bbox_pred(x) + return scores, proposal_deltas + + def losses(self, predictions, proposals): + """ + Args: + predictions: return values of :meth:`forward()`. + proposals (list[Instances]): proposals that match the features that were used + to compute predictions. The fields ``proposal_boxes``, ``gt_boxes``, + ``gt_classes`` are expected. + + Returns: + Dict[str, Tensor]: dict of losses + """ + scores, proposal_deltas = predictions + + # parse classification outputs + gt_classes = ( + cat([p.gt_classes for p in proposals], dim=0) if len(proposals) else torch.empty(0) + ) + _log_classification_stats(scores, gt_classes) + + # parse box regression outputs + if len(proposals): + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) # Nx4 + assert not proposal_boxes.requires_grad, "Proposals should not require gradients!" + # If "gt_boxes" does not exist, the proposals must be all negative and + # should not be included in regression loss computation. + # Here we just use proposal_boxes as an arbitrary placeholder because its + # value won't be used in self.box_reg_loss(). + gt_boxes = cat( + [(p.gt_boxes if p.has("gt_boxes") else p.proposal_boxes).tensor for p in proposals], + dim=0, + ) + else: + proposal_boxes = gt_boxes = torch.empty((0, 4), device=proposal_deltas.device) + + if self.use_sigmoid_ce: + loss_cls = self.sigmoid_cross_entropy_loss(scores, gt_classes) + else: + loss_cls = cross_entropy(scores, gt_classes, reduction="mean") + + losses = { + "loss_cls": loss_cls, + "loss_box_reg": self.box_reg_loss( + proposal_boxes, gt_boxes, proposal_deltas, gt_classes + ), + } + return {k: v * self.loss_weight.get(k, 1.0) for k, v in losses.items()} + + # Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/fed_loss.py # noqa + # with slight modifications + def get_fed_loss_classes(self, gt_classes, num_fed_loss_classes, num_classes, weight): + """ + Args: + gt_classes: a long tensor of shape R that contains the gt class label of each proposal. + num_fed_loss_classes: minimum number of classes to keep when calculating federated loss. + Will sample negative classes if number of unique gt_classes is smaller than this value. + num_classes: number of foreground classes + weight: probabilities used to sample negative classes + + Returns: + Tensor: + classes to keep when calculating the federated loss, including both unique gt + classes and sampled negative classes. + """ + unique_gt_classes = torch.unique(gt_classes) + prob = unique_gt_classes.new_ones(num_classes + 1).float() + prob[-1] = 0 + if len(unique_gt_classes) < num_fed_loss_classes: + prob[:num_classes] = weight.float().clone() + prob[unique_gt_classes] = 0 + sampled_negative_classes = torch.multinomial( + prob, num_fed_loss_classes - len(unique_gt_classes), replacement=False + ) + fed_loss_classes = torch.cat([unique_gt_classes, sampled_negative_classes]) + else: + fed_loss_classes = unique_gt_classes + return fed_loss_classes + + # Implementation from https://github.com/xingyizhou/CenterNet2/blob/master/projects/CenterNet2/centernet/modeling/roi_heads/custom_fast_rcnn.py#L113 # noqa + # with slight modifications + def sigmoid_cross_entropy_loss(self, pred_class_logits, gt_classes): + """ + Args: + pred_class_logits: shape (N, K+1), scores for each of the N box. Each row contains the + scores for K object categories and 1 background class + gt_classes: a long tensor of shape R that contains the gt class label of each proposal. + """ + if pred_class_logits.numel() == 0: + return pred_class_logits.new_zeros([1])[0] + + N = pred_class_logits.shape[0] + K = pred_class_logits.shape[1] - 1 + + target = pred_class_logits.new_zeros(N, K + 1) + target[range(len(gt_classes)), gt_classes] = 1 + target = target[:, :K] + + cls_loss = F.binary_cross_entropy_with_logits( + pred_class_logits[:, :-1], target, reduction="none" + ) + + if self.use_fed_loss: + fed_loss_classes = self.get_fed_loss_classes( + gt_classes, + num_fed_loss_classes=self.fed_loss_num_classes, + num_classes=K, + weight=self.fed_loss_cls_weights, + ) + fed_loss_classes_mask = fed_loss_classes.new_zeros(K + 1) + fed_loss_classes_mask[fed_loss_classes] = 1 + fed_loss_classes_mask = fed_loss_classes_mask[:K] + weight = fed_loss_classes_mask.view(1, K).expand(N, K).float() + else: + weight = 1 + + loss = torch.sum(cls_loss * weight) / N + return loss + + def box_reg_loss(self, proposal_boxes, gt_boxes, pred_deltas, gt_classes): + """ + Args: + proposal_boxes/gt_boxes are tensors with the same shape (R, 4 or 5). + pred_deltas has shape (R, 4 or 5), or (R, num_classes * (4 or 5)). + gt_classes is a long tensor of shape R, the gt class label of each proposal. + R shall be the number of proposals. + """ + box_dim = proposal_boxes.shape[1] # 4 or 5 + # Regression loss is only computed for foreground proposals (those matched to a GT) + fg_inds = nonzero_tuple((gt_classes >= 0) & (gt_classes < self.num_classes))[0] + if pred_deltas.shape[1] == box_dim: # cls-agnostic regression + fg_pred_deltas = pred_deltas[fg_inds] + else: + fg_pred_deltas = pred_deltas.view(-1, self.num_classes, box_dim)[ + fg_inds, gt_classes[fg_inds] + ] + + loss_box_reg = _dense_box_regression_loss( + [proposal_boxes[fg_inds]], + self.box2box_transform, + [fg_pred_deltas.unsqueeze(0)], + [gt_boxes[fg_inds]], + ..., + self.box_reg_loss_type, + self.smooth_l1_beta, + ) + + # The reg loss is normalized using the total number of regions (R), not the number + # of foreground regions even though the box regression loss is only defined on + # foreground regions. Why? Because doing so gives equal training influence to + # each foreground example. To see how, consider two different minibatches: + # (1) Contains a single foreground region + # (2) Contains 100 foreground regions + # If we normalize by the number of foreground regions, the single example in + # minibatch (1) will be given 100 times as much influence as each foreground + # example in minibatch (2). Normalizing by the total number of regions, R, + # means that the single example in minibatch (1) and each of the 100 examples + # in minibatch (2) are given equal influence. + return loss_box_reg / max(gt_classes.numel(), 1.0) # return 0 if empty + + def inference(self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances]): + """ + Args: + predictions: return values of :meth:`forward()`. + proposals (list[Instances]): proposals that match the features that were + used to compute predictions. The ``proposal_boxes`` field is expected. + + Returns: + list[Instances]: same as `fast_rcnn_inference`. + list[Tensor]: same as `fast_rcnn_inference`. + """ + boxes = self.predict_boxes(predictions, proposals) + scores = self.predict_probs(predictions, proposals) + image_shapes = [x.image_size for x in proposals] + return fast_rcnn_inference( + boxes, + scores, + image_shapes, + self.test_score_thresh, + self.test_nms_thresh, + self.test_topk_per_image, + ) + + def predict_boxes_for_gt_classes(self, predictions, proposals): + """ + Args: + predictions: return values of :meth:`forward()`. + proposals (list[Instances]): proposals that match the features that were used + to compute predictions. The fields ``proposal_boxes``, ``gt_classes`` are expected. + + Returns: + list[Tensor]: + A list of Tensors of predicted boxes for GT classes in case of + class-specific box head. Element i of the list has shape (Ri, B), where Ri is + the number of proposals for image i and B is the box dimension (4 or 5) + """ + if not len(proposals): + return [] + scores, proposal_deltas = predictions + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) + N, B = proposal_boxes.shape + predict_boxes = self.box2box_transform.apply_deltas( + proposal_deltas, proposal_boxes + ) # Nx(KxB) + + K = predict_boxes.shape[1] // B + if K > 1: + gt_classes = torch.cat([p.gt_classes for p in proposals], dim=0) + # Some proposals are ignored or have a background class. Their gt_classes + # cannot be used as index. + gt_classes = gt_classes.clamp_(0, K - 1) + + predict_boxes = predict_boxes.view(N, K, B)[ + torch.arange(N, dtype=torch.long, device=predict_boxes.device), gt_classes + ] + num_prop_per_image = [len(p) for p in proposals] + return predict_boxes.split(num_prop_per_image) + + def predict_boxes( + self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances] + ): + """ + Args: + predictions: return values of :meth:`forward()`. + proposals (list[Instances]): proposals that match the features that were + used to compute predictions. The ``proposal_boxes`` field is expected. + + Returns: + list[Tensor]: + A list of Tensors of predicted class-specific or class-agnostic boxes + for each image. Element i has shape (Ri, K * B) or (Ri, B), where Ri is + the number of proposals for image i and B is the box dimension (4 or 5) + """ + if not len(proposals): + return [] + _, proposal_deltas = predictions + num_prop_per_image = [len(p) for p in proposals] + proposal_boxes = cat([p.proposal_boxes.tensor for p in proposals], dim=0) + predict_boxes = self.box2box_transform.apply_deltas( + proposal_deltas, + proposal_boxes, + ) # Nx(KxB) + return predict_boxes.split(num_prop_per_image) + + def predict_probs( + self, predictions: Tuple[torch.Tensor, torch.Tensor], proposals: List[Instances] + ): + """ + Args: + predictions: return values of :meth:`forward()`. + proposals (list[Instances]): proposals that match the features that were + used to compute predictions. + + Returns: + list[Tensor]: + A list of Tensors of predicted class probabilities for each image. + Element i has shape (Ri, K + 1), where Ri is the number of proposals for image i. + """ + scores, _ = predictions + num_inst_per_image = [len(p) for p in proposals] + if self.use_sigmoid_ce: + probs = scores.sigmoid() + else: + probs = F.softmax(scores, dim=-1) + return probs.split(num_inst_per_image, dim=0) diff --git a/custom_detectron2/modeling/roi_heads/keypoint_head.py b/custom_detectron2/modeling/roi_heads/keypoint_head.py new file mode 100644 index 0000000000000000000000000000000000000000..d0af8e66c502f724ca97f3c72f19c0e54ee2d6e1 --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/keypoint_head.py @@ -0,0 +1,272 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import List +import torch +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.layers import Conv2d, ConvTranspose2d, cat, interpolate +from custom_detectron2.structures import Instances, heatmaps_to_keypoints +from custom_detectron2.utils.events import get_event_storage +from custom_detectron2.utils.registry import Registry + +_TOTAL_SKIPPED = 0 + + +__all__ = [ + "ROI_KEYPOINT_HEAD_REGISTRY", + "build_keypoint_head", + "BaseKeypointRCNNHead", + "KRCNNConvDeconvUpsampleHead", +] + + +ROI_KEYPOINT_HEAD_REGISTRY = Registry("ROI_KEYPOINT_HEAD") +ROI_KEYPOINT_HEAD_REGISTRY.__doc__ = """ +Registry for keypoint heads, which make keypoint predictions from per-region features. + +The registered object will be called with `obj(cfg, input_shape)`. +""" + + +def build_keypoint_head(cfg, input_shape): + """ + Build a keypoint head from `cfg.MODEL.ROI_KEYPOINT_HEAD.NAME`. + """ + name = cfg.MODEL.ROI_KEYPOINT_HEAD.NAME + return ROI_KEYPOINT_HEAD_REGISTRY.get(name)(cfg, input_shape) + + +def keypoint_rcnn_loss(pred_keypoint_logits, instances, normalizer): + """ + Arguments: + pred_keypoint_logits (Tensor): A tensor of shape (N, K, S, S) where N is the total number + of instances in the batch, K is the number of keypoints, and S is the side length + of the keypoint heatmap. The values are spatial logits. + instances (list[Instances]): A list of M Instances, where M is the batch size. + These instances are predictions from the model + that are in 1:1 correspondence with pred_keypoint_logits. + Each Instances should contain a `gt_keypoints` field containing a `structures.Keypoint` + instance. + normalizer (float): Normalize the loss by this amount. + If not specified, we normalize by the number of visible keypoints in the minibatch. + + Returns a scalar tensor containing the loss. + """ + heatmaps = [] + valid = [] + + keypoint_side_len = pred_keypoint_logits.shape[2] + for instances_per_image in instances: + if len(instances_per_image) == 0: + continue + keypoints = instances_per_image.gt_keypoints + heatmaps_per_image, valid_per_image = keypoints.to_heatmap( + instances_per_image.proposal_boxes.tensor, keypoint_side_len + ) + heatmaps.append(heatmaps_per_image.view(-1)) + valid.append(valid_per_image.view(-1)) + + if len(heatmaps): + keypoint_targets = cat(heatmaps, dim=0) + valid = cat(valid, dim=0).to(dtype=torch.uint8) + valid = torch.nonzero(valid).squeeze(1) + + # torch.mean (in binary_cross_entropy_with_logits) doesn't + # accept empty tensors, so handle it separately + if len(heatmaps) == 0 or valid.numel() == 0: + global _TOTAL_SKIPPED + _TOTAL_SKIPPED += 1 + storage = get_event_storage() + storage.put_scalar("kpts_num_skipped_batches", _TOTAL_SKIPPED, smoothing_hint=False) + return pred_keypoint_logits.sum() * 0 + + N, K, H, W = pred_keypoint_logits.shape + pred_keypoint_logits = pred_keypoint_logits.view(N * K, H * W) + + keypoint_loss = F.cross_entropy( + pred_keypoint_logits[valid], keypoint_targets[valid], reduction="sum" + ) + + # If a normalizer isn't specified, normalize by the number of visible keypoints in the minibatch + if normalizer is None: + normalizer = valid.numel() + keypoint_loss /= normalizer + + return keypoint_loss + + +def keypoint_rcnn_inference(pred_keypoint_logits: torch.Tensor, pred_instances: List[Instances]): + """ + Post process each predicted keypoint heatmap in `pred_keypoint_logits` into (x, y, score) + and add it to the `pred_instances` as a `pred_keypoints` field. + + Args: + pred_keypoint_logits (Tensor): A tensor of shape (R, K, S, S) where R is the total number + of instances in the batch, K is the number of keypoints, and S is the side length of + the keypoint heatmap. The values are spatial logits. + pred_instances (list[Instances]): A list of N Instances, where N is the number of images. + + Returns: + None. Each element in pred_instances will contain extra "pred_keypoints" and + "pred_keypoint_heatmaps" fields. "pred_keypoints" is a tensor of shape + (#instance, K, 3) where the last dimension corresponds to (x, y, score). + The scores are larger than 0. "pred_keypoint_heatmaps" contains the raw + keypoint logits as passed to this function. + """ + # flatten all bboxes from all images together (list[Boxes] -> Rx4 tensor) + bboxes_flat = cat([b.pred_boxes.tensor for b in pred_instances], dim=0) + + pred_keypoint_logits = pred_keypoint_logits.detach() + keypoint_results = heatmaps_to_keypoints(pred_keypoint_logits, bboxes_flat.detach()) + num_instances_per_image = [len(i) for i in pred_instances] + keypoint_results = keypoint_results[:, :, [0, 1, 3]].split(num_instances_per_image, dim=0) + heatmap_results = pred_keypoint_logits.split(num_instances_per_image, dim=0) + + for keypoint_results_per_image, heatmap_results_per_image, instances_per_image in zip( + keypoint_results, heatmap_results, pred_instances + ): + # keypoint_results_per_image is (num instances)x(num keypoints)x(x, y, score) + # heatmap_results_per_image is (num instances)x(num keypoints)x(side)x(side) + instances_per_image.pred_keypoints = keypoint_results_per_image + instances_per_image.pred_keypoint_heatmaps = heatmap_results_per_image + + +class BaseKeypointRCNNHead(nn.Module): + """ + Implement the basic Keypoint R-CNN losses and inference logic described in + Sec. 5 of :paper:`Mask R-CNN`. + """ + + @configurable + def __init__(self, *, num_keypoints, loss_weight=1.0, loss_normalizer=1.0): + """ + NOTE: this interface is experimental. + + Args: + num_keypoints (int): number of keypoints to predict + loss_weight (float): weight to multiple on the keypoint loss + loss_normalizer (float or str): + If float, divide the loss by `loss_normalizer * #images`. + If 'visible', the loss is normalized by the total number of + visible keypoints across images. + """ + super().__init__() + self.num_keypoints = num_keypoints + self.loss_weight = loss_weight + assert loss_normalizer == "visible" or isinstance(loss_normalizer, float), loss_normalizer + self.loss_normalizer = loss_normalizer + + @classmethod + def from_config(cls, cfg, input_shape): + ret = { + "loss_weight": cfg.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT, + "num_keypoints": cfg.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS, + } + normalize_by_visible = ( + cfg.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS + ) # noqa + if not normalize_by_visible: + batch_size_per_image = cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE + positive_sample_fraction = cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION + ret["loss_normalizer"] = ( + ret["num_keypoints"] * batch_size_per_image * positive_sample_fraction + ) + else: + ret["loss_normalizer"] = "visible" + return ret + + def forward(self, x, instances: List[Instances]): + """ + Args: + x: input 4D region feature(s) provided by :class:`ROIHeads`. + instances (list[Instances]): contains the boxes & labels corresponding + to the input features. + Exact format is up to its caller to decide. + Typically, this is the foreground instances in training, with + "proposal_boxes" field and other gt annotations. + In inference, it contains boxes that are already predicted. + + Returns: + A dict of losses if in training. The predicted "instances" if in inference. + """ + x = self.layers(x) + if self.training: + num_images = len(instances) + normalizer = ( + None if self.loss_normalizer == "visible" else num_images * self.loss_normalizer + ) + return { + "loss_keypoint": keypoint_rcnn_loss(x, instances, normalizer=normalizer) + * self.loss_weight + } + else: + keypoint_rcnn_inference(x, instances) + return instances + + def layers(self, x): + """ + Neural network layers that makes predictions from regional input features. + """ + raise NotImplementedError + + +# To get torchscript support, we make the head a subclass of `nn.Sequential`. +# Therefore, to add new layers in this head class, please make sure they are +# added in the order they will be used in forward(). +@ROI_KEYPOINT_HEAD_REGISTRY.register() +class KRCNNConvDeconvUpsampleHead(BaseKeypointRCNNHead, nn.Sequential): + """ + A standard keypoint head containing a series of 3x3 convs, followed by + a transpose convolution and bilinear interpolation for upsampling. + It is described in Sec. 5 of :paper:`Mask R-CNN`. + """ + + @configurable + def __init__(self, input_shape, *, num_keypoints, conv_dims, **kwargs): + """ + NOTE: this interface is experimental. + + Args: + input_shape (ShapeSpec): shape of the input feature + conv_dims: an iterable of output channel counts for each conv in the head + e.g. (512, 512, 512) for three convs outputting 512 channels. + """ + super().__init__(num_keypoints=num_keypoints, **kwargs) + + # default up_scale to 2.0 (this can be made an option) + up_scale = 2.0 + in_channels = input_shape.channels + + for idx, layer_channels in enumerate(conv_dims, 1): + module = Conv2d(in_channels, layer_channels, 3, stride=1, padding=1) + self.add_module("conv_fcn{}".format(idx), module) + self.add_module("conv_fcn_relu{}".format(idx), nn.ReLU()) + in_channels = layer_channels + + deconv_kernel = 4 + self.score_lowres = ConvTranspose2d( + in_channels, num_keypoints, deconv_kernel, stride=2, padding=deconv_kernel // 2 - 1 + ) + self.up_scale = up_scale + + for name, param in self.named_parameters(): + if "bias" in name: + nn.init.constant_(param, 0) + elif "weight" in name: + # Caffe2 implementation uses MSRAFill, which in fact + # corresponds to kaiming_normal_ in PyTorch + nn.init.kaiming_normal_(param, mode="fan_out", nonlinearity="relu") + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + ret["input_shape"] = input_shape + ret["conv_dims"] = cfg.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS + return ret + + def layers(self, x): + for layer in self: + x = layer(x) + x = interpolate(x, scale_factor=self.up_scale, mode="bilinear", align_corners=False) + return x diff --git a/custom_detectron2/modeling/roi_heads/mask_head.py b/custom_detectron2/modeling/roi_heads/mask_head.py new file mode 100644 index 0000000000000000000000000000000000000000..43b4318d7aab5896b4842a4e0efee3cb8f3a787f --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/mask_head.py @@ -0,0 +1,298 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import List +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.layers import Conv2d, ConvTranspose2d, ShapeSpec, cat, get_norm +from custom_detectron2.layers.wrappers import move_device_like +from custom_detectron2.structures import Instances +from custom_detectron2.utils.events import get_event_storage +from custom_detectron2.utils.registry import Registry + +__all__ = [ + "BaseMaskRCNNHead", + "MaskRCNNConvUpsampleHead", + "build_mask_head", + "ROI_MASK_HEAD_REGISTRY", +] + + +ROI_MASK_HEAD_REGISTRY = Registry("ROI_MASK_HEAD") +ROI_MASK_HEAD_REGISTRY.__doc__ = """ +Registry for mask heads, which predicts instance masks given +per-region features. + +The registered object will be called with `obj(cfg, input_shape)`. +""" + + +@torch.jit.unused +def mask_rcnn_loss(pred_mask_logits: torch.Tensor, instances: List[Instances], vis_period: int = 0): + """ + Compute the mask prediction loss defined in the Mask R-CNN paper. + + Args: + pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask) + for class-specific or class-agnostic, where B is the total number of predicted masks + in all images, C is the number of foreground classes, and Hmask, Wmask are the height + and width of the mask predictions. The values are logits. + instances (list[Instances]): A list of N Instances, where N is the number of images + in the batch. These instances are in 1:1 + correspondence with the pred_mask_logits. The ground-truth labels (class, box, mask, + ...) associated with each instance are stored in fields. + vis_period (int): the period (in steps) to dump visualization. + + Returns: + mask_loss (Tensor): A scalar tensor containing the loss. + """ + cls_agnostic_mask = pred_mask_logits.size(1) == 1 + total_num_masks = pred_mask_logits.size(0) + mask_side_len = pred_mask_logits.size(2) + assert pred_mask_logits.size(2) == pred_mask_logits.size(3), "Mask prediction must be square!" + + gt_classes = [] + gt_masks = [] + for instances_per_image in instances: + if len(instances_per_image) == 0: + continue + if not cls_agnostic_mask: + gt_classes_per_image = instances_per_image.gt_classes.to(dtype=torch.int64) + gt_classes.append(gt_classes_per_image) + + gt_masks_per_image = instances_per_image.gt_masks.crop_and_resize( + instances_per_image.proposal_boxes.tensor, mask_side_len + ).to(device=pred_mask_logits.device) + # A tensor of shape (N, M, M), N=#instances in the image; M=mask_side_len + gt_masks.append(gt_masks_per_image) + + if len(gt_masks) == 0: + return pred_mask_logits.sum() * 0 + + gt_masks = cat(gt_masks, dim=0) + + if cls_agnostic_mask: + pred_mask_logits = pred_mask_logits[:, 0] + else: + indices = torch.arange(total_num_masks) + gt_classes = cat(gt_classes, dim=0) + pred_mask_logits = pred_mask_logits[indices, gt_classes] + + if gt_masks.dtype == torch.bool: + gt_masks_bool = gt_masks + else: + # Here we allow gt_masks to be float as well (depend on the implementation of rasterize()) + gt_masks_bool = gt_masks > 0.5 + gt_masks = gt_masks.to(dtype=torch.float32) + + # Log the training accuracy (using gt classes and 0.5 threshold) + mask_incorrect = (pred_mask_logits > 0.0) != gt_masks_bool + mask_accuracy = 1 - (mask_incorrect.sum().item() / max(mask_incorrect.numel(), 1.0)) + num_positive = gt_masks_bool.sum().item() + false_positive = (mask_incorrect & ~gt_masks_bool).sum().item() / max( + gt_masks_bool.numel() - num_positive, 1.0 + ) + false_negative = (mask_incorrect & gt_masks_bool).sum().item() / max(num_positive, 1.0) + + storage = get_event_storage() + storage.put_scalar("mask_rcnn/accuracy", mask_accuracy) + storage.put_scalar("mask_rcnn/false_positive", false_positive) + storage.put_scalar("mask_rcnn/false_negative", false_negative) + if vis_period > 0 and storage.iter % vis_period == 0: + pred_masks = pred_mask_logits.sigmoid() + vis_masks = torch.cat([pred_masks, gt_masks], axis=2) + name = "Left: mask prediction; Right: mask GT" + for idx, vis_mask in enumerate(vis_masks): + vis_mask = torch.stack([vis_mask] * 3, axis=0) + storage.put_image(name + f" ({idx})", vis_mask) + + mask_loss = F.binary_cross_entropy_with_logits(pred_mask_logits, gt_masks, reduction="mean") + return mask_loss + + +def mask_rcnn_inference(pred_mask_logits: torch.Tensor, pred_instances: List[Instances]): + """ + Convert pred_mask_logits to estimated foreground probability masks while also + extracting only the masks for the predicted classes in pred_instances. For each + predicted box, the mask of the same class is attached to the instance by adding a + new "pred_masks" field to pred_instances. + + Args: + pred_mask_logits (Tensor): A tensor of shape (B, C, Hmask, Wmask) or (B, 1, Hmask, Wmask) + for class-specific or class-agnostic, where B is the total number of predicted masks + in all images, C is the number of foreground classes, and Hmask, Wmask are the height + and width of the mask predictions. The values are logits. + pred_instances (list[Instances]): A list of N Instances, where N is the number of images + in the batch. Each Instances must have field "pred_classes". + + Returns: + None. pred_instances will contain an extra "pred_masks" field storing a mask of size (Hmask, + Wmask) for predicted class. Note that the masks are returned as a soft (non-quantized) + masks the resolution predicted by the network; post-processing steps, such as resizing + the predicted masks to the original image resolution and/or binarizing them, is left + to the caller. + """ + cls_agnostic_mask = pred_mask_logits.size(1) == 1 + + if cls_agnostic_mask: + mask_probs_pred = pred_mask_logits.sigmoid() + else: + # Select masks corresponding to the predicted classes + num_masks = pred_mask_logits.shape[0] + class_pred = cat([i.pred_classes for i in pred_instances]) + device = ( + class_pred.device + if torch.jit.is_scripting() + else ("cpu" if torch.jit.is_tracing() else class_pred.device) + ) + indices = move_device_like(torch.arange(num_masks, device=device), class_pred) + mask_probs_pred = pred_mask_logits[indices, class_pred][:, None].sigmoid() + # mask_probs_pred.shape: (B, 1, Hmask, Wmask) + + num_boxes_per_image = [len(i) for i in pred_instances] + mask_probs_pred = mask_probs_pred.split(num_boxes_per_image, dim=0) + + for prob, instances in zip(mask_probs_pred, pred_instances): + instances.pred_masks = prob # (1, Hmask, Wmask) + + +class BaseMaskRCNNHead(nn.Module): + """ + Implement the basic Mask R-CNN losses and inference logic described in :paper:`Mask R-CNN` + """ + + @configurable + def __init__(self, *, loss_weight: float = 1.0, vis_period: int = 0): + """ + NOTE: this interface is experimental. + + Args: + loss_weight (float): multiplier of the loss + vis_period (int): visualization period + """ + super().__init__() + self.vis_period = vis_period + self.loss_weight = loss_weight + + @classmethod + def from_config(cls, cfg, input_shape): + return {"vis_period": cfg.VIS_PERIOD} + + def forward(self, x, instances: List[Instances]): + """ + Args: + x: input region feature(s) provided by :class:`ROIHeads`. + instances (list[Instances]): contains the boxes & labels corresponding + to the input features. + Exact format is up to its caller to decide. + Typically, this is the foreground instances in training, with + "proposal_boxes" field and other gt annotations. + In inference, it contains boxes that are already predicted. + + Returns: + A dict of losses in training. The predicted "instances" in inference. + """ + x = self.layers(x) + if self.training: + return {"loss_mask": mask_rcnn_loss(x, instances, self.vis_period) * self.loss_weight} + else: + mask_rcnn_inference(x, instances) + return instances + + def layers(self, x): + """ + Neural network layers that makes predictions from input features. + """ + raise NotImplementedError + + +# To get torchscript support, we make the head a subclass of `nn.Sequential`. +# Therefore, to add new layers in this head class, please make sure they are +# added in the order they will be used in forward(). +@ROI_MASK_HEAD_REGISTRY.register() +class MaskRCNNConvUpsampleHead(BaseMaskRCNNHead, nn.Sequential): + """ + A mask head with several conv layers, plus an upsample layer (with `ConvTranspose2d`). + Predictions are made with a final 1x1 conv layer. + """ + + @configurable + def __init__(self, input_shape: ShapeSpec, *, num_classes, conv_dims, conv_norm="", **kwargs): + """ + NOTE: this interface is experimental. + + Args: + input_shape (ShapeSpec): shape of the input feature + num_classes (int): the number of foreground classes (i.e. background is not + included). 1 if using class agnostic prediction. + conv_dims (list[int]): a list of N>0 integers representing the output dimensions + of N-1 conv layers and the last upsample layer. + conv_norm (str or callable): normalization for the conv layers. + See :func:`detectron2.layers.get_norm` for supported types. + """ + super().__init__(**kwargs) + assert len(conv_dims) >= 1, "conv_dims have to be non-empty!" + + self.conv_norm_relus = [] + + cur_channels = input_shape.channels + for k, conv_dim in enumerate(conv_dims[:-1]): + conv = Conv2d( + cur_channels, + conv_dim, + kernel_size=3, + stride=1, + padding=1, + bias=not conv_norm, + norm=get_norm(conv_norm, conv_dim), + activation=nn.ReLU(), + ) + self.add_module("mask_fcn{}".format(k + 1), conv) + self.conv_norm_relus.append(conv) + cur_channels = conv_dim + + self.deconv = ConvTranspose2d( + cur_channels, conv_dims[-1], kernel_size=2, stride=2, padding=0 + ) + self.add_module("deconv_relu", nn.ReLU()) + cur_channels = conv_dims[-1] + + self.predictor = Conv2d(cur_channels, num_classes, kernel_size=1, stride=1, padding=0) + + for layer in self.conv_norm_relus + [self.deconv]: + weight_init.c2_msra_fill(layer) + # use normal distribution initialization for mask prediction layer + nn.init.normal_(self.predictor.weight, std=0.001) + if self.predictor.bias is not None: + nn.init.constant_(self.predictor.bias, 0) + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg, input_shape) + conv_dim = cfg.MODEL.ROI_MASK_HEAD.CONV_DIM + num_conv = cfg.MODEL.ROI_MASK_HEAD.NUM_CONV + ret.update( + conv_dims=[conv_dim] * (num_conv + 1), # +1 for ConvTranspose + conv_norm=cfg.MODEL.ROI_MASK_HEAD.NORM, + input_shape=input_shape, + ) + if cfg.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK: + ret["num_classes"] = 1 + else: + ret["num_classes"] = cfg.MODEL.ROI_HEADS.NUM_CLASSES + return ret + + def layers(self, x): + for layer in self: + x = layer(x) + return x + + +def build_mask_head(cfg, input_shape): + """ + Build a mask head defined by `cfg.MODEL.ROI_MASK_HEAD.NAME`. + """ + name = cfg.MODEL.ROI_MASK_HEAD.NAME + return ROI_MASK_HEAD_REGISTRY.get(name)(cfg, input_shape) diff --git a/custom_detectron2/modeling/roi_heads/roi_heads.py b/custom_detectron2/modeling/roi_heads/roi_heads.py new file mode 100644 index 0000000000000000000000000000000000000000..1722ac72dc07f142e6b56e4996c975edf5a79973 --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/roi_heads.py @@ -0,0 +1,877 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import inspect +import logging +import numpy as np +from typing import Dict, List, Optional, Tuple +import torch +from torch import nn + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ShapeSpec, nonzero_tuple +from custom_detectron2.structures import Boxes, ImageList, Instances, pairwise_iou +from custom_detectron2.utils.events import get_event_storage +from custom_detectron2.utils.registry import Registry + +from ..backbone.resnet import BottleneckBlock, ResNet +from ..matcher import Matcher +from ..poolers import ROIPooler +from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals +from ..sampling import subsample_labels +from .box_head import build_box_head +from .fast_rcnn import FastRCNNOutputLayers +from .keypoint_head import build_keypoint_head +from .mask_head import build_mask_head + +ROI_HEADS_REGISTRY = Registry("ROI_HEADS") +ROI_HEADS_REGISTRY.__doc__ = """ +Registry for ROI heads in a generalized R-CNN model. +ROIHeads take feature maps and region proposals, and +perform per-region computation. + +The registered object will be called with `obj(cfg, input_shape)`. +The call is expected to return an :class:`ROIHeads`. +""" + +logger = logging.getLogger(__name__) + + +def build_roi_heads(cfg, input_shape): + """ + Build ROIHeads defined by `cfg.MODEL.ROI_HEADS.NAME`. + """ + name = cfg.MODEL.ROI_HEADS.NAME + return ROI_HEADS_REGISTRY.get(name)(cfg, input_shape) + + +def select_foreground_proposals( + proposals: List[Instances], bg_label: int +) -> Tuple[List[Instances], List[torch.Tensor]]: + """ + Given a list of N Instances (for N images), each containing a `gt_classes` field, + return a list of Instances that contain only instances with `gt_classes != -1 && + gt_classes != bg_label`. + + Args: + proposals (list[Instances]): A list of N Instances, where N is the number of + images in the batch. + bg_label: label index of background class. + + Returns: + list[Instances]: N Instances, each contains only the selected foreground instances. + list[Tensor]: N boolean vector, correspond to the selection mask of + each Instances object. True for selected instances. + """ + assert isinstance(proposals, (list, tuple)) + assert isinstance(proposals[0], Instances) + assert proposals[0].has("gt_classes") + fg_proposals = [] + fg_selection_masks = [] + for proposals_per_image in proposals: + gt_classes = proposals_per_image.gt_classes + fg_selection_mask = (gt_classes != -1) & (gt_classes != bg_label) + fg_idxs = fg_selection_mask.nonzero().squeeze(1) + fg_proposals.append(proposals_per_image[fg_idxs]) + fg_selection_masks.append(fg_selection_mask) + return fg_proposals, fg_selection_masks + + +def select_proposals_with_visible_keypoints(proposals: List[Instances]) -> List[Instances]: + """ + Args: + proposals (list[Instances]): a list of N Instances, where N is the + number of images. + + Returns: + proposals: only contains proposals with at least one visible keypoint. + + Note that this is still slightly different from Detectron. + In Detectron, proposals for training keypoint head are re-sampled from + all the proposals with IOU>threshold & >=1 visible keypoint. + + Here, the proposals are first sampled from all proposals with + IOU>threshold, then proposals with no visible keypoint are filtered out. + This strategy seems to make no difference on Detectron and is easier to implement. + """ + ret = [] + all_num_fg = [] + for proposals_per_image in proposals: + # If empty/unannotated image (hard negatives), skip filtering for train + if len(proposals_per_image) == 0: + ret.append(proposals_per_image) + continue + gt_keypoints = proposals_per_image.gt_keypoints.tensor + # #fg x K x 3 + vis_mask = gt_keypoints[:, :, 2] >= 1 + xs, ys = gt_keypoints[:, :, 0], gt_keypoints[:, :, 1] + proposal_boxes = proposals_per_image.proposal_boxes.tensor.unsqueeze(dim=1) # #fg x 1 x 4 + kp_in_box = ( + (xs >= proposal_boxes[:, :, 0]) + & (xs <= proposal_boxes[:, :, 2]) + & (ys >= proposal_boxes[:, :, 1]) + & (ys <= proposal_boxes[:, :, 3]) + ) + selection = (kp_in_box & vis_mask).any(dim=1) + selection_idxs = nonzero_tuple(selection)[0] + all_num_fg.append(selection_idxs.numel()) + ret.append(proposals_per_image[selection_idxs]) + + storage = get_event_storage() + storage.put_scalar("keypoint_head/num_fg_samples", np.mean(all_num_fg)) + return ret + + +class ROIHeads(torch.nn.Module): + """ + ROIHeads perform all per-region computation in an R-CNN. + + It typically contains logic to + + 1. (in training only) match proposals with ground truth and sample them + 2. crop the regions and extract per-region features using proposals + 3. make per-region predictions with different heads + + It can have many variants, implemented as subclasses of this class. + This base class contains the logic to match/sample proposals. + But it is not necessary to inherit this class if the sampling logic is not needed. + """ + + @configurable + def __init__( + self, + *, + num_classes, + batch_size_per_image, + positive_fraction, + proposal_matcher, + proposal_append_gt=True, + ): + """ + NOTE: this interface is experimental. + + Args: + num_classes (int): number of foreground classes (i.e. background is not included) + batch_size_per_image (int): number of proposals to sample for training + positive_fraction (float): fraction of positive (foreground) proposals + to sample for training. + proposal_matcher (Matcher): matcher that matches proposals and ground truth + proposal_append_gt (bool): whether to include ground truth as proposals as well + """ + super().__init__() + self.batch_size_per_image = batch_size_per_image + self.positive_fraction = positive_fraction + self.num_classes = num_classes + self.proposal_matcher = proposal_matcher + self.proposal_append_gt = proposal_append_gt + + @classmethod + def from_config(cls, cfg): + return { + "batch_size_per_image": cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, + "positive_fraction": cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION, + "num_classes": cfg.MODEL.ROI_HEADS.NUM_CLASSES, + "proposal_append_gt": cfg.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT, + # Matcher to assign box proposals to gt boxes + "proposal_matcher": Matcher( + cfg.MODEL.ROI_HEADS.IOU_THRESHOLDS, + cfg.MODEL.ROI_HEADS.IOU_LABELS, + allow_low_quality_matches=False, + ), + } + + def _sample_proposals( + self, matched_idxs: torch.Tensor, matched_labels: torch.Tensor, gt_classes: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Based on the matching between N proposals and M groundtruth, + sample the proposals and set their classification labels. + + Args: + matched_idxs (Tensor): a vector of length N, each is the best-matched + gt index in [0, M) for each proposal. + matched_labels (Tensor): a vector of length N, the matcher's label + (one of cfg.MODEL.ROI_HEADS.IOU_LABELS) for each proposal. + gt_classes (Tensor): a vector of length M. + + Returns: + Tensor: a vector of indices of sampled proposals. Each is in [0, N). + Tensor: a vector of the same length, the classification label for + each sampled proposal. Each sample is labeled as either a category in + [0, num_classes) or the background (num_classes). + """ + has_gt = gt_classes.numel() > 0 + # Get the corresponding GT for each proposal + if has_gt: + gt_classes = gt_classes[matched_idxs] + # Label unmatched proposals (0 label from matcher) as background (label=num_classes) + gt_classes[matched_labels == 0] = self.num_classes + # Label ignore proposals (-1 label) + gt_classes[matched_labels == -1] = -1 + else: + gt_classes = torch.zeros_like(matched_idxs) + self.num_classes + + sampled_fg_idxs, sampled_bg_idxs = subsample_labels( + gt_classes, self.batch_size_per_image, self.positive_fraction, self.num_classes + ) + + sampled_idxs = torch.cat([sampled_fg_idxs, sampled_bg_idxs], dim=0) + return sampled_idxs, gt_classes[sampled_idxs] + + @torch.no_grad() + def label_and_sample_proposals( + self, proposals: List[Instances], targets: List[Instances] + ) -> List[Instances]: + """ + Prepare some proposals to be used to train the ROI heads. + It performs box matching between `proposals` and `targets`, and assigns + training labels to the proposals. + It returns ``self.batch_size_per_image`` random samples from proposals and groundtruth + boxes, with a fraction of positives that is no larger than + ``self.positive_fraction``. + + Args: + See :meth:`ROIHeads.forward` + + Returns: + list[Instances]: + length `N` list of `Instances`s containing the proposals + sampled for training. Each `Instances` has the following fields: + + - proposal_boxes: the proposal boxes + - gt_boxes: the ground-truth box that the proposal is assigned to + (this is only meaningful if the proposal has a label > 0; if label = 0 + then the ground-truth box is random) + + Other fields such as "gt_classes", "gt_masks", that's included in `targets`. + """ + # Augment proposals with ground-truth boxes. + # In the case of learned proposals (e.g., RPN), when training starts + # the proposals will be low quality due to random initialization. + # It's possible that none of these initial + # proposals have high enough overlap with the gt objects to be used + # as positive examples for the second stage components (box head, + # cls head, mask head). Adding the gt boxes to the set of proposals + # ensures that the second stage components will have some positive + # examples from the start of training. For RPN, this augmentation improves + # convergence and empirically improves box AP on COCO by about 0.5 + # points (under one tested configuration). + if self.proposal_append_gt: + proposals = add_ground_truth_to_proposals(targets, proposals) + + proposals_with_gt = [] + + num_fg_samples = [] + num_bg_samples = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + has_gt = len(targets_per_image) > 0 + match_quality_matrix = pairwise_iou( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes + ) + matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix) + sampled_idxs, gt_classes = self._sample_proposals( + matched_idxs, matched_labels, targets_per_image.gt_classes + ) + + # Set target attributes of the sampled proposals: + proposals_per_image = proposals_per_image[sampled_idxs] + proposals_per_image.gt_classes = gt_classes + + if has_gt: + sampled_targets = matched_idxs[sampled_idxs] + # We index all the attributes of targets that start with "gt_" + # and have not been added to proposals yet (="gt_classes"). + # NOTE: here the indexing waste some compute, because heads + # like masks, keypoints, etc, will filter the proposals again, + # (by foreground/background, or number of keypoints in the image, etc) + # so we essentially index the data twice. + for (trg_name, trg_value) in targets_per_image.get_fields().items(): + if trg_name.startswith("gt_") and not proposals_per_image.has(trg_name): + proposals_per_image.set(trg_name, trg_value[sampled_targets]) + # If no GT is given in the image, we don't know what a dummy gt value can be. + # Therefore the returned proposals won't have any gt_* fields, except for a + # gt_classes full of background label. + + num_bg_samples.append((gt_classes == self.num_classes).sum().item()) + num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1]) + proposals_with_gt.append(proposals_per_image) + + # Log the number of fg/bg samples that are selected for training ROI heads + storage = get_event_storage() + storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples)) + storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples)) + + return proposals_with_gt + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + proposals: List[Instances], + targets: Optional[List[Instances]] = None, + ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]: + """ + Args: + images (ImageList): + features (dict[str,Tensor]): input data as a mapping from feature + map name to tensor. Axis 0 represents the number of images `N` in + the input data; axes 1-3 are channels, height, and width, which may + vary between feature maps (e.g., if a feature pyramid is used). + proposals (list[Instances]): length `N` list of `Instances`. The i-th + `Instances` contains object proposals for the i-th input image, + with fields "proposal_boxes" and "objectness_logits". + targets (list[Instances], optional): length `N` list of `Instances`. The i-th + `Instances` contains the ground-truth per-instance annotations + for the i-th input image. Specify `targets` during training only. + It may have the following fields: + + - gt_boxes: the bounding box of each instance. + - gt_classes: the label for each instance with a category ranging in [0, #class]. + - gt_masks: PolygonMasks or BitMasks, the ground-truth masks of each instance. + - gt_keypoints: NxKx3, the groud-truth keypoints for each instance. + + Returns: + list[Instances]: length `N` list of `Instances` containing the + detected instances. Returned during inference only; may be [] during training. + + dict[str->Tensor]: + mapping from a named loss to a tensor storing the loss. Used during training only. + """ + raise NotImplementedError() + + +@ROI_HEADS_REGISTRY.register() +class Res5ROIHeads(ROIHeads): + """ + The ROIHeads in a typical "C4" R-CNN model, where + the box and mask head share the cropping and + the per-region feature computation by a Res5 block. + See :paper:`ResNet` Appendix A. + """ + + @configurable + def __init__( + self, + *, + in_features: List[str], + pooler: ROIPooler, + res5: nn.Module, + box_predictor: nn.Module, + mask_head: Optional[nn.Module] = None, + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + in_features (list[str]): list of backbone feature map names to use for + feature extraction + pooler (ROIPooler): pooler to extra region features from backbone + res5 (nn.Sequential): a CNN to compute per-region features, to be used by + ``box_predictor`` and ``mask_head``. Typically this is a "res5" + block from a ResNet. + box_predictor (nn.Module): make box predictions from the feature. + Should have the same interface as :class:`FastRCNNOutputLayers`. + mask_head (nn.Module): transform features to make mask predictions + """ + super().__init__(**kwargs) + self.in_features = in_features + self.pooler = pooler + if isinstance(res5, (list, tuple)): + res5 = nn.Sequential(*res5) + self.res5 = res5 + self.box_predictor = box_predictor + self.mask_on = mask_head is not None + if self.mask_on: + self.mask_head = mask_head + + @classmethod + def from_config(cls, cfg, input_shape): + # fmt: off + ret = super().from_config(cfg) + in_features = ret["in_features"] = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + pooler_scales = (1.0 / input_shape[in_features[0]].stride, ) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + mask_on = cfg.MODEL.MASK_ON + # fmt: on + assert not cfg.MODEL.KEYPOINT_ON + assert len(in_features) == 1 + + ret["pooler"] = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + + # Compatbility with old moco code. Might be useful. + # See notes in StandardROIHeads.from_config + if not inspect.ismethod(cls._build_res5_block): + logger.warning( + "The behavior of _build_res5_block may change. " + "Please do not depend on private methods." + ) + cls._build_res5_block = classmethod(cls._build_res5_block) + + ret["res5"], out_channels = cls._build_res5_block(cfg) + ret["box_predictor"] = FastRCNNOutputLayers( + cfg, ShapeSpec(channels=out_channels, height=1, width=1) + ) + + if mask_on: + ret["mask_head"] = build_mask_head( + cfg, + ShapeSpec(channels=out_channels, width=pooler_resolution, height=pooler_resolution), + ) + return ret + + @classmethod + def _build_res5_block(cls, cfg): + # fmt: off + stage_channel_factor = 2 ** 3 # res5 is 8x res2 + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group * stage_channel_factor + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS * stage_channel_factor + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + norm = cfg.MODEL.RESNETS.NORM + assert not cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE[-1], \ + "Deformable conv is not yet supported in res5 head." + # fmt: on + + blocks = ResNet.make_stage( + BottleneckBlock, + 3, + stride_per_block=[2, 1, 1], + in_channels=out_channels // 2, + bottleneck_channels=bottleneck_channels, + out_channels=out_channels, + num_groups=num_groups, + norm=norm, + stride_in_1x1=stride_in_1x1, + ) + return nn.Sequential(*blocks), out_channels + + def _shared_roi_transform(self, features: List[torch.Tensor], boxes: List[Boxes]): + x = self.pooler(features, boxes) + return self.res5(x) + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + proposals: List[Instances], + targets: Optional[List[Instances]] = None, + ): + """ + See :meth:`ROIHeads.forward`. + """ + del images + + if self.training: + assert targets + proposals = self.label_and_sample_proposals(proposals, targets) + del targets + + proposal_boxes = [x.proposal_boxes for x in proposals] + box_features = self._shared_roi_transform( + [features[f] for f in self.in_features], proposal_boxes + ) + predictions = self.box_predictor(box_features.mean(dim=[2, 3])) + + if self.training: + del features + losses = self.box_predictor.losses(predictions, proposals) + if self.mask_on: + proposals, fg_selection_masks = select_foreground_proposals( + proposals, self.num_classes + ) + # Since the ROI feature transform is shared between boxes and masks, + # we don't need to recompute features. The mask loss is only defined + # on foreground proposals, so we need to select out the foreground + # features. + mask_features = box_features[torch.cat(fg_selection_masks, dim=0)] + del box_features + losses.update(self.mask_head(mask_features, proposals)) + return [], losses + else: + pred_instances, _ = self.box_predictor.inference(predictions, proposals) + pred_instances = self.forward_with_given_boxes(features, pred_instances) + return pred_instances, {} + + def forward_with_given_boxes( + self, features: Dict[str, torch.Tensor], instances: List[Instances] + ) -> List[Instances]: + """ + Use the given boxes in `instances` to produce other (non-box) per-ROI outputs. + + Args: + features: same as in `forward()` + instances (list[Instances]): instances to predict other outputs. Expect the keys + "pred_boxes" and "pred_classes" to exist. + + Returns: + instances (Instances): + the same `Instances` object, with extra + fields such as `pred_masks` or `pred_keypoints`. + """ + assert not self.training + assert instances[0].has("pred_boxes") and instances[0].has("pred_classes") + + if self.mask_on: + feature_list = [features[f] for f in self.in_features] + x = self._shared_roi_transform(feature_list, [x.pred_boxes for x in instances]) + return self.mask_head(x, instances) + else: + return instances + + +@ROI_HEADS_REGISTRY.register() +class StandardROIHeads(ROIHeads): + """ + It's "standard" in a sense that there is no ROI transform sharing + or feature sharing between tasks. + Each head independently processes the input features by each head's + own pooler and head. + + This class is used by most models, such as FPN and C5. + To implement more models, you can subclass it and implement a different + :meth:`forward()` or a head. + """ + + @configurable + def __init__( + self, + *, + box_in_features: List[str], + box_pooler: ROIPooler, + box_head: nn.Module, + box_predictor: nn.Module, + mask_in_features: Optional[List[str]] = None, + mask_pooler: Optional[ROIPooler] = None, + mask_head: Optional[nn.Module] = None, + keypoint_in_features: Optional[List[str]] = None, + keypoint_pooler: Optional[ROIPooler] = None, + keypoint_head: Optional[nn.Module] = None, + train_on_pred_boxes: bool = False, + **kwargs, + ): + """ + NOTE: this interface is experimental. + + Args: + box_in_features (list[str]): list of feature names to use for the box head. + box_pooler (ROIPooler): pooler to extra region features for box head + box_head (nn.Module): transform features to make box predictions + box_predictor (nn.Module): make box predictions from the feature. + Should have the same interface as :class:`FastRCNNOutputLayers`. + mask_in_features (list[str]): list of feature names to use for the mask + pooler or mask head. None if not using mask head. + mask_pooler (ROIPooler): pooler to extract region features from image features. + The mask head will then take region features to make predictions. + If None, the mask head will directly take the dict of image features + defined by `mask_in_features` + mask_head (nn.Module): transform features to make mask predictions + keypoint_in_features, keypoint_pooler, keypoint_head: similar to ``mask_*``. + train_on_pred_boxes (bool): whether to use proposal boxes or + predicted boxes from the box head to train other heads. + """ + super().__init__(**kwargs) + # keep self.in_features for backward compatibility + self.in_features = self.box_in_features = box_in_features + self.box_pooler = box_pooler + self.box_head = box_head + self.box_predictor = box_predictor + + self.mask_on = mask_in_features is not None + if self.mask_on: + self.mask_in_features = mask_in_features + self.mask_pooler = mask_pooler + self.mask_head = mask_head + + self.keypoint_on = keypoint_in_features is not None + if self.keypoint_on: + self.keypoint_in_features = keypoint_in_features + self.keypoint_pooler = keypoint_pooler + self.keypoint_head = keypoint_head + + self.train_on_pred_boxes = train_on_pred_boxes + + @classmethod + def from_config(cls, cfg, input_shape): + ret = super().from_config(cfg) + ret["train_on_pred_boxes"] = cfg.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES + # Subclasses that have not been updated to use from_config style construction + # may have overridden _init_*_head methods. In this case, those overridden methods + # will not be classmethods and we need to avoid trying to call them here. + # We test for this with ismethod which only returns True for bound methods of cls. + # Such subclasses will need to handle calling their overridden _init_*_head methods. + if inspect.ismethod(cls._init_box_head): + ret.update(cls._init_box_head(cfg, input_shape)) + if inspect.ismethod(cls._init_mask_head): + ret.update(cls._init_mask_head(cfg, input_shape)) + if inspect.ismethod(cls._init_keypoint_head): + ret.update(cls._init_keypoint_head(cfg, input_shape)) + return ret + + @classmethod + def _init_box_head(cls, cfg, input_shape): + # fmt: off + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + # fmt: on + + # If StandardROIHeads is applied on multiple feature maps (as in FPN), + # then we share the same predictors and therefore the channel counts must be the same + in_channels = [input_shape[f].channels for f in in_features] + # Check all channel counts are equal + assert len(set(in_channels)) == 1, in_channels + in_channels = in_channels[0] + + box_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + # Here we split "box head" and "box predictor", which is mainly due to historical reasons. + # They are used together so the "box predictor" layers should be part of the "box head". + # New subclasses of ROIHeads do not need "box predictor"s. + box_head = build_box_head( + cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution) + ) + box_predictor = FastRCNNOutputLayers(cfg, box_head.output_shape) + return { + "box_in_features": in_features, + "box_pooler": box_pooler, + "box_head": box_head, + "box_predictor": box_predictor, + } + + @classmethod + def _init_mask_head(cls, cfg, input_shape): + if not cfg.MODEL.MASK_ON: + return {} + # fmt: off + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) + sampling_ratio = cfg.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_MASK_HEAD.POOLER_TYPE + # fmt: on + + in_channels = [input_shape[f].channels for f in in_features][0] + + ret = {"mask_in_features": in_features} + ret["mask_pooler"] = ( + ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + if pooler_type + else None + ) + if pooler_type: + shape = ShapeSpec( + channels=in_channels, width=pooler_resolution, height=pooler_resolution + ) + else: + shape = {f: input_shape[f] for f in in_features} + ret["mask_head"] = build_mask_head(cfg, shape) + return ret + + @classmethod + def _init_keypoint_head(cls, cfg, input_shape): + if not cfg.MODEL.KEYPOINT_ON: + return {} + # fmt: off + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) # noqa + sampling_ratio = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE + # fmt: on + + in_channels = [input_shape[f].channels for f in in_features][0] + + ret = {"keypoint_in_features": in_features} + ret["keypoint_pooler"] = ( + ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + if pooler_type + else None + ) + if pooler_type: + shape = ShapeSpec( + channels=in_channels, width=pooler_resolution, height=pooler_resolution + ) + else: + shape = {f: input_shape[f] for f in in_features} + ret["keypoint_head"] = build_keypoint_head(cfg, shape) + return ret + + def forward( + self, + images: ImageList, + features: Dict[str, torch.Tensor], + proposals: List[Instances], + targets: Optional[List[Instances]] = None, + ) -> Tuple[List[Instances], Dict[str, torch.Tensor]]: + """ + See :class:`ROIHeads.forward`. + """ + del images + if self.training: + assert targets, "'targets' argument is required during training" + proposals = self.label_and_sample_proposals(proposals, targets) + del targets + + if self.training: + losses = self._forward_box(features, proposals) + # Usually the original proposals used by the box head are used by the mask, keypoint + # heads. But when `self.train_on_pred_boxes is True`, proposals will contain boxes + # predicted by the box head. + losses.update(self._forward_mask(features, proposals)) + losses.update(self._forward_keypoint(features, proposals)) + return proposals, losses + else: + pred_instances = self._forward_box(features, proposals) + # During inference cascaded prediction is used: the mask and keypoints heads are only + # applied to the top scoring box detections. + pred_instances = self.forward_with_given_boxes(features, pred_instances) + return pred_instances, {} + + def forward_with_given_boxes( + self, features: Dict[str, torch.Tensor], instances: List[Instances] + ) -> List[Instances]: + """ + Use the given boxes in `instances` to produce other (non-box) per-ROI outputs. + + This is useful for downstream tasks where a box is known, but need to obtain + other attributes (outputs of other heads). + Test-time augmentation also uses this. + + Args: + features: same as in `forward()` + instances (list[Instances]): instances to predict other outputs. Expect the keys + "pred_boxes" and "pred_classes" to exist. + + Returns: + list[Instances]: + the same `Instances` objects, with extra + fields such as `pred_masks` or `pred_keypoints`. + """ + assert not self.training + assert instances[0].has("pred_boxes") and instances[0].has("pred_classes") + + instances = self._forward_mask(features, instances) + instances = self._forward_keypoint(features, instances) + return instances + + def _forward_box(self, features: Dict[str, torch.Tensor], proposals: List[Instances]): + """ + Forward logic of the box prediction branch. If `self.train_on_pred_boxes is True`, + the function puts predicted boxes in the `proposal_boxes` field of `proposals` argument. + + Args: + features (dict[str, Tensor]): mapping from feature map names to tensor. + Same as in :meth:`ROIHeads.forward`. + proposals (list[Instances]): the per-image object proposals with + their matching ground truth. + Each has fields "proposal_boxes", and "objectness_logits", + "gt_classes", "gt_boxes". + + Returns: + In training, a dict of losses. + In inference, a list of `Instances`, the predicted instances. + """ + features = [features[f] for f in self.box_in_features] + box_features = self.box_pooler(features, [x.proposal_boxes for x in proposals]) + box_features = self.box_head(box_features) + predictions = self.box_predictor(box_features) + del box_features + + if self.training: + losses = self.box_predictor.losses(predictions, proposals) + # proposals is modified in-place below, so losses must be computed first. + if self.train_on_pred_boxes: + with torch.no_grad(): + pred_boxes = self.box_predictor.predict_boxes_for_gt_classes( + predictions, proposals + ) + for proposals_per_image, pred_boxes_per_image in zip(proposals, pred_boxes): + proposals_per_image.proposal_boxes = Boxes(pred_boxes_per_image) + return losses + else: + pred_instances, _ = self.box_predictor.inference(predictions, proposals) + return pred_instances + + def _forward_mask(self, features: Dict[str, torch.Tensor], instances: List[Instances]): + """ + Forward logic of the mask prediction branch. + + Args: + features (dict[str, Tensor]): mapping from feature map names to tensor. + Same as in :meth:`ROIHeads.forward`. + instances (list[Instances]): the per-image instances to train/predict masks. + In training, they can be the proposals. + In inference, they can be the boxes predicted by R-CNN box head. + + Returns: + In training, a dict of losses. + In inference, update `instances` with new fields "pred_masks" and return it. + """ + if not self.mask_on: + return {} if self.training else instances + + if self.training: + # head is only trained on positive proposals. + instances, _ = select_foreground_proposals(instances, self.num_classes) + + if self.mask_pooler is not None: + features = [features[f] for f in self.mask_in_features] + boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances] + features = self.mask_pooler(features, boxes) + else: + features = {f: features[f] for f in self.mask_in_features} + return self.mask_head(features, instances) + + def _forward_keypoint(self, features: Dict[str, torch.Tensor], instances: List[Instances]): + """ + Forward logic of the keypoint prediction branch. + + Args: + features (dict[str, Tensor]): mapping from feature map names to tensor. + Same as in :meth:`ROIHeads.forward`. + instances (list[Instances]): the per-image instances to train/predict keypoints. + In training, they can be the proposals. + In inference, they can be the boxes predicted by R-CNN box head. + + Returns: + In training, a dict of losses. + In inference, update `instances` with new fields "pred_keypoints" and return it. + """ + if not self.keypoint_on: + return {} if self.training else instances + + if self.training: + # head is only trained on positive proposals with >=1 visible keypoints. + instances, _ = select_foreground_proposals(instances, self.num_classes) + instances = select_proposals_with_visible_keypoints(instances) + + if self.keypoint_pooler is not None: + features = [features[f] for f in self.keypoint_in_features] + boxes = [x.proposal_boxes if self.training else x.pred_boxes for x in instances] + features = self.keypoint_pooler(features, boxes) + else: + features = {f: features[f] for f in self.keypoint_in_features} + return self.keypoint_head(features, instances) diff --git a/custom_detectron2/modeling/roi_heads/rotated_fast_rcnn.py b/custom_detectron2/modeling/roi_heads/rotated_fast_rcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..1e29bb9b9cd44420bb3cc4136dd51c4c848b32c1 --- /dev/null +++ b/custom_detectron2/modeling/roi_heads/rotated_fast_rcnn.py @@ -0,0 +1,271 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import numpy as np +import torch + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ShapeSpec, batched_nms_rotated +from custom_detectron2.structures import Instances, RotatedBoxes, pairwise_iou_rotated +from custom_detectron2.utils.events import get_event_storage + +from ..box_regression import Box2BoxTransformRotated +from ..poolers import ROIPooler +from ..proposal_generator.proposal_utils import add_ground_truth_to_proposals +from .box_head import build_box_head +from .fast_rcnn import FastRCNNOutputLayers +from .roi_heads import ROI_HEADS_REGISTRY, StandardROIHeads + +logger = logging.getLogger(__name__) + +""" +Shape shorthand in this module: + + N: number of images in the minibatch + R: number of ROIs, combined over all images, in the minibatch + Ri: number of ROIs in image i + K: number of foreground classes. E.g.,there are 80 foreground classes in COCO. + +Naming convention: + + deltas: refers to the 5-d (dx, dy, dw, dh, da) deltas that parameterize the box2box + transform (see :class:`box_regression.Box2BoxTransformRotated`). + + pred_class_logits: predicted class scores in [-inf, +inf]; use + softmax(pred_class_logits) to estimate P(class). + + gt_classes: ground-truth classification labels in [0, K], where [0, K) represent + foreground object classes and K represents the background class. + + pred_proposal_deltas: predicted rotated box2box transform deltas for transforming proposals + to detection box predictions. + + gt_proposal_deltas: ground-truth rotated box2box transform deltas +""" + + +def fast_rcnn_inference_rotated( + boxes, scores, image_shapes, score_thresh, nms_thresh, topk_per_image +): + """ + Call `fast_rcnn_inference_single_image_rotated` for all images. + + Args: + boxes (list[Tensor]): A list of Tensors of predicted class-specific or class-agnostic + boxes for each image. Element i has shape (Ri, K * 5) if doing + class-specific regression, or (Ri, 5) if doing class-agnostic + regression, where Ri is the number of predicted objects for image i. + This is compatible with the output of :meth:`FastRCNNOutputLayers.predict_boxes`. + scores (list[Tensor]): A list of Tensors of predicted class scores for each image. + Element i has shape (Ri, K + 1), where Ri is the number of predicted objects + for image i. Compatible with the output of :meth:`FastRCNNOutputLayers.predict_probs`. + image_shapes (list[tuple]): A list of (width, height) tuples for each image in the batch. + score_thresh (float): Only return detections with a confidence score exceeding this + threshold. + nms_thresh (float): The threshold to use for box non-maximum suppression. Value in [0, 1]. + topk_per_image (int): The number of top scoring detections to return. Set < 0 to return + all detections. + + Returns: + instances: (list[Instances]): A list of N instances, one for each image in the batch, + that stores the topk most confidence detections. + kept_indices: (list[Tensor]): A list of 1D tensor of length of N, each element indicates + the corresponding boxes/scores index in [0, Ri) from the input, for image i. + """ + result_per_image = [ + fast_rcnn_inference_single_image_rotated( + boxes_per_image, scores_per_image, image_shape, score_thresh, nms_thresh, topk_per_image + ) + for scores_per_image, boxes_per_image, image_shape in zip(scores, boxes, image_shapes) + ] + return [x[0] for x in result_per_image], [x[1] for x in result_per_image] + + +@torch.no_grad() +def fast_rcnn_inference_single_image_rotated( + boxes, scores, image_shape, score_thresh, nms_thresh, topk_per_image +): + """ + Single-image inference. Return rotated bounding-box detection results by thresholding + on scores and applying rotated non-maximum suppression (Rotated NMS). + + Args: + Same as `fast_rcnn_inference_rotated`, but with rotated boxes, scores, and image shapes + per image. + + Returns: + Same as `fast_rcnn_inference_rotated`, but for only one image. + """ + valid_mask = torch.isfinite(boxes).all(dim=1) & torch.isfinite(scores).all(dim=1) + if not valid_mask.all(): + boxes = boxes[valid_mask] + scores = scores[valid_mask] + + B = 5 # box dimension + scores = scores[:, :-1] + num_bbox_reg_classes = boxes.shape[1] // B + # Convert to Boxes to use the `clip` function ... + boxes = RotatedBoxes(boxes.reshape(-1, B)) + boxes.clip(image_shape) + boxes = boxes.tensor.view(-1, num_bbox_reg_classes, B) # R x C x B + # Filter results based on detection scores + filter_mask = scores > score_thresh # R x K + # R' x 2. First column contains indices of the R predictions; + # Second column contains indices of classes. + filter_inds = filter_mask.nonzero() + if num_bbox_reg_classes == 1: + boxes = boxes[filter_inds[:, 0], 0] + else: + boxes = boxes[filter_mask] + scores = scores[filter_mask] + + # Apply per-class Rotated NMS + keep = batched_nms_rotated(boxes, scores, filter_inds[:, 1], nms_thresh) + if topk_per_image >= 0: + keep = keep[:topk_per_image] + boxes, scores, filter_inds = boxes[keep], scores[keep], filter_inds[keep] + + result = Instances(image_shape) + result.pred_boxes = RotatedBoxes(boxes) + result.scores = scores + result.pred_classes = filter_inds[:, 1] + + return result, filter_inds[:, 0] + + +class RotatedFastRCNNOutputLayers(FastRCNNOutputLayers): + """ + Two linear layers for predicting Rotated Fast R-CNN outputs. + """ + + @classmethod + def from_config(cls, cfg, input_shape): + args = super().from_config(cfg, input_shape) + args["box2box_transform"] = Box2BoxTransformRotated( + weights=cfg.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS + ) + return args + + def inference(self, predictions, proposals): + """ + Returns: + list[Instances]: same as `fast_rcnn_inference_rotated`. + list[Tensor]: same as `fast_rcnn_inference_rotated`. + """ + boxes = self.predict_boxes(predictions, proposals) + scores = self.predict_probs(predictions, proposals) + image_shapes = [x.image_size for x in proposals] + + return fast_rcnn_inference_rotated( + boxes, + scores, + image_shapes, + self.test_score_thresh, + self.test_nms_thresh, + self.test_topk_per_image, + ) + + +@ROI_HEADS_REGISTRY.register() +class RROIHeads(StandardROIHeads): + """ + This class is used by Rotated Fast R-CNN to detect rotated boxes. + For now, it only supports box predictions but not mask or keypoints. + """ + + @configurable + def __init__(self, **kwargs): + """ + NOTE: this interface is experimental. + """ + super().__init__(**kwargs) + assert ( + not self.mask_on and not self.keypoint_on + ), "Mask/Keypoints not supported in Rotated ROIHeads." + assert not self.train_on_pred_boxes, "train_on_pred_boxes not implemented for RROIHeads!" + + @classmethod + def _init_box_head(cls, cfg, input_shape): + # fmt: off + in_features = cfg.MODEL.ROI_HEADS.IN_FEATURES + pooler_resolution = cfg.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION + pooler_scales = tuple(1.0 / input_shape[k].stride for k in in_features) + sampling_ratio = cfg.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO + pooler_type = cfg.MODEL.ROI_BOX_HEAD.POOLER_TYPE + # fmt: on + assert pooler_type in ["ROIAlignRotated"], pooler_type + # assume all channel counts are equal + in_channels = [input_shape[f].channels for f in in_features][0] + + box_pooler = ROIPooler( + output_size=pooler_resolution, + scales=pooler_scales, + sampling_ratio=sampling_ratio, + pooler_type=pooler_type, + ) + box_head = build_box_head( + cfg, ShapeSpec(channels=in_channels, height=pooler_resolution, width=pooler_resolution) + ) + # This line is the only difference v.s. StandardROIHeads + box_predictor = RotatedFastRCNNOutputLayers(cfg, box_head.output_shape) + return { + "box_in_features": in_features, + "box_pooler": box_pooler, + "box_head": box_head, + "box_predictor": box_predictor, + } + + @torch.no_grad() + def label_and_sample_proposals(self, proposals, targets): + """ + Prepare some proposals to be used to train the RROI heads. + It performs box matching between `proposals` and `targets`, and assigns + training labels to the proposals. + It returns `self.batch_size_per_image` random samples from proposals and groundtruth boxes, + with a fraction of positives that is no larger than `self.positive_sample_fraction. + + Args: + See :meth:`StandardROIHeads.forward` + + Returns: + list[Instances]: length `N` list of `Instances`s containing the proposals + sampled for training. Each `Instances` has the following fields: + - proposal_boxes: the rotated proposal boxes + - gt_boxes: the ground-truth rotated boxes that the proposal is assigned to + (this is only meaningful if the proposal has a label > 0; if label = 0 + then the ground-truth box is random) + - gt_classes: the ground-truth classification lable for each proposal + """ + if self.proposal_append_gt: + proposals = add_ground_truth_to_proposals(targets, proposals) + + proposals_with_gt = [] + + num_fg_samples = [] + num_bg_samples = [] + for proposals_per_image, targets_per_image in zip(proposals, targets): + has_gt = len(targets_per_image) > 0 + match_quality_matrix = pairwise_iou_rotated( + targets_per_image.gt_boxes, proposals_per_image.proposal_boxes + ) + matched_idxs, matched_labels = self.proposal_matcher(match_quality_matrix) + sampled_idxs, gt_classes = self._sample_proposals( + matched_idxs, matched_labels, targets_per_image.gt_classes + ) + + proposals_per_image = proposals_per_image[sampled_idxs] + proposals_per_image.gt_classes = gt_classes + + if has_gt: + sampled_targets = matched_idxs[sampled_idxs] + proposals_per_image.gt_boxes = targets_per_image.gt_boxes[sampled_targets] + + num_bg_samples.append((gt_classes == self.num_classes).sum().item()) + num_fg_samples.append(gt_classes.numel() - num_bg_samples[-1]) + proposals_with_gt.append(proposals_per_image) + + # Log the number of fg/bg samples that are selected for training ROI heads + storage = get_event_storage() + storage.put_scalar("roi_head/num_fg_samples", np.mean(num_fg_samples)) + storage.put_scalar("roi_head/num_bg_samples", np.mean(num_bg_samples)) + + return proposals_with_gt diff --git a/custom_detectron2/modeling/sampling.py b/custom_detectron2/modeling/sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..e143cf8073bd22f86b5d18de8e8c336874d3cda8 --- /dev/null +++ b/custom_detectron2/modeling/sampling.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch + +from custom_detectron2.layers import nonzero_tuple + +__all__ = ["subsample_labels"] + + +def subsample_labels( + labels: torch.Tensor, num_samples: int, positive_fraction: float, bg_label: int +): + """ + Return `num_samples` (or fewer, if not enough found) + random samples from `labels` which is a mixture of positives & negatives. + It will try to return as many positives as possible without + exceeding `positive_fraction * num_samples`, and then try to + fill the remaining slots with negatives. + + Args: + labels (Tensor): (N, ) label vector with values: + * -1: ignore + * bg_label: background ("negative") class + * otherwise: one or more foreground ("positive") classes + num_samples (int): The total number of labels with value >= 0 to return. + Values that are not sampled will be filled with -1 (ignore). + positive_fraction (float): The number of subsampled labels with values > 0 + is `min(num_positives, int(positive_fraction * num_samples))`. The number + of negatives sampled is `min(num_negatives, num_samples - num_positives_sampled)`. + In order words, if there are not enough positives, the sample is filled with + negatives. If there are also not enough negatives, then as many elements are + sampled as is possible. + bg_label (int): label index of background ("negative") class. + + Returns: + pos_idx, neg_idx (Tensor): + 1D vector of indices. The total length of both is `num_samples` or fewer. + """ + positive = nonzero_tuple((labels != -1) & (labels != bg_label))[0] + negative = nonzero_tuple(labels == bg_label)[0] + + num_pos = int(num_samples * positive_fraction) + # protect against not enough positive examples + num_pos = min(positive.numel(), num_pos) + num_neg = num_samples - num_pos + # protect against not enough negative examples + num_neg = min(negative.numel(), num_neg) + + # randomly select positive and negative examples + perm1 = torch.randperm(positive.numel(), device=positive.device)[:num_pos] + perm2 = torch.randperm(negative.numel(), device=negative.device)[:num_neg] + + pos_idx = positive[perm1] + neg_idx = negative[perm2] + return pos_idx, neg_idx diff --git a/custom_detectron2/modeling/test_time_augmentation.py b/custom_detectron2/modeling/test_time_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..3c3539057fb519885914bfe1838ff39a384a2e96 --- /dev/null +++ b/custom_detectron2/modeling/test_time_augmentation.py @@ -0,0 +1,307 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import numpy as np +from contextlib import contextmanager +from itertools import count +from typing import List +import torch +from fvcore.transforms import HFlipTransform, NoOpTransform +from torch import nn +from torch.nn.parallel import DistributedDataParallel + +from custom_detectron2.config import configurable +from custom_detectron2.data.detection_utils import read_image +from custom_detectron2.data.transforms import ( + RandomFlip, + ResizeShortestEdge, + ResizeTransform, + apply_augmentations, +) +from custom_detectron2.structures import Boxes, Instances + +from .meta_arch import GeneralizedRCNN +from .postprocessing import detector_postprocess +from .roi_heads.fast_rcnn import fast_rcnn_inference_single_image + +__all__ = ["DatasetMapperTTA", "GeneralizedRCNNWithTTA"] + + +class DatasetMapperTTA: + """ + Implement test-time augmentation for detection data. + It is a callable which takes a dataset dict from a detection dataset, + and returns a list of dataset dicts where the images + are augmented from the input image by the transformations defined in the config. + This is used for test-time augmentation. + """ + + @configurable + def __init__(self, min_sizes: List[int], max_size: int, flip: bool): + """ + Args: + min_sizes: list of short-edge size to resize the image to + max_size: maximum height or width of resized images + flip: whether to apply flipping augmentation + """ + self.min_sizes = min_sizes + self.max_size = max_size + self.flip = flip + + @classmethod + def from_config(cls, cfg): + return { + "min_sizes": cfg.TEST.AUG.MIN_SIZES, + "max_size": cfg.TEST.AUG.MAX_SIZE, + "flip": cfg.TEST.AUG.FLIP, + } + + def __call__(self, dataset_dict): + """ + Args: + dict: a dict in standard model input format. See tutorials for details. + + Returns: + list[dict]: + a list of dicts, which contain augmented version of the input image. + The total number of dicts is ``len(min_sizes) * (2 if flip else 1)``. + Each dict has field "transforms" which is a TransformList, + containing the transforms that are used to generate this image. + """ + numpy_image = dataset_dict["image"].permute(1, 2, 0).numpy() + shape = numpy_image.shape + orig_shape = (dataset_dict["height"], dataset_dict["width"]) + if shape[:2] != orig_shape: + # It transforms the "original" image in the dataset to the input image + pre_tfm = ResizeTransform(orig_shape[0], orig_shape[1], shape[0], shape[1]) + else: + pre_tfm = NoOpTransform() + + # Create all combinations of augmentations to use + aug_candidates = [] # each element is a list[Augmentation] + for min_size in self.min_sizes: + resize = ResizeShortestEdge(min_size, self.max_size) + aug_candidates.append([resize]) # resize only + if self.flip: + flip = RandomFlip(prob=1.0) + aug_candidates.append([resize, flip]) # resize + flip + + # Apply all the augmentations + ret = [] + for aug in aug_candidates: + new_image, tfms = apply_augmentations(aug, np.copy(numpy_image)) + torch_image = torch.from_numpy(np.ascontiguousarray(new_image.transpose(2, 0, 1))) + + dic = copy.deepcopy(dataset_dict) + dic["transforms"] = pre_tfm + tfms + dic["image"] = torch_image + ret.append(dic) + return ret + + +class GeneralizedRCNNWithTTA(nn.Module): + """ + A GeneralizedRCNN with test-time augmentation enabled. + Its :meth:`__call__` method has the same interface as :meth:`GeneralizedRCNN.forward`. + """ + + def __init__(self, cfg, model, tta_mapper=None, batch_size=3): + """ + Args: + cfg (CfgNode): + model (GeneralizedRCNN): a GeneralizedRCNN to apply TTA on. + tta_mapper (callable): takes a dataset dict and returns a list of + augmented versions of the dataset dict. Defaults to + `DatasetMapperTTA(cfg)`. + batch_size (int): batch the augmented images into this batch size for inference. + """ + super().__init__() + if isinstance(model, DistributedDataParallel): + model = model.module + assert isinstance( + model, GeneralizedRCNN + ), "TTA is only supported on GeneralizedRCNN. Got a model of type {}".format(type(model)) + self.cfg = cfg.clone() + assert not self.cfg.MODEL.KEYPOINT_ON, "TTA for keypoint is not supported yet" + assert ( + not self.cfg.MODEL.LOAD_PROPOSALS + ), "TTA for pre-computed proposals is not supported yet" + + self.model = model + + if tta_mapper is None: + tta_mapper = DatasetMapperTTA(cfg) + self.tta_mapper = tta_mapper + self.batch_size = batch_size + + @contextmanager + def _turn_off_roi_heads(self, attrs): + """ + Open a context where some heads in `model.roi_heads` are temporarily turned off. + Args: + attr (list[str]): the attribute in `model.roi_heads` which can be used + to turn off a specific head, e.g., "mask_on", "keypoint_on". + """ + roi_heads = self.model.roi_heads + old = {} + for attr in attrs: + try: + old[attr] = getattr(roi_heads, attr) + except AttributeError: + # The head may not be implemented in certain ROIHeads + pass + + if len(old.keys()) == 0: + yield + else: + for attr in old.keys(): + setattr(roi_heads, attr, False) + yield + for attr in old.keys(): + setattr(roi_heads, attr, old[attr]) + + def _batch_inference(self, batched_inputs, detected_instances=None): + """ + Execute inference on a list of inputs, + using batch size = self.batch_size, instead of the length of the list. + + Inputs & outputs have the same format as :meth:`GeneralizedRCNN.inference` + """ + if detected_instances is None: + detected_instances = [None] * len(batched_inputs) + + outputs = [] + inputs, instances = [], [] + for idx, input, instance in zip(count(), batched_inputs, detected_instances): + inputs.append(input) + instances.append(instance) + if len(inputs) == self.batch_size or idx == len(batched_inputs) - 1: + outputs.extend( + self.model.inference( + inputs, + instances if instances[0] is not None else None, + do_postprocess=False, + ) + ) + inputs, instances = [], [] + return outputs + + def __call__(self, batched_inputs): + """ + Same input/output format as :meth:`GeneralizedRCNN.forward` + """ + + def _maybe_read_image(dataset_dict): + ret = copy.copy(dataset_dict) + if "image" not in ret: + image = read_image(ret.pop("file_name"), self.model.input_format) + image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))) # CHW + ret["image"] = image + if "height" not in ret and "width" not in ret: + ret["height"] = image.shape[1] + ret["width"] = image.shape[2] + return ret + + return [self._inference_one_image(_maybe_read_image(x)) for x in batched_inputs] + + def _inference_one_image(self, input): + """ + Args: + input (dict): one dataset dict with "image" field being a CHW tensor + + Returns: + dict: one output dict + """ + orig_shape = (input["height"], input["width"]) + augmented_inputs, tfms = self._get_augmented_inputs(input) + # Detect boxes from all augmented versions + with self._turn_off_roi_heads(["mask_on", "keypoint_on"]): + # temporarily disable roi heads + all_boxes, all_scores, all_classes = self._get_augmented_boxes(augmented_inputs, tfms) + # merge all detected boxes to obtain final predictions for boxes + merged_instances = self._merge_detections(all_boxes, all_scores, all_classes, orig_shape) + + if self.cfg.MODEL.MASK_ON: + # Use the detected boxes to obtain masks + augmented_instances = self._rescale_detected_boxes( + augmented_inputs, merged_instances, tfms + ) + # run forward on the detected boxes + outputs = self._batch_inference(augmented_inputs, augmented_instances) + # Delete now useless variables to avoid being out of memory + del augmented_inputs, augmented_instances + # average the predictions + merged_instances.pred_masks = self._reduce_pred_masks(outputs, tfms) + merged_instances = detector_postprocess(merged_instances, *orig_shape) + return {"instances": merged_instances} + else: + return {"instances": merged_instances} + + def _get_augmented_inputs(self, input): + augmented_inputs = self.tta_mapper(input) + tfms = [x.pop("transforms") for x in augmented_inputs] + return augmented_inputs, tfms + + def _get_augmented_boxes(self, augmented_inputs, tfms): + # 1: forward with all augmented images + outputs = self._batch_inference(augmented_inputs) + # 2: union the results + all_boxes = [] + all_scores = [] + all_classes = [] + for output, tfm in zip(outputs, tfms): + # Need to inverse the transforms on boxes, to obtain results on original image + pred_boxes = output.pred_boxes.tensor + original_pred_boxes = tfm.inverse().apply_box(pred_boxes.cpu().numpy()) + all_boxes.append(torch.from_numpy(original_pred_boxes).to(pred_boxes.device)) + + all_scores.extend(output.scores) + all_classes.extend(output.pred_classes) + all_boxes = torch.cat(all_boxes, dim=0) + return all_boxes, all_scores, all_classes + + def _merge_detections(self, all_boxes, all_scores, all_classes, shape_hw): + # select from the union of all results + num_boxes = len(all_boxes) + num_classes = self.cfg.MODEL.ROI_HEADS.NUM_CLASSES + # +1 because fast_rcnn_inference expects background scores as well + all_scores_2d = torch.zeros(num_boxes, num_classes + 1, device=all_boxes.device) + for idx, cls, score in zip(count(), all_classes, all_scores): + all_scores_2d[idx, cls] = score + + merged_instances, _ = fast_rcnn_inference_single_image( + all_boxes, + all_scores_2d, + shape_hw, + 1e-8, + self.cfg.MODEL.ROI_HEADS.NMS_THRESH_TEST, + self.cfg.TEST.DETECTIONS_PER_IMAGE, + ) + + return merged_instances + + def _rescale_detected_boxes(self, augmented_inputs, merged_instances, tfms): + augmented_instances = [] + for input, tfm in zip(augmented_inputs, tfms): + # Transform the target box to the augmented image's coordinate space + pred_boxes = merged_instances.pred_boxes.tensor.cpu().numpy() + pred_boxes = torch.from_numpy(tfm.apply_box(pred_boxes)) + + aug_instances = Instances( + image_size=input["image"].shape[1:3], + pred_boxes=Boxes(pred_boxes), + pred_classes=merged_instances.pred_classes, + scores=merged_instances.scores, + ) + augmented_instances.append(aug_instances) + return augmented_instances + + def _reduce_pred_masks(self, outputs, tfms): + # Should apply inverse transforms on masks. + # We assume only resize & flip are used. pred_masks is a scale-invariant + # representation, so we handle flip specially + for output, tfm in zip(outputs, tfms): + if any(isinstance(t, HFlipTransform) for t in tfm.transforms): + output.pred_masks = output.pred_masks.flip(dims=[3]) + all_pred_masks = torch.stack([o.pred_masks for o in outputs], dim=0) + avg_pred_masks = torch.mean(all_pred_masks, dim=0) + return avg_pred_masks diff --git a/custom_detectron2/projects/README.md b/custom_detectron2/projects/README.md new file mode 100644 index 0000000000000000000000000000000000000000..95afe7ff8c8a9bd2f56621fcc3c1bdac11c256a9 --- /dev/null +++ b/custom_detectron2/projects/README.md @@ -0,0 +1,2 @@ + +Projects live in the [`projects` directory](../../projects) under the root of this repository, but not here. diff --git a/custom_detectron2/projects/__init__.py b/custom_detectron2/projects/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b2d0540b93ebbad78d6ff2cc0adc0fe8375816c2 --- /dev/null +++ b/custom_detectron2/projects/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import importlib.abc +import importlib.util +from pathlib import Path + +__all__ = [] + +_PROJECTS = { + "point_rend": "PointRend", + "deeplab": "DeepLab", + "panoptic_deeplab": "Panoptic-DeepLab", +} +_PROJECT_ROOT = Path(__file__).resolve().parent.parent.parent / "projects" + +if _PROJECT_ROOT.is_dir(): + # This is true only for in-place installation (pip install -e, setup.py develop), + # where setup(package_dir=) does not work: https://github.com/pypa/setuptools/issues/230 + + class _D2ProjectsFinder(importlib.abc.MetaPathFinder): + def find_spec(self, name, path, target=None): + if not name.startswith("detectron2.projects."): + return + project_name = name.split(".")[-1] + project_dir = _PROJECTS.get(project_name) + if not project_dir: + return + target_file = _PROJECT_ROOT / f"{project_dir}/{project_name}/__init__.py" + if not target_file.is_file(): + return + return importlib.util.spec_from_file_location(name, target_file) + + import sys + + sys.meta_path.append(_D2ProjectsFinder()) diff --git a/custom_detectron2/projects/deeplab/__init__.py b/custom_detectron2/projects/deeplab/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd88ff0c09d630577e3ac9f8afb5324a80a7be4 --- /dev/null +++ b/custom_detectron2/projects/deeplab/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .build_solver import build_lr_scheduler +from .config import add_deeplab_config +from .resnet import build_resnet_deeplab_backbone +from .semantic_seg import DeepLabV3Head, DeepLabV3PlusHead diff --git a/custom_detectron2/projects/deeplab/build_solver.py b/custom_detectron2/projects/deeplab/build_solver.py new file mode 100644 index 0000000000000000000000000000000000000000..bc81cb27493af340dde2bd09ac828ba47b272f70 --- /dev/null +++ b/custom_detectron2/projects/deeplab/build_solver.py @@ -0,0 +1,27 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch + +from custom_detectron2.config import CfgNode +from custom_detectron2.solver import LRScheduler +from custom_detectron2.solver import build_lr_scheduler as build_d2_lr_scheduler + +from .lr_scheduler import WarmupPolyLR + + +def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler: + """ + Build a LR scheduler from config. + """ + name = cfg.SOLVER.LR_SCHEDULER_NAME + if name == "WarmupPolyLR": + return WarmupPolyLR( + optimizer, + cfg.SOLVER.MAX_ITER, + warmup_factor=cfg.SOLVER.WARMUP_FACTOR, + warmup_iters=cfg.SOLVER.WARMUP_ITERS, + warmup_method=cfg.SOLVER.WARMUP_METHOD, + power=cfg.SOLVER.POLY_LR_POWER, + constant_ending=cfg.SOLVER.POLY_LR_CONSTANT_ENDING, + ) + else: + return build_d2_lr_scheduler(cfg, optimizer) diff --git a/custom_detectron2/projects/deeplab/config.py b/custom_detectron2/projects/deeplab/config.py new file mode 100644 index 0000000000000000000000000000000000000000..5f5e45a9124e61c12d90cfc5032b268496891a4a --- /dev/null +++ b/custom_detectron2/projects/deeplab/config.py @@ -0,0 +1,28 @@ +# -*- coding: utf-8 -*- +# Copyright (c) Facebook, Inc. and its affiliates. + + +def add_deeplab_config(cfg): + """ + Add config for DeepLab. + """ + # We retry random cropping until no single category in semantic segmentation GT occupies more + # than `SINGLE_CATEGORY_MAX_AREA` part of the crop. + cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0 + # Used for `poly` learning rate schedule. + cfg.SOLVER.POLY_LR_POWER = 0.9 + cfg.SOLVER.POLY_LR_CONSTANT_ENDING = 0.0 + # Loss type, choose from `cross_entropy`, `hard_pixel_mining`. + cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE = "hard_pixel_mining" + # DeepLab settings + cfg.MODEL.SEM_SEG_HEAD.PROJECT_FEATURES = ["res2"] + cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS = [48] + cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS = 256 + cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS = [6, 12, 18] + cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT = 0.1 + cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV = False + # Backbone new configs + cfg.MODEL.RESNETS.RES4_DILATION = 1 + cfg.MODEL.RESNETS.RES5_MULTI_GRID = [1, 2, 4] + # ResNet stem type from: `basic`, `deeplab` + cfg.MODEL.RESNETS.STEM_TYPE = "deeplab" diff --git a/custom_detectron2/projects/deeplab/loss.py b/custom_detectron2/projects/deeplab/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..3a43087b7c1a2b4d2b249fad117724dbd0f14fdd --- /dev/null +++ b/custom_detectron2/projects/deeplab/loss.py @@ -0,0 +1,40 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import torch +import torch.nn as nn + + +class DeepLabCE(nn.Module): + """ + Hard pixel mining with cross entropy loss, for semantic segmentation. + This is used in TensorFlow DeepLab frameworks. + Paper: DeeperLab: Single-Shot Image Parser + Reference: https://github.com/tensorflow/models/blob/bd488858d610e44df69da6f89277e9de8a03722c/research/deeplab/utils/train_utils.py#L33 # noqa + Arguments: + ignore_label: Integer, label to ignore. + top_k_percent_pixels: Float, the value lies in [0.0, 1.0]. When its + value < 1.0, only compute the loss for the top k percent pixels + (e.g., the top 20% pixels). This is useful for hard pixel mining. + weight: Tensor, a manual rescaling weight given to each class. + """ + + def __init__(self, ignore_label=-1, top_k_percent_pixels=1.0, weight=None): + super(DeepLabCE, self).__init__() + self.top_k_percent_pixels = top_k_percent_pixels + self.ignore_label = ignore_label + self.criterion = nn.CrossEntropyLoss( + weight=weight, ignore_index=ignore_label, reduction="none" + ) + + def forward(self, logits, labels, weights=None): + if weights is None: + pixel_losses = self.criterion(logits, labels).contiguous().view(-1) + else: + # Apply per-pixel loss weights. + pixel_losses = self.criterion(logits, labels) * weights + pixel_losses = pixel_losses.contiguous().view(-1) + if self.top_k_percent_pixels == 1.0: + return pixel_losses.mean() + + top_k_pixels = int(self.top_k_percent_pixels * pixel_losses.numel()) + pixel_losses, _ = torch.topk(pixel_losses, top_k_pixels) + return pixel_losses.mean() diff --git a/custom_detectron2/projects/deeplab/lr_scheduler.py b/custom_detectron2/projects/deeplab/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1f9ad8fededad41d9bb0ee2314532fb2bcc03b39 --- /dev/null +++ b/custom_detectron2/projects/deeplab/lr_scheduler.py @@ -0,0 +1,62 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from typing import List +import torch + +from custom_detectron2.solver.lr_scheduler import LRScheduler, _get_warmup_factor_at_iter + +# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes +# only on epoch boundaries. We typically use iteration based schedules instead. +# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean +# "iteration" instead. + +# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating +# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it. + + +class WarmupPolyLR(LRScheduler): + """ + Poly learning rate schedule used to train DeepLab. + Paper: DeepLab: Semantic Image Segmentation with Deep Convolutional Nets, + Atrous Convolution, and Fully Connected CRFs. + Reference: https://github.com/tensorflow/models/blob/21b73d22f3ed05b650e85ac50849408dd36de32e/research/deeplab/utils/train_utils.py#L337 # noqa + """ + + def __init__( + self, + optimizer: torch.optim.Optimizer, + max_iters: int, + warmup_factor: float = 0.001, + warmup_iters: int = 1000, + warmup_method: str = "linear", + last_epoch: int = -1, + power: float = 0.9, + constant_ending: float = 0.0, + ): + self.max_iters = max_iters + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + self.power = power + self.constant_ending = constant_ending + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor + ) + if self.constant_ending > 0 and warmup_factor == 1.0: + # Constant ending lr. + if ( + math.pow((1.0 - self.last_epoch / self.max_iters), self.power) + < self.constant_ending + ): + return [base_lr * self.constant_ending for base_lr in self.base_lrs] + return [ + base_lr * warmup_factor * math.pow((1.0 - self.last_epoch / self.max_iters), self.power) + for base_lr in self.base_lrs + ] + + def _compute_values(self) -> List[float]: + # The new interface + return self.get_lr() diff --git a/custom_detectron2/projects/deeplab/resnet.py b/custom_detectron2/projects/deeplab/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..08bc22557b2e0d630548790610cc4c461fe04bdc --- /dev/null +++ b/custom_detectron2/projects/deeplab/resnet.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import fvcore.nn.weight_init as weight_init +import torch.nn.functional as F + +from custom_detectron2.layers import CNNBlockBase, Conv2d, get_norm +from custom_detectron2.modeling import BACKBONE_REGISTRY +from custom_detectron2.modeling.backbone.resnet import ( + BasicStem, + BottleneckBlock, + DeformBottleneckBlock, + ResNet, +) + + +class DeepLabStem(CNNBlockBase): + """ + The DeepLab ResNet stem (layers before the first residual block). + """ + + def __init__(self, in_channels=3, out_channels=128, norm="BN"): + """ + Args: + norm (str or callable): norm after the first conv layer. + See :func:`layers.get_norm` for supported format. + """ + super().__init__(in_channels, out_channels, 4) + self.in_channels = in_channels + self.conv1 = Conv2d( + in_channels, + out_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False, + norm=get_norm(norm, out_channels // 2), + ) + self.conv2 = Conv2d( + out_channels // 2, + out_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels // 2), + ) + self.conv3 = Conv2d( + out_channels // 2, + out_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False, + norm=get_norm(norm, out_channels), + ) + weight_init.c2_msra_fill(self.conv1) + weight_init.c2_msra_fill(self.conv2) + weight_init.c2_msra_fill(self.conv3) + + def forward(self, x): + x = self.conv1(x) + x = F.relu_(x) + x = self.conv2(x) + x = F.relu_(x) + x = self.conv3(x) + x = F.relu_(x) + x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1) + return x + + +@BACKBONE_REGISTRY.register() +def build_resnet_deeplab_backbone(cfg, input_shape): + """ + Create a ResNet instance from config. + Returns: + ResNet: a :class:`ResNet` instance. + """ + # need registration of new blocks/stems? + norm = cfg.MODEL.RESNETS.NORM + if cfg.MODEL.RESNETS.STEM_TYPE == "basic": + stem = BasicStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + elif cfg.MODEL.RESNETS.STEM_TYPE == "deeplab": + stem = DeepLabStem( + in_channels=input_shape.channels, + out_channels=cfg.MODEL.RESNETS.STEM_OUT_CHANNELS, + norm=norm, + ) + else: + raise ValueError("Unknown stem type: {}".format(cfg.MODEL.RESNETS.STEM_TYPE)) + + # fmt: off + freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT + out_features = cfg.MODEL.RESNETS.OUT_FEATURES + depth = cfg.MODEL.RESNETS.DEPTH + num_groups = cfg.MODEL.RESNETS.NUM_GROUPS + width_per_group = cfg.MODEL.RESNETS.WIDTH_PER_GROUP + bottleneck_channels = num_groups * width_per_group + in_channels = cfg.MODEL.RESNETS.STEM_OUT_CHANNELS + out_channels = cfg.MODEL.RESNETS.RES2_OUT_CHANNELS + stride_in_1x1 = cfg.MODEL.RESNETS.STRIDE_IN_1X1 + res4_dilation = cfg.MODEL.RESNETS.RES4_DILATION + res5_dilation = cfg.MODEL.RESNETS.RES5_DILATION + deform_on_per_stage = cfg.MODEL.RESNETS.DEFORM_ON_PER_STAGE + deform_modulated = cfg.MODEL.RESNETS.DEFORM_MODULATED + deform_num_groups = cfg.MODEL.RESNETS.DEFORM_NUM_GROUPS + res5_multi_grid = cfg.MODEL.RESNETS.RES5_MULTI_GRID + # fmt: on + assert res4_dilation in {1, 2}, "res4_dilation cannot be {}.".format(res4_dilation) + assert res5_dilation in {1, 2, 4}, "res5_dilation cannot be {}.".format(res5_dilation) + if res4_dilation == 2: + # Always dilate res5 if res4 is dilated. + assert res5_dilation == 4 + + num_blocks_per_stage = {50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3]}[depth] + + stages = [] + + # Avoid creating variables without gradients + # It consumes extra memory and may cause allreduce to fail + out_stage_idx = [{"res2": 2, "res3": 3, "res4": 4, "res5": 5}[f] for f in out_features] + max_stage_idx = max(out_stage_idx) + for idx, stage_idx in enumerate(range(2, max_stage_idx + 1)): + if stage_idx == 4: + dilation = res4_dilation + elif stage_idx == 5: + dilation = res5_dilation + else: + dilation = 1 + first_stride = 1 if idx == 0 or dilation > 1 else 2 + stage_kargs = { + "num_blocks": num_blocks_per_stage[idx], + "stride_per_block": [first_stride] + [1] * (num_blocks_per_stage[idx] - 1), + "in_channels": in_channels, + "out_channels": out_channels, + "norm": norm, + } + stage_kargs["bottleneck_channels"] = bottleneck_channels + stage_kargs["stride_in_1x1"] = stride_in_1x1 + stage_kargs["dilation"] = dilation + stage_kargs["num_groups"] = num_groups + if deform_on_per_stage[idx]: + stage_kargs["block_class"] = DeformBottleneckBlock + stage_kargs["deform_modulated"] = deform_modulated + stage_kargs["deform_num_groups"] = deform_num_groups + else: + stage_kargs["block_class"] = BottleneckBlock + if stage_idx == 5: + stage_kargs.pop("dilation") + stage_kargs["dilation_per_block"] = [dilation * mg for mg in res5_multi_grid] + blocks = ResNet.make_stage(**stage_kargs) + in_channels = out_channels + out_channels *= 2 + bottleneck_channels *= 2 + stages.append(blocks) + return ResNet(stem, stages, out_features=out_features).freeze(freeze_at) diff --git a/custom_detectron2/projects/deeplab/semantic_seg.py b/custom_detectron2/projects/deeplab/semantic_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad38aa610ebc2f0ffcab914aa20558fffef9576 --- /dev/null +++ b/custom_detectron2/projects/deeplab/semantic_seg.py @@ -0,0 +1,348 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from typing import Callable, Dict, List, Optional, Tuple, Union +import fvcore.nn.weight_init as weight_init +import torch +from torch import nn +from torch.nn import functional as F + +from custom_detectron2.config import configurable +from custom_detectron2.layers import ASPP, Conv2d, DepthwiseSeparableConv2d, ShapeSpec, get_norm +from custom_detectron2.modeling import SEM_SEG_HEADS_REGISTRY + +from .loss import DeepLabCE + + +@SEM_SEG_HEADS_REGISTRY.register() +class DeepLabV3PlusHead(nn.Module): + """ + A semantic segmentation head described in :paper:`DeepLabV3+`. + """ + + @configurable + def __init__( + self, + input_shape: Dict[str, ShapeSpec], + *, + project_channels: List[int], + aspp_dilations: List[int], + aspp_dropout: float, + decoder_channels: List[int], + common_stride: int, + norm: Union[str, Callable], + train_size: Optional[Tuple], + loss_weight: float = 1.0, + loss_type: str = "cross_entropy", + ignore_value: int = -1, + num_classes: Optional[int] = None, + use_depthwise_separable_conv: bool = False, + ): + """ + NOTE: this interface is experimental. + + Args: + input_shape: shape of the input features. They will be ordered by stride + and the last one (with largest stride) is used as the input to the + decoder (i.e. the ASPP module); the rest are low-level feature for + the intermediate levels of decoder. + project_channels (list[int]): a list of low-level feature channels. + The length should be len(in_features) - 1. + aspp_dilations (list(int)): a list of 3 dilations in ASPP. + aspp_dropout (float): apply dropout on the output of ASPP. + decoder_channels (list[int]): a list of output channels of each + decoder stage. It should have the same length as "in_features" + (each element in "in_features" corresponds to one decoder stage). + common_stride (int): output stride of decoder. + norm (str or callable): normalization for all conv layers. + train_size (tuple): (height, width) of training images. + loss_weight (float): loss weight. + loss_type (str): type of loss function, 2 opptions: + (1) "cross_entropy" is the standard cross entropy loss. + (2) "hard_pixel_mining" is the loss in DeepLab that samples + top k% hardest pixels. + ignore_value (int): category to be ignored during training. + num_classes (int): number of classes, if set to None, the decoder + will not construct a predictor. + use_depthwise_separable_conv (bool): use DepthwiseSeparableConv2d + in ASPP and decoder. + """ + super().__init__() + input_shape = sorted(input_shape.items(), key=lambda x: x[1].stride) + + # fmt: off + self.in_features = [k for k, v in input_shape] # starting from "res2" to "res5" + in_channels = [x[1].channels for x in input_shape] + in_strides = [x[1].stride for x in input_shape] + aspp_channels = decoder_channels[-1] + self.ignore_value = ignore_value + self.common_stride = common_stride # output stride + self.loss_weight = loss_weight + self.loss_type = loss_type + self.decoder_only = num_classes is None + self.use_depthwise_separable_conv = use_depthwise_separable_conv + # fmt: on + + assert ( + len(project_channels) == len(self.in_features) - 1 + ), "Expected {} project_channels, got {}".format( + len(self.in_features) - 1, len(project_channels) + ) + assert len(decoder_channels) == len( + self.in_features + ), "Expected {} decoder_channels, got {}".format( + len(self.in_features), len(decoder_channels) + ) + self.decoder = nn.ModuleDict() + + use_bias = norm == "" + for idx, in_channel in enumerate(in_channels): + decoder_stage = nn.ModuleDict() + + if idx == len(self.in_features) - 1: + # ASPP module + if train_size is not None: + train_h, train_w = train_size + encoder_stride = in_strides[-1] + if train_h % encoder_stride or train_w % encoder_stride: + raise ValueError("Crop size need to be divisible by encoder stride.") + pool_h = train_h // encoder_stride + pool_w = train_w // encoder_stride + pool_kernel_size = (pool_h, pool_w) + else: + pool_kernel_size = None + project_conv = ASPP( + in_channel, + aspp_channels, + aspp_dilations, + norm=norm, + activation=F.relu, + pool_kernel_size=pool_kernel_size, + dropout=aspp_dropout, + use_depthwise_separable_conv=use_depthwise_separable_conv, + ) + fuse_conv = None + else: + project_conv = Conv2d( + in_channel, + project_channels[idx], + kernel_size=1, + bias=use_bias, + norm=get_norm(norm, project_channels[idx]), + activation=F.relu, + ) + weight_init.c2_xavier_fill(project_conv) + if use_depthwise_separable_conv: + # We use a single 5x5 DepthwiseSeparableConv2d to replace + # 2 3x3 Conv2d since they have the same receptive field, + # proposed in :paper:`Panoptic-DeepLab`. + fuse_conv = DepthwiseSeparableConv2d( + project_channels[idx] + decoder_channels[idx + 1], + decoder_channels[idx], + kernel_size=5, + padding=2, + norm1=norm, + activation1=F.relu, + norm2=norm, + activation2=F.relu, + ) + else: + fuse_conv = nn.Sequential( + Conv2d( + project_channels[idx] + decoder_channels[idx + 1], + decoder_channels[idx], + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, decoder_channels[idx]), + activation=F.relu, + ), + Conv2d( + decoder_channels[idx], + decoder_channels[idx], + kernel_size=3, + padding=1, + bias=use_bias, + norm=get_norm(norm, decoder_channels[idx]), + activation=F.relu, + ), + ) + weight_init.c2_xavier_fill(fuse_conv[0]) + weight_init.c2_xavier_fill(fuse_conv[1]) + + decoder_stage["project_conv"] = project_conv + decoder_stage["fuse_conv"] = fuse_conv + + self.decoder[self.in_features[idx]] = decoder_stage + + if not self.decoder_only: + self.predictor = Conv2d( + decoder_channels[0], num_classes, kernel_size=1, stride=1, padding=0 + ) + nn.init.normal_(self.predictor.weight, 0, 0.001) + nn.init.constant_(self.predictor.bias, 0) + + if self.loss_type == "cross_entropy": + self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value) + elif self.loss_type == "hard_pixel_mining": + self.loss = DeepLabCE(ignore_label=self.ignore_value, top_k_percent_pixels=0.2) + else: + raise ValueError("Unexpected loss type: %s" % self.loss_type) + + @classmethod + def from_config(cls, cfg, input_shape): + if cfg.INPUT.CROP.ENABLED: + assert cfg.INPUT.CROP.TYPE == "absolute" + train_size = cfg.INPUT.CROP.SIZE + else: + train_size = None + decoder_channels = [cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM] * ( + len(cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES) - 1 + ) + [cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS] + ret = dict( + input_shape={ + k: v for k, v in input_shape.items() if k in cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + }, + project_channels=cfg.MODEL.SEM_SEG_HEAD.PROJECT_CHANNELS, + aspp_dilations=cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS, + aspp_dropout=cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT, + decoder_channels=decoder_channels, + common_stride=cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE, + norm=cfg.MODEL.SEM_SEG_HEAD.NORM, + train_size=train_size, + loss_weight=cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT, + loss_type=cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE, + ignore_value=cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE, + num_classes=cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES, + use_depthwise_separable_conv=cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV, + ) + return ret + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + y = self.layers(features) + if self.decoder_only: + # Output from self.layers() only contains decoder feature. + return y + if self.training: + return None, self.losses(y, targets) + else: + y = F.interpolate( + y, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return y, {} + + def layers(self, features): + # Reverse feature maps into top-down order (from low to high resolution) + for f in self.in_features[::-1]: + x = features[f] + proj_x = self.decoder[f]["project_conv"](x) + if self.decoder[f]["fuse_conv"] is None: + # This is aspp module + y = proj_x + else: + # Upsample y + y = F.interpolate(y, size=proj_x.size()[2:], mode="bilinear", align_corners=False) + y = torch.cat([proj_x, y], dim=1) + y = self.decoder[f]["fuse_conv"](y) + if not self.decoder_only: + y = self.predictor(y) + return y + + def losses(self, predictions, targets): + predictions = F.interpolate( + predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + loss = self.loss(predictions, targets) + losses = {"loss_sem_seg": loss * self.loss_weight} + return losses + + +@SEM_SEG_HEADS_REGISTRY.register() +class DeepLabV3Head(nn.Module): + """ + A semantic segmentation head described in :paper:`DeepLabV3`. + """ + + def __init__(self, cfg, input_shape: Dict[str, ShapeSpec]): + super().__init__() + + # fmt: off + self.in_features = cfg.MODEL.SEM_SEG_HEAD.IN_FEATURES + in_channels = [input_shape[f].channels for f in self.in_features] + aspp_channels = cfg.MODEL.SEM_SEG_HEAD.ASPP_CHANNELS + aspp_dilations = cfg.MODEL.SEM_SEG_HEAD.ASPP_DILATIONS + self.ignore_value = cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE + num_classes = cfg.MODEL.SEM_SEG_HEAD.NUM_CLASSES + conv_dims = cfg.MODEL.SEM_SEG_HEAD.CONVS_DIM + self.common_stride = cfg.MODEL.SEM_SEG_HEAD.COMMON_STRIDE # output stride + norm = cfg.MODEL.SEM_SEG_HEAD.NORM + self.loss_weight = cfg.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT + self.loss_type = cfg.MODEL.SEM_SEG_HEAD.LOSS_TYPE + train_crop_size = cfg.INPUT.CROP.SIZE + aspp_dropout = cfg.MODEL.SEM_SEG_HEAD.ASPP_DROPOUT + use_depthwise_separable_conv = cfg.MODEL.SEM_SEG_HEAD.USE_DEPTHWISE_SEPARABLE_CONV + # fmt: on + + assert len(self.in_features) == 1 + assert len(in_channels) == 1 + + # ASPP module + if cfg.INPUT.CROP.ENABLED: + assert cfg.INPUT.CROP.TYPE == "absolute" + train_crop_h, train_crop_w = train_crop_size + if train_crop_h % self.common_stride or train_crop_w % self.common_stride: + raise ValueError("Crop size need to be divisible by output stride.") + pool_h = train_crop_h // self.common_stride + pool_w = train_crop_w // self.common_stride + pool_kernel_size = (pool_h, pool_w) + else: + pool_kernel_size = None + self.aspp = ASPP( + in_channels[0], + aspp_channels, + aspp_dilations, + norm=norm, + activation=F.relu, + pool_kernel_size=pool_kernel_size, + dropout=aspp_dropout, + use_depthwise_separable_conv=use_depthwise_separable_conv, + ) + + self.predictor = Conv2d(conv_dims, num_classes, kernel_size=1, stride=1, padding=0) + nn.init.normal_(self.predictor.weight, 0, 0.001) + nn.init.constant_(self.predictor.bias, 0) + + if self.loss_type == "cross_entropy": + self.loss = nn.CrossEntropyLoss(reduction="mean", ignore_index=self.ignore_value) + elif self.loss_type == "hard_pixel_mining": + self.loss = DeepLabCE(ignore_label=self.ignore_value, top_k_percent_pixels=0.2) + else: + raise ValueError("Unexpected loss type: %s" % self.loss_type) + + def forward(self, features, targets=None): + """ + Returns: + In training, returns (None, dict of losses) + In inference, returns (CxHxW logits, {}) + """ + x = features[self.in_features[0]] + x = self.aspp(x) + x = self.predictor(x) + if self.training: + return None, self.losses(x, targets) + else: + x = F.interpolate( + x, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + return x, {} + + def losses(self, predictions, targets): + predictions = F.interpolate( + predictions, scale_factor=self.common_stride, mode="bilinear", align_corners=False + ) + loss = self.loss(predictions, targets) + losses = {"loss_sem_seg": loss * self.loss_weight} + return losses diff --git a/custom_detectron2/solver/__init__.py b/custom_detectron2/solver/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e36c64f60f38f41d01dd2c9fb30364489a03841 --- /dev/null +++ b/custom_detectron2/solver/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .build import build_lr_scheduler, build_optimizer, get_default_optimizer_params +from .lr_scheduler import ( + LRMultiplier, + LRScheduler, + WarmupCosineLR, + WarmupMultiStepLR, + WarmupParamScheduler, +) + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/solver/build.py b/custom_detectron2/solver/build.py new file mode 100644 index 0000000000000000000000000000000000000000..5630936e9d52102e07029a685f79668476b6b141 --- /dev/null +++ b/custom_detectron2/solver/build.py @@ -0,0 +1,310 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import itertools +import logging +from collections import defaultdict +from enum import Enum +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union +import torch +from fvcore.common.param_scheduler import ( + CosineParamScheduler, + MultiStepParamScheduler, + StepWithFixedGammaParamScheduler, +) + +from custom_detectron2.config import CfgNode +from custom_detectron2.utils.env import TORCH_VERSION + +from .lr_scheduler import LRMultiplier, LRScheduler, WarmupParamScheduler + +_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]] +_GradientClipper = Callable[[_GradientClipperInput], None] + + +class GradientClipType(Enum): + VALUE = "value" + NORM = "norm" + + +def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper: + """ + Creates gradient clipping closure to clip by value or by norm, + according to the provided config. + """ + cfg = copy.deepcopy(cfg) + + def clip_grad_norm(p: _GradientClipperInput): + torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE) + + def clip_grad_value(p: _GradientClipperInput): + torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE) + + _GRADIENT_CLIP_TYPE_TO_CLIPPER = { + GradientClipType.VALUE: clip_grad_value, + GradientClipType.NORM: clip_grad_norm, + } + return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)] + + +def _generate_optimizer_class_with_gradient_clipping( + optimizer: Type[torch.optim.Optimizer], + *, + per_param_clipper: Optional[_GradientClipper] = None, + global_clipper: Optional[_GradientClipper] = None, +) -> Type[torch.optim.Optimizer]: + """ + Dynamically creates a new type that inherits the type of a given instance + and overrides the `step` method to add gradient clipping + """ + assert ( + per_param_clipper is None or global_clipper is None + ), "Not allowed to use both per-parameter clipping and global clipping" + + def optimizer_wgc_step(self, closure=None): + if per_param_clipper is not None: + for group in self.param_groups: + for p in group["params"]: + per_param_clipper(p) + else: + # global clipper for future use with detr + # (https://github.com/facebookresearch/detr/pull/287) + all_params = itertools.chain(*[g["params"] for g in self.param_groups]) + global_clipper(all_params) + super(type(self), self).step(closure) + + OptimizerWithGradientClip = type( + optimizer.__name__ + "WithGradientClip", + (optimizer,), + {"step": optimizer_wgc_step}, + ) + return OptimizerWithGradientClip + + +def maybe_add_gradient_clipping( + cfg: CfgNode, optimizer: Type[torch.optim.Optimizer] +) -> Type[torch.optim.Optimizer]: + """ + If gradient clipping is enabled through config options, wraps the existing + optimizer type to become a new dynamically created class OptimizerWithGradientClip + that inherits the given optimizer and overrides the `step` method to + include gradient clipping. + + Args: + cfg: CfgNode, configuration options + optimizer: type. A subclass of torch.optim.Optimizer + + Return: + type: either the input `optimizer` (if gradient clipping is disabled), or + a subclass of it with gradient clipping included in the `step` method. + """ + if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED: + return optimizer + if isinstance(optimizer, torch.optim.Optimizer): + optimizer_type = type(optimizer) + else: + assert issubclass(optimizer, torch.optim.Optimizer), optimizer + optimizer_type = optimizer + + grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS) + OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping( + optimizer_type, per_param_clipper=grad_clipper + ) + if isinstance(optimizer, torch.optim.Optimizer): + optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended + return optimizer + else: + return OptimizerWithGradientClip + + +def build_optimizer(cfg: CfgNode, model: torch.nn.Module) -> torch.optim.Optimizer: + """ + Build an optimizer from config. + """ + params = get_default_optimizer_params( + model, + base_lr=cfg.SOLVER.BASE_LR, + weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM, + bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR, + weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS, + ) + sgd_args = { + "params": params, + "lr": cfg.SOLVER.BASE_LR, + "momentum": cfg.SOLVER.MOMENTUM, + "nesterov": cfg.SOLVER.NESTEROV, + "weight_decay": cfg.SOLVER.WEIGHT_DECAY, + } + if TORCH_VERSION >= (1, 12): + sgd_args["foreach"] = True + return maybe_add_gradient_clipping(cfg, torch.optim.SGD(**sgd_args)) + + +def get_default_optimizer_params( + model: torch.nn.Module, + base_lr: Optional[float] = None, + weight_decay: Optional[float] = None, + weight_decay_norm: Optional[float] = None, + bias_lr_factor: Optional[float] = 1.0, + weight_decay_bias: Optional[float] = None, + lr_factor_func: Optional[Callable] = None, + overrides: Optional[Dict[str, Dict[str, float]]] = None, +) -> List[Dict[str, Any]]: + """ + Get default param list for optimizer, with support for a few types of + overrides. If no overrides needed, this is equivalent to `model.parameters()`. + + Args: + base_lr: lr for every group by default. Can be omitted to use the one in optimizer. + weight_decay: weight decay for every group by default. Can be omitted to use the one + in optimizer. + weight_decay_norm: override weight decay for params in normalization layers + bias_lr_factor: multiplier of lr for bias parameters. + weight_decay_bias: override weight decay for bias parameters. + lr_factor_func: function to calculate lr decay rate by mapping the parameter names to + corresponding lr decay rate. Note that setting this option requires + also setting ``base_lr``. + overrides: if not `None`, provides values for optimizer hyperparameters + (LR, weight decay) for module parameters with a given name; e.g. + ``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and + weight decay values for all module parameters named `embedding`. + + For common detection models, ``weight_decay_norm`` is the only option + needed to be set. ``bias_lr_factor,weight_decay_bias`` are legacy settings + from Detectron1 that are not found useful. + + Example: + :: + torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), + lr=0.01, weight_decay=1e-4, momentum=0.9) + """ + if overrides is None: + overrides = {} + defaults = {} + if base_lr is not None: + defaults["lr"] = base_lr + if weight_decay is not None: + defaults["weight_decay"] = weight_decay + bias_overrides = {} + if bias_lr_factor is not None and bias_lr_factor != 1.0: + # NOTE: unlike Detectron v1, we now by default make bias hyperparameters + # exactly the same as regular weights. + if base_lr is None: + raise ValueError("bias_lr_factor requires base_lr") + bias_overrides["lr"] = base_lr * bias_lr_factor + if weight_decay_bias is not None: + bias_overrides["weight_decay"] = weight_decay_bias + if len(bias_overrides): + if "bias" in overrides: + raise ValueError("Conflicting overrides for 'bias'") + overrides["bias"] = bias_overrides + if lr_factor_func is not None: + if base_lr is None: + raise ValueError("lr_factor_func requires base_lr") + norm_module_types = ( + torch.nn.BatchNorm1d, + torch.nn.BatchNorm2d, + torch.nn.BatchNorm3d, + torch.nn.SyncBatchNorm, + # NaiveSyncBatchNorm inherits from BatchNorm2d + torch.nn.GroupNorm, + torch.nn.InstanceNorm1d, + torch.nn.InstanceNorm2d, + torch.nn.InstanceNorm3d, + torch.nn.LayerNorm, + torch.nn.LocalResponseNorm, + ) + params: List[Dict[str, Any]] = [] + memo: Set[torch.nn.parameter.Parameter] = set() + for module_name, module in model.named_modules(): + for module_param_name, value in module.named_parameters(recurse=False): + if not value.requires_grad: + continue + # Avoid duplicating parameters + if value in memo: + continue + memo.add(value) + + hyperparams = copy.copy(defaults) + if isinstance(module, norm_module_types) and weight_decay_norm is not None: + hyperparams["weight_decay"] = weight_decay_norm + if lr_factor_func is not None: + hyperparams["lr"] *= lr_factor_func(f"{module_name}.{module_param_name}") + + hyperparams.update(overrides.get(module_param_name, {})) + params.append({"params": [value], **hyperparams}) + return reduce_param_groups(params) + + +def _expand_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + # Transform parameter groups into per-parameter structure. + # Later items in `params` can overwrite parameters set in previous items. + ret = defaultdict(dict) + for item in params: + assert "params" in item + cur_params = {x: y for x, y in item.items() if x != "params"} + for param in item["params"]: + ret[param].update({"params": [param], **cur_params}) + return list(ret.values()) + + +def reduce_param_groups(params: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + # Reorganize the parameter groups and merge duplicated groups. + # The number of parameter groups needs to be as small as possible in order + # to efficiently use the PyTorch multi-tensor optimizer. Therefore instead + # of using a parameter_group per single parameter, we reorganize the + # parameter groups and merge duplicated groups. This approach speeds + # up multi-tensor optimizer significantly. + params = _expand_param_groups(params) + groups = defaultdict(list) # re-group all parameter groups by their hyperparams + for item in params: + cur_params = tuple((x, y) for x, y in item.items() if x != "params") + groups[cur_params].extend(item["params"]) + ret = [] + for param_keys, param_values in groups.items(): + cur = {kv[0]: kv[1] for kv in param_keys} + cur["params"] = param_values + ret.append(cur) + return ret + + +def build_lr_scheduler(cfg: CfgNode, optimizer: torch.optim.Optimizer) -> LRScheduler: + """ + Build a LR scheduler from config. + """ + name = cfg.SOLVER.LR_SCHEDULER_NAME + + if name == "WarmupMultiStepLR": + steps = [x for x in cfg.SOLVER.STEPS if x <= cfg.SOLVER.MAX_ITER] + if len(steps) != len(cfg.SOLVER.STEPS): + logger = logging.getLogger(__name__) + logger.warning( + "SOLVER.STEPS contains values larger than SOLVER.MAX_ITER. " + "These values will be ignored." + ) + sched = MultiStepParamScheduler( + values=[cfg.SOLVER.GAMMA**k for k in range(len(steps) + 1)], + milestones=steps, + num_updates=cfg.SOLVER.MAX_ITER, + ) + elif name == "WarmupCosineLR": + end_value = cfg.SOLVER.BASE_LR_END / cfg.SOLVER.BASE_LR + assert end_value >= 0.0 and end_value <= 1.0, end_value + sched = CosineParamScheduler(1, end_value) + elif name == "WarmupStepWithFixedGammaLR": + sched = StepWithFixedGammaParamScheduler( + base_value=1.0, + gamma=cfg.SOLVER.GAMMA, + num_decays=cfg.SOLVER.NUM_DECAYS, + num_updates=cfg.SOLVER.MAX_ITER, + ) + else: + raise ValueError("Unknown LR scheduler: {}".format(name)) + + sched = WarmupParamScheduler( + sched, + cfg.SOLVER.WARMUP_FACTOR, + min(cfg.SOLVER.WARMUP_ITERS / cfg.SOLVER.MAX_ITER, 1.0), + cfg.SOLVER.WARMUP_METHOD, + cfg.SOLVER.RESCALE_INTERVAL, + ) + return LRMultiplier(optimizer, multiplier=sched, max_iter=cfg.SOLVER.MAX_ITER) diff --git a/custom_detectron2/solver/lr_scheduler.py b/custom_detectron2/solver/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..d6aed2bb20c418bf6cc5594c1244b241796d7086 --- /dev/null +++ b/custom_detectron2/solver/lr_scheduler.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import logging +import math +from bisect import bisect_right +from typing import List +import torch +from fvcore.common.param_scheduler import ( + CompositeParamScheduler, + ConstantParamScheduler, + LinearParamScheduler, + ParamScheduler, +) + +try: + from torch.optim.lr_scheduler import LRScheduler +except ImportError: + from torch.optim.lr_scheduler import _LRScheduler as LRScheduler + +logger = logging.getLogger(__name__) + + +class WarmupParamScheduler(CompositeParamScheduler): + """ + Add an initial warmup stage to another scheduler. + """ + + def __init__( + self, + scheduler: ParamScheduler, + warmup_factor: float, + warmup_length: float, + warmup_method: str = "linear", + rescale_interval: bool = False, + ): + """ + Args: + scheduler: warmup will be added at the beginning of this scheduler + warmup_factor: the factor w.r.t the initial value of ``scheduler``, e.g. 0.001 + warmup_length: the relative length (in [0, 1]) of warmup steps w.r.t the entire + training, e.g. 0.01 + warmup_method: one of "linear" or "constant" + rescale_interval: whether we will rescale the interval of the scheduler after + warmup + """ + end_value = scheduler(warmup_length) # the value to reach when warmup ends + start_value = warmup_factor * scheduler(0.0) + if warmup_method == "constant": + warmup = ConstantParamScheduler(start_value) + elif warmup_method == "linear": + warmup = LinearParamScheduler(start_value, end_value) + else: + raise ValueError("Unknown warmup method: {}".format(warmup_method)) + super().__init__( + [warmup, scheduler], + interval_scaling=["rescaled", "rescaled" if rescale_interval else "fixed"], + lengths=[warmup_length, 1 - warmup_length], + ) + + +class LRMultiplier(LRScheduler): + """ + A LRScheduler which uses fvcore :class:`ParamScheduler` to multiply the + learning rate of each param in the optimizer. + Every step, the learning rate of each parameter becomes its initial value + multiplied by the output of the given :class:`ParamScheduler`. + + The absolute learning rate value of each parameter can be different. + This scheduler can be used as long as the relative scale among them do + not change during training. + + Examples: + :: + LRMultiplier( + opt, + WarmupParamScheduler( + MultiStepParamScheduler( + [1, 0.1, 0.01], + milestones=[60000, 80000], + num_updates=90000, + ), 0.001, 100 / 90000 + ), + max_iter=90000 + ) + """ + + # NOTES: in the most general case, every LR can use its own scheduler. + # Supporting this requires interaction with the optimizer when its parameter + # group is initialized. For example, classyvision implements its own optimizer + # that allows different schedulers for every parameter group. + # To avoid this complexity, we use this class to support the most common cases + # where the relative scale among all LRs stay unchanged during training. In this + # case we only need a total of one scheduler that defines the relative LR multiplier. + + def __init__( + self, + optimizer: torch.optim.Optimizer, + multiplier: ParamScheduler, + max_iter: int, + last_iter: int = -1, + ): + """ + Args: + optimizer, last_iter: See ``torch.optim.lr_scheduler.LRScheduler``. + ``last_iter`` is the same as ``last_epoch``. + multiplier: a fvcore ParamScheduler that defines the multiplier on + every LR of the optimizer + max_iter: the total number of training iterations + """ + if not isinstance(multiplier, ParamScheduler): + raise ValueError( + "_LRMultiplier(multiplier=) must be an instance of fvcore " + f"ParamScheduler. Got {multiplier} instead." + ) + self._multiplier = multiplier + self._max_iter = max_iter + super().__init__(optimizer, last_epoch=last_iter) + + def state_dict(self): + # fvcore schedulers are stateless. Only keep pytorch scheduler states + return {"base_lrs": self.base_lrs, "last_epoch": self.last_epoch} + + def get_lr(self) -> List[float]: + multiplier = self._multiplier(self.last_epoch / self._max_iter) + return [base_lr * multiplier for base_lr in self.base_lrs] + + +""" +Content below is no longer needed! +""" + +# NOTE: PyTorch's LR scheduler interface uses names that assume the LR changes +# only on epoch boundaries. We typically use iteration based schedules instead. +# As a result, "epoch" (e.g., as in self.last_epoch) should be understood to mean +# "iteration" instead. + +# FIXME: ideally this would be achieved with a CombinedLRScheduler, separating +# MultiStepLR with WarmupLR but the current LRScheduler design doesn't allow it. + + +class WarmupMultiStepLR(LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + milestones: List[int], + gamma: float = 0.1, + warmup_factor: float = 0.001, + warmup_iters: int = 1000, + warmup_method: str = "linear", + last_epoch: int = -1, + ): + logger.warning( + "WarmupMultiStepLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!" + ) + if not list(milestones) == sorted(milestones): + raise ValueError( + "Milestones should be a list of" " increasing integers. Got {}", milestones + ) + self.milestones = milestones + self.gamma = gamma + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor + ) + return [ + base_lr * warmup_factor * self.gamma ** bisect_right(self.milestones, self.last_epoch) + for base_lr in self.base_lrs + ] + + def _compute_values(self) -> List[float]: + # The new interface + return self.get_lr() + + +class WarmupCosineLR(LRScheduler): + def __init__( + self, + optimizer: torch.optim.Optimizer, + max_iters: int, + warmup_factor: float = 0.001, + warmup_iters: int = 1000, + warmup_method: str = "linear", + last_epoch: int = -1, + ): + logger.warning( + "WarmupCosineLR is deprecated! Use LRMultipilier with fvcore ParamScheduler instead!" + ) + self.max_iters = max_iters + self.warmup_factor = warmup_factor + self.warmup_iters = warmup_iters + self.warmup_method = warmup_method + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + warmup_factor = _get_warmup_factor_at_iter( + self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor + ) + # Different definitions of half-cosine with warmup are possible. For + # simplicity we multiply the standard half-cosine schedule by the warmup + # factor. An alternative is to start the period of the cosine at warmup_iters + # instead of at 0. In the case that warmup_iters << max_iters the two are + # very close to each other. + return [ + base_lr + * warmup_factor + * 0.5 + * (1.0 + math.cos(math.pi * self.last_epoch / self.max_iters)) + for base_lr in self.base_lrs + ] + + def _compute_values(self) -> List[float]: + # The new interface + return self.get_lr() + + +def _get_warmup_factor_at_iter( + method: str, iter: int, warmup_iters: int, warmup_factor: float +) -> float: + """ + Return the learning rate warmup factor at a specific iteration. + See :paper:`ImageNet in 1h` for more details. + + Args: + method (str): warmup method; either "constant" or "linear". + iter (int): iteration at which to calculate the warmup factor. + warmup_iters (int): the number of warmup iterations. + warmup_factor (float): the base warmup factor (the meaning changes according + to the method used). + + Returns: + float: the effective warmup factor at the given iteration. + """ + if iter >= warmup_iters: + return 1.0 + + if method == "constant": + return warmup_factor + elif method == "linear": + alpha = iter / warmup_iters + return warmup_factor * (1 - alpha) + alpha + else: + raise ValueError("Unknown warmup method: {}".format(method)) diff --git a/custom_detectron2/structures/__init__.py b/custom_detectron2/structures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a23142782a68816b64fe0e3690f099b83008e9df --- /dev/null +++ b/custom_detectron2/structures/__init__.py @@ -0,0 +1,17 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .boxes import Boxes, BoxMode, pairwise_iou, pairwise_ioa, pairwise_point_box_distance +from .image_list import ImageList + +from .instances import Instances +from .keypoints import Keypoints, heatmaps_to_keypoints +from .masks import BitMasks, PolygonMasks, polygons_to_bitmask, ROIMasks +from .rotated_boxes import RotatedBoxes +from .rotated_boxes import pairwise_iou as pairwise_iou_rotated + +__all__ = [k for k in globals().keys() if not k.startswith("_")] + + +from custom_detectron2.utils.env import fixup_module_metadata + +fixup_module_metadata(__name__, globals(), __all__) +del fixup_module_metadata diff --git a/custom_detectron2/structures/boxes.py b/custom_detectron2/structures/boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..fd396f68645db1d6946056eed868ffcc02cd7a22 --- /dev/null +++ b/custom_detectron2/structures/boxes.py @@ -0,0 +1,425 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +import numpy as np +from enum import IntEnum, unique +from typing import List, Tuple, Union +import torch +from torch import device + +_RawBoxType = Union[List[float], Tuple[float, ...], torch.Tensor, np.ndarray] + + +@unique +class BoxMode(IntEnum): + """ + Enum of different ways to represent a box. + """ + + XYXY_ABS = 0 + """ + (x0, y0, x1, y1) in absolute floating points coordinates. + The coordinates in range [0, width or height]. + """ + XYWH_ABS = 1 + """ + (x0, y0, w, h) in absolute floating points coordinates. + """ + XYXY_REL = 2 + """ + Not yet supported! + (x0, y0, x1, y1) in range [0, 1]. They are relative to the size of the image. + """ + XYWH_REL = 3 + """ + Not yet supported! + (x0, y0, w, h) in range [0, 1]. They are relative to the size of the image. + """ + XYWHA_ABS = 4 + """ + (xc, yc, w, h, a) in absolute floating points coordinates. + (xc, yc) is the center of the rotated box, and the angle a is in degrees ccw. + """ + + @staticmethod + def convert(box: _RawBoxType, from_mode: "BoxMode", to_mode: "BoxMode") -> _RawBoxType: + """ + Args: + box: can be a k-tuple, k-list or an Nxk array/tensor, where k = 4 or 5 + from_mode, to_mode (BoxMode) + + Returns: + The converted box of the same type. + """ + if from_mode == to_mode: + return box + + original_type = type(box) + is_numpy = isinstance(box, np.ndarray) + single_box = isinstance(box, (list, tuple)) + if single_box: + assert len(box) == 4 or len(box) == 5, ( + "BoxMode.convert takes either a k-tuple/list or an Nxk array/tensor," + " where k == 4 or 5" + ) + arr = torch.tensor(box)[None, :] + else: + # avoid modifying the input box + if is_numpy: + arr = torch.from_numpy(np.asarray(box)).clone() + else: + arr = box.clone() + + assert to_mode not in [BoxMode.XYXY_REL, BoxMode.XYWH_REL] and from_mode not in [ + BoxMode.XYXY_REL, + BoxMode.XYWH_REL, + ], "Relative mode not yet supported!" + + if from_mode == BoxMode.XYWHA_ABS and to_mode == BoxMode.XYXY_ABS: + assert ( + arr.shape[-1] == 5 + ), "The last dimension of input shape must be 5 for XYWHA format" + original_dtype = arr.dtype + arr = arr.double() + + w = arr[:, 2] + h = arr[:, 3] + a = arr[:, 4] + c = torch.abs(torch.cos(a * math.pi / 180.0)) + s = torch.abs(torch.sin(a * math.pi / 180.0)) + # This basically computes the horizontal bounding rectangle of the rotated box + new_w = c * w + s * h + new_h = c * h + s * w + + # convert center to top-left corner + arr[:, 0] -= new_w / 2.0 + arr[:, 1] -= new_h / 2.0 + # bottom-right corner + arr[:, 2] = arr[:, 0] + new_w + arr[:, 3] = arr[:, 1] + new_h + + arr = arr[:, :4].to(dtype=original_dtype) + elif from_mode == BoxMode.XYWH_ABS and to_mode == BoxMode.XYWHA_ABS: + original_dtype = arr.dtype + arr = arr.double() + arr[:, 0] += arr[:, 2] / 2.0 + arr[:, 1] += arr[:, 3] / 2.0 + angles = torch.zeros((arr.shape[0], 1), dtype=arr.dtype) + arr = torch.cat((arr, angles), axis=1).to(dtype=original_dtype) + else: + if to_mode == BoxMode.XYXY_ABS and from_mode == BoxMode.XYWH_ABS: + arr[:, 2] += arr[:, 0] + arr[:, 3] += arr[:, 1] + elif from_mode == BoxMode.XYXY_ABS and to_mode == BoxMode.XYWH_ABS: + arr[:, 2] -= arr[:, 0] + arr[:, 3] -= arr[:, 1] + else: + raise NotImplementedError( + "Conversion from BoxMode {} to {} is not supported yet".format( + from_mode, to_mode + ) + ) + + if single_box: + return original_type(arr.flatten().tolist()) + if is_numpy: + return arr.numpy() + else: + return arr + + +class Boxes: + """ + This structure stores a list of boxes as a Nx4 torch.Tensor. + It supports some common methods about boxes + (`area`, `clip`, `nonempty`, etc), + and also behaves like a Tensor + (support indexing, `to(device)`, `.device`, and iteration over all boxes) + + Attributes: + tensor (torch.Tensor): float matrix of Nx4. Each row is (x1, y1, x2, y2). + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2). + """ + if not isinstance(tensor, torch.Tensor): + tensor = torch.as_tensor(tensor, dtype=torch.float32, device=torch.device("cpu")) + else: + tensor = tensor.to(torch.float32) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that does not depend on + # the inputs (and consequently confuses jit) + tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32) + assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size() + + self.tensor = tensor + + def clone(self) -> "Boxes": + """ + Clone the Boxes. + + Returns: + Boxes + """ + return Boxes(self.tensor.clone()) + + def to(self, device: torch.device): + # Boxes are assumed float32 and does not support to(dtype) + return Boxes(self.tensor.to(device=device)) + + def area(self) -> torch.Tensor: + """ + Computes the area of all the boxes. + + Returns: + torch.Tensor: a vector with areas of each box. + """ + box = self.tensor + area = (box[:, 2] - box[:, 0]) * (box[:, 3] - box[:, 1]) + return area + + def clip(self, box_size: Tuple[int, int]) -> None: + """ + Clip (in place) the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + Args: + box_size (height, width): The clipping box's size. + """ + assert torch.isfinite(self.tensor).all(), "Box tensor contains infinite or NaN!" + h, w = box_size + x1 = self.tensor[:, 0].clamp(min=0, max=w) + y1 = self.tensor[:, 1].clamp(min=0, max=h) + x2 = self.tensor[:, 2].clamp(min=0, max=w) + y2 = self.tensor[:, 3].clamp(min=0, max=h) + self.tensor = torch.stack((x1, y1, x2, y2), dim=-1) + + def nonempty(self, threshold: float = 0.0) -> torch.Tensor: + """ + Find boxes that are non-empty. + A box is considered empty, if either of its side is no larger than threshold. + + Returns: + Tensor: + a binary vector which represents whether each box is empty + (False) or non-empty (True). + """ + box = self.tensor + widths = box[:, 2] - box[:, 0] + heights = box[:, 3] - box[:, 1] + keep = (widths > threshold) & (heights > threshold) + return keep + + def __getitem__(self, item) -> "Boxes": + """ + Args: + item: int, slice, or a BoolTensor + + Returns: + Boxes: Create a new :class:`Boxes` by indexing. + + The following usage are allowed: + + 1. `new_boxes = boxes[3]`: return a `Boxes` which contains only one box. + 2. `new_boxes = boxes[2:10]`: return a slice of boxes. + 3. `new_boxes = boxes[vector]`, where vector is a torch.BoolTensor + with `length = len(boxes)`. Nonzero elements in the vector will be selected. + + Note that the returned Boxes might share storage with this Boxes, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return Boxes(self.tensor[item].view(1, -1)) + b = self.tensor[item] + assert b.dim() == 2, "Indexing on Boxes with {} failed to return a matrix!".format(item) + return Boxes(b) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __repr__(self) -> str: + return "Boxes(" + str(self.tensor) + ")" + + def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor: + """ + Args: + box_size (height, width): Size of the reference box. + boundary_threshold (int): Boxes that extend beyond the reference box + boundary by more than boundary_threshold are considered "outside". + + Returns: + a binary vector, indicating whether each box is inside the reference box. + """ + height, width = box_size + inds_inside = ( + (self.tensor[..., 0] >= -boundary_threshold) + & (self.tensor[..., 1] >= -boundary_threshold) + & (self.tensor[..., 2] < width + boundary_threshold) + & (self.tensor[..., 3] < height + boundary_threshold) + ) + return inds_inside + + def get_centers(self) -> torch.Tensor: + """ + Returns: + The box centers in a Nx2 array of (x, y). + """ + return (self.tensor[:, :2] + self.tensor[:, 2:]) / 2 + + def scale(self, scale_x: float, scale_y: float) -> None: + """ + Scale the box with horizontal and vertical scaling factors + """ + self.tensor[:, 0::2] *= scale_x + self.tensor[:, 1::2] *= scale_y + + @classmethod + def cat(cls, boxes_list: List["Boxes"]) -> "Boxes": + """ + Concatenates a list of Boxes into a single Boxes + + Arguments: + boxes_list (list[Boxes]) + + Returns: + Boxes: the concatenated Boxes + """ + assert isinstance(boxes_list, (list, tuple)) + if len(boxes_list) == 0: + return cls(torch.empty(0)) + assert all([isinstance(box, Boxes) for box in boxes_list]) + + # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input + cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0)) + return cat_boxes + + @property + def device(self) -> device: + return self.tensor.device + + # type "Iterator[torch.Tensor]", yield, and iter() not supported by torchscript + # https://github.com/pytorch/pytorch/issues/18627 + @torch.jit.unused + def __iter__(self): + """ + Yield a box as a Tensor of shape (4,) at a time. + """ + yield from self.tensor + + +def pairwise_intersection(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Given two lists of boxes of size N and M, + compute the intersection area between __all__ N x M pairs of boxes. + The box order must be (xmin, ymin, xmax, ymax) + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: intersection, sized [N,M]. + """ + boxes1, boxes2 = boxes1.tensor, boxes2.tensor + width_height = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) - torch.max( + boxes1[:, None, :2], boxes2[:, :2] + ) # [N,M,2] + + width_height.clamp_(min=0) # [N,M,2] + intersection = width_height.prod(dim=2) # [N,M] + return intersection + + +# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py +# with slight modifications +def pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Given two lists of boxes of size N and M, compute the IoU + (intersection over union) between **all** N x M pairs of boxes. + The box order must be (xmin, ymin, xmax, ymax). + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: IoU, sized [N,M]. + """ + area1 = boxes1.area() # [N] + area2 = boxes2.area() # [M] + inter = pairwise_intersection(boxes1, boxes2) + + # handle empty boxes + iou = torch.where( + inter > 0, + inter / (area1[:, None] + area2 - inter), + torch.zeros(1, dtype=inter.dtype, device=inter.device), + ) + return iou + + +def pairwise_ioa(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Similar to :func:`pariwise_iou` but compute the IoA (intersection over boxes2 area). + + Args: + boxes1,boxes2 (Boxes): two `Boxes`. Contains N & M boxes, respectively. + + Returns: + Tensor: IoA, sized [N,M]. + """ + area2 = boxes2.area() # [M] + inter = pairwise_intersection(boxes1, boxes2) + + # handle empty boxes + ioa = torch.where( + inter > 0, inter / area2, torch.zeros(1, dtype=inter.dtype, device=inter.device) + ) + return ioa + + +def pairwise_point_box_distance(points: torch.Tensor, boxes: Boxes): + """ + Pairwise distance between N points and M boxes. The distance between a + point and a box is represented by the distance from the point to 4 edges + of the box. Distances are all positive when the point is inside the box. + + Args: + points: Nx2 coordinates. Each row is (x, y) + boxes: M boxes + + Returns: + Tensor: distances of size (N, M, 4). The 4 values are distances from + the point to the left, top, right, bottom of the box. + """ + x, y = points.unsqueeze(dim=2).unbind(dim=1) # (N, 1) + x0, y0, x1, y1 = boxes.tensor.unsqueeze(dim=0).unbind(dim=2) # (1, M) + return torch.stack([x - x0, y - y0, x1 - x, y1 - y], dim=2) + + +def matched_pairwise_iou(boxes1: Boxes, boxes2: Boxes) -> torch.Tensor: + """ + Compute pairwise intersection over union (IOU) of two sets of matched + boxes that have the same number of boxes. + Similar to :func:`pairwise_iou`, but computes only diagonal elements of the matrix. + + Args: + boxes1 (Boxes): bounding boxes, sized [N,4]. + boxes2 (Boxes): same length as boxes1 + Returns: + Tensor: iou, sized [N]. + """ + assert len(boxes1) == len( + boxes2 + ), "boxlists should have the same" "number of entries, got {}, {}".format( + len(boxes1), len(boxes2) + ) + area1 = boxes1.area() # [N] + area2 = boxes2.area() # [N] + box1, box2 = boxes1.tensor, boxes2.tensor + lt = torch.max(box1[:, :2], box2[:, :2]) # [N,2] + rb = torch.min(box1[:, 2:], box2[:, 2:]) # [N,2] + wh = (rb - lt).clamp(min=0) # [N,2] + inter = wh[:, 0] * wh[:, 1] # [N] + iou = inter / (area1 + area2 - inter) # [N] + return iou diff --git a/custom_detectron2/structures/image_list.py b/custom_detectron2/structures/image_list.py new file mode 100644 index 0000000000000000000000000000000000000000..09c3d0ebf9517eaf6a651dab4f29cd549ee4784a --- /dev/null +++ b/custom_detectron2/structures/image_list.py @@ -0,0 +1,129 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from __future__ import division +from typing import Any, Dict, List, Optional, Tuple +import torch +from torch import device +from torch.nn import functional as F + +from custom_detectron2.layers.wrappers import move_device_like, shapes_to_tensor + + +class ImageList(object): + """ + Structure that holds a list of images (of possibly + varying sizes) as a single tensor. + This works by padding the images to the same size. + The original sizes of each image is stored in `image_sizes`. + + Attributes: + image_sizes (list[tuple[int, int]]): each tuple is (h, w). + During tracing, it becomes list[Tensor] instead. + """ + + def __init__(self, tensor: torch.Tensor, image_sizes: List[Tuple[int, int]]): + """ + Arguments: + tensor (Tensor): of shape (N, H, W) or (N, C_1, ..., C_K, H, W) where K >= 1 + image_sizes (list[tuple[int, int]]): Each tuple is (h, w). It can + be smaller than (H, W) due to padding. + """ + self.tensor = tensor + self.image_sizes = image_sizes + + def __len__(self) -> int: + return len(self.image_sizes) + + def __getitem__(self, idx) -> torch.Tensor: + """ + Access the individual image in its original size. + + Args: + idx: int or slice + + Returns: + Tensor: an image of shape (H, W) or (C_1, ..., C_K, H, W) where K >= 1 + """ + size = self.image_sizes[idx] + return self.tensor[idx, ..., : size[0], : size[1]] + + @torch.jit.unused + def to(self, *args: Any, **kwargs: Any) -> "ImageList": + cast_tensor = self.tensor.to(*args, **kwargs) + return ImageList(cast_tensor, self.image_sizes) + + @property + def device(self) -> device: + return self.tensor.device + + @staticmethod + def from_tensors( + tensors: List[torch.Tensor], + size_divisibility: int = 0, + pad_value: float = 0.0, + padding_constraints: Optional[Dict[str, int]] = None, + ) -> "ImageList": + """ + Args: + tensors: a tuple or list of `torch.Tensor`, each of shape (Hi, Wi) or + (C_1, ..., C_K, Hi, Wi) where K >= 1. The Tensors will be padded + to the same shape with `pad_value`. + size_divisibility (int): If `size_divisibility > 0`, add padding to ensure + the common height and width is divisible by `size_divisibility`. + This depends on the model and many models need a divisibility of 32. + pad_value (float): value to pad. + padding_constraints (optional[Dict]): If given, it would follow the format as + {"size_divisibility": int, "square_size": int}, where `size_divisibility` will + overwrite the above one if presented and `square_size` indicates the + square padding size if `square_size` > 0. + Returns: + an `ImageList`. + """ + assert len(tensors) > 0 + assert isinstance(tensors, (tuple, list)) + for t in tensors: + assert isinstance(t, torch.Tensor), type(t) + assert t.shape[:-2] == tensors[0].shape[:-2], t.shape + + image_sizes = [(im.shape[-2], im.shape[-1]) for im in tensors] + image_sizes_tensor = [shapes_to_tensor(x) for x in image_sizes] + max_size = torch.stack(image_sizes_tensor).max(0).values + + if padding_constraints is not None: + square_size = padding_constraints.get("square_size", 0) + if square_size > 0: + # pad to square. + max_size[0] = max_size[1] = square_size + if "size_divisibility" in padding_constraints: + size_divisibility = padding_constraints["size_divisibility"] + if size_divisibility > 1: + stride = size_divisibility + # the last two dims are H,W, both subject to divisibility requirement + max_size = (max_size + (stride - 1)).div(stride, rounding_mode="floor") * stride + + # handle weirdness of scripting and tracing ... + if torch.jit.is_scripting(): + max_size: List[int] = max_size.to(dtype=torch.long).tolist() + else: + if torch.jit.is_tracing(): + image_sizes = image_sizes_tensor + + if len(tensors) == 1: + # This seems slightly (2%) faster. + # TODO: check whether it's faster for multiple images as well + image_size = image_sizes[0] + padding_size = [0, max_size[-1] - image_size[1], 0, max_size[-2] - image_size[0]] + batched_imgs = F.pad(tensors[0], padding_size, value=pad_value).unsqueeze_(0) + else: + # max_size can be a tensor in tracing mode, therefore convert to list + batch_shape = [len(tensors)] + list(tensors[0].shape[:-2]) + list(max_size) + device = ( + None if torch.jit.is_scripting() else ("cpu" if torch.jit.is_tracing() else None) + ) + batched_imgs = tensors[0].new_full(batch_shape, pad_value, device=device) + batched_imgs = move_device_like(batched_imgs, tensors[0]) + for i, img in enumerate(tensors): + # Use `batched_imgs` directly instead of `img, pad_img = zip(tensors, batched_imgs)` + # Tracing mode cannot capture `copy_()` of temporary locals + batched_imgs[i, ..., : img.shape[-2], : img.shape[-1]].copy_(img) + + return ImageList(batched_imgs.contiguous(), image_sizes) diff --git a/custom_detectron2/structures/instances.py b/custom_detectron2/structures/instances.py new file mode 100644 index 0000000000000000000000000000000000000000..c9579bce2730f42e256c6eed99d9014d09304c99 --- /dev/null +++ b/custom_detectron2/structures/instances.py @@ -0,0 +1,194 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import itertools +import warnings +from typing import Any, Dict, List, Tuple, Union +import torch + + +class Instances: + """ + This class represents a list of instances in an image. + It stores the attributes of instances (e.g., boxes, masks, labels, scores) as "fields". + All fields must have the same ``__len__`` which is the number of instances. + + All other (non-field) attributes of this class are considered private: + they must start with '_' and are not modifiable by a user. + + Some basic usage: + + 1. Set/get/check a field: + + .. code-block:: python + + instances.gt_boxes = Boxes(...) + print(instances.pred_masks) # a tensor of shape (N, H, W) + print('gt_masks' in instances) + + 2. ``len(instances)`` returns the number of instances + 3. Indexing: ``instances[indices]`` will apply the indexing on all the fields + and returns a new :class:`Instances`. + Typically, ``indices`` is a integer vector of indices, + or a binary mask of length ``num_instances`` + + .. code-block:: python + + category_3_detections = instances[instances.pred_classes == 3] + confident_detections = instances[instances.scores > 0.9] + """ + + def __init__(self, image_size: Tuple[int, int], **kwargs: Any): + """ + Args: + image_size (height, width): the spatial size of the image. + kwargs: fields to add to this `Instances`. + """ + self._image_size = image_size + self._fields: Dict[str, Any] = {} + for k, v in kwargs.items(): + self.set(k, v) + + @property + def image_size(self) -> Tuple[int, int]: + """ + Returns: + tuple: height, width + """ + return self._image_size + + def __setattr__(self, name: str, val: Any) -> None: + if name.startswith("_"): + super().__setattr__(name, val) + else: + self.set(name, val) + + def __getattr__(self, name: str) -> Any: + if name == "_fields" or name not in self._fields: + raise AttributeError("Cannot find field '{}' in the given Instances!".format(name)) + return self._fields[name] + + def set(self, name: str, value: Any) -> None: + """ + Set the field named `name` to `value`. + The length of `value` must be the number of instances, + and must agree with other existing fields in this object. + """ + with warnings.catch_warnings(record=True): + data_len = len(value) + if len(self._fields): + assert ( + len(self) == data_len + ), "Adding a field of length {} to a Instances of length {}".format(data_len, len(self)) + self._fields[name] = value + + def has(self, name: str) -> bool: + """ + Returns: + bool: whether the field called `name` exists. + """ + return name in self._fields + + def remove(self, name: str) -> None: + """ + Remove the field called `name`. + """ + del self._fields[name] + + def get(self, name: str) -> Any: + """ + Returns the field called `name`. + """ + return self._fields[name] + + def get_fields(self) -> Dict[str, Any]: + """ + Returns: + dict: a dict which maps names (str) to data of the fields + + Modifying the returned dict will modify this instance. + """ + return self._fields + + # Tensor-like methods + def to(self, *args: Any, **kwargs: Any) -> "Instances": + """ + Returns: + Instances: all fields are called with a `to(device)`, if the field has this method. + """ + ret = Instances(self._image_size) + for k, v in self._fields.items(): + if hasattr(v, "to"): + v = v.to(*args, **kwargs) + ret.set(k, v) + return ret + + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Instances": + """ + Args: + item: an index-like object and will be used to index all the fields. + + Returns: + If `item` is a string, return the data in the corresponding field. + Otherwise, returns an `Instances` where all fields are indexed by `item`. + """ + if type(item) == int: + if item >= len(self) or item < -len(self): + raise IndexError("Instances index out of range!") + else: + item = slice(item, None, len(self)) + + ret = Instances(self._image_size) + for k, v in self._fields.items(): + ret.set(k, v[item]) + return ret + + def __len__(self) -> int: + for v in self._fields.values(): + # use __len__ because len() has to be int and is not friendly to tracing + return v.__len__() + raise NotImplementedError("Empty Instances does not support __len__!") + + def __iter__(self): + raise NotImplementedError("`Instances` object is not iterable!") + + @staticmethod + def cat(instance_lists: List["Instances"]) -> "Instances": + """ + Args: + instance_lists (list[Instances]) + + Returns: + Instances + """ + assert all(isinstance(i, Instances) for i in instance_lists) + assert len(instance_lists) > 0 + if len(instance_lists) == 1: + return instance_lists[0] + + image_size = instance_lists[0].image_size + if not isinstance(image_size, torch.Tensor): # could be a tensor in tracing + for i in instance_lists[1:]: + assert i.image_size == image_size + ret = Instances(image_size) + for k in instance_lists[0]._fields.keys(): + values = [i.get(k) for i in instance_lists] + v0 = values[0] + if isinstance(v0, torch.Tensor): + values = torch.cat(values, dim=0) + elif isinstance(v0, list): + values = list(itertools.chain(*values)) + elif hasattr(type(v0), "cat"): + values = type(v0).cat(values) + else: + raise ValueError("Unsupported type {} for concatenation".format(type(v0))) + ret.set(k, values) + return ret + + def __str__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={}, ".format(len(self)) + s += "image_height={}, ".format(self._image_size[0]) + s += "image_width={}, ".format(self._image_size[1]) + s += "fields=[{}])".format(", ".join((f"{k}: {v}" for k, v in self._fields.items()))) + return s + + __repr__ = __str__ diff --git a/custom_detectron2/structures/keypoints.py b/custom_detectron2/structures/keypoints.py new file mode 100644 index 0000000000000000000000000000000000000000..b93ebed4f6554e67ba9bde8d3af90e8dbb3246b6 --- /dev/null +++ b/custom_detectron2/structures/keypoints.py @@ -0,0 +1,235 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import Any, List, Tuple, Union +import torch +from torch.nn import functional as F + + +class Keypoints: + """ + Stores keypoint **annotation** data. GT Instances have a `gt_keypoints` property + containing the x,y location and visibility flag of each keypoint. This tensor has shape + (N, K, 3) where N is the number of instances and K is the number of keypoints per instance. + + The visibility flag follows the COCO format and must be one of three integers: + + * v=0: not labeled (in which case x=y=0) + * v=1: labeled but not visible + * v=2: labeled and visible + """ + + def __init__(self, keypoints: Union[torch.Tensor, np.ndarray, List[List[float]]]): + """ + Arguments: + keypoints: A Tensor, numpy array, or list of the x, y, and visibility of each keypoint. + The shape should be (N, K, 3) where N is the number of + instances, and K is the number of keypoints per instance. + """ + device = keypoints.device if isinstance(keypoints, torch.Tensor) else torch.device("cpu") + keypoints = torch.as_tensor(keypoints, dtype=torch.float32, device=device) + assert keypoints.dim() == 3 and keypoints.shape[2] == 3, keypoints.shape + self.tensor = keypoints + + def __len__(self) -> int: + return self.tensor.size(0) + + def to(self, *args: Any, **kwargs: Any) -> "Keypoints": + return type(self)(self.tensor.to(*args, **kwargs)) + + @property + def device(self) -> torch.device: + return self.tensor.device + + def to_heatmap(self, boxes: torch.Tensor, heatmap_size: int) -> torch.Tensor: + """ + Convert keypoint annotations to a heatmap of one-hot labels for training, + as described in :paper:`Mask R-CNN`. + + Arguments: + boxes: Nx4 tensor, the boxes to draw the keypoints to + + Returns: + heatmaps: + A tensor of shape (N, K), each element is integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: + A tensor of shape (N, K) containing whether each keypoint is in the roi or not. + """ + return _keypoints_to_heatmap(self.tensor, boxes, heatmap_size) + + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "Keypoints": + """ + Create a new `Keypoints` by indexing on this `Keypoints`. + + The following usage are allowed: + + 1. `new_kpts = kpts[3]`: return a `Keypoints` which contains only one instance. + 2. `new_kpts = kpts[2:10]`: return a slice of key points. + 3. `new_kpts = kpts[vector]`, where vector is a torch.ByteTensor + with `length = len(kpts)`. Nonzero elements in the vector will be selected. + + Note that the returned Keypoints might share storage with this Keypoints, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return Keypoints([self.tensor[item]]) + return Keypoints(self.tensor[item]) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + @staticmethod + def cat(keypoints_list: List["Keypoints"]) -> "Keypoints": + """ + Concatenates a list of Keypoints into a single Keypoints + + Arguments: + keypoints_list (list[Keypoints]) + + Returns: + Keypoints: the concatenated Keypoints + """ + assert isinstance(keypoints_list, (list, tuple)) + assert len(keypoints_list) > 0 + assert all(isinstance(keypoints, Keypoints) for keypoints in keypoints_list) + + cat_kpts = type(keypoints_list[0])( + torch.cat([kpts.tensor for kpts in keypoints_list], dim=0) + ) + return cat_kpts + + +# TODO make this nicer, this is a direct translation from C2 (but removing the inner loop) +def _keypoints_to_heatmap( + keypoints: torch.Tensor, rois: torch.Tensor, heatmap_size: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Encode keypoint locations into a target heatmap for use in SoftmaxWithLoss across space. + + Maps keypoints from the half-open interval [x1, x2) on continuous image coordinates to the + closed interval [0, heatmap_size - 1] on discrete image coordinates. We use the + continuous-discrete conversion from Heckbert 1990 ("What is the coordinate of a pixel?"): + d = floor(c) and c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. + + Arguments: + keypoints: tensor of keypoint locations in of shape (N, K, 3). + rois: Nx4 tensor of rois in xyxy format + heatmap_size: integer side length of square heatmap. + + Returns: + heatmaps: A tensor of shape (N, K) containing an integer spatial label + in the range [0, heatmap_size**2 - 1] for each keypoint in the input. + valid: A tensor of shape (N, K) containing whether each keypoint is in + the roi or not. + """ + + if rois.numel() == 0: + return rois.new().long(), rois.new().long() + offset_x = rois[:, 0] + offset_y = rois[:, 1] + scale_x = heatmap_size / (rois[:, 2] - rois[:, 0]) + scale_y = heatmap_size / (rois[:, 3] - rois[:, 1]) + + offset_x = offset_x[:, None] + offset_y = offset_y[:, None] + scale_x = scale_x[:, None] + scale_y = scale_y[:, None] + + x = keypoints[..., 0] + y = keypoints[..., 1] + + x_boundary_inds = x == rois[:, 2][:, None] + y_boundary_inds = y == rois[:, 3][:, None] + + x = (x - offset_x) * scale_x + x = x.floor().long() + y = (y - offset_y) * scale_y + y = y.floor().long() + + x[x_boundary_inds] = heatmap_size - 1 + y[y_boundary_inds] = heatmap_size - 1 + + valid_loc = (x >= 0) & (y >= 0) & (x < heatmap_size) & (y < heatmap_size) + vis = keypoints[..., 2] > 0 + valid = (valid_loc & vis).long() + + lin_ind = y * heatmap_size + x + heatmaps = lin_ind * valid + + return heatmaps, valid + + +@torch.jit.script_if_tracing +def heatmaps_to_keypoints(maps: torch.Tensor, rois: torch.Tensor) -> torch.Tensor: + """ + Extract predicted keypoint locations from heatmaps. + + Args: + maps (Tensor): (#ROIs, #keypoints, POOL_H, POOL_W). The predicted heatmap of logits for + each ROI and each keypoint. + rois (Tensor): (#ROIs, 4). The box of each ROI. + + Returns: + Tensor of shape (#ROIs, #keypoints, 4) with the last dimension corresponding to + (x, y, logit, score) for each keypoint. + + When converting discrete pixel indices in an NxN image to a continuous keypoint coordinate, + we maintain consistency with :meth:`Keypoints.to_heatmap` by using the conversion from + Heckbert 1990: c = d + 0.5, where d is a discrete coordinate and c is a continuous coordinate. + """ + + offset_x = rois[:, 0] + offset_y = rois[:, 1] + + widths = (rois[:, 2] - rois[:, 0]).clamp(min=1) + heights = (rois[:, 3] - rois[:, 1]).clamp(min=1) + widths_ceil = widths.ceil() + heights_ceil = heights.ceil() + + num_rois, num_keypoints = maps.shape[:2] + xy_preds = maps.new_zeros(rois.shape[0], num_keypoints, 4) + + width_corrections = widths / widths_ceil + height_corrections = heights / heights_ceil + + keypoints_idx = torch.arange(num_keypoints, device=maps.device) + + for i in range(num_rois): + outsize = (int(heights_ceil[i]), int(widths_ceil[i])) + roi_map = F.interpolate(maps[[i]], size=outsize, mode="bicubic", align_corners=False) + + # Although semantically equivalent, `reshape` is used instead of `squeeze` due + # to limitation during ONNX export of `squeeze` in scripting mode + roi_map = roi_map.reshape(roi_map.shape[1:]) # keypoints x H x W + + # softmax over the spatial region + max_score, _ = roi_map.view(num_keypoints, -1).max(1) + max_score = max_score.view(num_keypoints, 1, 1) + tmp_full_resolution = (roi_map - max_score).exp_() + tmp_pool_resolution = (maps[i] - max_score).exp_() + # Produce scores over the region H x W, but normalize with POOL_H x POOL_W, + # so that the scores of objects of different absolute sizes will be more comparable + roi_map_scores = tmp_full_resolution / tmp_pool_resolution.sum((1, 2), keepdim=True) + + w = roi_map.shape[2] + pos = roi_map.view(num_keypoints, -1).argmax(1) + + x_int = pos % w + y_int = (pos - x_int) // w + + assert ( + roi_map_scores[keypoints_idx, y_int, x_int] + == roi_map_scores.view(num_keypoints, -1).max(1)[0] + ).all() + + x = (x_int.float() + 0.5) * width_corrections[i] + y = (y_int.float() + 0.5) * height_corrections[i] + + xy_preds[i, :, 0] = x + offset_x[i] + xy_preds[i, :, 1] = y + offset_y[i] + xy_preds[i, :, 2] = roi_map[keypoints_idx, y_int, x_int] + xy_preds[i, :, 3] = roi_map_scores[keypoints_idx, y_int, x_int] + + return xy_preds diff --git a/custom_detectron2/structures/masks.py b/custom_detectron2/structures/masks.py new file mode 100644 index 0000000000000000000000000000000000000000..cc41656e414b801cb4cb088460caad79b40939f3 --- /dev/null +++ b/custom_detectron2/structures/masks.py @@ -0,0 +1,534 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import copy +import itertools +import numpy as np +from typing import Any, Iterator, List, Union +import custom_pycocotools.mask as mask_util +import torch +from torch import device + +from custom_detectron2.layers.roi_align import ROIAlign +from custom_detectron2.utils.memory import retry_if_cuda_oom + +from .boxes import Boxes + + +def polygon_area(x, y): + # Using the shoelace formula + # https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + return 0.5 * np.abs(np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1))) + + +def polygons_to_bitmask(polygons: List[np.ndarray], height: int, width: int) -> np.ndarray: + """ + Args: + polygons (list[ndarray]): each array has shape (Nx2,) + height, width (int) + + Returns: + ndarray: a bool mask of shape (height, width) + """ + if len(polygons) == 0: + # COCOAPI does not support empty polygons + return np.zeros((height, width)).astype(bool) + rles = mask_util.frPyObjects(polygons, height, width) + rle = mask_util.merge(rles) + return mask_util.decode(rle).astype(bool) + + +def rasterize_polygons_within_box( + polygons: List[np.ndarray], box: np.ndarray, mask_size: int +) -> torch.Tensor: + """ + Rasterize the polygons into a mask image and + crop the mask content in the given box. + The cropped mask is resized to (mask_size, mask_size). + + This function is used when generating training targets for mask head in Mask R-CNN. + Given original ground-truth masks for an image, new ground-truth mask + training targets in the size of `mask_size x mask_size` + must be provided for each predicted box. This function will be called to + produce such targets. + + Args: + polygons (list[ndarray[float]]): a list of polygons, which represents an instance. + box: 4-element numpy array + mask_size (int): + + Returns: + Tensor: BoolTensor of shape (mask_size, mask_size) + """ + # 1. Shift the polygons w.r.t the boxes + w, h = box[2] - box[0], box[3] - box[1] + + polygons = copy.deepcopy(polygons) + for p in polygons: + p[0::2] = p[0::2] - box[0] + p[1::2] = p[1::2] - box[1] + + # 2. Rescale the polygons to the new box size + # max() to avoid division by small number + ratio_h = mask_size / max(h, 0.1) + ratio_w = mask_size / max(w, 0.1) + + if ratio_h == ratio_w: + for p in polygons: + p *= ratio_h + else: + for p in polygons: + p[0::2] *= ratio_w + p[1::2] *= ratio_h + + # 3. Rasterize the polygons with coco api + mask = polygons_to_bitmask(polygons, mask_size, mask_size) + mask = torch.from_numpy(mask) + return mask + + +class BitMasks: + """ + This class stores the segmentation masks for all objects in one image, in + the form of bitmaps. + + Attributes: + tensor: bool Tensor of N,H,W, representing N instances in the image. + """ + + def __init__(self, tensor: Union[torch.Tensor, np.ndarray]): + """ + Args: + tensor: bool Tensor of N,H,W, representing N instances in the image. + """ + if isinstance(tensor, torch.Tensor): + tensor = tensor.to(torch.bool) + else: + tensor = torch.as_tensor(tensor, dtype=torch.bool, device=torch.device("cpu")) + assert tensor.dim() == 3, tensor.size() + self.image_size = tensor.shape[1:] + self.tensor = tensor + + @torch.jit.unused + def to(self, *args: Any, **kwargs: Any) -> "BitMasks": + return BitMasks(self.tensor.to(*args, **kwargs)) + + @property + def device(self) -> torch.device: + return self.tensor.device + + @torch.jit.unused + def __getitem__(self, item: Union[int, slice, torch.BoolTensor]) -> "BitMasks": + """ + Returns: + BitMasks: Create a new :class:`BitMasks` by indexing. + + The following usage are allowed: + + 1. `new_masks = masks[3]`: return a `BitMasks` which contains only one mask. + 2. `new_masks = masks[2:10]`: return a slice of masks. + 3. `new_masks = masks[vector]`, where vector is a torch.BoolTensor + with `length = len(masks)`. Nonzero elements in the vector will be selected. + + Note that the returned object might share storage with this object, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return BitMasks(self.tensor[item].unsqueeze(0)) + m = self.tensor[item] + assert m.dim() == 3, "Indexing on BitMasks with {} returns a tensor with shape {}!".format( + item, m.shape + ) + return BitMasks(m) + + @torch.jit.unused + def __iter__(self) -> torch.Tensor: + yield from self.tensor + + @torch.jit.unused + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + def __len__(self) -> int: + return self.tensor.shape[0] + + def nonempty(self) -> torch.Tensor: + """ + Find masks that are non-empty. + + Returns: + Tensor: a BoolTensor which represents + whether each mask is empty (False) or non-empty (True). + """ + return self.tensor.flatten(1).any(dim=1) + + @staticmethod + def from_polygon_masks( + polygon_masks: Union["PolygonMasks", List[List[np.ndarray]]], height: int, width: int + ) -> "BitMasks": + """ + Args: + polygon_masks (list[list[ndarray]] or PolygonMasks) + height, width (int) + """ + if isinstance(polygon_masks, PolygonMasks): + polygon_masks = polygon_masks.polygons + masks = [polygons_to_bitmask(p, height, width) for p in polygon_masks] + if len(masks): + return BitMasks(torch.stack([torch.from_numpy(x) for x in masks])) + else: + return BitMasks(torch.empty(0, height, width, dtype=torch.bool)) + + @staticmethod + def from_roi_masks(roi_masks: "ROIMasks", height: int, width: int) -> "BitMasks": + """ + Args: + roi_masks: + height, width (int): + """ + return roi_masks.to_bitmasks(height, width) + + def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor: + """ + Crop each bitmask by the given box, and resize results to (mask_size, mask_size). + This can be used to prepare training targets for Mask R-CNN. + It has less reconstruction error compared to rasterization with polygons. + However we observe no difference in accuracy, + but BitMasks requires more memory to store all the masks. + + Args: + boxes (Tensor): Nx4 tensor storing the boxes for each mask + mask_size (int): the size of the rasterized mask. + + Returns: + Tensor: + A bool tensor of shape (N, mask_size, mask_size), where + N is the number of predicted boxes for this image. + """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + device = self.tensor.device + + batch_inds = torch.arange(len(boxes), device=device).to(dtype=boxes.dtype)[:, None] + rois = torch.cat([batch_inds, boxes], dim=1) # Nx5 + + bit_masks = self.tensor.to(dtype=torch.float32) + rois = rois.to(device=device) + output = ( + ROIAlign((mask_size, mask_size), 1.0, 0, aligned=True) + .forward(bit_masks[:, None, :, :], rois) + .squeeze(1) + ) + output = output >= 0.5 + return output + + def get_bounding_boxes(self) -> Boxes: + """ + Returns: + Boxes: tight bounding boxes around bitmasks. + If a mask is empty, it's bounding box will be all zero. + """ + boxes = torch.zeros(self.tensor.shape[0], 4, dtype=torch.float32) + x_any = torch.any(self.tensor, dim=1) + y_any = torch.any(self.tensor, dim=2) + for idx in range(self.tensor.shape[0]): + x = torch.where(x_any[idx, :])[0] + y = torch.where(y_any[idx, :])[0] + if len(x) > 0 and len(y) > 0: + boxes[idx, :] = torch.as_tensor( + [x[0], y[0], x[-1] + 1, y[-1] + 1], dtype=torch.float32 + ) + return Boxes(boxes) + + @staticmethod + def cat(bitmasks_list: List["BitMasks"]) -> "BitMasks": + """ + Concatenates a list of BitMasks into a single BitMasks + + Arguments: + bitmasks_list (list[BitMasks]) + + Returns: + BitMasks: the concatenated BitMasks + """ + assert isinstance(bitmasks_list, (list, tuple)) + assert len(bitmasks_list) > 0 + assert all(isinstance(bitmask, BitMasks) for bitmask in bitmasks_list) + + cat_bitmasks = type(bitmasks_list[0])(torch.cat([bm.tensor for bm in bitmasks_list], dim=0)) + return cat_bitmasks + + +class PolygonMasks: + """ + This class stores the segmentation masks for all objects in one image, in the form of polygons. + + Attributes: + polygons: list[list[ndarray]]. Each ndarray is a float64 vector representing a polygon. + """ + + def __init__(self, polygons: List[List[Union[torch.Tensor, np.ndarray]]]): + """ + Arguments: + polygons (list[list[np.ndarray]]): The first + level of the list correspond to individual instances, + the second level to all the polygons that compose the + instance, and the third level to the polygon coordinates. + The third level array should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + """ + if not isinstance(polygons, list): + raise ValueError( + "Cannot create PolygonMasks: Expect a list of list of polygons per image. " + "Got '{}' instead.".format(type(polygons)) + ) + + def _make_array(t: Union[torch.Tensor, np.ndarray]) -> np.ndarray: + # Use float64 for higher precision, because why not? + # Always put polygons on CPU (self.to is a no-op) since they + # are supposed to be small tensors. + # May need to change this assumption if GPU placement becomes useful + if isinstance(t, torch.Tensor): + t = t.cpu().numpy() + return np.asarray(t).astype("float64") + + def process_polygons( + polygons_per_instance: List[Union[torch.Tensor, np.ndarray]] + ) -> List[np.ndarray]: + if not isinstance(polygons_per_instance, list): + raise ValueError( + "Cannot create polygons: Expect a list of polygons per instance. " + "Got '{}' instead.".format(type(polygons_per_instance)) + ) + # transform each polygon to a numpy array + polygons_per_instance = [_make_array(p) for p in polygons_per_instance] + for polygon in polygons_per_instance: + if len(polygon) % 2 != 0 or len(polygon) < 6: + raise ValueError(f"Cannot create a polygon from {len(polygon)} coordinates.") + return polygons_per_instance + + self.polygons: List[List[np.ndarray]] = [ + process_polygons(polygons_per_instance) for polygons_per_instance in polygons + ] + + def to(self, *args: Any, **kwargs: Any) -> "PolygonMasks": + return self + + @property + def device(self) -> torch.device: + return torch.device("cpu") + + def get_bounding_boxes(self) -> Boxes: + """ + Returns: + Boxes: tight bounding boxes around polygon masks. + """ + boxes = torch.zeros(len(self.polygons), 4, dtype=torch.float32) + for idx, polygons_per_instance in enumerate(self.polygons): + minxy = torch.as_tensor([float("inf"), float("inf")], dtype=torch.float32) + maxxy = torch.zeros(2, dtype=torch.float32) + for polygon in polygons_per_instance: + coords = torch.from_numpy(polygon).view(-1, 2).to(dtype=torch.float32) + minxy = torch.min(minxy, torch.min(coords, dim=0).values) + maxxy = torch.max(maxxy, torch.max(coords, dim=0).values) + boxes[idx, :2] = minxy + boxes[idx, 2:] = maxxy + return Boxes(boxes) + + def nonempty(self) -> torch.Tensor: + """ + Find masks that are non-empty. + + Returns: + Tensor: + a BoolTensor which represents whether each mask is empty (False) or not (True). + """ + keep = [1 if len(polygon) > 0 else 0 for polygon in self.polygons] + return torch.from_numpy(np.asarray(keep, dtype=bool)) + + def __getitem__(self, item: Union[int, slice, List[int], torch.BoolTensor]) -> "PolygonMasks": + """ + Support indexing over the instances and return a `PolygonMasks` object. + `item` can be: + + 1. An integer. It will return an object with only one instance. + 2. A slice. It will return an object with the selected instances. + 3. A list[int]. It will return an object with the selected instances, + correpsonding to the indices in the list. + 4. A vector mask of type BoolTensor, whose length is num_instances. + It will return an object with the instances whose mask is nonzero. + """ + if isinstance(item, int): + selected_polygons = [self.polygons[item]] + elif isinstance(item, slice): + selected_polygons = self.polygons[item] + elif isinstance(item, list): + selected_polygons = [self.polygons[i] for i in item] + elif isinstance(item, torch.Tensor): + # Polygons is a list, so we have to move the indices back to CPU. + if item.dtype == torch.bool: + assert item.dim() == 1, item.shape + item = item.nonzero().squeeze(1).cpu().numpy().tolist() + elif item.dtype in [torch.int32, torch.int64]: + item = item.cpu().numpy().tolist() + else: + raise ValueError("Unsupported tensor dtype={} for indexing!".format(item.dtype)) + selected_polygons = [self.polygons[i] for i in item] + return PolygonMasks(selected_polygons) + + def __iter__(self) -> Iterator[List[np.ndarray]]: + """ + Yields: + list[ndarray]: the polygons for one instance. + Each Tensor is a float64 vector representing a polygon. + """ + return iter(self.polygons) + + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.polygons)) + return s + + def __len__(self) -> int: + return len(self.polygons) + + def crop_and_resize(self, boxes: torch.Tensor, mask_size: int) -> torch.Tensor: + """ + Crop each mask by the given box, and resize results to (mask_size, mask_size). + This can be used to prepare training targets for Mask R-CNN. + + Args: + boxes (Tensor): Nx4 tensor storing the boxes for each mask + mask_size (int): the size of the rasterized mask. + + Returns: + Tensor: A bool tensor of shape (N, mask_size, mask_size), where + N is the number of predicted boxes for this image. + """ + assert len(boxes) == len(self), "{} != {}".format(len(boxes), len(self)) + + device = boxes.device + # Put boxes on the CPU, as the polygon representation is not efficient GPU-wise + # (several small tensors for representing a single instance mask) + boxes = boxes.to(torch.device("cpu")) + + results = [ + rasterize_polygons_within_box(poly, box.numpy(), mask_size) + for poly, box in zip(self.polygons, boxes) + ] + """ + poly: list[list[float]], the polygons for one instance + box: a tensor of shape (4,) + """ + if len(results) == 0: + return torch.empty(0, mask_size, mask_size, dtype=torch.bool, device=device) + return torch.stack(results, dim=0).to(device=device) + + def area(self): + """ + Computes area of the mask. + Only works with Polygons, using the shoelace formula: + https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates + + Returns: + Tensor: a vector, area for each instance + """ + + area = [] + for polygons_per_instance in self.polygons: + area_per_instance = 0 + for p in polygons_per_instance: + area_per_instance += polygon_area(p[0::2], p[1::2]) + area.append(area_per_instance) + + return torch.tensor(area) + + @staticmethod + def cat(polymasks_list: List["PolygonMasks"]) -> "PolygonMasks": + """ + Concatenates a list of PolygonMasks into a single PolygonMasks + + Arguments: + polymasks_list (list[PolygonMasks]) + + Returns: + PolygonMasks: the concatenated PolygonMasks + """ + assert isinstance(polymasks_list, (list, tuple)) + assert len(polymasks_list) > 0 + assert all(isinstance(polymask, PolygonMasks) for polymask in polymasks_list) + + cat_polymasks = type(polymasks_list[0])( + list(itertools.chain.from_iterable(pm.polygons for pm in polymasks_list)) + ) + return cat_polymasks + + +class ROIMasks: + """ + Represent masks by N smaller masks defined in some ROIs. Once ROI boxes are given, + full-image bitmask can be obtained by "pasting" the mask on the region defined + by the corresponding ROI box. + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor: (N, M, M) mask tensor that defines the mask within each ROI. + """ + if tensor.dim() != 3: + raise ValueError("ROIMasks must take a masks of 3 dimension.") + self.tensor = tensor + + def to(self, device: torch.device) -> "ROIMasks": + return ROIMasks(self.tensor.to(device)) + + @property + def device(self) -> device: + return self.tensor.device + + def __len__(self): + return self.tensor.shape[0] + + def __getitem__(self, item) -> "ROIMasks": + """ + Returns: + ROIMasks: Create a new :class:`ROIMasks` by indexing. + + The following usage are allowed: + + 1. `new_masks = masks[2:10]`: return a slice of masks. + 2. `new_masks = masks[vector]`, where vector is a torch.BoolTensor + with `length = len(masks)`. Nonzero elements in the vector will be selected. + + Note that the returned object might share storage with this object, + subject to Pytorch's indexing semantics. + """ + t = self.tensor[item] + if t.dim() != 3: + raise ValueError( + f"Indexing on ROIMasks with {item} returns a tensor with shape {t.shape}!" + ) + return ROIMasks(t) + + @torch.jit.unused + def __repr__(self) -> str: + s = self.__class__.__name__ + "(" + s += "num_instances={})".format(len(self.tensor)) + return s + + @torch.jit.unused + def to_bitmasks(self, boxes: torch.Tensor, height, width, threshold=0.5): + """ + Args: see documentation of :func:`paste_masks_in_image`. + """ + from custom_detectron2.layers.mask_ops import paste_masks_in_image, _paste_masks_tensor_shape + + if torch.jit.is_tracing(): + if isinstance(height, torch.Tensor): + paste_func = _paste_masks_tensor_shape + else: + paste_func = paste_masks_in_image + else: + paste_func = retry_if_cuda_oom(paste_masks_in_image) + bitmasks = paste_func(self.tensor, boxes.tensor, (height, width), threshold=threshold) + return BitMasks(bitmasks) diff --git a/custom_detectron2/structures/rotated_boxes.py b/custom_detectron2/structures/rotated_boxes.py new file mode 100644 index 0000000000000000000000000000000000000000..7fc0ae02935b69fbb7a603bafccb0d4e728eda31 --- /dev/null +++ b/custom_detectron2/structures/rotated_boxes.py @@ -0,0 +1,505 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import math +from typing import List, Tuple +import torch + +from custom_detectron2.layers.rotated_boxes import pairwise_iou_rotated + +from .boxes import Boxes + + +class RotatedBoxes(Boxes): + """ + This structure stores a list of rotated boxes as a Nx5 torch.Tensor. + It supports some common methods about boxes + (`area`, `clip`, `nonempty`, etc), + and also behaves like a Tensor + (support indexing, `to(device)`, `.device`, and iteration over all boxes) + """ + + def __init__(self, tensor: torch.Tensor): + """ + Args: + tensor (Tensor[float]): a Nx5 matrix. Each row is + (x_center, y_center, width, height, angle), + in which angle is represented in degrees. + While there's no strict range restriction for it, + the recommended principal range is between [-180, 180) degrees. + + Assume we have a horizontal box B = (x_center, y_center, width, height), + where width is along the x-axis and height is along the y-axis. + The rotated box B_rot (x_center, y_center, width, height, angle) + can be seen as: + + 1. When angle == 0: + B_rot == B + 2. When angle > 0: + B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CCW; + 3. When angle < 0: + B_rot is obtained by rotating B w.r.t its center by :math:`|angle|` degrees CW. + + Mathematically, since the right-handed coordinate system for image space + is (y, x), where y is top->down and x is left->right, the 4 vertices of the + rotated rectangle :math:`(yr_i, xr_i)` (i = 1, 2, 3, 4) can be obtained from + the vertices of the horizontal rectangle :math:`(y_i, x_i)` (i = 1, 2, 3, 4) + in the following way (:math:`\\theta = angle*\\pi/180` is the angle in radians, + :math:`(y_c, x_c)` is the center of the rectangle): + + .. math:: + + yr_i = \\cos(\\theta) (y_i - y_c) - \\sin(\\theta) (x_i - x_c) + y_c, + + xr_i = \\sin(\\theta) (y_i - y_c) + \\cos(\\theta) (x_i - x_c) + x_c, + + which is the standard rigid-body rotation transformation. + + Intuitively, the angle is + (1) the rotation angle from y-axis in image space + to the height vector (top->down in the box's local coordinate system) + of the box in CCW, and + (2) the rotation angle from x-axis in image space + to the width vector (left->right in the box's local coordinate system) + of the box in CCW. + + More intuitively, consider the following horizontal box ABCD represented + in (x1, y1, x2, y2): (3, 2, 7, 4), + covering the [3, 7] x [2, 4] region of the continuous coordinate system + which looks like this: + + .. code:: none + + O--------> x + | + | A---B + | | | + | D---C + | + v y + + Note that each capital letter represents one 0-dimensional geometric point + instead of a 'square pixel' here. + + In the example above, using (x, y) to represent a point we have: + + .. math:: + + O = (0, 0), A = (3, 2), B = (7, 2), C = (7, 4), D = (3, 4) + + We name vector AB = vector DC as the width vector in box's local coordinate system, and + vector AD = vector BC as the height vector in box's local coordinate system. Initially, + when angle = 0 degree, they're aligned with the positive directions of x-axis and y-axis + in the image space, respectively. + + For better illustration, we denote the center of the box as E, + + .. code:: none + + O--------> x + | + | A---B + | | E | + | D---C + | + v y + + where the center E = ((3+7)/2, (2+4)/2) = (5, 3). + + Also, + + .. math:: + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Therefore, the corresponding representation for the same shape in rotated box in + (x_center, y_center, width, height, angle) format is: + + (5, 3, 4, 2, 0), + + Now, let's consider (5, 3, 4, 2, 90), which is rotated by 90 degrees + CCW (counter-clockwise) by definition. It looks like this: + + .. code:: none + + O--------> x + | B-C + | | | + | |E| + | | | + | A-D + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CCW with regard to E: + A = (4, 5), B = (4, 1), C = (6, 1), D = (6, 5) + + Here, 90 degrees can be seen as the CCW angle to rotate from y-axis to + vector AD or vector BC (the top->down height vector in box's local coordinate system), + or the CCW angle to rotate from x-axis to vector AB or vector DC (the left->right + width vector in box's local coordinate system). + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + Next, how about (5, 3, 4, 2, -90), which is rotated by 90 degrees CW (clockwise) + by definition? It looks like this: + + .. code:: none + + O--------> x + | D-A + | | | + | |E| + | | | + | C-B + v y + + The center E is still located at the same point (5, 3), while the vertices + ABCD are rotated by 90 degrees CW with regard to E: + A = (6, 1), B = (6, 5), C = (4, 5), D = (4, 1) + + .. math:: + + width = |AB| = |CD| = 5 - 1 = 4, + height = |AD| = |BC| = 6 - 4 = 2. + + This covers exactly the same region as (5, 3, 4, 2, 90) does, and their IoU + will be 1. However, these two will generate different RoI Pooling results and + should not be treated as an identical box. + + On the other hand, it's easy to see that (X, Y, W, H, A) is identical to + (X, Y, W, H, A+360N), for any integer N. For example (5, 3, 4, 2, 270) would be + identical to (5, 3, 4, 2, -90), because rotating the shape 270 degrees CCW is + equivalent to rotating the same shape 90 degrees CW. + + We could rotate further to get (5, 3, 4, 2, 180), or (5, 3, 4, 2, -180): + + .. code:: none + + O--------> x + | + | C---D + | | E | + | B---A + | + v y + + .. math:: + + A = (7, 4), B = (3, 4), C = (3, 2), D = (7, 2), + + width = |AB| = |CD| = 7 - 3 = 4, + height = |AD| = |BC| = 4 - 2 = 2. + + Finally, this is a very inaccurate (heavily quantized) illustration of + how (5, 3, 4, 2, 60) looks like in case anyone wonders: + + .. code:: none + + O--------> x + | B\ + | / C + | /E / + | A / + | `D + v y + + It's still a rectangle with center of (5, 3), width of 4 and height of 2, + but its angle (and thus orientation) is somewhere between + (5, 3, 4, 2, 0) and (5, 3, 4, 2, 90). + """ + device = tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu") + tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device) + if tensor.numel() == 0: + # Use reshape, so we don't end up creating a new tensor that does not depend on + # the inputs (and consequently confuses jit) + tensor = tensor.reshape((0, 5)).to(dtype=torch.float32, device=device) + assert tensor.dim() == 2 and tensor.size(-1) == 5, tensor.size() + + self.tensor = tensor + + def clone(self) -> "RotatedBoxes": + """ + Clone the RotatedBoxes. + + Returns: + RotatedBoxes + """ + return RotatedBoxes(self.tensor.clone()) + + def to(self, device: torch.device): + # Boxes are assumed float32 and does not support to(dtype) + return RotatedBoxes(self.tensor.to(device=device)) + + def area(self) -> torch.Tensor: + """ + Computes the area of all the boxes. + + Returns: + torch.Tensor: a vector with areas of each box. + """ + box = self.tensor + area = box[:, 2] * box[:, 3] + return area + + # Avoid in-place operations so that we can torchscript; NOTE: this creates a new tensor + def normalize_angles(self) -> None: + """ + Restrict angles to the range of [-180, 180) degrees + """ + angle_tensor = (self.tensor[:, 4] + 180.0) % 360.0 - 180.0 + self.tensor = torch.cat((self.tensor[:, :4], angle_tensor[:, None]), dim=1) + + def clip(self, box_size: Tuple[int, int], clip_angle_threshold: float = 1.0) -> None: + """ + Clip (in place) the boxes by limiting x coordinates to the range [0, width] + and y coordinates to the range [0, height]. + + For RRPN: + Only clip boxes that are almost horizontal with a tolerance of + clip_angle_threshold to maintain backward compatibility. + + Rotated boxes beyond this threshold are not clipped for two reasons: + + 1. There are potentially multiple ways to clip a rotated box to make it + fit within the image. + 2. It's tricky to make the entire rectangular box fit within the image + and still be able to not leave out pixels of interest. + + Therefore we rely on ops like RoIAlignRotated to safely handle this. + + Args: + box_size (height, width): The clipping box's size. + clip_angle_threshold: + Iff. abs(normalized(angle)) <= clip_angle_threshold (in degrees), + we do the clipping as horizontal boxes. + """ + h, w = box_size + + # normalize angles to be within (-180, 180] degrees + self.normalize_angles() + + idx = torch.where(torch.abs(self.tensor[:, 4]) <= clip_angle_threshold)[0] + + # convert to (x1, y1, x2, y2) + x1 = self.tensor[idx, 0] - self.tensor[idx, 2] / 2.0 + y1 = self.tensor[idx, 1] - self.tensor[idx, 3] / 2.0 + x2 = self.tensor[idx, 0] + self.tensor[idx, 2] / 2.0 + y2 = self.tensor[idx, 1] + self.tensor[idx, 3] / 2.0 + + # clip + x1.clamp_(min=0, max=w) + y1.clamp_(min=0, max=h) + x2.clamp_(min=0, max=w) + y2.clamp_(min=0, max=h) + + # convert back to (xc, yc, w, h) + self.tensor[idx, 0] = (x1 + x2) / 2.0 + self.tensor[idx, 1] = (y1 + y2) / 2.0 + # make sure widths and heights do not increase due to numerical errors + self.tensor[idx, 2] = torch.min(self.tensor[idx, 2], x2 - x1) + self.tensor[idx, 3] = torch.min(self.tensor[idx, 3], y2 - y1) + + def nonempty(self, threshold: float = 0.0) -> torch.Tensor: + """ + Find boxes that are non-empty. + A box is considered empty, if either of its side is no larger than threshold. + + Returns: + Tensor: a binary vector which represents + whether each box is empty (False) or non-empty (True). + """ + box = self.tensor + widths = box[:, 2] + heights = box[:, 3] + keep = (widths > threshold) & (heights > threshold) + return keep + + def __getitem__(self, item) -> "RotatedBoxes": + """ + Returns: + RotatedBoxes: Create a new :class:`RotatedBoxes` by indexing. + + The following usage are allowed: + + 1. `new_boxes = boxes[3]`: return a `RotatedBoxes` which contains only one box. + 2. `new_boxes = boxes[2:10]`: return a slice of boxes. + 3. `new_boxes = boxes[vector]`, where vector is a torch.ByteTensor + with `length = len(boxes)`. Nonzero elements in the vector will be selected. + + Note that the returned RotatedBoxes might share storage with this RotatedBoxes, + subject to Pytorch's indexing semantics. + """ + if isinstance(item, int): + return RotatedBoxes(self.tensor[item].view(1, -1)) + b = self.tensor[item] + assert b.dim() == 2, "Indexing on RotatedBoxes with {} failed to return a matrix!".format( + item + ) + return RotatedBoxes(b) + + def __len__(self) -> int: + return self.tensor.shape[0] + + def __repr__(self) -> str: + return "RotatedBoxes(" + str(self.tensor) + ")" + + def inside_box(self, box_size: Tuple[int, int], boundary_threshold: int = 0) -> torch.Tensor: + """ + Args: + box_size (height, width): Size of the reference box covering + [0, width] x [0, height] + boundary_threshold (int): Boxes that extend beyond the reference box + boundary by more than boundary_threshold are considered "outside". + + For RRPN, it might not be necessary to call this function since it's common + for rotated box to extend to outside of the image boundaries + (the clip function only clips the near-horizontal boxes) + + Returns: + a binary vector, indicating whether each box is inside the reference box. + """ + height, width = box_size + + cnt_x = self.tensor[..., 0] + cnt_y = self.tensor[..., 1] + half_w = self.tensor[..., 2] / 2.0 + half_h = self.tensor[..., 3] / 2.0 + a = self.tensor[..., 4] + c = torch.abs(torch.cos(a * math.pi / 180.0)) + s = torch.abs(torch.sin(a * math.pi / 180.0)) + # This basically computes the horizontal bounding rectangle of the rotated box + max_rect_dx = c * half_w + s * half_h + max_rect_dy = c * half_h + s * half_w + + inds_inside = ( + (cnt_x - max_rect_dx >= -boundary_threshold) + & (cnt_y - max_rect_dy >= -boundary_threshold) + & (cnt_x + max_rect_dx < width + boundary_threshold) + & (cnt_y + max_rect_dy < height + boundary_threshold) + ) + + return inds_inside + + def get_centers(self) -> torch.Tensor: + """ + Returns: + The box centers in a Nx2 array of (x, y). + """ + return self.tensor[:, :2] + + def scale(self, scale_x: float, scale_y: float) -> None: + """ + Scale the rotated box with horizontal and vertical scaling factors + Note: when scale_factor_x != scale_factor_y, + the rotated box does not preserve the rectangular shape when the angle + is not a multiple of 90 degrees under resize transformation. + Instead, the shape is a parallelogram (that has skew) + Here we make an approximation by fitting a rotated rectangle to the parallelogram. + """ + self.tensor[:, 0] *= scale_x + self.tensor[:, 1] *= scale_y + theta = self.tensor[:, 4] * math.pi / 180.0 + c = torch.cos(theta) + s = torch.sin(theta) + + # In image space, y is top->down and x is left->right + # Consider the local coordintate system for the rotated box, + # where the box center is located at (0, 0), and the four vertices ABCD are + # A(-w / 2, -h / 2), B(w / 2, -h / 2), C(w / 2, h / 2), D(-w / 2, h / 2) + # the midpoint of the left edge AD of the rotated box E is: + # E = (A+D)/2 = (-w / 2, 0) + # the midpoint of the top edge AB of the rotated box F is: + # F(0, -h / 2) + # To get the old coordinates in the global system, apply the rotation transformation + # (Note: the right-handed coordinate system for image space is yOx): + # (old_x, old_y) = (s * y + c * x, c * y - s * x) + # E(old) = (s * 0 + c * (-w/2), c * 0 - s * (-w/2)) = (-c * w / 2, s * w / 2) + # F(old) = (s * (-h / 2) + c * 0, c * (-h / 2) - s * 0) = (-s * h / 2, -c * h / 2) + # After applying the scaling factor (sfx, sfy): + # E(new) = (-sfx * c * w / 2, sfy * s * w / 2) + # F(new) = (-sfx * s * h / 2, -sfy * c * h / 2) + # The new width after scaling tranformation becomes: + + # w(new) = |E(new) - O| * 2 + # = sqrt[(sfx * c * w / 2)^2 + (sfy * s * w / 2)^2] * 2 + # = sqrt[(sfx * c)^2 + (sfy * s)^2] * w + # i.e., scale_factor_w = sqrt[(sfx * c)^2 + (sfy * s)^2] + # + # For example, + # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_w == scale_factor_x; + # when |angle| = 90, c = 0, |s| = 1, scale_factor_w == scale_factor_y + self.tensor[:, 2] *= torch.sqrt((scale_x * c) ** 2 + (scale_y * s) ** 2) + + # h(new) = |F(new) - O| * 2 + # = sqrt[(sfx * s * h / 2)^2 + (sfy * c * h / 2)^2] * 2 + # = sqrt[(sfx * s)^2 + (sfy * c)^2] * h + # i.e., scale_factor_h = sqrt[(sfx * s)^2 + (sfy * c)^2] + # + # For example, + # when angle = 0 or 180, |c| = 1, s = 0, scale_factor_h == scale_factor_y; + # when |angle| = 90, c = 0, |s| = 1, scale_factor_h == scale_factor_x + self.tensor[:, 3] *= torch.sqrt((scale_x * s) ** 2 + (scale_y * c) ** 2) + + # The angle is the rotation angle from y-axis in image space to the height + # vector (top->down in the box's local coordinate system) of the box in CCW. + # + # angle(new) = angle_yOx(O - F(new)) + # = angle_yOx( (sfx * s * h / 2, sfy * c * h / 2) ) + # = atan2(sfx * s * h / 2, sfy * c * h / 2) + # = atan2(sfx * s, sfy * c) + # + # For example, + # when sfx == sfy, angle(new) == atan2(s, c) == angle(old) + self.tensor[:, 4] = torch.atan2(scale_x * s, scale_y * c) * 180 / math.pi + + @classmethod + def cat(cls, boxes_list: List["RotatedBoxes"]) -> "RotatedBoxes": + """ + Concatenates a list of RotatedBoxes into a single RotatedBoxes + + Arguments: + boxes_list (list[RotatedBoxes]) + + Returns: + RotatedBoxes: the concatenated RotatedBoxes + """ + assert isinstance(boxes_list, (list, tuple)) + if len(boxes_list) == 0: + return cls(torch.empty(0)) + assert all([isinstance(box, RotatedBoxes) for box in boxes_list]) + + # use torch.cat (v.s. layers.cat) so the returned boxes never share storage with input + cat_boxes = cls(torch.cat([b.tensor for b in boxes_list], dim=0)) + return cat_boxes + + @property + def device(self) -> torch.device: + return self.tensor.device + + @torch.jit.unused + def __iter__(self): + """ + Yield a box as a Tensor of shape (5,) at a time. + """ + yield from self.tensor + + +def pairwise_iou(boxes1: RotatedBoxes, boxes2: RotatedBoxes) -> None: + """ + Given two lists of rotated boxes of size N and M, + compute the IoU (intersection over union) + between **all** N x M pairs of boxes. + The box order must be (x_center, y_center, width, height, angle). + + Args: + boxes1, boxes2 (RotatedBoxes): + two `RotatedBoxes`. Contains N & M rotated boxes, respectively. + + Returns: + Tensor: IoU, sized [N,M]. + """ + + return pairwise_iou_rotated(boxes1.tensor, boxes2.tensor) diff --git a/custom_detectron2/tracking/__init__.py b/custom_detectron2/tracking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..21078ae822b04b71dbd8b056b5993d173eaf6bff --- /dev/null +++ b/custom_detectron2/tracking/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from .base_tracker import ( # noqa + BaseTracker, + build_tracker_head, + TRACKER_HEADS_REGISTRY, +) +from .bbox_iou_tracker import BBoxIOUTracker # noqa +from .hungarian_tracker import BaseHungarianTracker # noqa +from .iou_weighted_hungarian_bbox_iou_tracker import ( # noqa + IOUWeightedHungarianBBoxIOUTracker, +) +from .utils import create_prediction_pairs # noqa +from .vanilla_hungarian_bbox_iou_tracker import VanillaHungarianBBoxIOUTracker # noqa + +__all__ = [k for k in globals().keys() if not k.startswith("_")] diff --git a/custom_detectron2/tracking/base_tracker.py b/custom_detectron2/tracking/base_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..66a85b9b9ee2dbc47bec17ccd66df412cc4371ec --- /dev/null +++ b/custom_detectron2/tracking/base_tracker.py @@ -0,0 +1,64 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. +from custom_detectron2.config import configurable +from custom_detectron2.utils.registry import Registry + +from ..config.config import CfgNode as CfgNode_ +from ..structures import Instances + +TRACKER_HEADS_REGISTRY = Registry("TRACKER_HEADS") +TRACKER_HEADS_REGISTRY.__doc__ = """ +Registry for tracking classes. +""" + + +class BaseTracker(object): + """ + A parent class for all trackers + """ + + @configurable + def __init__(self, **kwargs): + self._prev_instances = None # (D2)instances for previous frame + self._matched_idx = set() # indices in prev_instances found matching + self._matched_ID = set() # idendities in prev_instances found matching + self._untracked_prev_idx = set() # indices in prev_instances not found matching + self._id_count = 0 # used to assign new id + + @classmethod + def from_config(cls, cfg: CfgNode_): + raise NotImplementedError("Calling BaseTracker::from_config") + + def update(self, predictions: Instances) -> Instances: + """ + Args: + predictions: D2 Instances for predictions of the current frame + Return: + D2 Instances for predictions of the current frame with ID assigned + + _prev_instances and instances will have the following fields: + .pred_boxes (shape=[N, 4]) + .scores (shape=[N,]) + .pred_classes (shape=[N,]) + .pred_keypoints (shape=[N, M, 3], Optional) + .pred_masks (shape=List[2D_MASK], Optional) 2D_MASK: shape=[H, W] + .ID (shape=[N,]) + + N: # of detected bboxes + H and W: height and width of 2D mask + """ + raise NotImplementedError("Calling BaseTracker::update") + + +def build_tracker_head(cfg: CfgNode_) -> BaseTracker: + """ + Build a tracker head from `cfg.TRACKER_HEADS.TRACKER_NAME`. + + Args: + cfg: D2 CfgNode, config file with tracker information + Return: + tracker object + """ + name = cfg.TRACKER_HEADS.TRACKER_NAME + tracker_class = TRACKER_HEADS_REGISTRY.get(name) + return tracker_class(cfg) diff --git a/custom_detectron2/tracking/bbox_iou_tracker.py b/custom_detectron2/tracking/bbox_iou_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..f306db6d45d55e0741066d8d1f750ca1a3687eca --- /dev/null +++ b/custom_detectron2/tracking/bbox_iou_tracker.py @@ -0,0 +1,276 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. +import copy +import numpy as np +from typing import List +import torch + +from custom_detectron2.config import configurable +from custom_detectron2.structures import Boxes, Instances +from custom_detectron2.structures.boxes import pairwise_iou + +from ..config.config import CfgNode as CfgNode_ +from .base_tracker import TRACKER_HEADS_REGISTRY, BaseTracker + + +@TRACKER_HEADS_REGISTRY.register() +class BBoxIOUTracker(BaseTracker): + """ + A bounding box tracker to assign ID based on IoU between current and previous instances + """ + + @configurable + def __init__( + self, + *, + video_height: int, + video_width: int, + max_num_instances: int = 200, + max_lost_frame_count: int = 0, + min_box_rel_dim: float = 0.02, + min_instance_period: int = 1, + track_iou_threshold: float = 0.5, + **kwargs, + ): + """ + Args: + video_height: height the video frame + video_width: width of the video frame + max_num_instances: maximum number of id allowed to be tracked + max_lost_frame_count: maximum number of frame an id can lost tracking + exceed this number, an id is considered as lost + forever + min_box_rel_dim: a percentage, smaller than this dimension, a bbox is + removed from tracking + min_instance_period: an instance will be shown after this number of period + since its first showing up in the video + track_iou_threshold: iou threshold, below this number a bbox pair is removed + from tracking + """ + super().__init__(**kwargs) + self._video_height = video_height + self._video_width = video_width + self._max_num_instances = max_num_instances + self._max_lost_frame_count = max_lost_frame_count + self._min_box_rel_dim = min_box_rel_dim + self._min_instance_period = min_instance_period + self._track_iou_threshold = track_iou_threshold + + @classmethod + def from_config(cls, cfg: CfgNode_): + """ + Old style initialization using CfgNode + + Args: + cfg: D2 CfgNode, config file + Return: + dictionary storing arguments for __init__ method + """ + assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS + assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS + video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT") + video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH") + max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200) + max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0) + min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02) + min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1) + track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5) + return { + "_target_": "detectron2.tracking.bbox_iou_tracker.BBoxIOUTracker", + "video_height": video_height, + "video_width": video_width, + "max_num_instances": max_num_instances, + "max_lost_frame_count": max_lost_frame_count, + "min_box_rel_dim": min_box_rel_dim, + "min_instance_period": min_instance_period, + "track_iou_threshold": track_iou_threshold, + } + + def update(self, instances: Instances) -> Instances: + """ + See BaseTracker description + """ + instances = self._initialize_extra_fields(instances) + if self._prev_instances is not None: + # calculate IoU of all bbox pairs + iou_all = pairwise_iou( + boxes1=instances.pred_boxes, + boxes2=self._prev_instances.pred_boxes, + ) + # sort IoU in descending order + bbox_pairs = self._create_prediction_pairs(instances, iou_all) + # assign previous ID to current bbox if IoU > track_iou_threshold + self._reset_fields() + for bbox_pair in bbox_pairs: + idx = bbox_pair["idx"] + prev_id = bbox_pair["prev_id"] + if ( + idx in self._matched_idx + or prev_id in self._matched_ID + or bbox_pair["IoU"] < self._track_iou_threshold + ): + continue + instances.ID[idx] = prev_id + instances.ID_period[idx] = bbox_pair["prev_period"] + 1 + instances.lost_frame_count[idx] = 0 + self._matched_idx.add(idx) + self._matched_ID.add(prev_id) + self._untracked_prev_idx.remove(bbox_pair["prev_idx"]) + instances = self._assign_new_id(instances) + instances = self._merge_untracked_instances(instances) + self._prev_instances = copy.deepcopy(instances) + return instances + + def _create_prediction_pairs(self, instances: Instances, iou_all: np.ndarray) -> List: + """ + For all instances in previous and current frames, create pairs. For each + pair, store index of the instance in current frame predcitions, index in + previous predictions, ID in previous predictions, IoU of the bboxes in this + pair, period in previous predictions. + + Args: + instances: D2 Instances, for predictions of the current frame + iou_all: IoU for all bboxes pairs + Return: + A list of IoU for all pairs + """ + bbox_pairs = [] + for i in range(len(instances)): + for j in range(len(self._prev_instances)): + bbox_pairs.append( + { + "idx": i, + "prev_idx": j, + "prev_id": self._prev_instances.ID[j], + "IoU": iou_all[i, j], + "prev_period": self._prev_instances.ID_period[j], + } + ) + return bbox_pairs + + def _initialize_extra_fields(self, instances: Instances) -> Instances: + """ + If input instances don't have ID, ID_period, lost_frame_count fields, + this method is used to initialize these fields. + + Args: + instances: D2 Instances, for predictions of the current frame + Return: + D2 Instances with extra fields added + """ + if not instances.has("ID"): + instances.set("ID", [None] * len(instances)) + if not instances.has("ID_period"): + instances.set("ID_period", [None] * len(instances)) + if not instances.has("lost_frame_count"): + instances.set("lost_frame_count", [None] * len(instances)) + if self._prev_instances is None: + instances.ID = list(range(len(instances))) + self._id_count += len(instances) + instances.ID_period = [1] * len(instances) + instances.lost_frame_count = [0] * len(instances) + return instances + + def _reset_fields(self): + """ + Before each uodate call, reset fields first + """ + self._matched_idx = set() + self._matched_ID = set() + self._untracked_prev_idx = set(range(len(self._prev_instances))) + + def _assign_new_id(self, instances: Instances) -> Instances: + """ + For each untracked instance, assign a new id + + Args: + instances: D2 Instances, for predictions of the current frame + Return: + D2 Instances with new ID assigned + """ + untracked_idx = set(range(len(instances))).difference(self._matched_idx) + for idx in untracked_idx: + instances.ID[idx] = self._id_count + self._id_count += 1 + instances.ID_period[idx] = 1 + instances.lost_frame_count[idx] = 0 + return instances + + def _merge_untracked_instances(self, instances: Instances) -> Instances: + """ + For untracked previous instances, under certain condition, still keep them + in tracking and merge with the current instances. + + Args: + instances: D2 Instances, for predictions of the current frame + Return: + D2 Instances merging current instances and instances from previous + frame decided to keep tracking + """ + untracked_instances = Instances( + image_size=instances.image_size, + pred_boxes=[], + pred_classes=[], + scores=[], + ID=[], + ID_period=[], + lost_frame_count=[], + ) + prev_bboxes = list(self._prev_instances.pred_boxes) + prev_classes = list(self._prev_instances.pred_classes) + prev_scores = list(self._prev_instances.scores) + prev_ID_period = self._prev_instances.ID_period + if instances.has("pred_masks"): + untracked_instances.set("pred_masks", []) + prev_masks = list(self._prev_instances.pred_masks) + if instances.has("pred_keypoints"): + untracked_instances.set("pred_keypoints", []) + prev_keypoints = list(self._prev_instances.pred_keypoints) + if instances.has("pred_keypoint_heatmaps"): + untracked_instances.set("pred_keypoint_heatmaps", []) + prev_keypoint_heatmaps = list(self._prev_instances.pred_keypoint_heatmaps) + for idx in self._untracked_prev_idx: + x_left, y_top, x_right, y_bot = prev_bboxes[idx] + if ( + (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim) + or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim) + or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count + or prev_ID_period[idx] <= self._min_instance_period + ): + continue + untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy())) + untracked_instances.pred_classes.append(int(prev_classes[idx])) + untracked_instances.scores.append(float(prev_scores[idx])) + untracked_instances.ID.append(self._prev_instances.ID[idx]) + untracked_instances.ID_period.append(self._prev_instances.ID_period[idx]) + untracked_instances.lost_frame_count.append( + self._prev_instances.lost_frame_count[idx] + 1 + ) + if instances.has("pred_masks"): + untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8)) + if instances.has("pred_keypoints"): + untracked_instances.pred_keypoints.append( + prev_keypoints[idx].numpy().astype(np.uint8) + ) + if instances.has("pred_keypoint_heatmaps"): + untracked_instances.pred_keypoint_heatmaps.append( + prev_keypoint_heatmaps[idx].numpy().astype(np.float32) + ) + untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes)) + untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes) + untracked_instances.scores = torch.FloatTensor(untracked_instances.scores) + if instances.has("pred_masks"): + untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks) + if instances.has("pred_keypoints"): + untracked_instances.pred_keypoints = torch.IntTensor(untracked_instances.pred_keypoints) + if instances.has("pred_keypoint_heatmaps"): + untracked_instances.pred_keypoint_heatmaps = torch.FloatTensor( + untracked_instances.pred_keypoint_heatmaps + ) + + return Instances.cat( + [ + instances, + untracked_instances, + ] + ) diff --git a/custom_detectron2/tracking/hungarian_tracker.py b/custom_detectron2/tracking/hungarian_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..00b3126057f665aeeb1742ce77bf8da1b0a56577 --- /dev/null +++ b/custom_detectron2/tracking/hungarian_tracker.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. +import copy +import numpy as np +from typing import Dict +import torch +from scipy.optimize import linear_sum_assignment + +from custom_detectron2.config import configurable +from custom_detectron2.structures import Boxes, Instances + +from ..config.config import CfgNode as CfgNode_ +from .base_tracker import BaseTracker + + +class BaseHungarianTracker(BaseTracker): + """ + A base class for all Hungarian trackers + """ + + @configurable + def __init__( + self, + video_height: int, + video_width: int, + max_num_instances: int = 200, + max_lost_frame_count: int = 0, + min_box_rel_dim: float = 0.02, + min_instance_period: int = 1, + **kwargs + ): + """ + Args: + video_height: height the video frame + video_width: width of the video frame + max_num_instances: maximum number of id allowed to be tracked + max_lost_frame_count: maximum number of frame an id can lost tracking + exceed this number, an id is considered as lost + forever + min_box_rel_dim: a percentage, smaller than this dimension, a bbox is + removed from tracking + min_instance_period: an instance will be shown after this number of period + since its first showing up in the video + """ + super().__init__(**kwargs) + self._video_height = video_height + self._video_width = video_width + self._max_num_instances = max_num_instances + self._max_lost_frame_count = max_lost_frame_count + self._min_box_rel_dim = min_box_rel_dim + self._min_instance_period = min_instance_period + + @classmethod + def from_config(cls, cfg: CfgNode_) -> Dict: + raise NotImplementedError("Calling HungarianTracker::from_config") + + def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray: + raise NotImplementedError("Calling HungarianTracker::build_matrix") + + def update(self, instances: Instances) -> Instances: + if instances.has("pred_keypoints"): + raise NotImplementedError("Need to add support for keypoints") + instances = self._initialize_extra_fields(instances) + if self._prev_instances is not None: + self._untracked_prev_idx = set(range(len(self._prev_instances))) + cost_matrix = self.build_cost_matrix(instances, self._prev_instances) + matched_idx, matched_prev_idx = linear_sum_assignment(cost_matrix) + instances = self._process_matched_idx(instances, matched_idx, matched_prev_idx) + instances = self._process_unmatched_idx(instances, matched_idx) + instances = self._process_unmatched_prev_idx(instances, matched_prev_idx) + self._prev_instances = copy.deepcopy(instances) + return instances + + def _initialize_extra_fields(self, instances: Instances) -> Instances: + """ + If input instances don't have ID, ID_period, lost_frame_count fields, + this method is used to initialize these fields. + + Args: + instances: D2 Instances, for predictions of the current frame + Return: + D2 Instances with extra fields added + """ + if not instances.has("ID"): + instances.set("ID", [None] * len(instances)) + if not instances.has("ID_period"): + instances.set("ID_period", [None] * len(instances)) + if not instances.has("lost_frame_count"): + instances.set("lost_frame_count", [None] * len(instances)) + if self._prev_instances is None: + instances.ID = list(range(len(instances))) + self._id_count += len(instances) + instances.ID_period = [1] * len(instances) + instances.lost_frame_count = [0] * len(instances) + return instances + + def _process_matched_idx( + self, instances: Instances, matched_idx: np.ndarray, matched_prev_idx: np.ndarray + ) -> Instances: + assert matched_idx.size == matched_prev_idx.size + for i in range(matched_idx.size): + instances.ID[matched_idx[i]] = self._prev_instances.ID[matched_prev_idx[i]] + instances.ID_period[matched_idx[i]] = ( + self._prev_instances.ID_period[matched_prev_idx[i]] + 1 + ) + instances.lost_frame_count[matched_idx[i]] = 0 + return instances + + def _process_unmatched_idx(self, instances: Instances, matched_idx: np.ndarray) -> Instances: + untracked_idx = set(range(len(instances))).difference(set(matched_idx)) + for idx in untracked_idx: + instances.ID[idx] = self._id_count + self._id_count += 1 + instances.ID_period[idx] = 1 + instances.lost_frame_count[idx] = 0 + return instances + + def _process_unmatched_prev_idx( + self, instances: Instances, matched_prev_idx: np.ndarray + ) -> Instances: + untracked_instances = Instances( + image_size=instances.image_size, + pred_boxes=[], + pred_masks=[], + pred_classes=[], + scores=[], + ID=[], + ID_period=[], + lost_frame_count=[], + ) + prev_bboxes = list(self._prev_instances.pred_boxes) + prev_classes = list(self._prev_instances.pred_classes) + prev_scores = list(self._prev_instances.scores) + prev_ID_period = self._prev_instances.ID_period + if instances.has("pred_masks"): + prev_masks = list(self._prev_instances.pred_masks) + untracked_prev_idx = set(range(len(self._prev_instances))).difference(set(matched_prev_idx)) + for idx in untracked_prev_idx: + x_left, y_top, x_right, y_bot = prev_bboxes[idx] + if ( + (1.0 * (x_right - x_left) / self._video_width < self._min_box_rel_dim) + or (1.0 * (y_bot - y_top) / self._video_height < self._min_box_rel_dim) + or self._prev_instances.lost_frame_count[idx] >= self._max_lost_frame_count + or prev_ID_period[idx] <= self._min_instance_period + ): + continue + untracked_instances.pred_boxes.append(list(prev_bboxes[idx].numpy())) + untracked_instances.pred_classes.append(int(prev_classes[idx])) + untracked_instances.scores.append(float(prev_scores[idx])) + untracked_instances.ID.append(self._prev_instances.ID[idx]) + untracked_instances.ID_period.append(self._prev_instances.ID_period[idx]) + untracked_instances.lost_frame_count.append( + self._prev_instances.lost_frame_count[idx] + 1 + ) + if instances.has("pred_masks"): + untracked_instances.pred_masks.append(prev_masks[idx].numpy().astype(np.uint8)) + + untracked_instances.pred_boxes = Boxes(torch.FloatTensor(untracked_instances.pred_boxes)) + untracked_instances.pred_classes = torch.IntTensor(untracked_instances.pred_classes) + untracked_instances.scores = torch.FloatTensor(untracked_instances.scores) + if instances.has("pred_masks"): + untracked_instances.pred_masks = torch.IntTensor(untracked_instances.pred_masks) + else: + untracked_instances.remove("pred_masks") + + return Instances.cat( + [ + instances, + untracked_instances, + ] + ) diff --git a/custom_detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py b/custom_detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..07fa28c1b6af274827a45fee1fb7beabc1f130e5 --- /dev/null +++ b/custom_detectron2/tracking/iou_weighted_hungarian_bbox_iou_tracker.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import numpy as np +from typing import List + +from custom_detectron2.config import CfgNode as CfgNode_ +from custom_detectron2.config import configurable + +from .base_tracker import TRACKER_HEADS_REGISTRY +from .vanilla_hungarian_bbox_iou_tracker import VanillaHungarianBBoxIOUTracker + + +@TRACKER_HEADS_REGISTRY.register() +class IOUWeightedHungarianBBoxIOUTracker(VanillaHungarianBBoxIOUTracker): + """ + A tracker using IoU as weight in Hungarian algorithm, also known + as Munkres or Kuhn-Munkres algorithm + """ + + @configurable + def __init__( + self, + *, + video_height: int, + video_width: int, + max_num_instances: int = 200, + max_lost_frame_count: int = 0, + min_box_rel_dim: float = 0.02, + min_instance_period: int = 1, + track_iou_threshold: float = 0.5, + **kwargs, + ): + """ + Args: + video_height: height the video frame + video_width: width of the video frame + max_num_instances: maximum number of id allowed to be tracked + max_lost_frame_count: maximum number of frame an id can lost tracking + exceed this number, an id is considered as lost + forever + min_box_rel_dim: a percentage, smaller than this dimension, a bbox is + removed from tracking + min_instance_period: an instance will be shown after this number of period + since its first showing up in the video + track_iou_threshold: iou threshold, below this number a bbox pair is removed + from tracking + """ + super().__init__( + video_height=video_height, + video_width=video_width, + max_num_instances=max_num_instances, + max_lost_frame_count=max_lost_frame_count, + min_box_rel_dim=min_box_rel_dim, + min_instance_period=min_instance_period, + track_iou_threshold=track_iou_threshold, + ) + + @classmethod + def from_config(cls, cfg: CfgNode_): + """ + Old style initialization using CfgNode + + Args: + cfg: D2 CfgNode, config file + Return: + dictionary storing arguments for __init__ method + """ + assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS + assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS + video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT") + video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH") + max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200) + max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0) + min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02) + min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1) + track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5) + return { + "_target_": "detectron2.tracking.iou_weighted_hungarian_bbox_iou_tracker.IOUWeightedHungarianBBoxIOUTracker", # noqa + "video_height": video_height, + "video_width": video_width, + "max_num_instances": max_num_instances, + "max_lost_frame_count": max_lost_frame_count, + "min_box_rel_dim": min_box_rel_dim, + "min_instance_period": min_instance_period, + "track_iou_threshold": track_iou_threshold, + } + + def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray: + """ + Based on IoU for each pair of bbox, assign the associated value in cost matrix + + Args: + cost_matrix: np.ndarray, initialized 2D array with target dimensions + bbox_pairs: list of bbox pair, in each pair, iou value is stored + Return: + np.ndarray, cost_matrix with assigned values + """ + for pair in bbox_pairs: + # assign (-1 * IoU) for above threshold pairs, algorithms will minimize cost + cost_matrix[pair["idx"]][pair["prev_idx"]] = -1 * pair["IoU"] + return cost_matrix diff --git a/custom_detectron2/tracking/utils.py b/custom_detectron2/tracking/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dd81e6a064312d60af6b3c506b61cd3aaec3aa81 --- /dev/null +++ b/custom_detectron2/tracking/utils.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +import numpy as np +from typing import List + +from custom_detectron2.structures import Instances + + +def create_prediction_pairs( + instances: Instances, + prev_instances: Instances, + iou_all: np.ndarray, + threshold: float = 0.5, +) -> List: + """ + Args: + instances: predictions from current frame + prev_instances: predictions from previous frame + iou_all: 2D numpy array containing iou for each bbox pair + threshold: below the threshold, doesn't consider the pair of bbox is valid + Return: + List of bbox pairs + """ + bbox_pairs = [] + for i in range(len(instances)): + for j in range(len(prev_instances)): + if iou_all[i, j] < threshold: + continue + bbox_pairs.append( + { + "idx": i, + "prev_idx": j, + "prev_id": prev_instances.ID[j], + "IoU": iou_all[i, j], + "prev_period": prev_instances.ID_period[j], + } + ) + return bbox_pairs + + +LARGE_COST_VALUE = 100000 diff --git a/custom_detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py b/custom_detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..42dbbebbbf1b0f9368686183765b6f9db61fea92 --- /dev/null +++ b/custom_detectron2/tracking/vanilla_hungarian_bbox_iou_tracker.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +# Copyright 2004-present Facebook. All Rights Reserved. + +import numpy as np +from typing import List + +from custom_detectron2.config import CfgNode as CfgNode_ +from custom_detectron2.config import configurable +from custom_detectron2.structures import Instances +from custom_detectron2.structures.boxes import pairwise_iou +from custom_detectron2.tracking.utils import LARGE_COST_VALUE, create_prediction_pairs + +from .base_tracker import TRACKER_HEADS_REGISTRY +from .hungarian_tracker import BaseHungarianTracker + + +@TRACKER_HEADS_REGISTRY.register() +class VanillaHungarianBBoxIOUTracker(BaseHungarianTracker): + """ + Hungarian algo based tracker using bbox iou as metric + """ + + @configurable + def __init__( + self, + *, + video_height: int, + video_width: int, + max_num_instances: int = 200, + max_lost_frame_count: int = 0, + min_box_rel_dim: float = 0.02, + min_instance_period: int = 1, + track_iou_threshold: float = 0.5, + **kwargs, + ): + """ + Args: + video_height: height the video frame + video_width: width of the video frame + max_num_instances: maximum number of id allowed to be tracked + max_lost_frame_count: maximum number of frame an id can lost tracking + exceed this number, an id is considered as lost + forever + min_box_rel_dim: a percentage, smaller than this dimension, a bbox is + removed from tracking + min_instance_period: an instance will be shown after this number of period + since its first showing up in the video + track_iou_threshold: iou threshold, below this number a bbox pair is removed + from tracking + """ + super().__init__( + video_height=video_height, + video_width=video_width, + max_num_instances=max_num_instances, + max_lost_frame_count=max_lost_frame_count, + min_box_rel_dim=min_box_rel_dim, + min_instance_period=min_instance_period, + ) + self._track_iou_threshold = track_iou_threshold + + @classmethod + def from_config(cls, cfg: CfgNode_): + """ + Old style initialization using CfgNode + + Args: + cfg: D2 CfgNode, config file + Return: + dictionary storing arguments for __init__ method + """ + assert "VIDEO_HEIGHT" in cfg.TRACKER_HEADS + assert "VIDEO_WIDTH" in cfg.TRACKER_HEADS + video_height = cfg.TRACKER_HEADS.get("VIDEO_HEIGHT") + video_width = cfg.TRACKER_HEADS.get("VIDEO_WIDTH") + max_num_instances = cfg.TRACKER_HEADS.get("MAX_NUM_INSTANCES", 200) + max_lost_frame_count = cfg.TRACKER_HEADS.get("MAX_LOST_FRAME_COUNT", 0) + min_box_rel_dim = cfg.TRACKER_HEADS.get("MIN_BOX_REL_DIM", 0.02) + min_instance_period = cfg.TRACKER_HEADS.get("MIN_INSTANCE_PERIOD", 1) + track_iou_threshold = cfg.TRACKER_HEADS.get("TRACK_IOU_THRESHOLD", 0.5) + return { + "_target_": "detectron2.tracking.vanilla_hungarian_bbox_iou_tracker.VanillaHungarianBBoxIOUTracker", # noqa + "video_height": video_height, + "video_width": video_width, + "max_num_instances": max_num_instances, + "max_lost_frame_count": max_lost_frame_count, + "min_box_rel_dim": min_box_rel_dim, + "min_instance_period": min_instance_period, + "track_iou_threshold": track_iou_threshold, + } + + def build_cost_matrix(self, instances: Instances, prev_instances: Instances) -> np.ndarray: + """ + Build the cost matrix for assignment problem + (https://en.wikipedia.org/wiki/Assignment_problem) + + Args: + instances: D2 Instances, for current frame predictions + prev_instances: D2 Instances, for previous frame predictions + + Return: + the cost matrix in numpy array + """ + assert instances is not None and prev_instances is not None + # calculate IoU of all bbox pairs + iou_all = pairwise_iou( + boxes1=instances.pred_boxes, + boxes2=self._prev_instances.pred_boxes, + ) + bbox_pairs = create_prediction_pairs( + instances, self._prev_instances, iou_all, threshold=self._track_iou_threshold + ) + # assign large cost value to make sure pair below IoU threshold won't be matched + cost_matrix = np.full((len(instances), len(prev_instances)), LARGE_COST_VALUE) + return self.assign_cost_matrix_values(cost_matrix, bbox_pairs) + + def assign_cost_matrix_values(self, cost_matrix: np.ndarray, bbox_pairs: List) -> np.ndarray: + """ + Based on IoU for each pair of bbox, assign the associated value in cost matrix + + Args: + cost_matrix: np.ndarray, initialized 2D array with target dimensions + bbox_pairs: list of bbox pair, in each pair, iou value is stored + Return: + np.ndarray, cost_matrix with assigned values + """ + for pair in bbox_pairs: + # assign -1 for IoU above threshold pairs, algorithms will minimize cost + cost_matrix[pair["idx"]][pair["prev_idx"]] = -1 + return cost_matrix diff --git a/custom_detectron2/utils/README.md b/custom_detectron2/utils/README.md new file mode 100644 index 0000000000000000000000000000000000000000..9765b24a730b77556104187ac3ef5439ab0859fd --- /dev/null +++ b/custom_detectron2/utils/README.md @@ -0,0 +1,5 @@ +# Utility functions + +This folder contain utility functions that are not used in the +core library, but are useful for building models or training +code using the config system. diff --git a/custom_detectron2/utils/__init__.py b/custom_detectron2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9020c2df23e2af280b7bb168b996ae9eaf312eb8 --- /dev/null +++ b/custom_detectron2/utils/__init__.py @@ -0,0 +1 @@ +# Copyright (c) Facebook, Inc. and its affiliates. diff --git a/custom_detectron2/utils/analysis.py b/custom_detectron2/utils/analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..35b7bdcca4c545213c4cc92f9f71dd75cf288936 --- /dev/null +++ b/custom_detectron2/utils/analysis.py @@ -0,0 +1,188 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# -*- coding: utf-8 -*- + +import typing +from typing import Any, List +import fvcore +from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table +from torch import nn + +from custom_detectron2.export import TracingAdapter + +__all__ = [ + "activation_count_operators", + "flop_count_operators", + "parameter_count_table", + "parameter_count", + "FlopCountAnalysis", +] + +FLOPS_MODE = "flops" +ACTIVATIONS_MODE = "activations" + + +# Some extra ops to ignore from counting, including elementwise and reduction ops +_IGNORED_OPS = { + "aten::add", + "aten::add_", + "aten::argmax", + "aten::argsort", + "aten::batch_norm", + "aten::constant_pad_nd", + "aten::div", + "aten::div_", + "aten::exp", + "aten::log2", + "aten::max_pool2d", + "aten::meshgrid", + "aten::mul", + "aten::mul_", + "aten::neg", + "aten::nonzero_numpy", + "aten::reciprocal", + "aten::repeat_interleave", + "aten::rsub", + "aten::sigmoid", + "aten::sigmoid_", + "aten::softmax", + "aten::sort", + "aten::sqrt", + "aten::sub", + "torchvision::nms", # TODO estimate flop for nms +} + + +class FlopCountAnalysis(fvcore.nn.FlopCountAnalysis): + """ + Same as :class:`fvcore.nn.FlopCountAnalysis`, but supports detectron2 models. + """ + + def __init__(self, model, inputs): + """ + Args: + model (nn.Module): + inputs (Any): inputs of the given model. Does not have to be tuple of tensors. + """ + wrapper = TracingAdapter(model, inputs, allow_non_tensor=True) + super().__init__(wrapper, wrapper.flattened_inputs) + self.set_op_handle(**{k: None for k in _IGNORED_OPS}) + + +def flop_count_operators(model: nn.Module, inputs: list) -> typing.DefaultDict[str, float]: + """ + Implement operator-level flops counting using jit. + This is a wrapper of :func:`fvcore.nn.flop_count` and adds supports for standard + detection models in detectron2. + Please use :class:`FlopCountAnalysis` for more advanced functionalities. + + Note: + The function runs the input through the model to compute flops. + The flops of a detection model is often input-dependent, for example, + the flops of box & mask head depends on the number of proposals & + the number of detected objects. + Therefore, the flops counting using a single input may not accurately + reflect the computation cost of a model. It's recommended to average + across a number of inputs. + + Args: + model: a detectron2 model that takes `list[dict]` as input. + inputs (list[dict]): inputs to model, in detectron2's standard format. + Only "image" key will be used. + supported_ops (dict[str, Handle]): see documentation of :func:`fvcore.nn.flop_count` + + Returns: + Counter: Gflop count per operator + """ + old_train = model.training + model.eval() + ret = FlopCountAnalysis(model, inputs).by_operator() + model.train(old_train) + return {k: v / 1e9 for k, v in ret.items()} + + +def activation_count_operators( + model: nn.Module, inputs: list, **kwargs +) -> typing.DefaultDict[str, float]: + """ + Implement operator-level activations counting using jit. + This is a wrapper of fvcore.nn.activation_count, that supports standard detection models + in detectron2. + + Note: + The function runs the input through the model to compute activations. + The activations of a detection model is often input-dependent, for example, + the activations of box & mask head depends on the number of proposals & + the number of detected objects. + + Args: + model: a detectron2 model that takes `list[dict]` as input. + inputs (list[dict]): inputs to model, in detectron2's standard format. + Only "image" key will be used. + + Returns: + Counter: activation count per operator + """ + return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs) + + +def _wrapper_count_operators( + model: nn.Module, inputs: list, mode: str, **kwargs +) -> typing.DefaultDict[str, float]: + # ignore some ops + supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS} + supported_ops.update(kwargs.pop("supported_ops", {})) + kwargs["supported_ops"] = supported_ops + + assert len(inputs) == 1, "Please use batch size=1" + tensor_input = inputs[0]["image"] + inputs = [{"image": tensor_input}] # remove other keys, in case there are any + + old_train = model.training + if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)): + model = model.module + wrapper = TracingAdapter(model, inputs) + wrapper.eval() + if mode == FLOPS_MODE: + ret = flop_count(wrapper, (tensor_input,), **kwargs) + elif mode == ACTIVATIONS_MODE: + ret = activation_count(wrapper, (tensor_input,), **kwargs) + else: + raise NotImplementedError("Count for mode {} is not supported yet.".format(mode)) + # compatible with change in fvcore + if isinstance(ret, tuple): + ret = ret[0] + model.train(old_train) + return ret + + +def find_unused_parameters(model: nn.Module, inputs: Any) -> List[str]: + """ + Given a model, find parameters that do not contribute + to the loss. + + Args: + model: a model in training mode that returns losses + inputs: argument or a tuple of arguments. Inputs of the model + + Returns: + list[str]: the name of unused parameters + """ + assert model.training + for _, prm in model.named_parameters(): + prm.grad = None + + if isinstance(inputs, tuple): + losses = model(*inputs) + else: + losses = model(inputs) + + if isinstance(losses, dict): + losses = sum(losses.values()) + losses.backward() + + unused: List[str] = [] + for name, prm in model.named_parameters(): + if prm.grad is None: + unused.append(name) + prm.grad = None + return unused diff --git a/custom_detectron2/utils/collect_env.py b/custom_detectron2/utils/collect_env.py new file mode 100644 index 0000000000000000000000000000000000000000..bbc55fd9ff10506fe4d49f8fa2346ecc8a4b4606 --- /dev/null +++ b/custom_detectron2/utils/collect_env.py @@ -0,0 +1,246 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import importlib +import numpy as np +import os +import re +import subprocess +import sys +from collections import defaultdict +import PIL +import torch +import torchvision +from tabulate import tabulate + +__all__ = ["collect_env_info"] + + +def collect_torch_env(): + try: + import torch.__config__ + + return torch.__config__.show() + except ImportError: + # compatible with older versions of pytorch + from torch.utils.collect_env import get_pretty_env_info + + return get_pretty_env_info() + + +def get_env_module(): + var_name = "DETECTRON2_ENV_MODULE" + return var_name, os.environ.get(var_name, "") + + +def detect_compute_compatibility(CUDA_HOME, so_file): + try: + cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump") + if os.path.isfile(cuobjdump): + output = subprocess.check_output( + "'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True + ) + output = output.decode("utf-8").strip().split("\n") + arch = [] + for line in output: + line = re.findall(r"\.sm_([0-9]*)\.", line)[0] + arch.append(".".join(line)) + arch = sorted(set(arch)) + return ", ".join(arch) + else: + return so_file + "; cannot find cuobjdump" + except Exception: + # unhandled failure + return so_file + + +def collect_env_info(): + has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM + torch_version = torch.__version__ + + # NOTE that CUDA_HOME/ROCM_HOME could be None even when CUDA runtime libs are functional + from torch.utils.cpp_extension import CUDA_HOME, ROCM_HOME + + has_rocm = False + if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None): + has_rocm = True + has_cuda = has_gpu and (not has_rocm) + + data = [] + data.append(("sys.platform", sys.platform)) # check-template.yml depends on it + data.append(("Python", sys.version.replace("\n", ""))) + data.append(("numpy", np.__version__)) + + try: + import custom_detectron2 # noqa + + data.append( + ("detectron2", detectron2.__version__ + " @" + os.path.dirname(detectron2.__file__)) + ) + except ImportError: + data.append(("detectron2", "failed to import")) + except AttributeError: + data.append(("detectron2", "imported a wrong installation")) + + try: + import custom_detectron2._C as _C + except ImportError as e: + data.append(("detectron2._C", f"not built correctly: {e}")) + + # print system compilers when extension fails to build + if sys.platform != "win32": # don't know what to do for windows + try: + # this is how torch/utils/cpp_extensions.py choose compiler + cxx = os.environ.get("CXX", "c++") + cxx = subprocess.check_output("'{}' --version".format(cxx), shell=True) + cxx = cxx.decode("utf-8").strip().split("\n")[0] + except subprocess.SubprocessError: + cxx = "Not found" + data.append(("Compiler ($CXX)", cxx)) + + if has_cuda and CUDA_HOME is not None: + try: + nvcc = os.path.join(CUDA_HOME, "bin", "nvcc") + nvcc = subprocess.check_output("'{}' -V".format(nvcc), shell=True) + nvcc = nvcc.decode("utf-8").strip().split("\n")[-1] + except subprocess.SubprocessError: + nvcc = "Not found" + data.append(("CUDA compiler", nvcc)) + if has_cuda and sys.platform != "win32": + try: + so_file = importlib.util.find_spec("detectron2._C").origin + except (ImportError, AttributeError): + pass + else: + data.append( + ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, so_file)) + ) + else: + # print compilers that are used to build extension + data.append(("Compiler", _C.get_compiler_version())) + data.append(("CUDA compiler", _C.get_cuda_version())) # cuda or hip + if has_cuda and getattr(_C, "has_cuda", lambda: True)(): + data.append( + ("detectron2 arch flags", detect_compute_compatibility(CUDA_HOME, _C.__file__)) + ) + + data.append(get_env_module()) + data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__))) + data.append(("PyTorch debug build", torch.version.debug)) + try: + data.append(("torch._C._GLIBCXX_USE_CXX11_ABI", torch._C._GLIBCXX_USE_CXX11_ABI)) + except Exception: + pass + + if not has_gpu: + has_gpu_text = "No: torch.cuda.is_available() == False" + else: + has_gpu_text = "Yes" + data.append(("GPU available", has_gpu_text)) + if has_gpu: + devices = defaultdict(list) + for k in range(torch.cuda.device_count()): + cap = ".".join((str(x) for x in torch.cuda.get_device_capability(k))) + name = torch.cuda.get_device_name(k) + f" (arch={cap})" + devices[name].append(str(k)) + for name, devids in devices.items(): + data.append(("GPU " + ",".join(devids), name)) + + if has_rocm: + msg = " - invalid!" if not (ROCM_HOME and os.path.isdir(ROCM_HOME)) else "" + data.append(("ROCM_HOME", str(ROCM_HOME) + msg)) + else: + try: + from torch.utils.collect_env import get_nvidia_driver_version, run as _run + + data.append(("Driver version", get_nvidia_driver_version(_run))) + except Exception: + pass + msg = " - invalid!" if not (CUDA_HOME and os.path.isdir(CUDA_HOME)) else "" + data.append(("CUDA_HOME", str(CUDA_HOME) + msg)) + + cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None) + if cuda_arch_list: + data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list)) + data.append(("Pillow", PIL.__version__)) + + try: + data.append( + ( + "torchvision", + str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__), + ) + ) + if has_cuda: + try: + torchvision_C = importlib.util.find_spec("torchvision._C").origin + msg = detect_compute_compatibility(CUDA_HOME, torchvision_C) + data.append(("torchvision arch flags", msg)) + except (ImportError, AttributeError): + data.append(("torchvision._C", "Not found")) + except AttributeError: + data.append(("torchvision", "unknown")) + + try: + import fvcore + + data.append(("fvcore", fvcore.__version__)) + except (ImportError, AttributeError): + pass + + try: + import iopath + + data.append(("iopath", iopath.__version__)) + except (ImportError, AttributeError): + pass + + try: + import cv2 + + data.append(("cv2", cv2.__version__)) + except (ImportError, AttributeError): + data.append(("cv2", "Not found")) + env_str = tabulate(data) + "\n" + env_str += collect_torch_env() + return env_str + + +def test_nccl_ops(): + num_gpu = torch.cuda.device_count() + if os.access("/tmp", os.W_OK): + import torch.multiprocessing as mp + + dist_url = "file:///tmp/nccl_tmp_file" + print("Testing NCCL connectivity ... this should not hang.") + mp.spawn(_test_nccl_worker, nprocs=num_gpu, args=(num_gpu, dist_url), daemon=False) + print("NCCL succeeded.") + + +def _test_nccl_worker(rank, num_gpu, dist_url): + import torch.distributed as dist + + dist.init_process_group(backend="NCCL", init_method=dist_url, rank=rank, world_size=num_gpu) + dist.barrier(device_ids=[rank]) + + +if __name__ == "__main__": + try: + from custom_detectron2.utils.collect_env import collect_env_info as f + + print(f()) + except ImportError: + print(collect_env_info()) + + if torch.cuda.is_available(): + num_gpu = torch.cuda.device_count() + for k in range(num_gpu): + device = f"cuda:{k}" + try: + x = torch.tensor([1, 2.0], dtype=torch.float32) + x = x.to(device) + except Exception as e: + print( + f"Unable to copy tensor to device={device}: {e}. " + "Your CUDA environment is broken." + ) + if num_gpu > 1: + test_nccl_ops() diff --git a/custom_detectron2/utils/colormap.py b/custom_detectron2/utils/colormap.py new file mode 100644 index 0000000000000000000000000000000000000000..14ded1659b40b161358c4aaf9cc84ffe0ffafe64 --- /dev/null +++ b/custom_detectron2/utils/colormap.py @@ -0,0 +1,158 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +""" +An awesome colormap for really neat visualizations. +Copied from Detectron, and removed gray colors. +""" + +import numpy as np +import random + +__all__ = ["colormap", "random_color", "random_colors"] + +# fmt: off +# RGB: +_COLORS = np.array( + [ + 0.000, 0.447, 0.741, + 0.850, 0.325, 0.098, + 0.929, 0.694, 0.125, + 0.494, 0.184, 0.556, + 0.466, 0.674, 0.188, + 0.301, 0.745, 0.933, + 0.635, 0.078, 0.184, + 0.300, 0.300, 0.300, + 0.600, 0.600, 0.600, + 1.000, 0.000, 0.000, + 1.000, 0.500, 0.000, + 0.749, 0.749, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 1.000, + 0.667, 0.000, 1.000, + 0.333, 0.333, 0.000, + 0.333, 0.667, 0.000, + 0.333, 1.000, 0.000, + 0.667, 0.333, 0.000, + 0.667, 0.667, 0.000, + 0.667, 1.000, 0.000, + 1.000, 0.333, 0.000, + 1.000, 0.667, 0.000, + 1.000, 1.000, 0.000, + 0.000, 0.333, 0.500, + 0.000, 0.667, 0.500, + 0.000, 1.000, 0.500, + 0.333, 0.000, 0.500, + 0.333, 0.333, 0.500, + 0.333, 0.667, 0.500, + 0.333, 1.000, 0.500, + 0.667, 0.000, 0.500, + 0.667, 0.333, 0.500, + 0.667, 0.667, 0.500, + 0.667, 1.000, 0.500, + 1.000, 0.000, 0.500, + 1.000, 0.333, 0.500, + 1.000, 0.667, 0.500, + 1.000, 1.000, 0.500, + 0.000, 0.333, 1.000, + 0.000, 0.667, 1.000, + 0.000, 1.000, 1.000, + 0.333, 0.000, 1.000, + 0.333, 0.333, 1.000, + 0.333, 0.667, 1.000, + 0.333, 1.000, 1.000, + 0.667, 0.000, 1.000, + 0.667, 0.333, 1.000, + 0.667, 0.667, 1.000, + 0.667, 1.000, 1.000, + 1.000, 0.000, 1.000, + 1.000, 0.333, 1.000, + 1.000, 0.667, 1.000, + 0.333, 0.000, 0.000, + 0.500, 0.000, 0.000, + 0.667, 0.000, 0.000, + 0.833, 0.000, 0.000, + 1.000, 0.000, 0.000, + 0.000, 0.167, 0.000, + 0.000, 0.333, 0.000, + 0.000, 0.500, 0.000, + 0.000, 0.667, 0.000, + 0.000, 0.833, 0.000, + 0.000, 1.000, 0.000, + 0.000, 0.000, 0.167, + 0.000, 0.000, 0.333, + 0.000, 0.000, 0.500, + 0.000, 0.000, 0.667, + 0.000, 0.000, 0.833, + 0.000, 0.000, 1.000, + 0.000, 0.000, 0.000, + 0.143, 0.143, 0.143, + 0.857, 0.857, 0.857, + 1.000, 1.000, 1.000 + ] +).astype(np.float32).reshape(-1, 3) +# fmt: on + + +def colormap(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a float32 array of Nx3 colors, in range [0, 255] or [0, 1] + """ + assert maximum in [255, 1], maximum + c = _COLORS * maximum + if not rgb: + c = c[:, ::-1] + return c + + +def random_color(rgb=False, maximum=255): + """ + Args: + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a vector of 3 numbers + """ + idx = np.random.randint(0, len(_COLORS)) + ret = _COLORS[idx] * maximum + if not rgb: + ret = ret[::-1] + return ret + + +def random_colors(N, rgb=False, maximum=255): + """ + Args: + N (int): number of unique colors needed + rgb (bool): whether to return RGB colors or BGR colors. + maximum (int): either 255 or 1 + + Returns: + ndarray: a list of random_color + """ + indices = random.sample(range(len(_COLORS)), N) + ret = [_COLORS[i] * maximum for i in indices] + if not rgb: + ret = [x[::-1] for x in ret] + return ret + + +if __name__ == "__main__": + import cv2 + + size = 100 + H, W = 10, 10 + canvas = np.random.rand(H * size, W * size, 3).astype("float32") + for h in range(H): + for w in range(W): + idx = h * W + w + if idx >= len(_COLORS): + break + canvas[h * size : (h + 1) * size, w * size : (w + 1) * size] = _COLORS[idx] + cv2.imshow("a", canvas) + cv2.waitKey(0) diff --git a/custom_detectron2/utils/comm.py b/custom_detectron2/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..a9ea9a9f578c5704d1e7ff563ef156e9133ab465 --- /dev/null +++ b/custom_detectron2/utils/comm.py @@ -0,0 +1,238 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import functools +import numpy as np +import torch +import torch.distributed as dist + +_LOCAL_PROCESS_GROUP = None +_MISSING_LOCAL_PG_ERROR = ( + "Local process group is not yet created! Please use detectron2's `launch()` " + "to start processes and initialize pytorch process group. If you need to start " + "processes in other ways, please call comm.create_local_process_group(" + "num_workers_per_machine) after calling torch.distributed.init_process_group()." +) + + +def get_world_size() -> int: + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank() -> int: + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +@functools.lru_cache() +def create_local_process_group(num_workers_per_machine: int) -> None: + """ + Create a process group that contains ranks within the same machine. + + Detectron2's launch() in engine/launch.py will call this function. If you start + workers without launch(), you'll have to also call this. Otherwise utilities + like `get_local_rank()` will not work. + + This function contains a barrier. All processes must call it together. + + Args: + num_workers_per_machine: the number of worker processes per machine. Typically + the number of GPUs. + """ + global _LOCAL_PROCESS_GROUP + assert _LOCAL_PROCESS_GROUP is None + assert get_world_size() % num_workers_per_machine == 0 + num_machines = get_world_size() // num_workers_per_machine + machine_rank = get_rank() // num_workers_per_machine + for i in range(num_machines): + ranks_on_i = list(range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine)) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + _LOCAL_PROCESS_GROUP = pg + + +def get_local_process_group(): + """ + Returns: + A torch process group which only includes processes that are on the same + machine as the current process. This group can be useful for communication + within a machine, e.g. a per-machine SyncBN. + """ + assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR + return _LOCAL_PROCESS_GROUP + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. + """ + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + assert _LOCAL_PROCESS_GROUP is not None, _MISSING_LOCAL_PG_ERROR + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def is_main_process() -> bool: + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + if dist.get_backend() == dist.Backend.NCCL: + # This argument is needed to avoid warnings. + # It's valid only for NCCL backend. + dist.barrier(device_ids=[torch.cuda.current_device()]) + else: + dist.barrier() + + +@functools.lru_cache() +def _get_global_gloo_group(): + """ + Return a process group based on gloo backend, containing all the ranks + The result is cached. + """ + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: list of data gathered from each rank + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def gather(data, dst=0, group=None): + """ + Run gather on arbitrary picklable data (not necessarily tensors). + + Args: + data: any picklable object + dst (int): destination rank + group: a torch process group. By default, will use a group which + contains all ranks on gloo backend. + + Returns: + list[data]: on dst, a list of data gathered from each rank. Otherwise, + an empty list. + """ + if get_world_size() == 1: + return [data] + if group is None: + group = _get_global_gloo_group() + world_size = dist.get_world_size(group=group) + if world_size == 1: + return [data] + rank = dist.get_rank(group=group) + + if rank == dst: + output = [None for _ in range(world_size)] + dist.gather_object(data, output, dst=dst, group=group) + return output + else: + dist.gather_object(data, None, dst=dst, group=group) + return [] + + +def shared_random_seed(): + """ + Returns: + int: a random number that is the same across all workers. + If workers need a shared RNG, they can use this shared seed to + create one. + + All workers must call this function, otherwise it will deadlock. + """ + ints = np.random.randint(2**31) + all_ints = all_gather(ints) + return all_ints[0] + + +def reduce_dict(input_dict, average=True): + """ + Reduce the values in the dictionary from all processes so that process with rank + 0 has the reduced results. + + Args: + input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor. + average (bool): whether to do average or sum + + Returns: + a dict with the same keys as input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/custom_detectron2/utils/develop.py b/custom_detectron2/utils/develop.py new file mode 100644 index 0000000000000000000000000000000000000000..e8416984954f7b32fc269100620e3c0d0d0f9585 --- /dev/null +++ b/custom_detectron2/utils/develop.py @@ -0,0 +1,59 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +""" Utilities for developers only. +These are not visible to users (not automatically imported). And should not +appeared in docs.""" +# adapted from https://github.com/tensorpack/tensorpack/blob/master/tensorpack/utils/develop.py + + +def create_dummy_class(klass, dependency, message=""): + """ + When a dependency of a class is not available, create a dummy class which throws ImportError + when used. + + Args: + klass (str): name of the class. + dependency (str): name of the dependency. + message: extra message to print + Returns: + class: a class object + """ + err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, klass) + if message: + err = err + " " + message + + class _DummyMetaClass(type): + # throw error on class attribute access + def __getattr__(_, __): # noqa: B902 + raise ImportError(err) + + class _Dummy(object, metaclass=_DummyMetaClass): + # throw error on constructor + def __init__(self, *args, **kwargs): + raise ImportError(err) + + return _Dummy + + +def create_dummy_func(func, dependency, message=""): + """ + When a dependency of a function is not available, create a dummy function which throws + ImportError when used. + + Args: + func (str): name of the function. + dependency (str or list[str]): name(s) of the dependency. + message: extra message to print + Returns: + function: a function object + """ + err = "Cannot import '{}', therefore '{}' is not available.".format(dependency, func) + if message: + err = err + " " + message + + if isinstance(dependency, (list, tuple)): + dependency = ",".join(dependency) + + def _dummy(*args, **kwargs): + raise ImportError(err) + + return _dummy diff --git a/custom_detectron2/utils/env.py b/custom_detectron2/utils/env.py new file mode 100644 index 0000000000000000000000000000000000000000..40634c17c73273ac8927632be164f466cfe7d1fa --- /dev/null +++ b/custom_detectron2/utils/env.py @@ -0,0 +1,170 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import importlib +import importlib.util +import logging +import numpy as np +import os +import random +import sys +from datetime import datetime +import torch + +__all__ = ["seed_all_rng"] + + +TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2]) +""" +PyTorch version as a tuple of 2 ints. Useful for comparison. +""" + + +DOC_BUILDING = os.getenv("_DOC_BUILDING", False) # set in docs/conf.py +""" +Whether we're building documentation. +""" + + +def seed_all_rng(seed=None): + """ + Set the random seed for the RNG in torch, numpy and python. + + Args: + seed (int): if None, will use a strong random seed. + """ + if seed is None: + seed = ( + os.getpid() + + int(datetime.now().strftime("%S%f")) + + int.from_bytes(os.urandom(2), "big") + ) + logger = logging.getLogger(__name__) + logger.info("Using a generated random seed {}".format(seed)) + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + +# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path +def _import_file(module_name, file_path, make_importable=False): + spec = importlib.util.spec_from_file_location(module_name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + if make_importable: + sys.modules[module_name] = module + return module + + +def _configure_libraries(): + """ + Configurations for some libraries. + """ + # An environment option to disable `import cv2` globally, + # in case it leads to negative performance impact + disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False)) + if disable_cv2: + sys.modules["cv2"] = None + else: + # Disable opencl in opencv since its interaction with cuda often has negative effects + # This envvar is supported after OpenCV 3.4.0 + os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled" + try: + import cv2 + + if int(cv2.__version__.split(".")[0]) >= 3: + cv2.ocl.setUseOpenCL(False) + except ModuleNotFoundError: + # Other types of ImportError, if happened, should not be ignored. + # Because a failed opencv import could mess up address space + # https://github.com/skvark/opencv-python/issues/381 + pass + + def get_version(module, digit=2): + return tuple(map(int, module.__version__.split(".")[:digit])) + + # fmt: off + assert get_version(torch) >= (1, 4), "Requires torch>=1.4" + import fvcore + assert get_version(fvcore, 3) >= (0, 1, 2), "Requires fvcore>=0.1.2" + import yaml + assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1" + # fmt: on + + +_ENV_SETUP_DONE = False + + +def setup_environment(): + """Perform environment setup work. The default setup is a no-op, but this + function allows the user to specify a Python source file or a module in + the $DETECTRON2_ENV_MODULE environment variable, that performs + custom setup work that may be necessary to their computing environment. + """ + global _ENV_SETUP_DONE + if _ENV_SETUP_DONE: + return + _ENV_SETUP_DONE = True + + _configure_libraries() + + custom_module_path = os.environ.get("DETECTRON2_ENV_MODULE") + + if custom_module_path: + setup_custom_environment(custom_module_path) + else: + # The default setup is a no-op + pass + + +def setup_custom_environment(custom_module): + """ + Load custom environment setup by importing a Python source file or a + module, and run the setup function. + """ + if custom_module.endswith(".py"): + module = _import_file("detectron2.utils.env.custom_module", custom_module) + else: + module = importlib.import_module(custom_module) + assert hasattr(module, "setup_environment") and callable(module.setup_environment), ( + "Custom environment module defined in {} does not have the " + "required callable attribute 'setup_environment'." + ).format(custom_module) + module.setup_environment() + + +def fixup_module_metadata(module_name, namespace, keys=None): + """ + Fix the __qualname__ of module members to be their exported api name, so + when they are referenced in docs, sphinx can find them. Reference: + https://github.com/python-trio/trio/blob/6754c74eacfad9cc5c92d5c24727a2f3b620624e/trio/_util.py#L216-L241 + """ + if not DOC_BUILDING: + return + seen_ids = set() + + def fix_one(qualname, name, obj): + # avoid infinite recursion (relevant when using + # typing.Generic, for example) + if id(obj) in seen_ids: + return + seen_ids.add(id(obj)) + + mod = getattr(obj, "__module__", None) + if mod is not None and (mod.startswith(module_name) or mod.startswith("fvcore.")): + obj.__module__ = module_name + # Modules, unlike everything else in Python, put fully-qualitied + # names into their __name__ attribute. We check for "." to avoid + # rewriting these. + if hasattr(obj, "__name__") and "." not in obj.__name__: + obj.__name__ = name + obj.__qualname__ = qualname + if isinstance(obj, type): + for attr_name, attr_value in obj.__dict__.items(): + fix_one(objname + "." + attr_name, attr_name, attr_value) + + if keys is None: + keys = namespace.keys() + for objname in keys: + if not objname.startswith("_"): + obj = namespace[objname] + fix_one(objname, objname, obj) diff --git a/custom_detectron2/utils/events.py b/custom_detectron2/utils/events.py new file mode 100644 index 0000000000000000000000000000000000000000..cd059b49742f943f50fa530bb4519cb7ee0a9bb9 --- /dev/null +++ b/custom_detectron2/utils/events.py @@ -0,0 +1,534 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import datetime +import json +import logging +import os +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Optional +import torch +from fvcore.common.history_buffer import HistoryBuffer + +from custom_detectron2.utils.file_io import PathManager + +__all__ = [ + "get_event_storage", + "JSONWriter", + "TensorboardXWriter", + "CommonMetricPrinter", + "EventStorage", +] + +_CURRENT_STORAGE_STACK = [] + + +def get_event_storage(): + """ + Returns: + The :class:`EventStorage` object that's currently being used. + Throws an error if no :class:`EventStorage` is currently enabled. + """ + assert len( + _CURRENT_STORAGE_STACK + ), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!" + return _CURRENT_STORAGE_STACK[-1] + + +class EventWriter: + """ + Base class for writers that obtain events from :class:`EventStorage` and process them. + """ + + def write(self): + raise NotImplementedError + + def close(self): + pass + + +class JSONWriter(EventWriter): + """ + Write scalars to a json file. + + It saves scalars as one json per line (instead of a big json) for easy parsing. + + Examples parsing such a json file: + :: + $ cat metrics.json | jq -s '.[0:2]' + [ + { + "data_time": 0.008433341979980469, + "iteration": 19, + "loss": 1.9228371381759644, + "loss_box_reg": 0.050025828182697296, + "loss_classifier": 0.5316952466964722, + "loss_mask": 0.7236229181289673, + "loss_rpn_box": 0.0856662318110466, + "loss_rpn_cls": 0.48198649287223816, + "lr": 0.007173333333333333, + "time": 0.25401854515075684 + }, + { + "data_time": 0.007216215133666992, + "iteration": 39, + "loss": 1.282649278640747, + "loss_box_reg": 0.06222952902317047, + "loss_classifier": 0.30682939291000366, + "loss_mask": 0.6970193982124329, + "loss_rpn_box": 0.038663312792778015, + "loss_rpn_cls": 0.1471673548221588, + "lr": 0.007706666666666667, + "time": 0.2490077018737793 + } + ] + + $ cat metrics.json | jq '.loss_mask' + 0.7126231789588928 + 0.689423680305481 + 0.6776131987571716 + ... + + """ + + def __init__(self, json_file, window_size=20): + """ + Args: + json_file (str): path to the json file. New data will be appended if the file exists. + window_size (int): the window size of median smoothing for the scalars whose + `smoothing_hint` are True. + """ + self._file_handle = PathManager.open(json_file, "a") + self._window_size = window_size + self._last_write = -1 + + def write(self): + storage = get_event_storage() + to_save = defaultdict(dict) + + for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items(): + # keep scalars that have not been written + if iter <= self._last_write: + continue + to_save[iter][k] = v + if len(to_save): + all_iters = sorted(to_save.keys()) + self._last_write = max(all_iters) + + for itr, scalars_per_iter in to_save.items(): + scalars_per_iter["iteration"] = itr + self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n") + self._file_handle.flush() + try: + os.fsync(self._file_handle.fileno()) + except AttributeError: + pass + + def close(self): + self._file_handle.close() + + +class TensorboardXWriter(EventWriter): + """ + Write all scalars to a tensorboard file. + """ + + def __init__(self, log_dir: str, window_size: int = 20, **kwargs): + """ + Args: + log_dir (str): the directory to save the output events + window_size (int): the scalars will be median-smoothed by this window size + + kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` + """ + self._window_size = window_size + from torch.utils.tensorboard import SummaryWriter + + self._writer = SummaryWriter(log_dir, **kwargs) + self._last_write = -1 + + def write(self): + storage = get_event_storage() + new_last_write = self._last_write + for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items(): + if iter > self._last_write: + self._writer.add_scalar(k, v, iter) + new_last_write = max(new_last_write, iter) + self._last_write = new_last_write + + # storage.put_{image,histogram} is only meant to be used by + # tensorboard writer. So we access its internal fields directly from here. + if len(storage._vis_data) >= 1: + for img_name, img, step_num in storage._vis_data: + self._writer.add_image(img_name, img, step_num) + # Storage stores all image data and rely on this writer to clear them. + # As a result it assumes only one writer will use its image data. + # An alternative design is to let storage store limited recent + # data (e.g. only the most recent image) that all writers can access. + # In that case a writer may not see all image data if its period is long. + storage.clear_images() + + if len(storage._histograms) >= 1: + for params in storage._histograms: + self._writer.add_histogram_raw(**params) + storage.clear_histograms() + + def close(self): + if hasattr(self, "_writer"): # doesn't exist when the code fails at import + self._writer.close() + + +class CommonMetricPrinter(EventWriter): + """ + Print **common** metrics to the terminal, including + iteration time, ETA, memory, all losses, and the learning rate. + It also applies smoothing using a window of 20 elements. + + It's meant to print common metrics in common ways. + To print something in more customized ways, please implement a similar printer by yourself. + """ + + def __init__(self, max_iter: Optional[int] = None, window_size: int = 20): + """ + Args: + max_iter: the maximum number of iterations to train. + Used to compute ETA. If not given, ETA will not be printed. + window_size (int): the losses will be median-smoothed by this window size + """ + self.logger = logging.getLogger(__name__) + self._max_iter = max_iter + self._window_size = window_size + self._last_write = None # (step, time) of last call to write(). Used to compute ETA + + def _get_eta(self, storage) -> Optional[str]: + if self._max_iter is None: + return "" + iteration = storage.iter + try: + eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1) + storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False) + return str(datetime.timedelta(seconds=int(eta_seconds))) + except KeyError: + # estimate eta on our own - more noisy + eta_string = None + if self._last_write is not None: + estimate_iter_time = (time.perf_counter() - self._last_write[1]) / ( + iteration - self._last_write[0] + ) + eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + self._last_write = (iteration, time.perf_counter()) + return eta_string + + def write(self): + storage = get_event_storage() + iteration = storage.iter + if iteration == self._max_iter: + # This hook only reports training progress (loss, ETA, etc) but not other data, + # therefore do not write anything after training succeeds, even if this method + # is called. + return + + try: + avg_data_time = storage.history("data_time").avg( + storage.count_samples("data_time", self._window_size) + ) + last_data_time = storage.history("data_time").latest() + except KeyError: + # they may not exist in the first few iterations (due to warmup) + # or when SimpleTrainer is not used + avg_data_time = None + last_data_time = None + try: + avg_iter_time = storage.history("time").global_avg() + last_iter_time = storage.history("time").latest() + except KeyError: + avg_iter_time = None + last_iter_time = None + try: + lr = "{:.5g}".format(storage.history("lr").latest()) + except KeyError: + lr = "N/A" + + eta_string = self._get_eta(storage) + + if torch.cuda.is_available(): + max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0 + else: + max_mem_mb = None + + # NOTE: max_mem is parsed by grep in "dev/parse_results.sh" + self.logger.info( + str.format( + " {eta}iter: {iter} {losses} {non_losses} {avg_time}{last_time}" + + "{avg_data_time}{last_data_time} lr: {lr} {memory}", + eta=f"eta: {eta_string} " if eta_string else "", + iter=iteration, + losses=" ".join( + [ + "{}: {:.4g}".format( + k, v.median(storage.count_samples(k, self._window_size)) + ) + for k, v in storage.histories().items() + if "loss" in k + ] + ), + non_losses=" ".join( + [ + "{}: {:.4g}".format( + k, v.median(storage.count_samples(k, self._window_size)) + ) + for k, v in storage.histories().items() + if "[metric]" in k + ] + ), + avg_time="time: {:.4f} ".format(avg_iter_time) + if avg_iter_time is not None + else "", + last_time="last_time: {:.4f} ".format(last_iter_time) + if last_iter_time is not None + else "", + avg_data_time="data_time: {:.4f} ".format(avg_data_time) + if avg_data_time is not None + else "", + last_data_time="last_data_time: {:.4f} ".format(last_data_time) + if last_data_time is not None + else "", + lr=lr, + memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "", + ) + ) + + +class EventStorage: + """ + The user-facing class that provides metric storage functionalities. + + In the future we may add support for storing / logging other types of data if needed. + """ + + def __init__(self, start_iter=0): + """ + Args: + start_iter (int): the iteration number to start with + """ + self._history = defaultdict(HistoryBuffer) + self._smoothing_hints = {} + self._latest_scalars = {} + self._iter = start_iter + self._current_prefix = "" + self._vis_data = [] + self._histograms = [] + + def put_image(self, img_name, img_tensor): + """ + Add an `img_tensor` associated with `img_name`, to be shown on + tensorboard. + + Args: + img_name (str): The name of the image to put into tensorboard. + img_tensor (torch.Tensor or numpy.array): An `uint8` or `float` + Tensor of shape `[channel, height, width]` where `channel` is + 3. The image format should be RGB. The elements in img_tensor + can either have values in [0, 1] (float32) or [0, 255] (uint8). + The `img_tensor` will be visualized in tensorboard. + """ + self._vis_data.append((img_name, img_tensor, self._iter)) + + def put_scalar(self, name, value, smoothing_hint=True): + """ + Add a scalar `value` to the `HistoryBuffer` associated with `name`. + + Args: + smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be + smoothed when logged. The hint will be accessible through + :meth:`EventStorage.smoothing_hints`. A writer may ignore the hint + and apply custom smoothing rule. + + It defaults to True because most scalars we save need to be smoothed to + provide any useful signal. + """ + name = self._current_prefix + name + history = self._history[name] + value = float(value) + history.update(value, self._iter) + self._latest_scalars[name] = (value, self._iter) + + existing_hint = self._smoothing_hints.get(name) + if existing_hint is not None: + assert ( + existing_hint == smoothing_hint + ), "Scalar {} was put with a different smoothing_hint!".format(name) + else: + self._smoothing_hints[name] = smoothing_hint + + def put_scalars(self, *, smoothing_hint=True, **kwargs): + """ + Put multiple scalars from keyword arguments. + + Examples: + + storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True) + """ + for k, v in kwargs.items(): + self.put_scalar(k, v, smoothing_hint=smoothing_hint) + + def put_histogram(self, hist_name, hist_tensor, bins=1000): + """ + Create a histogram from a tensor. + + Args: + hist_name (str): The name of the histogram to put into tensorboard. + hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted + into a histogram. + bins (int): Number of histogram bins. + """ + ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item() + + # Create a histogram with PyTorch + hist_counts = torch.histc(hist_tensor, bins=bins) + hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32) + + # Parameter for the add_histogram_raw function of SummaryWriter + hist_params = dict( + tag=hist_name, + min=ht_min, + max=ht_max, + num=len(hist_tensor), + sum=float(hist_tensor.sum()), + sum_squares=float(torch.sum(hist_tensor**2)), + bucket_limits=hist_edges[1:].tolist(), + bucket_counts=hist_counts.tolist(), + global_step=self._iter, + ) + self._histograms.append(hist_params) + + def history(self, name): + """ + Returns: + HistoryBuffer: the scalar history for name + """ + ret = self._history.get(name, None) + if ret is None: + raise KeyError("No history metric available for {}!".format(name)) + return ret + + def histories(self): + """ + Returns: + dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars + """ + return self._history + + def latest(self): + """ + Returns: + dict[str -> (float, int)]: mapping from the name of each scalar to the most + recent value and the iteration number its added. + """ + return self._latest_scalars + + def latest_with_smoothing_hint(self, window_size=20): + """ + Similar to :meth:`latest`, but the returned values + are either the un-smoothed original latest value, + or a median of the given window_size, + depend on whether the smoothing_hint is True. + + This provides a default behavior that other writers can use. + + Note: All scalars saved in the past `window_size` iterations are used for smoothing. + This is different from the `window_size` definition in HistoryBuffer. + Use :meth:`get_history_window_size` to get the `window_size` used in HistoryBuffer. + """ + result = {} + for k, (v, itr) in self._latest_scalars.items(): + result[k] = ( + self._history[k].median(self.count_samples(k, window_size)) + if self._smoothing_hints[k] + else v, + itr, + ) + return result + + def count_samples(self, name, window_size=20): + """ + Return the number of samples logged in the past `window_size` iterations. + """ + samples = 0 + data = self._history[name].values() + for _, iter_ in reversed(data): + if iter_ > data[-1][1] - window_size: + samples += 1 + else: + break + return samples + + def smoothing_hints(self): + """ + Returns: + dict[name -> bool]: the user-provided hint on whether the scalar + is noisy and needs smoothing. + """ + return self._smoothing_hints + + def step(self): + """ + User should either: (1) Call this function to increment storage.iter when needed. Or + (2) Set `storage.iter` to the correct iteration number before each iteration. + + The storage will then be able to associate the new data with an iteration number. + """ + self._iter += 1 + + @property + def iter(self): + """ + Returns: + int: The current iteration number. When used together with a trainer, + this is ensured to be the same as trainer.iter. + """ + return self._iter + + @iter.setter + def iter(self, val): + self._iter = int(val) + + @property + def iteration(self): + # for backward compatibility + return self._iter + + def __enter__(self): + _CURRENT_STORAGE_STACK.append(self) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + assert _CURRENT_STORAGE_STACK[-1] == self + _CURRENT_STORAGE_STACK.pop() + + @contextmanager + def name_scope(self, name): + """ + Yields: + A context within which all the events added to this storage + will be prefixed by the name scope. + """ + old_prefix = self._current_prefix + self._current_prefix = name.rstrip("/") + "/" + yield + self._current_prefix = old_prefix + + def clear_images(self): + """ + Delete all the stored images for visualization. This should be called + after images are written to tensorboard. + """ + self._vis_data = [] + + def clear_histograms(self): + """ + Delete all the stored histograms for visualization. + This should be called after histograms are written to tensorboard. + """ + self._histograms = [] diff --git a/custom_detectron2/utils/file_io.py b/custom_detectron2/utils/file_io.py new file mode 100644 index 0000000000000000000000000000000000000000..09f7dffdb36199350bba57bd3b4e9e8babb40594 --- /dev/null +++ b/custom_detectron2/utils/file_io.py @@ -0,0 +1,39 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +from iopath.common.file_io import HTTPURLHandler, OneDrivePathHandler, PathHandler +from iopath.common.file_io import PathManager as PathManagerBase + +__all__ = ["PathManager", "PathHandler"] + + +PathManager = PathManagerBase() +""" +This is a detectron2 project-specific PathManager. +We try to stay away from global PathManager in fvcore as it +introduces potential conflicts among other libraries. +""" + + +class Detectron2Handler(PathHandler): + """ + Resolve anything that's hosted under detectron2's namespace. + """ + + PREFIX = "detectron2://" + S3_DETECTRON2_PREFIX = "https://dl.fbaipublicfiles.com/detectron2/" + + def _get_supported_prefixes(self): + return [self.PREFIX] + + def _get_local_path(self, path, **kwargs): + name = path[len(self.PREFIX) :] + return PathManager.get_local_path(self.S3_DETECTRON2_PREFIX + name, **kwargs) + + def _open(self, path, mode="r", **kwargs): + return PathManager.open( + self.S3_DETECTRON2_PREFIX + path[len(self.PREFIX) :], mode, **kwargs + ) + + +PathManager.register_handler(HTTPURLHandler()) +PathManager.register_handler(OneDrivePathHandler()) +PathManager.register_handler(Detectron2Handler()) diff --git a/custom_detectron2/utils/logger.py b/custom_detectron2/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..1bce0439512808c3cb31147edcc664485ae9e8fe --- /dev/null +++ b/custom_detectron2/utils/logger.py @@ -0,0 +1,237 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import atexit +import functools +import logging +import os +import sys +import time +from collections import Counter +import torch +from tabulate import tabulate +from termcolor import colored + +from custom_detectron2.utils.file_io import PathManager + +__all__ = ["setup_logger", "log_first_n", "log_every_n", "log_every_n_seconds"] + + +class _ColorfulFormatter(logging.Formatter): + def __init__(self, *args, **kwargs): + self._root_name = kwargs.pop("root_name") + "." + self._abbrev_name = kwargs.pop("abbrev_name", "") + if len(self._abbrev_name): + self._abbrev_name = self._abbrev_name + "." + super(_ColorfulFormatter, self).__init__(*args, **kwargs) + + def formatMessage(self, record): + record.name = record.name.replace(self._root_name, self._abbrev_name) + log = super(_ColorfulFormatter, self).formatMessage(record) + if record.levelno == logging.WARNING: + prefix = colored("WARNING", "red", attrs=["blink"]) + elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL: + prefix = colored("ERROR", "red", attrs=["blink", "underline"]) + else: + return log + return prefix + " " + log + + +@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers +def setup_logger( + output=None, distributed_rank=0, *, color=True, name="detectron2", abbrev_name=None +): + """ + Initialize the detectron2 logger and set its verbosity level to "DEBUG". + + Args: + output (str): a file name or a directory to save log. If None, will not save log file. + If ends with ".txt" or ".log", assumed to be a file name. + Otherwise, logs will be saved to `output/log.txt`. + name (str): the root module name of this logger + abbrev_name (str): an abbreviation of the module, to avoid long names in logs. + Set to "" to not log the root module in logs. + By default, will abbreviate "detectron2" to "d2" and leave other + modules unchanged. + + Returns: + logging.Logger: a logger + """ + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + if abbrev_name is None: + abbrev_name = "d2" if name == "detectron2" else name + + plain_formatter = logging.Formatter( + "[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S" + ) + # stdout logging: master only + if distributed_rank == 0: + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + if color: + formatter = _ColorfulFormatter( + colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s", + datefmt="%m/%d %H:%M:%S", + root_name=name, + abbrev_name=str(abbrev_name), + ) + else: + formatter = plain_formatter + ch.setFormatter(formatter) + logger.addHandler(ch) + + # file logging: all workers + if output is not None: + if output.endswith(".txt") or output.endswith(".log"): + filename = output + else: + filename = os.path.join(output, "log.txt") + if distributed_rank > 0: + filename = filename + ".rank{}".format(distributed_rank) + PathManager.mkdirs(os.path.dirname(filename)) + + fh = logging.StreamHandler(_cached_log_stream(filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(plain_formatter) + logger.addHandler(fh) + + return logger + + +# cache the opened file object, so that different calls to `setup_logger` +# with the same file name can safely write to the same file. +@functools.lru_cache(maxsize=None) +def _cached_log_stream(filename): + # use 1K buffer if writing to cloud storage + io = PathManager.open(filename, "a", buffering=1024 if "://" in filename else -1) + atexit.register(io.close) + return io + + +""" +Below are some other convenient logging methods. +They are mainly adopted from +https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py +""" + + +def _find_caller(): + """ + Returns: + str: module name of the caller + tuple: a hashable key to be used to identify different callers + """ + frame = sys._getframe(2) + while frame: + code = frame.f_code + if os.path.join("utils", "logger.") not in code.co_filename: + mod_name = frame.f_globals["__name__"] + if mod_name == "__main__": + mod_name = "detectron2" + return mod_name, (code.co_filename, frame.f_lineno, code.co_name) + frame = frame.f_back + + +_LOG_COUNTER = Counter() +_LOG_TIMER = {} + + +def log_first_n(lvl, msg, n=1, *, name=None, key="caller"): + """ + Log only for the first n times. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + key (str or tuple[str]): the string(s) can be one of "caller" or + "message", which defines how to identify duplicated logs. + For example, if called with `n=1, key="caller"`, this function + will only log the first call from the same caller, regardless of + the message content. + If called with `n=1, key="message"`, this function will log the + same content only once, even if they are called from different places. + If called with `n=1, key=("caller", "message")`, this function + will not log only if the same caller has logged the same message before. + """ + if isinstance(key, str): + key = (key,) + assert len(key) > 0 + + caller_module, caller_key = _find_caller() + hash_key = () + if "caller" in key: + hash_key = hash_key + caller_key + if "message" in key: + hash_key = hash_key + (msg,) + + _LOG_COUNTER[hash_key] += 1 + if _LOG_COUNTER[hash_key] <= n: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n(lvl, msg, n=1, *, name=None): + """ + Log once per n times. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + """ + caller_module, key = _find_caller() + _LOG_COUNTER[key] += 1 + if n == 1 or _LOG_COUNTER[key] % n == 1: + logging.getLogger(name or caller_module).log(lvl, msg) + + +def log_every_n_seconds(lvl, msg, n=1, *, name=None): + """ + Log no more than once per n seconds. + + Args: + lvl (int): the logging level + msg (str): + n (int): + name (str): name of the logger to use. Will use the caller's module by default. + """ + caller_module, key = _find_caller() + last_logged = _LOG_TIMER.get(key, None) + current_time = time.time() + if last_logged is None or current_time - last_logged >= n: + logging.getLogger(name or caller_module).log(lvl, msg) + _LOG_TIMER[key] = current_time + + +def create_small_table(small_dict): + """ + Create a small table using the keys of small_dict as headers. This is only + suitable for small dictionaries. + + Args: + small_dict (dict): a result dictionary of only a few items. + + Returns: + str: the table as a string. + """ + keys, values = tuple(zip(*small_dict.items())) + table = tabulate( + [values], + headers=keys, + tablefmt="pipe", + floatfmt=".3f", + stralign="center", + numalign="center", + ) + return table + + +def _log_api_usage(identifier: str): + """ + Internal function used to log the usage of different detectron2 components + inside facebook's infra. + """ + torch._C._log_api_usage_once("detectron2." + identifier) diff --git a/custom_detectron2/utils/memory.py b/custom_detectron2/utils/memory.py new file mode 100644 index 0000000000000000000000000000000000000000..bd494780b9dbbd1571688cd270bb9b53d113c13e --- /dev/null +++ b/custom_detectron2/utils/memory.py @@ -0,0 +1,84 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import logging +from contextlib import contextmanager +from functools import wraps +import torch + +__all__ = ["retry_if_cuda_oom"] + + +@contextmanager +def _ignore_torch_cuda_oom(): + """ + A context which ignores CUDA OOM exception from pytorch. + """ + try: + yield + except RuntimeError as e: + # NOTE: the string may change? + if "CUDA out of memory. " in str(e): + pass + else: + raise + + +def retry_if_cuda_oom(func): + """ + Makes a function retry itself after encountering + pytorch's CUDA OOM error. + It will first retry after calling `torch.cuda.empty_cache()`. + + If that still fails, it will then retry by trying to convert inputs to CPUs. + In this case, it expects the function to dispatch to CPU implementation. + The return values may become CPU tensors as well and it's user's + responsibility to convert it back to CUDA tensor if needed. + + Args: + func: a stateless callable that takes tensor-like objects as arguments + + Returns: + a callable which retries `func` if OOM is encountered. + + Examples: + :: + output = retry_if_cuda_oom(some_torch_function)(input1, input2) + # output may be on CPU even if inputs are on GPU + + Note: + 1. When converting inputs to CPU, it will only look at each argument and check + if it has `.device` and `.to` for conversion. Nested structures of tensors + are not supported. + + 2. Since the function might be called more than once, it has to be + stateless. + """ + + def maybe_to_cpu(x): + try: + like_gpu_tensor = x.device.type == "cuda" and hasattr(x, "to") + except AttributeError: + like_gpu_tensor = False + if like_gpu_tensor: + return x.to(device="cpu") + else: + return x + + @wraps(func) + def wrapped(*args, **kwargs): + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Clear cache and retry + torch.cuda.empty_cache() + with _ignore_torch_cuda_oom(): + return func(*args, **kwargs) + + # Try on CPU. This slows down the code significantly, therefore print a notice. + logger = logging.getLogger(__name__) + logger.info("Attempting to copy inputs of {} to CPU due to CUDA OOM".format(str(func))) + new_args = (maybe_to_cpu(x) for x in args) + new_kwargs = {k: maybe_to_cpu(v) for k, v in kwargs.items()} + return func(*new_args, **new_kwargs) + + return wrapped diff --git a/custom_detectron2/utils/registry.py b/custom_detectron2/utils/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..4b01e9007c2578a7b5ae555c926cc06c8a3010f9 --- /dev/null +++ b/custom_detectron2/utils/registry.py @@ -0,0 +1,60 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +from typing import Any +import pydoc +from fvcore.common.registry import Registry # for backward compatibility. + +""" +``Registry`` and `locate` provide ways to map a string (typically found +in config files) to callable objects. +""" + +__all__ = ["Registry", "locate"] + + +def _convert_target_to_string(t: Any) -> str: + """ + Inverse of ``locate()``. + + Args: + t: any object with ``__module__`` and ``__qualname__`` + """ + module, qualname = t.__module__, t.__qualname__ + + # Compress the path to this object, e.g. ``module.submodule._impl.class`` + # may become ``module.submodule.class``, if the later also resolves to the same + # object. This simplifies the string, and also is less affected by moving the + # class implementation. + module_parts = module.split(".") + for k in range(1, len(module_parts)): + prefix = ".".join(module_parts[:k]) + candidate = f"{prefix}.{qualname}" + try: + if locate(candidate) is t: + return candidate + except ImportError: + pass + return f"{module}.{qualname}" + + +def locate(name: str) -> Any: + """ + Locate and return an object ``x`` using an input string ``{x.__module__}.{x.__qualname__}``, + such as "module.submodule.class_name". + + Raise Exception if it cannot be found. + """ + obj = pydoc.locate(name) + + # Some cases (e.g. torch.optim.sgd.SGD) not handled correctly + # by pydoc.locate. Try a private function from hydra. + if obj is None: + try: + # from hydra.utils import get_method - will print many errors + from hydra.utils import _locate + except ImportError as e: + raise ImportError(f"Cannot dynamically locate object {name}!") from e + else: + obj = _locate(name) # it raises if fails + + return obj diff --git a/custom_detectron2/utils/serialize.py b/custom_detectron2/utils/serialize.py new file mode 100644 index 0000000000000000000000000000000000000000..ed45065184f0512ef65c8f38d398de553ce576ca --- /dev/null +++ b/custom_detectron2/utils/serialize.py @@ -0,0 +1,32 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# import cloudpickle + + +class PicklableWrapper(object): + """ + Wrap an object to make it more picklable, note that it uses + heavy weight serialization libraries that are slower than pickle. + It's best to use it only on closures (which are usually not picklable). + + This is a simplified version of + https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py + """ + + def __init__(self, obj): + while isinstance(obj, PicklableWrapper): + # Wrapping an object twice is no-op + obj = obj._obj + self._obj = obj + + # def __reduce__(self): + # s = cloudpickle.dumps(self._obj) + # return cloudpickle.loads, (s,) + + def __call__(self, *args, **kwargs): + return self._obj(*args, **kwargs) + + def __getattr__(self, attr): + # Ensure that the wrapped object can be used seamlessly as the previous object. + if attr not in ["_obj"]: + return getattr(self._obj, attr) + return getattr(self, attr) diff --git a/custom_detectron2/utils/testing.py b/custom_detectron2/utils/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..b8f31836f6d702228e1661457c5647f1c908cf21 --- /dev/null +++ b/custom_detectron2/utils/testing.py @@ -0,0 +1,478 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import io +import numpy as np +import os +import re +import tempfile +import unittest +from typing import Callable +import torch +import torch.onnx.symbolic_helper as sym_help +from packaging import version +from torch._C import ListType +from torch.onnx import register_custom_op_symbolic + +from custom_detectron2 import model_zoo +from custom_detectron2.config import CfgNode, LazyConfig, instantiate +from custom_detectron2.data import DatasetCatalog +from custom_detectron2.data.detection_utils import read_image +from custom_detectron2.modeling import build_model +from custom_detectron2.structures import Boxes, Instances, ROIMasks +from custom_detectron2.utils.file_io import PathManager + + +""" +Internal utilities for tests. Don't use except for writing tests. +""" + + +def get_model_no_weights(config_path): + """ + Like model_zoo.get, but do not load any weights (even pretrained) + """ + cfg = model_zoo.get_config(config_path) + if isinstance(cfg, CfgNode): + if not torch.cuda.is_available(): + cfg.MODEL.DEVICE = "cpu" + return build_model(cfg) + else: + return instantiate(cfg.model) + + +def random_boxes(num_boxes, max_coord=100, device="cpu"): + """ + Create a random Nx4 boxes tensor, with coordinates < max_coord. + """ + boxes = torch.rand(num_boxes, 4, device=device) * (max_coord * 0.5) + boxes.clamp_(min=1.0) # tiny boxes cause numerical instability in box regression + # Note: the implementation of this function in torchvision is: + # boxes[:, 2:] += torch.rand(N, 2) * 100 + # but it does not guarantee non-negative widths/heights constraints: + # boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]: + boxes[:, 2:] += boxes[:, :2] + return boxes + + +def get_sample_coco_image(tensor=True): + """ + Args: + tensor (bool): if True, returns 3xHxW tensor. + else, returns a HxWx3 numpy array. + + Returns: + an image, in BGR color. + """ + try: + file_name = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"] + if not PathManager.exists(file_name): + raise FileNotFoundError() + except IOError: + # for public CI to run + file_name = PathManager.get_local_path( + "http://images.cocodataset.org/train2017/000000000009.jpg" + ) + ret = read_image(file_name, format="BGR") + if tensor: + ret = torch.from_numpy(np.ascontiguousarray(ret.transpose(2, 0, 1))) + return ret + + +def convert_scripted_instances(instances): + """ + Convert a scripted Instances object to a regular :class:`Instances` object + """ + assert hasattr( + instances, "image_size" + ), f"Expect an Instances object, but got {type(instances)}!" + ret = Instances(instances.image_size) + for name in instances._field_names: + val = getattr(instances, "_" + name, None) + if val is not None: + ret.set(name, val) + return ret + + +def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False): + """ + Args: + input, other (Instances): + size_as_tensor: compare image_size of the Instances as tensors (instead of tuples). + Useful for comparing outputs of tracing. + """ + if not isinstance(input, Instances): + input = convert_scripted_instances(input) + if not isinstance(other, Instances): + other = convert_scripted_instances(other) + + if not msg: + msg = "Two Instances are different! " + else: + msg = msg.rstrip() + " " + + size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!" + if size_as_tensor: + assert torch.equal( + torch.tensor(input.image_size), torch.tensor(other.image_size) + ), size_error_msg + else: + assert input.image_size == other.image_size, size_error_msg + fields = sorted(input.get_fields().keys()) + fields_other = sorted(other.get_fields().keys()) + assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!" + + for f in fields: + val1, val2 = input.get(f), other.get(f) + if isinstance(val1, (Boxes, ROIMasks)): + # boxes in the range of O(100) and can have a larger tolerance + assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), ( + msg + f"Field {f} differs too much!" + ) + elif isinstance(val1, torch.Tensor): + if val1.dtype.is_floating_point: + mag = torch.abs(val1).max().cpu().item() + assert torch.allclose(val1, val2, atol=mag * rtol), ( + msg + f"Field {f} differs too much!" + ) + else: + assert torch.equal(val1, val2), msg + f"Field {f} is different!" + else: + raise ValueError(f"Don't know how to compare type {type(val1)}") + + +def reload_script_model(module): + """ + Save a jit module and load it back. + Similar to the `getExportImportCopy` function in torch/testing/ + """ + buffer = io.BytesIO() + torch.jit.save(module, buffer) + buffer.seek(0) + return torch.jit.load(buffer) + + +def reload_lazy_config(cfg): + """ + Save an object by LazyConfig.save and load it back. + This is used to test that a config still works the same after + serialization/deserialization. + """ + with tempfile.TemporaryDirectory(prefix="detectron2") as d: + fname = os.path.join(d, "d2_cfg_test.yaml") + LazyConfig.save(cfg, fname) + return LazyConfig.load(fname) + + +def min_torch_version(min_version: str) -> bool: + """ + Returns True when torch's version is at least `min_version`. + """ + try: + import torch + except ImportError: + return False + + installed_version = version.parse(torch.__version__.split("+")[0]) + min_version = version.parse(min_version) + return installed_version >= min_version + + +def has_dynamic_axes(onnx_model): + """ + Return True when all ONNX input/output have only dynamic axes for all ranks + """ + return all( + not dim.dim_param.isnumeric() + for inp in onnx_model.graph.input + for dim in inp.type.tensor_type.shape.dim + ) and all( + not dim.dim_param.isnumeric() + for out in onnx_model.graph.output + for dim in out.type.tensor_type.shape.dim + ) + + +def register_custom_op_onnx_export( + opname: str, symbolic_fn: Callable, opset_version: int, min_version: str +) -> None: + """ + Register `symbolic_fn` as PyTorch's symbolic `opname`-`opset_version` for ONNX export. + The registration is performed only when current PyTorch's version is < `min_version.` + IMPORTANT: symbolic must be manually unregistered after the caller function returns + """ + if min_torch_version(min_version): + return + register_custom_op_symbolic(opname, symbolic_fn, opset_version) + print(f"_register_custom_op_onnx_export({opname}, {opset_version}) succeeded.") + + +def unregister_custom_op_onnx_export(opname: str, opset_version: int, min_version: str) -> None: + """ + Unregister PyTorch's symbolic `opname`-`opset_version` for ONNX export. + The un-registration is performed only when PyTorch's version is < `min_version` + IMPORTANT: The symbolic must have been manually registered by the caller, otherwise + the incorrect symbolic may be unregistered instead. + """ + + # TODO: _unregister_custom_op_symbolic is introduced PyTorch>=1.10 + # Remove after PyTorch 1.10+ is used by ALL detectron2's CI + try: + from torch.onnx import unregister_custom_op_symbolic as _unregister_custom_op_symbolic + except ImportError: + + def _unregister_custom_op_symbolic(symbolic_name, opset_version): + import torch.onnx.symbolic_registry as sym_registry + from torch.onnx.symbolic_helper import _onnx_main_opset, _onnx_stable_opsets + + def _get_ns_op_name_from_custom_op(symbolic_name): + try: + from torch.onnx.utils import get_ns_op_name_from_custom_op + + ns, op_name = get_ns_op_name_from_custom_op(symbolic_name) + except ImportError as import_error: + if not bool( + re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name) + ): + raise ValueError( + f"Invalid symbolic name {symbolic_name}. Must be `domain::name`" + ) from import_error + + ns, op_name = symbolic_name.split("::") + if ns == "onnx": + raise ValueError(f"{ns} domain cannot be modified.") from import_error + + if ns == "aten": + ns = "" + + return ns, op_name + + def _unregister_op(opname: str, domain: str, version: int): + try: + sym_registry.unregister_op(op_name, ns, ver) + except AttributeError as attribute_error: + if sym_registry.is_registered_op(opname, domain, version): + del sym_registry._registry[(domain, version)][opname] + if not sym_registry._registry[(domain, version)]: + del sym_registry._registry[(domain, version)] + else: + raise RuntimeError( + f"The opname {opname} is not registered." + ) from attribute_error + + ns, op_name = _get_ns_op_name_from_custom_op(symbolic_name) + for ver in _onnx_stable_opsets + [_onnx_main_opset]: + if ver >= opset_version: + _unregister_op(op_name, ns, ver) + + if min_torch_version(min_version): + return + _unregister_custom_op_symbolic(opname, opset_version) + print(f"_unregister_custom_op_onnx_export({opname}, {opset_version}) succeeded.") + + +skipIfOnCPUCI = unittest.skipIf( + os.environ.get("CI") and not torch.cuda.is_available(), + "The test is too slow on CPUs and will be executed on CircleCI's GPU jobs.", +) + + +def skipIfUnsupportedMinOpsetVersion(min_opset_version, current_opset_version=None): + """ + Skips tests for ONNX Opset versions older than min_opset_version. + """ + + def skip_dec(func): + def wrapper(self): + try: + opset_version = self.opset_version + except AttributeError: + opset_version = current_opset_version + if opset_version < min_opset_version: + raise unittest.SkipTest( + f"Unsupported opset_version {opset_version}" + f", required is {min_opset_version}" + ) + return func(self) + + return wrapper + + return skip_dec + + +def skipIfUnsupportedMinTorchVersion(min_version): + """ + Skips tests for PyTorch versions older than min_version. + """ + reason = f"module 'torch' has __version__ {torch.__version__}" f", required is: {min_version}" + return unittest.skipIf(not min_torch_version(min_version), reason) + + +# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI +def _pytorch1111_symbolic_opset9_to(g, self, *args): + """aten::to() symbolic that must be used for testing with PyTorch < 1.11.1.""" + + def is_aten_to_device_only(args): + if len(args) == 4: + # aten::to(Tensor, Device, bool, bool, memory_format) + return ( + args[0].node().kind() == "prim::device" + or args[0].type().isSubtypeOf(ListType.ofInts()) + or ( + sym_help._is_value(args[0]) + and args[0].node().kind() == "onnx::Constant" + and isinstance(args[0].node()["value"], str) + ) + ) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = sym_help._get_const(args[1], "i", "dtype") + return dtype is None + elif len(args) in (6, 7): + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) + # When dtype is None, this is a aten::to(device) call + dtype = sym_help._get_const(args[0], "i", "dtype") + return dtype is None + return False + + # ONNX doesn't have a concept of a device, so we ignore device-only casts + if is_aten_to_device_only(args): + return self + + if len(args) == 4: + # TestONNXRuntime::test_ones_bool shows args[0] of aten::to can be onnx::Constant[Tensor] + # In this case, the constant value is a tensor not int, + # so sym_help._maybe_get_const(args[0], 'i') would not work. + dtype = args[0] + if sym_help._is_value(args[0]) and args[0].node().kind() == "onnx::Constant": + tval = args[0].node()["value"] + if isinstance(tval, torch.Tensor): + if len(tval.shape) == 0: + tval = tval.item() + dtype = int(tval) + else: + dtype = tval + + if sym_help._is_value(dtype) or isinstance(dtype, torch.Tensor): + # aten::to(Tensor, Tensor, bool, bool, memory_format) + dtype = args[0].type().scalarType() + return g.op("Cast", self, to_i=sym_help.cast_pytorch_to_onnx[dtype]) + else: + # aten::to(Tensor, ScalarType, bool, bool, memory_format) + # memory_format is ignored + return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype]) + elif len(args) == 5: + # aten::to(Tensor, Device, ScalarType, bool, bool, memory_format) + dtype = sym_help._get_const(args[1], "i", "dtype") + # memory_format is ignored + return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype]) + elif len(args) == 6: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, memory_format) + dtype = sym_help._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype]) + elif len(args) == 7: + # aten::to(Tensor, ScalarType, Layout, Device, bool, bool, bool, memory_format) + dtype = sym_help._get_const(args[0], "i", "dtype") + # Layout, device and memory_format are ignored + return g.op("Cast", self, to_i=sym_help.scalar_type_to_onnx[dtype]) + else: + return sym_help._onnx_unsupported("Unknown aten::to signature") + + +# TODO: Remove after PyTorch 1.11.1+ is used by detectron2's CI +def _pytorch1111_symbolic_opset9_repeat_interleave(g, self, repeats, dim=None, output_size=None): + + # from torch.onnx.symbolic_helper import ScalarType + from torch.onnx.symbolic_opset9 import expand, unsqueeze + + input = self + # if dim is None flatten + # By default, use the flattened input array, and return a flat output array + if sym_help._is_none(dim): + input = sym_help._reshape_helper(g, self, g.op("Constant", value_t=torch.tensor([-1]))) + dim = 0 + else: + dim = sym_help._maybe_get_scalar(dim) + + repeats_dim = sym_help._get_tensor_rank(repeats) + repeats_sizes = sym_help._get_tensor_sizes(repeats) + input_sizes = sym_help._get_tensor_sizes(input) + if repeats_dim is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "repeats rank." + ) + if repeats_sizes is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "repeats size." + ) + if input_sizes is None: + raise RuntimeError( + "Unsupported: ONNX export of repeat_interleave for unknown " "input size." + ) + + input_sizes_temp = input_sizes.copy() + for idx, input_size in enumerate(input_sizes): + if input_size is None: + input_sizes[idx], input_sizes_temp[idx] = 0, -1 + + # Cases where repeats is an int or single value tensor + if repeats_dim == 0 or (repeats_dim == 1 and repeats_sizes[0] == 1): + if not sym_help._is_tensor(repeats): + repeats = g.op("Constant", value_t=torch.LongTensor(repeats)) + if input_sizes[dim] == 0: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + ) + else: + reps = input_sizes[dim] + repeats = expand(g, repeats, g.op("Constant", value_t=torch.tensor([reps])), None) + + # Cases where repeats is a 1 dim Tensor + elif repeats_dim == 1: + if input_sizes[dim] == 0: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", + 9, + 13, + "Unsupported along dimension with unknown input size", + ) + if repeats_sizes[0] is None: + return sym_help._onnx_opset_unsupported_detailed( + "repeat_interleave", 9, 13, "Unsupported for cases with dynamic repeats" + ) + assert ( + repeats_sizes[0] == input_sizes[dim] + ), "repeats must have the same size as input along dim" + reps = repeats_sizes[0] + else: + raise RuntimeError("repeats must be 0-dim or 1-dim tensor") + + final_splits = list() + r_splits = sym_help._repeat_interleave_split_helper(g, repeats, reps, 0) + if isinstance(r_splits, torch._C.Value): + r_splits = [r_splits] + i_splits = sym_help._repeat_interleave_split_helper(g, input, reps, dim) + if isinstance(i_splits, torch._C.Value): + i_splits = [i_splits] + input_sizes[dim], input_sizes_temp[dim] = -1, 1 + for idx, r_split in enumerate(r_splits): + i_split = unsqueeze(g, i_splits[idx], dim + 1) + r_concat = [ + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[: dim + 1])), + r_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes_temp[dim + 1 :])), + ] + r_concat = g.op("Concat", *r_concat, axis_i=0) + i_split = expand(g, i_split, r_concat, None) + i_split = sym_help._reshape_helper( + g, + i_split, + g.op("Constant", value_t=torch.LongTensor(input_sizes)), + allowzero=0, + ) + final_splits.append(i_split) + return g.op("Concat", *final_splits, axis_i=dim) diff --git a/custom_detectron2/utils/tracing.py b/custom_detectron2/utils/tracing.py new file mode 100644 index 0000000000000000000000000000000000000000..3ffc44d23cb6aa43e940bf8562a9130f2a4f27a8 --- /dev/null +++ b/custom_detectron2/utils/tracing.py @@ -0,0 +1,71 @@ +import inspect +import torch + +from custom_detectron2.utils.env import TORCH_VERSION + +try: + from torch.fx._symbolic_trace import is_fx_tracing as is_fx_tracing_current + + tracing_current_exists = True +except ImportError: + tracing_current_exists = False + +try: + from torch.fx._symbolic_trace import _orig_module_call + + tracing_legacy_exists = True +except ImportError: + tracing_legacy_exists = False + + +@torch.jit.ignore +def is_fx_tracing_legacy() -> bool: + """ + Returns a bool indicating whether torch.fx is currently symbolically tracing a module. + Can be useful for gating module logic that is incompatible with symbolic tracing. + """ + return torch.nn.Module.__call__ is not _orig_module_call + + +@torch.jit.ignore +def is_fx_tracing() -> bool: + """Returns whether execution is currently in + Torch FX tracing mode""" + if TORCH_VERSION >= (1, 10) and tracing_current_exists: + return is_fx_tracing_current() + elif tracing_legacy_exists: + return is_fx_tracing_legacy() + else: + # Can't find either current or legacy tracing indication code. + # Enabling this assert_fx_safe() call regardless of tracing status. + return False + + +@torch.jit.ignore +def assert_fx_safe(condition: bool, message: str) -> torch.Tensor: + """An FX-tracing safe version of assert. + Avoids erroneous type assertion triggering when types are masked inside + an fx.proxy.Proxy object during tracing. + Args: condition - either a boolean expression or a string representing + the condition to test. If this assert triggers an exception when tracing + due to dynamic control flow, try encasing the expression in quotation + marks and supplying it as a string.""" + # Must return a concrete tensor for compatibility with PyTorch <=1.8. + # If <=1.8 compatibility is not needed, return type can be converted to None + if not is_fx_tracing(): + try: + if isinstance(condition, str): + caller_frame = inspect.currentframe().f_back + torch._assert( + eval(condition, caller_frame.f_globals, caller_frame.f_locals), message + ) + return torch.ones(1) + else: + torch._assert(condition, message) + return torch.ones(1) + except torch.fx.proxy.TraceError as e: + print( + "Found a non-FX compatible assertion. Skipping the check. Failure is shown below" + + str(e) + ) + return torch.zeros(1) diff --git a/custom_detectron2/utils/video_visualizer.py b/custom_detectron2/utils/video_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..a22836d17eb5895d66fd5d8f68cc109d132ea9fe --- /dev/null +++ b/custom_detectron2/utils/video_visualizer.py @@ -0,0 +1,287 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import numpy as np +from typing import List +import custom_pycocotools.mask as mask_util + +from custom_detectron2.structures import Instances +from custom_detectron2.utils.visualizer import ( + ColorMode, + Visualizer, + _create_text_labels, + _PanopticPrediction, +) + +from .colormap import random_color, random_colors + + +class _DetectedInstance: + """ + Used to store data about detected objects in video frame, + in order to transfer color to objects in the future frames. + + Attributes: + label (int): + bbox (tuple[float]): + mask_rle (dict): + color (tuple[float]): RGB colors in range (0, 1) + ttl (int): time-to-live for the instance. For example, if ttl=2, + the instance color can be transferred to objects in the next two frames. + """ + + __slots__ = ["label", "bbox", "mask_rle", "color", "ttl"] + + def __init__(self, label, bbox, mask_rle, color, ttl): + self.label = label + self.bbox = bbox + self.mask_rle = mask_rle + self.color = color + self.ttl = ttl + + +class VideoVisualizer: + def __init__(self, metadata, instance_mode=ColorMode.IMAGE): + """ + Args: + metadata (MetadataCatalog): image metadata. + """ + self.metadata = metadata + self._old_instances = [] + assert instance_mode in [ + ColorMode.IMAGE, + ColorMode.IMAGE_BW, + ], "Other mode not supported yet." + self._instance_mode = instance_mode + self._max_num_instances = self.metadata.get("max_num_instances", 74) + self._assigned_colors = {} + self._color_pool = random_colors(self._max_num_instances, rgb=True, maximum=1) + self._color_idx_set = set(range(len(self._color_pool))) + + def draw_instance_predictions(self, frame, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + frame (ndarray): an RGB image of shape (H, W, C), in the range [0, 255]. + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + frame_visualizer = Visualizer(frame, self.metadata) + num_instances = len(predictions) + if num_instances == 0: + return frame_visualizer.output + + boxes = predictions.pred_boxes.tensor.numpy() if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.numpy() if predictions.has("pred_classes") else None + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + colors = predictions.COLOR if predictions.has("COLOR") else [None] * len(predictions) + periods = predictions.ID_period if predictions.has("ID_period") else None + period_threshold = self.metadata.get("period_threshold", 0) + visibilities = ( + [True] * len(predictions) + if periods is None + else [x > period_threshold for x in periods] + ) + + if predictions.has("pred_masks"): + masks = predictions.pred_masks + # mask IOU is not yet enabled + # masks_rles = mask_util.encode(np.asarray(masks.permute(1, 2, 0), order="F")) + # assert len(masks_rles) == num_instances + else: + masks = None + + if not predictions.has("COLOR"): + if predictions.has("ID"): + colors = self._assign_colors_by_id(predictions) + else: + # ToDo: clean old assign color method and use a default tracker to assign id + detected = [ + _DetectedInstance(classes[i], boxes[i], mask_rle=None, color=colors[i], ttl=8) + for i in range(num_instances) + ] + colors = self._assign_colors(detected) + + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + + if self._instance_mode == ColorMode.IMAGE_BW: + # any() returns uint8 tensor + frame_visualizer.output.reset_image( + frame_visualizer._create_grayscale_image( + (masks.any(dim=0) > 0).numpy() if masks is not None else None + ) + ) + alpha = 0.3 + else: + alpha = 0.5 + + labels = ( + None + if labels is None + else [y[0] for y in filter(lambda x: x[1], zip(labels, visibilities))] + ) # noqa + assigned_colors = ( + None + if colors is None + else [y[0] for y in filter(lambda x: x[1], zip(colors, visibilities))] + ) # noqa + frame_visualizer.overlay_instances( + boxes=None if masks is not None else boxes[visibilities], # boxes are a bit distracting + masks=None if masks is None else masks[visibilities], + labels=labels, + keypoints=None if keypoints is None else keypoints[visibilities], + assigned_colors=assigned_colors, + alpha=alpha, + ) + + return frame_visualizer.output + + def draw_sem_seg(self, frame, sem_seg, area_threshold=None): + """ + Args: + sem_seg (ndarray or Tensor): semantic segmentation of shape (H, W), + each value is the integer label. + area_threshold (Optional[int]): only draw segmentations larger than the threshold + """ + # don't need to do anything special + frame_visualizer = Visualizer(frame, self.metadata) + frame_visualizer.draw_sem_seg(sem_seg, area_threshold=None) + return frame_visualizer.output + + def draw_panoptic_seg_predictions( + self, frame, panoptic_seg, segments_info, area_threshold=None, alpha=0.5 + ): + frame_visualizer = Visualizer(frame, self.metadata) + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + frame_visualizer.output.reset_image( + frame_visualizer._create_grayscale_image(pred.non_empty_mask()) + ) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + frame_visualizer.draw_binary_mask( + mask, + color=mask_color, + text=self.metadata.stuff_classes[category_idx], + alpha=alpha, + area_threshold=area_threshold, + ) + + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return frame_visualizer.output + # draw mask for all instances second + masks, sinfo = list(zip(*all_instances)) + num_instances = len(masks) + masks_rles = mask_util.encode( + np.asarray(np.asarray(masks).transpose(1, 2, 0), dtype=np.uint8, order="F") + ) + assert len(masks_rles) == num_instances + + category_ids = [x["category_id"] for x in sinfo] + detected = [ + _DetectedInstance(category_ids[i], bbox=None, mask_rle=masks_rles[i], color=None, ttl=8) + for i in range(num_instances) + ] + colors = self._assign_colors(detected) + labels = [self.metadata.thing_classes[k] for k in category_ids] + + frame_visualizer.overlay_instances( + boxes=None, + masks=masks, + labels=labels, + keypoints=None, + assigned_colors=colors, + alpha=alpha, + ) + return frame_visualizer.output + + def _assign_colors(self, instances): + """ + Naive tracking heuristics to assign same color to the same instance, + will update the internal state of tracked instances. + + Returns: + list[tuple[float]]: list of colors. + """ + + # Compute iou with either boxes or masks: + is_crowd = np.zeros((len(instances),), dtype=bool) + if instances[0].bbox is None: + assert instances[0].mask_rle is not None + # use mask iou only when box iou is None + # because box seems good enough + rles_old = [x.mask_rle for x in self._old_instances] + rles_new = [x.mask_rle for x in instances] + ious = mask_util.iou(rles_old, rles_new, is_crowd) + threshold = 0.5 + else: + boxes_old = [x.bbox for x in self._old_instances] + boxes_new = [x.bbox for x in instances] + ious = mask_util.iou(boxes_old, boxes_new, is_crowd) + threshold = 0.6 + if len(ious) == 0: + ious = np.zeros((len(self._old_instances), len(instances)), dtype="float32") + + # Only allow matching instances of the same label: + for old_idx, old in enumerate(self._old_instances): + for new_idx, new in enumerate(instances): + if old.label != new.label: + ious[old_idx, new_idx] = 0 + + matched_new_per_old = np.asarray(ious).argmax(axis=1) + max_iou_per_old = np.asarray(ious).max(axis=1) + + # Try to find match for each old instance: + extra_instances = [] + for idx, inst in enumerate(self._old_instances): + if max_iou_per_old[idx] > threshold: + newidx = matched_new_per_old[idx] + if instances[newidx].color is None: + instances[newidx].color = inst.color + continue + # If an old instance does not match any new instances, + # keep it for the next frame in case it is just missed by the detector + inst.ttl -= 1 + if inst.ttl > 0: + extra_instances.append(inst) + + # Assign random color to newly-detected instances: + for inst in instances: + if inst.color is None: + inst.color = random_color(rgb=True, maximum=1) + self._old_instances = instances[:] + extra_instances + return [d.color for d in instances] + + def _assign_colors_by_id(self, instances: Instances) -> List: + colors = [] + untracked_ids = set(self._assigned_colors.keys()) + for id in instances.ID: + if id in self._assigned_colors: + colors.append(self._color_pool[self._assigned_colors[id]]) + untracked_ids.remove(id) + else: + assert ( + len(self._color_idx_set) >= 1 + ), f"Number of id exceeded maximum, \ + max = {self._max_num_instances}" + idx = self._color_idx_set.pop() + color = self._color_pool[idx] + self._assigned_colors[id] = idx + colors.append(color) + for id in untracked_ids: + self._color_idx_set.add(self._assigned_colors[id]) + del self._assigned_colors[id] + return colors diff --git a/custom_detectron2/utils/visualizer.py b/custom_detectron2/utils/visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a6fd59fb67867a9a420f03fc42d36779f989ed --- /dev/null +++ b/custom_detectron2/utils/visualizer.py @@ -0,0 +1,1267 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +import colorsys +import logging +import math +import numpy as np +from enum import Enum, unique +import cv2 +import matplotlib as mpl +import matplotlib.colors as mplc +import matplotlib.figure as mplfigure +import custom_pycocotools.mask as mask_util +import torch +from matplotlib.backends.backend_agg import FigureCanvasAgg +from PIL import Image + +from custom_detectron2.data import MetadataCatalog +from custom_detectron2.structures import BitMasks, Boxes, BoxMode, Keypoints, PolygonMasks, RotatedBoxes +from custom_detectron2.utils.file_io import PathManager + +from .colormap import random_color + +logger = logging.getLogger(__name__) + +__all__ = ["ColorMode", "VisImage", "Visualizer"] + + +_SMALL_OBJECT_AREA_THRESH = 1000 +_LARGE_MASK_AREA_THRESH = 120000 +_OFF_WHITE = (1.0, 1.0, 240.0 / 255) +_BLACK = (0, 0, 0) +_RED = (1.0, 0, 0) + +_KEYPOINT_THRESHOLD = 0.05 + + +@unique +class ColorMode(Enum): + """ + Enum of different color modes to use for instance visualizations. + """ + + IMAGE = 0 + """ + Picks a random color for every instance and overlay segmentations with low opacity. + """ + SEGMENTATION = 1 + """ + Let instances of the same category have similar colors + (from metadata.thing_colors), and overlay them with + high opacity. This provides more attention on the quality of segmentation. + """ + IMAGE_BW = 2 + """ + Same as IMAGE, but convert all areas without masks to gray-scale. + Only available for drawing per-instance mask predictions. + """ + + +class GenericMask: + """ + Attribute: + polygons (list[ndarray]): list[ndarray]: polygons for this mask. + Each ndarray has format [x, y, x, y, ...] + mask (ndarray): a binary mask + """ + + def __init__(self, mask_or_polygons, height, width): + self._mask = self._polygons = self._has_holes = None + self.height = height + self.width = width + + m = mask_or_polygons + if isinstance(m, dict): + # RLEs + assert "counts" in m and "size" in m + if isinstance(m["counts"], list): # uncompressed RLEs + h, w = m["size"] + assert h == height and w == width + m = mask_util.frPyObjects(m, h, w) + self._mask = mask_util.decode(m)[:, :] + return + + if isinstance(m, list): # list[ndarray] + self._polygons = [np.asarray(x).reshape(-1) for x in m] + return + + if isinstance(m, np.ndarray): # assumed to be a binary mask + assert m.shape[1] != 2, m.shape + assert m.shape == ( + height, + width, + ), f"mask shape: {m.shape}, target dims: {height}, {width}" + self._mask = m.astype("uint8") + return + + raise ValueError("GenericMask cannot handle object {} of type '{}'".format(m, type(m))) + + @property + def mask(self): + if self._mask is None: + self._mask = self.polygons_to_mask(self._polygons) + return self._mask + + @property + def polygons(self): + if self._polygons is None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + return self._polygons + + @property + def has_holes(self): + if self._has_holes is None: + if self._mask is not None: + self._polygons, self._has_holes = self.mask_to_polygons(self._mask) + else: + self._has_holes = False # if original format is polygon, does not have holes + return self._has_holes + + def mask_to_polygons(self, mask): + # cv2.RETR_CCOMP flag retrieves all the contours and arranges them to a 2-level + # hierarchy. External contours (boundary) of the object are placed in hierarchy-1. + # Internal contours (holes) are placed in hierarchy-2. + # cv2.CHAIN_APPROX_NONE flag gets vertices of polygons from contours. + mask = np.ascontiguousarray(mask) # some versions of cv2 does not support incontiguous arr + res = cv2.findContours(mask.astype("uint8"), cv2.RETR_CCOMP, cv2.CHAIN_APPROX_NONE) + hierarchy = res[-1] + if hierarchy is None: # empty mask + return [], False + has_holes = (hierarchy.reshape(-1, 4)[:, 3] >= 0).sum() > 0 + res = res[-2] + res = [x.flatten() for x in res] + # These coordinates from OpenCV are integers in range [0, W-1 or H-1]. + # We add 0.5 to turn them into real-value coordinate space. A better solution + # would be to first +0.5 and then dilate the returned polygon by 0.5. + res = [x + 0.5 for x in res if len(x) >= 6] + return res, has_holes + + def polygons_to_mask(self, polygons): + rle = mask_util.frPyObjects(polygons, self.height, self.width) + rle = mask_util.merge(rle) + return mask_util.decode(rle)[:, :] + + def area(self): + return self.mask.sum() + + def bbox(self): + p = mask_util.frPyObjects(self.polygons, self.height, self.width) + p = mask_util.merge(p) + bbox = mask_util.toBbox(p) + bbox[2] += bbox[0] + bbox[3] += bbox[1] + return bbox + + +class _PanopticPrediction: + """ + Unify different panoptic annotation/prediction formats + """ + + def __init__(self, panoptic_seg, segments_info, metadata=None): + if segments_info is None: + assert metadata is not None + # If "segments_info" is None, we assume "panoptic_img" is a + # H*W int32 image storing the panoptic_id in the format of + # category_id * label_divisor + instance_id. We reserve -1 for + # VOID label. + label_divisor = metadata.label_divisor + segments_info = [] + for panoptic_label in np.unique(panoptic_seg.numpy()): + if panoptic_label == -1: + # VOID region. + continue + pred_class = panoptic_label // label_divisor + isthing = pred_class in metadata.thing_dataset_id_to_contiguous_id.values() + segments_info.append( + { + "id": int(panoptic_label), + "category_id": int(pred_class), + "isthing": bool(isthing), + } + ) + del metadata + + self._seg = panoptic_seg + + self._sinfo = {s["id"]: s for s in segments_info} # seg id -> seg info + segment_ids, areas = torch.unique(panoptic_seg, sorted=True, return_counts=True) + areas = areas.numpy() + sorted_idxs = np.argsort(-areas) + self._seg_ids, self._seg_areas = segment_ids[sorted_idxs], areas[sorted_idxs] + self._seg_ids = self._seg_ids.tolist() + for sid, area in zip(self._seg_ids, self._seg_areas): + if sid in self._sinfo: + self._sinfo[sid]["area"] = float(area) + + def non_empty_mask(self): + """ + Returns: + (H, W) array, a mask for all pixels that have a prediction + """ + empty_ids = [] + for id in self._seg_ids: + if id not in self._sinfo: + empty_ids.append(id) + if len(empty_ids) == 0: + return np.zeros(self._seg.shape, dtype=np.uint8) + assert ( + len(empty_ids) == 1 + ), ">1 ids corresponds to no labels. This is currently not supported" + return (self._seg != empty_ids[0]).numpy().astype(bool) + + def semantic_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or sinfo["isthing"]: + # Some pixels (e.g. id 0 in PanopticFPN) have no instance or semantic predictions. + continue + yield (self._seg == sid).numpy().astype(bool), sinfo + + def instance_masks(self): + for sid in self._seg_ids: + sinfo = self._sinfo.get(sid) + if sinfo is None or not sinfo["isthing"]: + continue + mask = (self._seg == sid).numpy().astype(bool) + if mask.sum() > 0: + yield mask, sinfo + + +def _create_text_labels(classes, scores, class_names, is_crowd=None): + """ + Args: + classes (list[int] or None): + scores (list[float] or None): + class_names (list[str] or None): + is_crowd (list[bool] or None): + + Returns: + list[str] or None + """ + labels = None + if classes is not None: + if class_names is not None and len(class_names) > 0: + labels = [class_names[i] for i in classes] + else: + labels = [str(i) for i in classes] + if scores is not None: + if labels is None: + labels = ["{:.0f}%".format(s * 100) for s in scores] + else: + labels = ["{} {:.0f}%".format(l, s * 100) for l, s in zip(labels, scores)] + if labels is not None and is_crowd is not None: + labels = [l + ("|crowd" if crowd else "") for l, crowd in zip(labels, is_crowd)] + return labels + + +class VisImage: + def __init__(self, img, scale=1.0): + """ + Args: + img (ndarray): an RGB image of shape (H, W, 3) in range [0, 255]. + scale (float): scale the input image + """ + self.img = img + self.scale = scale + self.width, self.height = img.shape[1], img.shape[0] + self._setup_figure(img) + + def _setup_figure(self, img): + """ + Args: + Same as in :meth:`__init__()`. + + Returns: + fig (matplotlib.pyplot.figure): top level container for all the image plot elements. + ax (matplotlib.pyplot.Axes): contains figure elements and sets the coordinate system. + """ + fig = mplfigure.Figure(frameon=False) + self.dpi = fig.get_dpi() + # add a small 1e-2 to avoid precision lost due to matplotlib's truncation + # (https://github.com/matplotlib/matplotlib/issues/15363) + fig.set_size_inches( + (self.width * self.scale + 1e-2) / self.dpi, + (self.height * self.scale + 1e-2) / self.dpi, + ) + self.canvas = FigureCanvasAgg(fig) + # self.canvas = mpl.backends.backend_cairo.FigureCanvasCairo(fig) + ax = fig.add_axes([0.0, 0.0, 1.0, 1.0]) + ax.axis("off") + self.fig = fig + self.ax = ax + self.reset_image(img) + + def reset_image(self, img): + """ + Args: + img: same as in __init__ + """ + img = img.astype("uint8") + self.ax.imshow(img, extent=(0, self.width, self.height, 0), interpolation="nearest") + + def save(self, filepath): + """ + Args: + filepath (str): a string that contains the absolute path, including the file name, where + the visualized image will be saved. + """ + self.fig.savefig(filepath) + + def get_image(self): + """ + Returns: + ndarray: + the visualized image of shape (H, W, 3) (RGB) in uint8 type. + The shape is scaled w.r.t the input image using the given `scale` argument. + """ + canvas = self.canvas + s, (width, height) = canvas.print_to_buffer() + # buf = io.BytesIO() # works for cairo backend + # canvas.print_rgba(buf) + # width, height = self.width, self.height + # s = buf.getvalue() + + buffer = np.frombuffer(s, dtype="uint8") + + img_rgba = buffer.reshape(height, width, 4) + rgb, alpha = np.split(img_rgba, [3], axis=2) + return rgb.astype("uint8") + + +class Visualizer: + """ + Visualizer that draws data about detection/segmentation on images. + + It contains methods like `draw_{text,box,circle,line,binary_mask,polygon}` + that draw primitive objects to images, as well as high-level wrappers like + `draw_{instance_predictions,sem_seg,panoptic_seg_predictions,dataset_dict}` + that draw composite data in some pre-defined style. + + Note that the exact visualization style for the high-level wrappers are subject to change. + Style such as color, opacity, label contents, visibility of labels, or even the visibility + of objects themselves (e.g. when the object is too small) may change according + to different heuristics, as long as the results still look visually reasonable. + + To obtain a consistent style, you can implement custom drawing functions with the + abovementioned primitive methods instead. If you need more customized visualization + styles, you can process the data yourself following their format documented in + tutorials (:doc:`/tutorials/models`, :doc:`/tutorials/datasets`). This class does not + intend to satisfy everyone's preference on drawing styles. + + This visualizer focuses on high rendering quality rather than performance. It is not + designed to be used for real-time applications. + """ + + # TODO implement a fast, rasterized version using OpenCV + + def __init__(self, img_rgb, metadata=None, scale=1.0, instance_mode=ColorMode.IMAGE): + """ + Args: + img_rgb: a numpy array of shape (H, W, C), where H and W correspond to + the height and width of the image respectively. C is the number of + color channels. The image is required to be in RGB format since that + is a requirement of the Matplotlib library. The image is also expected + to be in the range [0, 255]. + metadata (Metadata): dataset metadata (e.g. class names and colors) + instance_mode (ColorMode): defines one of the pre-defined style for drawing + instances on an image. + """ + self.img = np.asarray(img_rgb).clip(0, 255).astype(np.uint8) + if metadata is None: + metadata = MetadataCatalog.get("__nonexist__") + self.metadata = metadata + self.output = VisImage(self.img, scale=scale) + self.cpu_device = torch.device("cpu") + + # too small texts are useless, therefore clamp to 9 + self._default_font_size = max( + np.sqrt(self.output.height * self.output.width) // 90, 10 // scale + ) + self._instance_mode = instance_mode + self.keypoint_threshold = _KEYPOINT_THRESHOLD + + def draw_instance_predictions(self, predictions): + """ + Draw instance-level prediction results on an image. + + Args: + predictions (Instances): the output of an instance detection/segmentation + model. Following fields will be used to draw: + "pred_boxes", "pred_classes", "scores", "pred_masks" (or "pred_masks_rle"). + + Returns: + output (VisImage): image object with visualizations. + """ + boxes = predictions.pred_boxes if predictions.has("pred_boxes") else None + scores = predictions.scores if predictions.has("scores") else None + classes = predictions.pred_classes.tolist() if predictions.has("pred_classes") else None + labels = _create_text_labels(classes, scores, self.metadata.get("thing_classes", None)) + keypoints = predictions.pred_keypoints if predictions.has("pred_keypoints") else None + + if predictions.has("pred_masks"): + masks = np.asarray(predictions.pred_masks) + masks = [GenericMask(x, self.output.height, self.output.width) for x in masks] + else: + masks = None + + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in classes + ] + alpha = 0.8 + else: + colors = None + alpha = 0.5 + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image( + self._create_grayscale_image( + (predictions.pred_masks.any(dim=0) > 0).numpy() + if predictions.has("pred_masks") + else None + ) + ) + alpha = 0.3 + + self.overlay_instances( + masks=masks, + boxes=boxes, + labels=labels, + keypoints=keypoints, + assigned_colors=colors, + alpha=alpha, + ) + return self.output + + def draw_sem_seg(self, sem_seg, area_threshold=None, alpha=0.8): + """ + Draw semantic segmentation predictions/labels. + + Args: + sem_seg (Tensor or ndarray): the segmentation of shape (H, W). + Each value is the integer label of the pixel. + area_threshold (int): segments with less than `area_threshold` are not drawn. + alpha (float): the larger it is, the more opaque the segmentations are. + + Returns: + output (VisImage): image object with visualizations. + """ + if isinstance(sem_seg, torch.Tensor): + sem_seg = sem_seg.numpy() + labels, areas = np.unique(sem_seg, return_counts=True) + sorted_idxs = np.argsort(-areas).tolist() + labels = labels[sorted_idxs] + for label in filter(lambda l: l < len(self.metadata.stuff_classes), labels): + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[label]] + except (AttributeError, IndexError): + mask_color = None + + binary_mask = (sem_seg == label).astype(np.uint8) + text = self.metadata.stuff_classes[label] + self.draw_binary_mask( + binary_mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + return self.output + + def draw_panoptic_seg(self, panoptic_seg, segments_info, area_threshold=None, alpha=0.7): + """ + Draw panoptic prediction annotations or results. + + Args: + panoptic_seg (Tensor): of shape (height, width) where the values are ids for each + segment. + segments_info (list[dict] or None): Describe each segment in `panoptic_seg`. + If it is a ``list[dict]``, each dict contains keys "id", "category_id". + If None, category id of each pixel is computed by + ``pixel // metadata.label_divisor``. + area_threshold (int): stuff segments with less than `area_threshold` are not drawn. + + Returns: + output (VisImage): image object with visualizations. + """ + pred = _PanopticPrediction(panoptic_seg, segments_info, self.metadata) + + if self._instance_mode == ColorMode.IMAGE_BW: + self.output.reset_image(self._create_grayscale_image(pred.non_empty_mask())) + + # draw mask for all semantic segments first i.e. "stuff" + for mask, sinfo in pred.semantic_masks(): + category_idx = sinfo["category_id"] + try: + mask_color = [x / 255 for x in self.metadata.stuff_colors[category_idx]] + except AttributeError: + mask_color = None + + text = self.metadata.stuff_classes[category_idx] + self.draw_binary_mask( + mask, + color=mask_color, + edge_color=_OFF_WHITE, + text=text, + alpha=alpha, + area_threshold=area_threshold, + ) + + # draw mask for all instances second + all_instances = list(pred.instance_masks()) + if len(all_instances) == 0: + return self.output + masks, sinfo = list(zip(*all_instances)) + category_ids = [x["category_id"] for x in sinfo] + + try: + scores = [x["score"] for x in sinfo] + except KeyError: + scores = None + labels = _create_text_labels( + category_ids, scores, self.metadata.thing_classes, [x.get("iscrowd", 0) for x in sinfo] + ) + + try: + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) for c in category_ids + ] + except AttributeError: + colors = None + self.overlay_instances(masks=masks, labels=labels, assigned_colors=colors, alpha=alpha) + + return self.output + + draw_panoptic_seg_predictions = draw_panoptic_seg # backward compatibility + + def draw_dataset_dict(self, dic): + """ + Draw annotations/segmentations in Detectron2 Dataset format. + + Args: + dic (dict): annotation/segmentation data of one image, in Detectron2 Dataset format. + + Returns: + output (VisImage): image object with visualizations. + """ + annos = dic.get("annotations", None) + if annos: + if "segmentation" in annos[0]: + masks = [x["segmentation"] for x in annos] + else: + masks = None + if "keypoints" in annos[0]: + keypts = [x["keypoints"] for x in annos] + keypts = np.array(keypts).reshape(len(annos), -1, 3) + else: + keypts = None + + boxes = [ + BoxMode.convert(x["bbox"], x["bbox_mode"], BoxMode.XYXY_ABS) + if len(x["bbox"]) == 4 + else x["bbox"] + for x in annos + ] + + colors = None + category_ids = [x["category_id"] for x in annos] + if self._instance_mode == ColorMode.SEGMENTATION and self.metadata.get("thing_colors"): + colors = [ + self._jitter([x / 255 for x in self.metadata.thing_colors[c]]) + for c in category_ids + ] + names = self.metadata.get("thing_classes", None) + labels = _create_text_labels( + category_ids, + scores=None, + class_names=names, + is_crowd=[x.get("iscrowd", 0) for x in annos], + ) + self.overlay_instances( + labels=labels, boxes=boxes, masks=masks, keypoints=keypts, assigned_colors=colors + ) + + sem_seg = dic.get("sem_seg", None) + if sem_seg is None and "sem_seg_file_name" in dic: + with PathManager.open(dic["sem_seg_file_name"], "rb") as f: + sem_seg = Image.open(f) + sem_seg = np.asarray(sem_seg, dtype="uint8") + if sem_seg is not None: + self.draw_sem_seg(sem_seg, area_threshold=0, alpha=0.5) + + pan_seg = dic.get("pan_seg", None) + if pan_seg is None and "pan_seg_file_name" in dic: + with PathManager.open(dic["pan_seg_file_name"], "rb") as f: + pan_seg = Image.open(f) + pan_seg = np.asarray(pan_seg) + from panopticapi.utils import rgb2id + + pan_seg = rgb2id(pan_seg) + if pan_seg is not None: + segments_info = dic["segments_info"] + pan_seg = torch.tensor(pan_seg) + self.draw_panoptic_seg(pan_seg, segments_info, area_threshold=0, alpha=0.5) + return self.output + + def overlay_instances( + self, + *, + boxes=None, + labels=None, + masks=None, + keypoints=None, + assigned_colors=None, + alpha=0.5, + ): + """ + Args: + boxes (Boxes, RotatedBoxes or ndarray): either a :class:`Boxes`, + or an Nx4 numpy array of XYXY_ABS format for the N objects in a single image, + or a :class:`RotatedBoxes`, + or an Nx5 numpy array of (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image, + labels (list[str]): the text to be displayed for each instance. + masks (masks-like object): Supported types are: + + * :class:`detectron2.structures.PolygonMasks`, + :class:`detectron2.structures.BitMasks`. + * list[list[ndarray]]: contains the segmentation masks for all objects in one image. + The first level of the list corresponds to individual instances. The second + level to all the polygon that compose the instance, and the third level + to the polygon coordinates. The third level should have the format of + [x0, y0, x1, y1, ..., xn, yn] (n >= 3). + * list[ndarray]: each ndarray is a binary mask of shape (H, W). + * list[dict]: each dict is a COCO-style RLE. + keypoints (Keypoint or array like): an array-like object of shape (N, K, 3), + where the N is the number of instances and K is the number of keypoints. + The last dimension corresponds to (x, y, visibility or score). + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = 0 + if boxes is not None: + boxes = self._convert_boxes(boxes) + num_instances = len(boxes) + if masks is not None: + masks = self._convert_masks(masks) + if num_instances: + assert len(masks) == num_instances + else: + num_instances = len(masks) + if keypoints is not None: + if num_instances: + assert len(keypoints) == num_instances + else: + num_instances = len(keypoints) + keypoints = self._convert_keypoints(keypoints) + if labels is not None: + assert len(labels) == num_instances + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + if boxes is not None and boxes.shape[1] == 5: + return self.overlay_rotated_instances( + boxes=boxes, labels=labels, assigned_colors=assigned_colors + ) + + # Display in largest to smallest order to reduce occlusion. + areas = None + if boxes is not None: + areas = np.prod(boxes[:, 2:] - boxes[:, :2], axis=1) + elif masks is not None: + areas = np.asarray([x.area() for x in masks]) + + if areas is not None: + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] if boxes is not None else None + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + masks = [masks[idx] for idx in sorted_idxs] if masks is not None else None + assigned_colors = [assigned_colors[idx] for idx in sorted_idxs] + keypoints = keypoints[sorted_idxs] if keypoints is not None else None + + for i in range(num_instances): + color = assigned_colors[i] + if boxes is not None: + self.draw_box(boxes[i], edge_color=color) + + if masks is not None: + for segment in masks[i].polygons: + self.draw_polygon(segment.reshape(-1, 2), color, alpha=alpha) + + if labels is not None: + # first get a box + if boxes is not None: + x0, y0, x1, y1 = boxes[i] + text_pos = (x0, y0) # if drawing boxes, put text on the box corner. + horiz_align = "left" + elif masks is not None: + # skip small mask without polygon + if len(masks[i].polygons) == 0: + continue + + x0, y0, x1, y1 = masks[i].bbox() + + # draw text in the center (defined by median) when box is not drawn + # median is less sensitive to outliers. + text_pos = np.median(masks[i].mask.nonzero(), axis=1)[::-1] + horiz_align = "center" + else: + continue # drawing the box confidence for keypoints isn't very useful. + # for small objects, draw text at the side to avoid occlusion + instance_area = (y1 - y0) * (x1 - x0) + if ( + instance_area < _SMALL_OBJECT_AREA_THRESH * self.output.scale + or y1 - y0 < 40 * self.output.scale + ): + if y1 >= self.output.height - 5: + text_pos = (x1, y0) + else: + text_pos = (x0, y1) + + height_ratio = (y1 - y0) / np.sqrt(self.output.height * self.output.width) + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) + * 0.5 + * self._default_font_size + ) + self.draw_text( + labels[i], + text_pos, + color=lighter_color, + horizontal_alignment=horiz_align, + font_size=font_size, + ) + + # draw keypoints + if keypoints is not None: + for keypoints_per_instance in keypoints: + self.draw_and_connect_keypoints(keypoints_per_instance) + + return self.output + + def overlay_rotated_instances(self, boxes=None, labels=None, assigned_colors=None): + """ + Args: + boxes (ndarray): an Nx5 numpy array of + (x_center, y_center, width, height, angle_degrees) format + for the N objects in a single image. + labels (list[str]): the text to be displayed for each instance. + assigned_colors (list[matplotlib.colors]): a list of colors, where each color + corresponds to each mask or box in the image. Refer to 'matplotlib.colors' + for full list of formats that the colors are accepted in. + + Returns: + output (VisImage): image object with visualizations. + """ + num_instances = len(boxes) + + if assigned_colors is None: + assigned_colors = [random_color(rgb=True, maximum=1) for _ in range(num_instances)] + if num_instances == 0: + return self.output + + # Display in largest to smallest order to reduce occlusion. + if boxes is not None: + areas = boxes[:, 2] * boxes[:, 3] + + sorted_idxs = np.argsort(-areas).tolist() + # Re-order overlapped instances in descending order. + boxes = boxes[sorted_idxs] + labels = [labels[k] for k in sorted_idxs] if labels is not None else None + colors = [assigned_colors[idx] for idx in sorted_idxs] + + for i in range(num_instances): + self.draw_rotated_box_with_label( + boxes[i], edge_color=colors[i], label=labels[i] if labels is not None else None + ) + + return self.output + + def draw_and_connect_keypoints(self, keypoints): + """ + Draws keypoints of an instance and follows the rules for keypoint connections + to draw lines between appropriate keypoints. This follows color heuristics for + line color. + + Args: + keypoints (Tensor): a tensor of shape (K, 3), where K is the number of keypoints + and the last dimension corresponds to (x, y, probability). + + Returns: + output (VisImage): image object with visualizations. + """ + visible = {} + keypoint_names = self.metadata.get("keypoint_names") + for idx, keypoint in enumerate(keypoints): + + # draw keypoint + x, y, prob = keypoint + if prob > self.keypoint_threshold: + self.draw_circle((x, y), color=_RED) + if keypoint_names: + keypoint_name = keypoint_names[idx] + visible[keypoint_name] = (x, y) + + if self.metadata.get("keypoint_connection_rules"): + for kp0, kp1, color in self.metadata.keypoint_connection_rules: + if kp0 in visible and kp1 in visible: + x0, y0 = visible[kp0] + x1, y1 = visible[kp1] + color = tuple(x / 255.0 for x in color) + self.draw_line([x0, x1], [y0, y1], color=color) + + # draw lines from nose to mid-shoulder and mid-shoulder to mid-hip + # Note that this strategy is specific to person keypoints. + # For other keypoints, it should just do nothing + try: + ls_x, ls_y = visible["left_shoulder"] + rs_x, rs_y = visible["right_shoulder"] + mid_shoulder_x, mid_shoulder_y = (ls_x + rs_x) / 2, (ls_y + rs_y) / 2 + except KeyError: + pass + else: + # draw line from nose to mid-shoulder + nose_x, nose_y = visible.get("nose", (None, None)) + if nose_x is not None: + self.draw_line([nose_x, mid_shoulder_x], [nose_y, mid_shoulder_y], color=_RED) + + try: + # draw line from mid-shoulder to mid-hip + lh_x, lh_y = visible["left_hip"] + rh_x, rh_y = visible["right_hip"] + except KeyError: + pass + else: + mid_hip_x, mid_hip_y = (lh_x + rh_x) / 2, (lh_y + rh_y) / 2 + self.draw_line([mid_hip_x, mid_shoulder_x], [mid_hip_y, mid_shoulder_y], color=_RED) + return self.output + + """ + Primitive drawing functions: + """ + + def draw_text( + self, + text, + position, + *, + font_size=None, + color="g", + horizontal_alignment="center", + rotation=0, + ): + """ + Args: + text (str): class label + position (tuple): a tuple of the x and y coordinates to place text on image. + font_size (int, optional): font of the text. If not provided, a font size + proportional to the image width is calculated and used. + color: color of the text. Refer to `matplotlib.colors` for full list + of formats that are accepted. + horizontal_alignment (str): see `matplotlib.text.Text` + rotation: rotation angle in degrees CCW + + Returns: + output (VisImage): image object with text drawn. + """ + if not font_size: + font_size = self._default_font_size + + # since the text background is dark, we don't want the text to be dark + color = np.maximum(list(mplc.to_rgb(color)), 0.2) + color[np.argmax(color)] = max(0.8, np.max(color)) + + x, y = position + self.output.ax.text( + x, + y, + text, + size=font_size * self.output.scale, + family="sans-serif", + bbox={"facecolor": "black", "alpha": 0.8, "pad": 0.7, "edgecolor": "none"}, + verticalalignment="top", + horizontalalignment=horizontal_alignment, + color=color, + zorder=10, + rotation=rotation, + ) + return self.output + + def draw_box(self, box_coord, alpha=0.5, edge_color="g", line_style="-"): + """ + Args: + box_coord (tuple): a tuple containing x0, y0, x1, y1 coordinates, where x0 and y0 + are the coordinates of the image's top left corner. x1 and y1 are the + coordinates of the image's bottom right corner. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + + Returns: + output (VisImage): image object with box drawn. + """ + x0, y0, x1, y1 = box_coord + width = x1 - x0 + height = y1 - y0 + + linewidth = max(self._default_font_size / 4, 1) + + self.output.ax.add_patch( + mpl.patches.Rectangle( + (x0, y0), + width, + height, + fill=False, + edgecolor=edge_color, + linewidth=linewidth * self.output.scale, + alpha=alpha, + linestyle=line_style, + ) + ) + return self.output + + def draw_rotated_box_with_label( + self, rotated_box, alpha=0.5, edge_color="g", line_style="-", label=None + ): + """ + Draw a rotated box with label on its top-left corner. + + Args: + rotated_box (tuple): a tuple containing (cnt_x, cnt_y, w, h, angle), + where cnt_x and cnt_y are the center coordinates of the box. + w and h are the width and height of the box. angle represents how + many degrees the box is rotated CCW with regard to the 0-degree box. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + edge_color: color of the outline of the box. Refer to `matplotlib.colors` + for full list of formats that are accepted. + line_style (string): the string to use to create the outline of the boxes. + label (string): label for rotated box. It will not be rendered when set to None. + + Returns: + output (VisImage): image object with box drawn. + """ + cnt_x, cnt_y, w, h, angle = rotated_box + area = w * h + # use thinner lines when the box is small + linewidth = self._default_font_size / ( + 6 if area < _SMALL_OBJECT_AREA_THRESH * self.output.scale else 3 + ) + + theta = angle * math.pi / 180.0 + c = math.cos(theta) + s = math.sin(theta) + rect = [(-w / 2, h / 2), (-w / 2, -h / 2), (w / 2, -h / 2), (w / 2, h / 2)] + # x: left->right ; y: top->down + rotated_rect = [(s * yy + c * xx + cnt_x, c * yy - s * xx + cnt_y) for (xx, yy) in rect] + for k in range(4): + j = (k + 1) % 4 + self.draw_line( + [rotated_rect[k][0], rotated_rect[j][0]], + [rotated_rect[k][1], rotated_rect[j][1]], + color=edge_color, + linestyle="--" if k == 1 else line_style, + linewidth=linewidth, + ) + + if label is not None: + text_pos = rotated_rect[1] # topleft corner + + height_ratio = h / np.sqrt(self.output.height * self.output.width) + label_color = self._change_color_brightness(edge_color, brightness_factor=0.7) + font_size = ( + np.clip((height_ratio - 0.02) / 0.08 + 1, 1.2, 2) * 0.5 * self._default_font_size + ) + self.draw_text(label, text_pos, color=label_color, font_size=font_size, rotation=angle) + + return self.output + + def draw_circle(self, circle_coord, color, radius=3): + """ + Args: + circle_coord (list(int) or tuple(int)): contains the x and y coordinates + of the center of the circle. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + radius (int): radius of the circle. + + Returns: + output (VisImage): image object with box drawn. + """ + x, y = circle_coord + self.output.ax.add_patch( + mpl.patches.Circle(circle_coord, radius=radius, fill=True, color=color) + ) + return self.output + + def draw_line(self, x_data, y_data, color, linestyle="-", linewidth=None): + """ + Args: + x_data (list[int]): a list containing x values of all the points being drawn. + Length of list should match the length of y_data. + y_data (list[int]): a list containing y values of all the points being drawn. + Length of list should match the length of x_data. + color: color of the line. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + linestyle: style of the line. Refer to `matplotlib.lines.Line2D` + for a full list of formats that are accepted. + linewidth (float or None): width of the line. When it's None, + a default value will be computed and used. + + Returns: + output (VisImage): image object with line drawn. + """ + if linewidth is None: + linewidth = self._default_font_size / 3 + linewidth = max(linewidth, 1) + self.output.ax.add_line( + mpl.lines.Line2D( + x_data, + y_data, + linewidth=linewidth * self.output.scale, + color=color, + linestyle=linestyle, + ) + ) + return self.output + + def draw_binary_mask( + self, binary_mask, color=None, *, edge_color=None, text=None, alpha=0.5, area_threshold=10 + ): + """ + Args: + binary_mask (ndarray): numpy array of shape (H, W), where H is the image height and + W is the image width. Each value in the array is either a 0 or 1 value of uint8 + type. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + area_threshold (float): a connected component smaller than this area will not be shown. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + has_valid_segment = False + binary_mask = binary_mask.astype("uint8") # opencv needs uint8 + mask = GenericMask(binary_mask, self.output.height, self.output.width) + shape2d = (binary_mask.shape[0], binary_mask.shape[1]) + + if not mask.has_holes: + # draw polygons for regular masks + for segment in mask.polygons: + area = mask_util.area(mask_util.frPyObjects([segment], shape2d[0], shape2d[1])) + if area < (area_threshold or 0): + continue + has_valid_segment = True + segment = segment.reshape(-1, 2) + self.draw_polygon(segment, color=color, edge_color=edge_color, alpha=alpha) + else: + # TODO: Use Path/PathPatch to draw vector graphics: + # https://stackoverflow.com/questions/8919719/how-to-plot-a-complex-polygon + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = (mask.mask == 1).astype("float32") * alpha + has_valid_segment = True + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None and has_valid_segment: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_soft_mask(self, soft_mask, color=None, *, text=None, alpha=0.5): + """ + Args: + soft_mask (ndarray): float array of shape (H, W), each value in [0, 1]. + color: color of the mask. Refer to `matplotlib.colors` for a full list of + formats that are accepted. If None, will pick a random color. + text (str): if None, will be drawn on the object + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with mask drawn. + """ + if color is None: + color = random_color(rgb=True, maximum=1) + color = mplc.to_rgb(color) + + shape2d = (soft_mask.shape[0], soft_mask.shape[1]) + rgba = np.zeros(shape2d + (4,), dtype="float32") + rgba[:, :, :3] = color + rgba[:, :, 3] = soft_mask * alpha + self.output.ax.imshow(rgba, extent=(0, self.output.width, self.output.height, 0)) + + if text is not None: + lighter_color = self._change_color_brightness(color, brightness_factor=0.7) + binary_mask = (soft_mask > 0.5).astype("uint8") + self._draw_text_in_mask(binary_mask, text, lighter_color) + return self.output + + def draw_polygon(self, segment, color, edge_color=None, alpha=0.5): + """ + Args: + segment: numpy array of shape Nx2, containing all the points in the polygon. + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + edge_color: color of the polygon edges. Refer to `matplotlib.colors` for a + full list of formats that are accepted. If not provided, a darker shade + of the polygon color will be used instead. + alpha (float): blending efficient. Smaller values lead to more transparent masks. + + Returns: + output (VisImage): image object with polygon drawn. + """ + if edge_color is None: + # make edge color darker than the polygon color + if alpha > 0.8: + edge_color = self._change_color_brightness(color, brightness_factor=-0.7) + else: + edge_color = color + edge_color = mplc.to_rgb(edge_color) + (1,) + + polygon = mpl.patches.Polygon( + segment, + fill=True, + facecolor=mplc.to_rgb(color) + (alpha,), + edgecolor=edge_color, + linewidth=max(self._default_font_size // 15 * self.output.scale, 1), + ) + self.output.ax.add_patch(polygon) + return self.output + + """ + Internal methods: + """ + + def _jitter(self, color): + """ + Randomly modifies given color to produce a slightly different color than the color given. + + Args: + color (tuple[double]): a tuple of 3 elements, containing the RGB values of the color + picked. The values in the list are in the [0.0, 1.0] range. + + Returns: + jittered_color (tuple[double]): a tuple of 3 elements, containing the RGB values of the + color after being jittered. The values in the list are in the [0.0, 1.0] range. + """ + color = mplc.to_rgb(color) + vec = np.random.rand(3) + # better to do it in another color space + vec = vec / np.linalg.norm(vec) * 0.5 + res = np.clip(vec + color, 0, 1) + return tuple(res) + + def _create_grayscale_image(self, mask=None): + """ + Create a grayscale version of the original image. + The colors in masked area, if given, will be kept. + """ + img_bw = self.img.astype("f4").mean(axis=2) + img_bw = np.stack([img_bw] * 3, axis=2) + if mask is not None: + img_bw[mask] = self.img[mask] + return img_bw + + def _change_color_brightness(self, color, brightness_factor): + """ + Depending on the brightness_factor, gives a lighter or darker color i.e. a color with + less or more saturation than the original color. + + Args: + color: color of the polygon. Refer to `matplotlib.colors` for a full list of + formats that are accepted. + brightness_factor (float): a value in [-1.0, 1.0] range. A lightness factor of + 0 will correspond to no change, a factor in [-1.0, 0) range will result in + a darker color and a factor in (0, 1.0] range will result in a lighter color. + + Returns: + modified_color (tuple[double]): a tuple containing the RGB values of the + modified color. Each value in the tuple is in the [0.0, 1.0] range. + """ + assert brightness_factor >= -1.0 and brightness_factor <= 1.0 + color = mplc.to_rgb(color) + polygon_color = colorsys.rgb_to_hls(*mplc.to_rgb(color)) + modified_lightness = polygon_color[1] + (brightness_factor * polygon_color[1]) + modified_lightness = 0.0 if modified_lightness < 0.0 else modified_lightness + modified_lightness = 1.0 if modified_lightness > 1.0 else modified_lightness + modified_color = colorsys.hls_to_rgb(polygon_color[0], modified_lightness, polygon_color[2]) + return tuple(np.clip(modified_color, 0.0, 1.0)) + + def _convert_boxes(self, boxes): + """ + Convert different format of boxes to an NxB array, where B = 4 or 5 is the box dimension. + """ + if isinstance(boxes, Boxes) or isinstance(boxes, RotatedBoxes): + return boxes.tensor.detach().numpy() + else: + return np.asarray(boxes) + + def _convert_masks(self, masks_or_polygons): + """ + Convert different format of masks or polygons to a tuple of masks and polygons. + + Returns: + list[GenericMask]: + """ + + m = masks_or_polygons + if isinstance(m, PolygonMasks): + m = m.polygons + if isinstance(m, BitMasks): + m = m.tensor.numpy() + if isinstance(m, torch.Tensor): + m = m.numpy() + ret = [] + for x in m: + if isinstance(x, GenericMask): + ret.append(x) + else: + ret.append(GenericMask(x, self.output.height, self.output.width)) + return ret + + def _draw_text_in_mask(self, binary_mask, text, color): + """ + Find proper places to draw text given a binary mask. + """ + # TODO sometimes drawn on wrong objects. the heuristics here can improve. + _num_cc, cc_labels, stats, centroids = cv2.connectedComponentsWithStats(binary_mask, 8) + if stats[1:, -1].size == 0: + return + largest_component_id = np.argmax(stats[1:, -1]) + 1 + + # draw text on the largest component, as well as other very large components. + for cid in range(1, _num_cc): + if cid == largest_component_id or stats[cid, -1] > _LARGE_MASK_AREA_THRESH: + # median is more stable than centroid + # center = centroids[largest_component_id] + center = np.median((cc_labels == cid).nonzero(), axis=1)[::-1] + self.draw_text(text, center, color=color) + + def _convert_keypoints(self, keypoints): + if isinstance(keypoints, Keypoints): + keypoints = keypoints.tensor + keypoints = np.asarray(keypoints) + return keypoints + + def get_output(self): + """ + Returns: + output (VisImage): the image output containing the visualizations added + to the image. + """ + return self.output diff --git a/custom_manopth/CHANGES.md b/custom_manopth/CHANGES.md new file mode 100644 index 0000000000000000000000000000000000000000..27e7d74595a4b048f0e0aff1f77a7488870a821e --- /dev/null +++ b/custom_manopth/CHANGES.md @@ -0,0 +1 @@ +* Chumpy is removed \ No newline at end of file diff --git a/custom_manopth/LICENSE b/custom_manopth/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..e72bfddabc15be5718a7cc061ac10e47741d8219 --- /dev/null +++ b/custom_manopth/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. \ No newline at end of file diff --git a/custom_manopth/__init__.py b/custom_manopth/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e27cf8699e86b13d9d4cb28da10d54c405effe96 --- /dev/null +++ b/custom_manopth/__init__.py @@ -0,0 +1 @@ +name = 'manopth' diff --git a/custom_manopth/argutils.py b/custom_manopth/argutils.py new file mode 100644 index 0000000000000000000000000000000000000000..7e86eb025ad0618e63d730b4f59ee3615118d197 --- /dev/null +++ b/custom_manopth/argutils.py @@ -0,0 +1,51 @@ +import datetime +import os +import pickle +import subprocess +import sys + + +def print_args(args): + opts = vars(args) + print('======= Options ========') + for k, v in sorted(opts.items()): + print('{}: {}'.format(k, v)) + print('========================') + + +def save_args(args, save_folder, opt_prefix='opt', verbose=True): + opts = vars(args) + # Create checkpoint folder + if not os.path.exists(save_folder): + os.makedirs(save_folder, exist_ok=True) + + # Save options + opt_filename = '{}.txt'.format(opt_prefix) + opt_path = os.path.join(save_folder, opt_filename) + with open(opt_path, 'a') as opt_file: + opt_file.write('====== Options ======\n') + for k, v in sorted(opts.items()): + opt_file.write( + '{option}: {value}\n'.format(option=str(k), value=str(v))) + opt_file.write('=====================\n') + opt_file.write('launched {} at {}\n'.format( + str(sys.argv[0]), str(datetime.datetime.now()))) + + # Add git info + label = subprocess.check_output(["git", "describe", + "--always"]).strip() + if subprocess.call( + ["git", "branch"], + stderr=subprocess.STDOUT, + stdout=open(os.devnull, 'w')) == 0: + opt_file.write('=== Git info ====\n') + opt_file.write('{}\n'.format(label)) + commit = subprocess.check_output(['git', 'rev-parse', 'HEAD']) + opt_file.write('commit : {}\n'.format(commit.strip())) + + opt_picklename = '{}.pkl'.format(opt_prefix) + opt_picklepath = os.path.join(save_folder, opt_picklename) + with open(opt_picklepath, 'wb') as opt_file: + pickle.dump(opts, opt_file) + if verbose: + print('Saved options to {}'.format(opt_path)) diff --git a/custom_manopth/demo.py b/custom_manopth/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..0a250b7c4d4e2fd2face4611d87839faf795eecd --- /dev/null +++ b/custom_manopth/demo.py @@ -0,0 +1,59 @@ +from matplotlib import pyplot as plt +from mpl_toolkits.mplot3d import Axes3D +from mpl_toolkits.mplot3d.art3d import Poly3DCollection +import numpy as np +import torch + +from custom_manopth.manolayer import ManoLayer + + +def generate_random_hand(batch_size=1, ncomps=6, mano_root='mano/models'): + nfull_comps = ncomps + 3 # Add global orientation dims to PCA + random_pcapose = torch.rand(batch_size, nfull_comps) + mano_layer = ManoLayer(mano_root=mano_root) + verts, joints = mano_layer(random_pcapose) + return {'verts': verts, 'joints': joints, 'faces': mano_layer.th_faces} + + +def display_hand(hand_info, mano_faces=None, ax=None, alpha=0.2, batch_idx=0, show=True): + """ + Displays hand batch_idx in batch of hand_info, hand_info as returned by + generate_random_hand + """ + if ax is None: + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + verts, joints = hand_info['verts'][batch_idx], hand_info['joints'][ + batch_idx] + if mano_faces is None: + ax.scatter(verts[:, 0], verts[:, 1], verts[:, 2], alpha=0.1) + else: + mesh = Poly3DCollection(verts[mano_faces], alpha=alpha) + face_color = (141 / 255, 184 / 255, 226 / 255) + edge_color = (50 / 255, 50 / 255, 50 / 255) + mesh.set_edgecolor(edge_color) + mesh.set_facecolor(face_color) + ax.add_collection3d(mesh) + ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], color='r') + cam_equal_aspect_3d(ax, verts.numpy()) + if show: + plt.show() + + +def cam_equal_aspect_3d(ax, verts, flip_x=False): + """ + Centers view on cuboid containing hand and flips y and z axis + and fixes azimuth + """ + extents = np.stack([verts.min(0), verts.max(0)], axis=1) + sz = extents[:, 1] - extents[:, 0] + centers = np.mean(extents, axis=1) + maxsize = max(abs(sz)) + r = maxsize / 2 + if flip_x: + ax.set_xlim(centers[0] + r, centers[0] - r) + else: + ax.set_xlim(centers[0] - r, centers[0] + r) + # Invert y and z axis + ax.set_ylim(centers[1] + r, centers[1] - r) + ax.set_zlim(centers[2] + r, centers[2] - r) diff --git a/custom_manopth/manolayer.py b/custom_manopth/manolayer.py new file mode 100644 index 0000000000000000000000000000000000000000..29440d96eb0e41ca7945add89e1df12e274f2b01 --- /dev/null +++ b/custom_manopth/manolayer.py @@ -0,0 +1,274 @@ +import os + +import numpy as np +import torch +from torch.nn import Module + +from custom_manopth.smpl_handpca_wrapper_HAND_only import ready_arguments +from custom_manopth import rodrigues_layer, rotproj, rot6d +from custom_manopth.tensutils import (th_posemap_axisang, th_with_zeros, th_pack, + subtract_flat_id, make_list) + + +class ManoLayer(Module): + __constants__ = [ + 'use_pca', 'rot', 'ncomps', 'ncomps', 'kintree_parents', 'check', + 'side', 'center_idx', 'joint_rot_mode' + ] + + def __init__(self, + center_idx=None, + flat_hand_mean=True, + ncomps=6, + side='right', + mano_root='mano/models', + use_pca=True, + root_rot_mode='axisang', + joint_rot_mode='axisang', + robust_rot=False): + """ + Args: + center_idx: index of center joint in our computations, + if -1 centers on estimate of palm as middle of base + of middle finger and wrist + flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match + flat hand, else match average hand pose + mano_root: path to MANO pkl files for left and right hand + ncomps: number of PCA components form pose space (<45) + side: 'right' or 'left' + use_pca: Use PCA decomposition for pose space. + joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca + """ + super().__init__() + + self.center_idx = center_idx + self.robust_rot = robust_rot + if root_rot_mode == 'axisang': + self.rot = 3 + else: + self.rot = 6 + self.flat_hand_mean = flat_hand_mean + self.side = side + self.use_pca = use_pca + self.joint_rot_mode = joint_rot_mode + self.root_rot_mode = root_rot_mode + if use_pca: + self.ncomps = ncomps + else: + self.ncomps = 45 + + if side == 'right': + self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl') + elif side == 'left': + self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl') + + smpl_data = ready_arguments(self.mano_path) + + hands_components = smpl_data['hands_components'] + + self.smpl_data = smpl_data + + self.register_buffer('th_betas', + torch.Tensor(smpl_data['betas']).unsqueeze(0)) + self.register_buffer('th_shapedirs', + torch.Tensor(smpl_data['shapedirs'])) + self.register_buffer('th_posedirs', + torch.Tensor(smpl_data['posedirs'])) + self.register_buffer( + 'th_v_template', + torch.Tensor(smpl_data['v_template']).unsqueeze(0)) + self.register_buffer( + 'th_J_regressor', + torch.Tensor(np.array(smpl_data['J_regressor'].toarray()))) + self.register_buffer('th_weights', + torch.Tensor(smpl_data['weights'])) + self.register_buffer('th_faces', + torch.Tensor(smpl_data['f'].astype(np.int32)).long()) + + # Get hand mean + hands_mean = np.zeros(hands_components.shape[1] + ) if flat_hand_mean else smpl_data['hands_mean'] + hands_mean = hands_mean.copy() + th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0) + if self.use_pca or self.joint_rot_mode == 'axisang': + # Save as axis-angle + self.register_buffer('th_hands_mean', th_hands_mean) + selected_components = hands_components[:ncomps] + self.register_buffer('th_comps', torch.Tensor(hands_components)) + self.register_buffer('th_selected_comps', + torch.Tensor(selected_components)) + else: + th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues( + th_hands_mean.view(15, 3)).reshape(15, 3, 3) + self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat) + + # Kinematic chain params + self.kintree_table = smpl_data['kintree_table'] + parents = list(self.kintree_table[0].tolist()) + self.kintree_parents = parents + + def forward(self, + th_pose_coeffs, + th_betas=torch.zeros(1), + th_trans=torch.zeros(1), + root_palm=torch.Tensor([0]), + share_betas=torch.Tensor([0]), + ): + """ + Args: + th_trans (Tensor (batch_size x ncomps)): if provided, applies trans to joints and vertices + th_betas (Tensor (batch_size x 10)): if provided, uses given shape parameters for hand shape + else centers on root joint (9th joint) + root_palm: return palm as hand root instead of wrist + """ + # if len(th_pose_coeffs) == 0: + # return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0) + + batch_size = th_pose_coeffs.shape[0] + # Get axis angle from PCA components and coefficients + if self.use_pca or self.joint_rot_mode == 'axisang': + # Remove global rot coeffs + th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot + + self.ncomps] + if self.use_pca: + # PCA components --> axis angles + th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps) + else: + th_full_hand_pose = th_hand_pose_coeffs + + # Concatenate back global rot + th_full_pose = torch.cat([ + th_pose_coeffs[:, :self.rot], + self.th_hands_mean + th_full_hand_pose + ], 1) + if self.root_rot_mode == 'axisang': + # compute rotation matrixes from axis-angle while skipping global rotation + th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose) + root_rot = th_rot_map[:, :9].view(batch_size, 3, 3) + th_rot_map = th_rot_map[:, 9:] + th_pose_map = th_pose_map[:, 9:] + else: + # th_posemap offsets by 3, so add offset or 3 to get to self.rot=6 + th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose[:, 6:]) + if self.robust_rot: + root_rot = rot6d.robust_compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6]) + else: + root_rot = rot6d.compute_rotation_matrix_from_ortho6d(th_full_pose[:, :6]) + else: + assert th_pose_coeffs.dim() == 4, ( + 'When not self.use_pca, ' + 'th_pose_coeffs should have 4 dims, got {}'.format( + th_pose_coeffs.dim())) + assert th_pose_coeffs.shape[2:4] == (3, 3), ( + 'When not self.use_pca, th_pose_coeffs have 3x3 matrix for two' + 'last dims, got {}'.format(th_pose_coeffs.shape[2:4])) + th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs) + th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1) + th_pose_map = subtract_flat_id(th_rot_map) + root_rot = th_pose_rots[:, 0] + + # Full axis angle representation with root joint + if th_betas is None or th_betas.numel() == 1: + th_v_shaped = torch.matmul(self.th_shapedirs, + self.th_betas.transpose(1, 0)).permute( + 2, 0, 1) + self.th_v_template + th_j = torch.matmul(self.th_J_regressor, th_v_shaped).repeat( + batch_size, 1, 1) + + else: + if share_betas: + th_betas = th_betas.mean(0, keepdim=True).expand(th_betas.shape[0], 10) + th_v_shaped = torch.matmul(self.th_shapedirs, + th_betas.transpose(1, 0)).permute( + 2, 0, 1) + self.th_v_template + th_j = torch.matmul(self.th_J_regressor, th_v_shaped) + # th_pose_map should have shape 20x135 + + th_v_posed = th_v_shaped + torch.matmul( + self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1) + # Final T pose with transformation done ! + + # Global rigid transformation + + root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1) + root_trans = th_with_zeros(torch.cat([root_rot, root_j], 2)) + + all_rots = th_rot_map.view(th_rot_map.shape[0], 15, 3, 3) + lev1_idxs = [1, 4, 7, 10, 13] + lev2_idxs = [2, 5, 8, 11, 14] + lev3_idxs = [3, 6, 9, 12, 15] + lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]] + lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]] + lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]] + lev1_j = th_j[:, lev1_idxs] + lev2_j = th_j[:, lev2_idxs] + lev3_j = th_j[:, lev3_idxs] + + # From base to tips + # Get lev1 results + all_transforms = [root_trans.unsqueeze(1)] + lev1_j_rel = lev1_j - root_j.transpose(1, 2) + lev1_rel_transform_flt = th_with_zeros(torch.cat([lev1_rots, lev1_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + root_trans_flt = root_trans.unsqueeze(1).repeat(1, 5, 1, 1).view(root_trans.shape[0] * 5, 4, 4) + lev1_flt = torch.matmul(root_trans_flt, lev1_rel_transform_flt) + all_transforms.append(lev1_flt.view(all_rots.shape[0], 5, 4, 4)) + + # Get lev2 results + lev2_j_rel = lev2_j - lev1_j + lev2_rel_transform_flt = th_with_zeros(torch.cat([lev2_rots, lev2_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + lev2_flt = torch.matmul(lev1_flt, lev2_rel_transform_flt) + all_transforms.append(lev2_flt.view(all_rots.shape[0], 5, 4, 4)) + + # Get lev3 results + lev3_j_rel = lev3_j - lev2_j + lev3_rel_transform_flt = th_with_zeros(torch.cat([lev3_rots, lev3_j_rel.unsqueeze(3)], 3).view(-1, 3, 4)) + lev3_flt = torch.matmul(lev2_flt, lev3_rel_transform_flt) + all_transforms.append(lev3_flt.view(all_rots.shape[0], 5, 4, 4)) + + reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15] + th_results = torch.cat(all_transforms, 1)[:, reorder_idxs] + th_results_global = th_results + + joint_js = torch.cat([th_j, th_j.new_zeros(th_j.shape[0], 16, 1)], 2) + tmp2 = torch.matmul(th_results, joint_js.unsqueeze(3)) + th_results2 = (th_results - torch.cat([tmp2.new_zeros(*tmp2.shape[:2], 4, 3), tmp2], 3)).permute(0, 2, 3, 1) + + th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1)) + + th_rest_shape_h = torch.cat([ + th_v_posed.transpose(2, 1), + torch.ones((batch_size, 1, th_v_posed.shape[1]), + dtype=th_T.dtype, + device=th_T.device), + ], 1) + + th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1) + th_verts = th_verts[:, :, :3] + th_jtr = th_results_global[:, :, :3, 3] + # In addition to MANO reference joints we sample vertices on each finger + # to serve as finger tips + if self.side == 'right': + tips = th_verts[:, [745, 317, 444, 556, 673]] + else: + tips = th_verts[:, [745, 317, 445, 556, 673]] + if bool(root_palm): + palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2 + th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1) + th_jtr = torch.cat([th_jtr, tips], 1) + + # Reorder joints to match visualization utilities + th_jtr = th_jtr[:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]] + + if th_trans is None or bool(torch.norm(th_trans) == 0): + if self.center_idx is not None: + center_joint = th_jtr[:, self.center_idx].unsqueeze(1) + th_jtr = th_jtr - center_joint + th_verts = th_verts - center_joint + else: + th_jtr = th_jtr + th_trans.unsqueeze(1) + th_verts = th_verts + th_trans.unsqueeze(1) + + # Scale to milimeters + th_verts = th_verts * 1000 + th_jtr = th_jtr * 1000 + return th_verts, th_jtr diff --git a/custom_manopth/posemapper.py b/custom_manopth/posemapper.py new file mode 100644 index 0000000000000000000000000000000000000000..9b86ea0a5bfbcf5f2f59a6874fb1bd0d8aa99e6d --- /dev/null +++ b/custom_manopth/posemapper.py @@ -0,0 +1,37 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. + +''' + + +import numpy as np +import cv2 + +def lrotmin(p): + if isinstance(p, np.ndarray): + p = p.ravel()[3:] + return np.concatenate( + [(cv2.Rodrigues(np.array(pp))[0] - np.eye(3)).ravel() + for pp in p.reshape((-1, 3))]).ravel() + + +def posemap(s): + if s == 'lrotmin': + return lrotmin + else: + raise Exception('Unknown posemapping: %s' % (str(s), )) \ No newline at end of file diff --git a/custom_manopth/rodrigues_layer.py b/custom_manopth/rodrigues_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..eb7a2565a6c7ccf185446c157bd2599943d1832c --- /dev/null +++ b/custom_manopth/rodrigues_layer.py @@ -0,0 +1,89 @@ +""" +This part reuses code from https://github.com/MandyMo/pytorch_HMR/blob/master/src/util.py +which is part of a PyTorch port of SMPL. +Thanks to Zhang Xiong (MandyMo) for making this great code available on github ! +""" + +import argparse +from torch.autograd import gradcheck +import torch +from torch.autograd import Variable + +from custom_manopth import argutils + + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [batch_size, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [batch_size, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat / norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:, 0], norm_quat[:, 1], norm_quat[:, + 2], norm_quat[:, + 3] + + batch_size = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w * x, w * y, w * z + xy, xz, yz = x * y, x * z, y * z + + rotMat = torch.stack([ + w2 + x2 - y2 - z2, 2 * xy - 2 * wz, 2 * wy + 2 * xz, 2 * wz + 2 * xy, + w2 - x2 + y2 - z2, 2 * yz - 2 * wx, 2 * xz - 2 * wy, 2 * wx + 2 * yz, + w2 - x2 - y2 + z2 + ], + dim=1).view(batch_size, 3, 3) + return rotMat + + +def batch_rodrigues(axisang): + #axisang N x 3 + axisang_norm = torch.norm(axisang + 1e-8, p=2, dim=1) + angle = torch.unsqueeze(axisang_norm, -1) + axisang_normalized = torch.div(axisang, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * axisang_normalized], dim=1) + rot_mat = quat2mat(quat) + rot_mat = rot_mat.view(rot_mat.shape[0], 9) + return rot_mat + + +def th_get_axis_angle(vector): + angle = torch.norm(vector, 2, 1) + axes = vector / angle.unsqueeze(1) + return axes, angle + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--batch_size', default=1, type=int) + parser.add_argument('--cuda', action='store_true') + args = parser.parse_args() + + argutils.print_args(args) + + n_components = 6 + rot = 3 + inputs = torch.rand(args.batch_size, rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + if args.cuda: + inputs = inputs.cuda() + # outputs = batch_rodrigues(inputs) + test_function = gradcheck(batch_rodrigues, (inputs_var, )) + print('batch test passed !') + + inputs = torch.rand(rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + test_function = gradcheck(th_cv2_rod_sub_id.apply, (inputs_var, )) + print('th_cv2_rod test passed') + + inputs = torch.rand(rot) + inputs_var = Variable(inputs.double(), requires_grad=True) + test_th = gradcheck(th_cv2_rod.apply, (inputs_var, )) + print('th_cv2_rod_id test passed !') diff --git a/custom_manopth/rot6d.py b/custom_manopth/rot6d.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d60efbcfadda5f216c0eb9a60b348e248435a0 --- /dev/null +++ b/custom_manopth/rot6d.py @@ -0,0 +1,71 @@ +import torch + + +def compute_rotation_matrix_from_ortho6d(poses): + """ + Code from + https://github.com/papagina/RotationContinuity + On the Continuity of Rotation Representations in Neural Networks + Zhou et al. CVPR19 + https://zhouyisjtu.github.io/project_rotation/rotation.html + """ + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + z = cross_product(x, y_raw) # batch*3 + z = normalize_vector(z) # batch*3 + y = cross_product(z, x) # batch*3 + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + return matrix + +def robust_compute_rotation_matrix_from_ortho6d(poses): + """ + Instead of making 2nd vector orthogonal to first + create a base that takes into account the two predicted + directions equally + """ + x_raw = poses[:, 0:3] # batch*3 + y_raw = poses[:, 3:6] # batch*3 + + x = normalize_vector(x_raw) # batch*3 + y = normalize_vector(y_raw) # batch*3 + middle = normalize_vector(x + y) + orthmid = normalize_vector(x - y) + x = normalize_vector(middle + orthmid) + y = normalize_vector(middle - orthmid) + # Their scalar product should be small ! + # assert torch.einsum("ij,ij->i", [x, y]).abs().max() < 0.00001 + z = normalize_vector(cross_product(x, y)) + + x = x.view(-1, 3, 1) + y = y.view(-1, 3, 1) + z = z.view(-1, 3, 1) + matrix = torch.cat((x, y, z), 2) # batch*3*3 + # Check for reflection in matrix ! If found, flip last vector TODO + assert (torch.stack([torch.det(mat) for mat in matrix ])< 0).sum() == 0 + return matrix + + +def normalize_vector(v): + batch = v.shape[0] + v_mag = torch.sqrt(v.pow(2).sum(1)) # batch + v_mag = torch.max(v_mag, v.new([1e-8])) + v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1]) + v = v/v_mag + return v + + +def cross_product(u, v): + batch = u.shape[0] + i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1] + j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2] + k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0] + + out = torch.cat((i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1) + + return out diff --git a/custom_manopth/rotproj.py b/custom_manopth/rotproj.py new file mode 100644 index 0000000000000000000000000000000000000000..91a601d5de8117ed9fe1708c2d11e8713dc18011 --- /dev/null +++ b/custom_manopth/rotproj.py @@ -0,0 +1,21 @@ +import torch + + +def batch_rotprojs(batches_rotmats): + proj_rotmats = [] + for batch_idx, batch_rotmats in enumerate(batches_rotmats): + proj_batch_rotmats = [] + for rot_idx, rotmat in enumerate(batch_rotmats): + # GPU implementation of svd is VERY slow + # ~ 2 10^-3 per hit vs 5 10^-5 on cpu + U, S, V = rotmat.cpu().svd() + rotmat = torch.matmul(U, V.transpose(0, 1)) + orth_det = rotmat.det() + # Remove reflection + if orth_det < 0: + rotmat[:, 2] = -1 * rotmat[:, 2] + + rotmat = rotmat.cuda() + proj_batch_rotmats.append(rotmat) + proj_rotmats.append(torch.stack(proj_batch_rotmats)) + return torch.stack(proj_rotmats) diff --git a/custom_manopth/smpl_handpca_wrapper_HAND_only.py b/custom_manopth/smpl_handpca_wrapper_HAND_only.py new file mode 100644 index 0000000000000000000000000000000000000000..e2d157625974deb10f9f77f95de1f1d369ecbf8f --- /dev/null +++ b/custom_manopth/smpl_handpca_wrapper_HAND_only.py @@ -0,0 +1,155 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. + +''' + +def col(A): + return A.reshape((-1, 1)) + +def MatVecMult(mtx, vec): + result = mtx.dot(col(vec.ravel())).ravel() + if len(vec.shape) > 1 and vec.shape[1] > 1: + result = result.reshape((-1, vec.shape[1])) + return result + +def ready_arguments(fname_or_dict, posekey4vposed='pose'): + import numpy as np + import pickle + from custom_manopth.posemapper import posemap + + if not isinstance(fname_or_dict, dict): + dd = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1') + # dd = pickle.load(open(fname_or_dict, 'rb')) + else: + dd = fname_or_dict + + want_shapemodel = 'shapedirs' in dd + nposeparms = dd['kintree_table'].shape[1] * 3 + + if 'trans' not in dd: + dd['trans'] = np.zeros(3) + if 'pose' not in dd: + dd['pose'] = np.zeros(nposeparms) + if 'shapedirs' in dd and 'betas' not in dd: + dd['betas'] = np.zeros(dd['shapedirs'].shape[-1]) + + for s in [ + 'v_template', 'weights', 'posedirs', 'pose', 'trans', 'shapedirs', + 'betas', 'J' + ]: + if (s in dd) and not hasattr(dd[s], 'dterms'): + dd[s] = np.array(dd[s]) + + assert (posekey4vposed in dd) + if want_shapemodel: + dd['v_shaped'] = dd['shapedirs'].dot(dd['betas']) + dd['v_template'] + v_shaped = dd['v_shaped'] + J_tmpx = MatVecMult(dd['J_regressor'], v_shaped[:, 0]) + J_tmpy = MatVecMult(dd['J_regressor'], v_shaped[:, 1]) + J_tmpz = MatVecMult(dd['J_regressor'], v_shaped[:, 2]) + dd['J'] = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T + pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed]) + dd['v_posed'] = v_shaped + dd['posedirs'].dot(pose_map_res) + else: + pose_map_res = posemap(dd['bs_type'])(dd[posekey4vposed]) + dd_add = dd['posedirs'].dot(pose_map_res) + dd['v_posed'] = dd['v_template'] + dd_add + + return dd + + +def load_model(fname_or_dict, ncomps=6, flat_hand_mean=False, v_template=None): + ''' This model loads the fully articulable HAND SMPL model, + and replaces the pose DOFS by ncomps from PCA''' + + from custom_manopth.verts import verts_core + import numpy as np + import pickle + import scipy.sparse as sp + np.random.seed(1) + + if not isinstance(fname_or_dict, dict): + smpl_data = pickle.load(open(fname_or_dict, 'rb'), encoding='latin1') + # smpl_data = pickle.load(open(fname_or_dict, 'rb')) + else: + smpl_data = fname_or_dict + + rot = 3 # for global orientation!!! + + hands_components = smpl_data['hands_components'] + hands_mean = np.zeros(hands_components.shape[ + 1]) if flat_hand_mean else smpl_data['hands_mean'] + hands_coeffs = smpl_data['hands_coeffs'][:, :ncomps] + + selected_components = np.vstack((hands_components[:ncomps])) + hands_mean = hands_mean.copy() + + pose_coeffs = np.zeros(rot + selected_components.shape[0]) + full_hand_pose = pose_coeffs[rot:(rot + ncomps)].dot(selected_components) + + smpl_data['fullpose'] = np.concatenate((pose_coeffs[:rot], + hands_mean + full_hand_pose)) + smpl_data['pose'] = pose_coeffs + + Jreg = smpl_data['J_regressor'] + if not sp.issparse(Jreg): + smpl_data['J_regressor'] = (sp.csc_matrix( + (Jreg.data, (Jreg.row, Jreg.col)), shape=Jreg.shape)) + + # slightly modify ready_arguments to make sure that it uses the fullpose + # (which will NOT be pose) for the computation of posedirs + dd = ready_arguments(smpl_data, posekey4vposed='fullpose') + + # create the smpl formula with the fullpose, + # but expose the PCA coefficients as smpl.pose for compatibility + args = { + 'pose': dd['fullpose'], + 'v': dd['v_posed'], + 'J': dd['J'], + 'weights': dd['weights'], + 'kintree_table': dd['kintree_table'], + 'xp': np, + 'want_Jtr': True, + 'bs_style': dd['bs_style'], + } + + result_previous, meta = verts_core(**args) + + result = result_previous + dd['trans'].reshape((1, 3)) + result.no_translation = result_previous + + if meta is not None: + for field in ['Jtr', 'A', 'A_global', 'A_weighted']: + if (hasattr(meta, field)): + setattr(result, field, getattr(meta, field)) + + setattr(result, 'Jtr', meta) + if hasattr(result, 'Jtr'): + result.J_transformed = result.Jtr + dd['trans'].reshape((1, 3)) + + for k, v in dd.items(): + setattr(result, k, v) + + if v_template is not None: + result.v_template[:] = v_template + + return result + + +if __name__ == '__main__': + load_model() \ No newline at end of file diff --git a/custom_manopth/tensutils.py b/custom_manopth/tensutils.py new file mode 100644 index 0000000000000000000000000000000000000000..ce627beb5b4bd607a0e7ecb2d3a46814199e7b89 --- /dev/null +++ b/custom_manopth/tensutils.py @@ -0,0 +1,47 @@ +import torch + +from custom_manopth import rodrigues_layer + + +def th_posemap_axisang(pose_vectors): + rot_nb = int(pose_vectors.shape[1] / 3) + pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3) + rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped) + rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9) + pose_maps = subtract_flat_id(rot_mats) + return pose_maps, rot_mats + + +def th_with_zeros(tensor): + batch_size = tensor.shape[0] + padding = torch.tensor([0.0, 0.0, 0.0, 1.0], device = tensor.device, dtype = tensor.dtype) + padding.requires_grad = False + + concat_list = [tensor, padding.view(1, 1, 4).repeat(batch_size, 1, 1)] + cat_res = torch.cat(concat_list, 1) + return cat_res + + +def th_pack(tensor): + batch_size = tensor.shape[0] + padding = tensor.new_zeros((batch_size, 4, 3)) + padding.requires_grad = False + pack_list = [padding, tensor] + pack_res = torch.cat(pack_list, 2) + return pack_res + + +def subtract_flat_id(rot_mats): + # Subtracts identity as a flattened tensor + rot_nb = int(rot_mats.shape[1] / 9) + id_flat = torch.eye( + 3, dtype=rot_mats.dtype, device=rot_mats.device).view(1, 9).repeat( + rot_mats.shape[0], rot_nb) + # id_flat.requires_grad = False + results = rot_mats - id_flat + return results + + +def make_list(tensor): + # type: (List[int]) -> List[int] + return tensor diff --git a/custom_manopth/verts.py b/custom_manopth/verts.py new file mode 100644 index 0000000000000000000000000000000000000000..7a9e5c320544a50b1c3c2d35dce0cb9967ed0ee9 --- /dev/null +++ b/custom_manopth/verts.py @@ -0,0 +1,117 @@ +''' +Copyright 2017 Javier Romero, Dimitrios Tzionas, Michael J Black and the Max Planck Gesellschaft. All rights reserved. +This software is provided for research purposes only. +By using this software you agree to the terms of the MANO/SMPL+H Model license here http://mano.is.tue.mpg.de/license + +More information about MANO/SMPL+H is available at http://mano.is.tue.mpg.de. +For comments or questions, please email us at: mano@tue.mpg.de + + +About this file: +================ +This file defines a wrapper for the loading functions of the MANO model. + +Modules included: +- load_model: + loads the MANO model from a given file location (i.e. a .pkl file location), + or a dictionary object. + +''' + + +import numpy as np +import mano.webuser.lbs as lbs +from mano.webuser.posemapper import posemap +import scipy.sparse as sp + + +def ischumpy(x): + return hasattr(x, 'dterms') + + +def verts_decorated(trans, + pose, + v_template, + J_regressor, + weights, + kintree_table, + bs_style, + f, + bs_type=None, + posedirs=None, + betas=None, + shapedirs=None, + want_Jtr=False): + + for which in [ + trans, pose, v_template, weights, posedirs, betas, shapedirs + ]: + if which is not None: + assert ischumpy(which) + + v = v_template + + if shapedirs is not None: + if betas is None: + betas = np.zeros(shapedirs.shape[-1]) + v_shaped = v + shapedirs.dot(betas) + else: + v_shaped = v + + if posedirs is not None: + v_posed = v_shaped + posedirs.dot(posemap(bs_type)(pose)) + else: + v_posed = v_shaped + + v = v_posed + + if sp.issparse(J_regressor): + J_tmpx = np.matmul(J_regressor, v_shaped[:, 0]) + J_tmpy = np.matmul(J_regressor, v_shaped[:, 1]) + J_tmpz = np.matmul(J_regressor, v_shaped[:, 2]) + J = np.vstack((J_tmpx, J_tmpy, J_tmpz)).T + else: + assert (ischumpy(J)) + + assert (bs_style == 'lbs') + result, Jtr = lbs.verts_core( + pose, v, J, weights, kintree_table, want_Jtr=True, xp=np) + + tr = trans.reshape((1, 3)) + result = result + tr + Jtr = Jtr + tr + + result.trans = trans + result.f = f + result.pose = pose + result.v_template = v_template + result.J = J + result.J_regressor = J_regressor + result.weights = weights + result.kintree_table = kintree_table + result.bs_style = bs_style + result.bs_type = bs_type + if posedirs is not None: + result.posedirs = posedirs + result.v_posed = v_posed + if shapedirs is not None: + result.shapedirs = shapedirs + result.betas = betas + result.v_shaped = v_shaped + if want_Jtr: + result.J_transformed = Jtr + return result + + +def verts_core(pose, + v, + J, + weights, + kintree_table, + bs_style, + want_Jtr=False, + xp=np): + + assert (bs_style == 'lbs') + result = lbs.verts_core(pose, v, J, weights, kintree_table, want_Jtr, xp) + return result \ No newline at end of file diff --git a/custom_mesh_graphormer/__init__.py b/custom_mesh_graphormer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3dc1f76bc69e3f559bee6253b24fc93acee9e1f9 --- /dev/null +++ b/custom_mesh_graphormer/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/custom_mesh_graphormer/datasets/__init__.py b/custom_mesh_graphormer/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/custom_mesh_graphormer/datasets/__init__.py @@ -0,0 +1 @@ + diff --git a/custom_mesh_graphormer/datasets/build.py b/custom_mesh_graphormer/datasets/build.py new file mode 100644 index 0000000000000000000000000000000000000000..e91c283074eb264c8780568108a7d58435a4b19f --- /dev/null +++ b/custom_mesh_graphormer/datasets/build.py @@ -0,0 +1,147 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + + +import os.path as op +import torch +import logging +import code +from custom_mesh_graphormer.utils.comm import get_world_size +from custom_mesh_graphormer.datasets.human_mesh_tsv import (MeshTSVDataset, MeshTSVYamlDataset) +from custom_mesh_graphormer.datasets.hand_mesh_tsv import (HandMeshTSVDataset, HandMeshTSVYamlDataset) + + +def build_dataset(yaml_file, args, is_train=True, scale_factor=1): + print(yaml_file) + if not op.isfile(yaml_file): + yaml_file = op.join(args.data_dir, yaml_file) + # code.interact(local=locals()) + assert op.isfile(yaml_file) + return MeshTSVYamlDataset(yaml_file, is_train, False, scale_factor) + + +class IterationBasedBatchSampler(torch.utils.data.sampler.BatchSampler): + """ + Wraps a BatchSampler, resampling from it until + a specified number of iterations have been sampled + """ + + def __init__(self, batch_sampler, num_iterations, start_iter=0): + self.batch_sampler = batch_sampler + self.num_iterations = num_iterations + self.start_iter = start_iter + + def __iter__(self): + iteration = self.start_iter + while iteration <= self.num_iterations: + # if the underlying sampler has a set_epoch method, like + # DistributedSampler, used for making each process see + # a different split of the dataset, then set it + if hasattr(self.batch_sampler.sampler, "set_epoch"): + self.batch_sampler.sampler.set_epoch(iteration) + for batch in self.batch_sampler: + iteration += 1 + if iteration > self.num_iterations: + break + yield batch + + def __len__(self): + return self.num_iterations + + +def make_batch_data_sampler(sampler, images_per_gpu, num_iters=None, start_iter=0): + batch_sampler = torch.utils.data.sampler.BatchSampler( + sampler, images_per_gpu, drop_last=False + ) + if num_iters is not None and num_iters >= 0: + batch_sampler = IterationBasedBatchSampler( + batch_sampler, num_iters, start_iter + ) + return batch_sampler + + +def make_data_sampler(dataset, shuffle, distributed): + if distributed: + return torch.utils.data.distributed.DistributedSampler(dataset, shuffle=shuffle) + if shuffle: + sampler = torch.utils.data.sampler.RandomSampler(dataset) + else: + sampler = torch.utils.data.sampler.SequentialSampler(dataset) + return sampler + + +def make_data_loader(args, yaml_file, is_distributed=True, + is_train=True, start_iter=0, scale_factor=1): + + dataset = build_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) + logger = logging.getLogger(__name__) + if is_train==True: + shuffle = True + images_per_gpu = args.per_gpu_train_batch_size + images_per_batch = images_per_gpu * get_world_size() + iters_per_batch = len(dataset) // images_per_batch + num_iters = iters_per_batch * args.num_train_epochs + logger.info("Train with {} images per GPU.".format(images_per_gpu)) + logger.info("Total batch size {}".format(images_per_batch)) + logger.info("Total training steps {}".format(num_iters)) + else: + shuffle = False + images_per_gpu = args.per_gpu_eval_batch_size + num_iters = None + start_iter = 0 + + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + sampler, images_per_gpu, num_iters, start_iter + ) + data_loader = torch.utils.data.DataLoader( + dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, + pin_memory=True, + ) + return data_loader + + +#============================================================================================== + +def build_hand_dataset(yaml_file, args, is_train=True, scale_factor=1): + print(yaml_file) + if not op.isfile(yaml_file): + yaml_file = op.join(args.data_dir, yaml_file) + # code.interact(local=locals()) + assert op.isfile(yaml_file) + return HandMeshTSVYamlDataset(args, yaml_file, is_train, False, scale_factor) + + +def make_hand_data_loader(args, yaml_file, is_distributed=True, + is_train=True, start_iter=0, scale_factor=1): + + dataset = build_hand_dataset(yaml_file, args, is_train=is_train, scale_factor=scale_factor) + logger = logging.getLogger(__name__) + if is_train==True: + shuffle = True + images_per_gpu = args.per_gpu_train_batch_size + images_per_batch = images_per_gpu * get_world_size() + iters_per_batch = len(dataset) // images_per_batch + num_iters = iters_per_batch * args.num_train_epochs + logger.info("Train with {} images per GPU.".format(images_per_gpu)) + logger.info("Total batch size {}".format(images_per_batch)) + logger.info("Total training steps {}".format(num_iters)) + else: + shuffle = False + images_per_gpu = args.per_gpu_eval_batch_size + num_iters = None + start_iter = 0 + + sampler = make_data_sampler(dataset, shuffle, is_distributed) + batch_sampler = make_batch_data_sampler( + sampler, images_per_gpu, num_iters, start_iter + ) + data_loader = torch.utils.data.DataLoader( + dataset, num_workers=args.num_workers, batch_sampler=batch_sampler, + pin_memory=True, + ) + return data_loader + diff --git a/custom_mesh_graphormer/datasets/hand_mesh_tsv.py b/custom_mesh_graphormer/datasets/hand_mesh_tsv.py new file mode 100644 index 0000000000000000000000000000000000000000..d4a839200a50aa353689a022e215e2013890a66c --- /dev/null +++ b/custom_mesh_graphormer/datasets/hand_mesh_tsv.py @@ -0,0 +1,334 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + + +import cv2 +import math +import json +from PIL import Image +import os.path as op +import numpy as np +import code + +from custom_mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile +from custom_mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml +from custom_mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa +import torch +import torchvision.transforms as transforms + + +class HandMeshTSVDataset(object): + def __init__(self, args, img_file, label_file=None, hw_file=None, + linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): + + self.args = args + self.img_file = img_file + self.label_file = label_file + self.hw_file = hw_file + self.linelist_file = linelist_file + self.img_tsv = self.get_tsv_file(img_file) + self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) + self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) + + if self.is_composite: + assert op.isfile(self.linelist_file) + self.line_list = [i for i in range(self.hw_tsv.num_rows())] + else: + self.line_list = load_linelist_file(linelist_file) + + self.cv2_output = cv2_output + self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.is_train = is_train + self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] + self.noise_factor = 0.4 + self.rot_factor = 90 # Random rotation in the range [-rot_factor, rot_factor] + self.img_res = 224 + self.image_keys = self.prepare_image_keys() + self.joints_definition = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', + 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') + self.root_index = self.joints_definition.index('Wrist') + + def get_tsv_file(self, tsv_file): + if tsv_file: + if self.is_composite: + return CompositeTSVFile(tsv_file, self.linelist_file, + root=self.root) + tsv_path = find_file_path_in_yaml(tsv_file, self.root) + return TSVFile(tsv_path) + + def get_valid_tsv(self): + # sorted by file size + if self.hw_tsv: + return self.hw_tsv + if self.label_tsv: + return self.label_tsv + + def prepare_image_keys(self): + tsv = self.get_valid_tsv() + return [tsv.get_key(i) for i in range(tsv.num_rows())] + + def prepare_image_key_to_index(self): + tsv = self.get_valid_tsv() + return {tsv.get_key(i) : i for i in range(tsv.num_rows())} + + + def augm_params(self): + """Get augmentation parameters.""" + flip = 0 # flipping + pn = np.ones(3) # per channel pixel-noise + + if self.args.multiscale_inference == False: + rot = 0 # rotation + sc = 1.0 # scaling + elif self.args.multiscale_inference == True: + rot = self.args.rot + sc = self.args.sc + + if self.is_train: + sc = 1.0 + # Each channel is multiplied with a number + # in the area [1-opt.noiseFactor,1+opt.noiseFactor] + pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) + + # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] + rot = min(2*self.rot_factor, + max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) + + # The scale is multiplied with a number + # in the area [1-scaleFactor,1+scaleFactor] + sc = min(1+self.scale_factor, + max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) + # but it is zero with probability 3/5 + if np.random.uniform() <= 0.6: + rot = 0 + + return flip, pn, rot, sc + + def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): + """Process rgb image and do augmentation.""" + rgb_img = crop(rgb_img, center, scale, + [self.img_res, self.img_res], rot=rot) + # flip the image + if flip: + rgb_img = flip_img(rgb_img) + # in the rgb image we add pixel noise in a channel-wise manner + rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) + rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) + rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) + # (3,224,224),float,[0,1] + rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 + return rgb_img + + def j2d_processing(self, kp, center, scale, r, f): + """Process gt 2D keypoints and apply all augmentation transforms.""" + nparts = kp.shape[0] + for i in range(nparts): + kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, + [self.img_res, self.img_res], rot=r) + # convert to normalized coordinates + kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. + # flip the x coordinates + if f: + kp = flip_kp(kp) + kp = kp.astype('float32') + return kp + + + def j3d_processing(self, S, r, f): + """Process gt 3D keypoints and apply all augmentation transforms.""" + # in-plane rotation + rot_mat = np.eye(3) + if not r == 0: + rot_rad = -r * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) + # flip the x coordinates + if f: + S = flip_kp(S) + S = S.astype('float32') + return S + + def pose_processing(self, pose, r, f): + """Process SMPL theta parameters and apply all augmentation transforms.""" + # rotation or the pose parameters + pose = pose.astype('float32') + pose[:3] = rot_aa(pose[:3], r) + # flip the pose parameters + if f: + pose = flip_pose(pose) + # (72),float + pose = pose.astype('float32') + return pose + + def get_line_no(self, idx): + return idx if self.line_list is None else self.line_list[idx] + + def get_image(self, idx): + line_no = self.get_line_no(idx) + row = self.img_tsv[line_no] + # use -1 to support old format with multiple columns. + cv2_im = img_from_base64(row[-1]) + if self.cv2_output: + return cv2_im.astype(np.float32, copy=True) + cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) + return cv2_im + + def get_annotations(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv[line_no] + annotations = json.loads(row[1]) + return annotations + else: + return [] + + def get_target_from_annotations(self, annotations, img_size, idx): + # This function will be overwritten by each dataset to + # decode the labels to specific formats for each task. + return annotations + + def get_img_info(self, idx): + if self.hw_tsv is not None: + line_no = self.get_line_no(idx) + row = self.hw_tsv[line_no] + try: + # json string format with "height" and "width" being the keys + return json.loads(row[1])[0] + except ValueError: + # list of strings representing height and width in order + hw_str = row[1].split(' ') + hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])} + return hw_dict + + def get_img_key(self, idx): + line_no = self.get_line_no(idx) + # based on the overhead of reading each row. + if self.hw_tsv: + return self.hw_tsv[line_no][0] + elif self.label_tsv: + return self.label_tsv[line_no][0] + else: + return self.img_tsv[line_no][0] + + def __len__(self): + if self.line_list is None: + return self.img_tsv.num_rows() + else: + return len(self.line_list) + + def __getitem__(self, idx): + + img = self.get_image(idx) + img_key = self.get_img_key(idx) + annotations = self.get_annotations(idx) + + annotations = annotations[0] + center = annotations['center'] + scale = annotations['scale'] + has_2d_joints = annotations['has_2d_joints'] + has_3d_joints = annotations['has_3d_joints'] + joints_2d = np.asarray(annotations['2d_joints']) + joints_3d = np.asarray(annotations['3d_joints']) + + if joints_2d.ndim==3: + joints_2d = joints_2d[0] + if joints_3d.ndim==3: + joints_3d = joints_3d[0] + + # Get SMPL parameters, if available + has_smpl = np.asarray(annotations['has_smpl']) + pose = np.asarray(annotations['pose']) + betas = np.asarray(annotations['betas']) + + # Get augmentation parameters + flip,pn,rot,sc = self.augm_params() + + # Process image + img = self.rgb_processing(img, center, sc*scale, rot, flip, pn) + img = torch.from_numpy(img).float() + # Store image before normalization to use it in visualization + transfromed_img = self.normalize_img(img) + + # normalize 3d pose by aligning the wrist as the root (at origin) + root_coord = joints_3d[self.root_index,:-1] + joints_3d[:,:-1] = joints_3d[:,:-1] - root_coord[None,:] + # 3d pose augmentation (random flip + rotation, consistent to image and SMPL) + joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip) + # 2d pose augmentation + joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip) + + ################################### + # Masking percantage + # We observe that 0% or 5% works better for 3D hand mesh + # We think this is probably becasue 3D vertices are quite sparse in the down-sampled hand mesh + mvm_percent = 0.0 # or 0.05 + ################################### + + mjm_mask = np.ones((21,1)) + if self.is_train: + num_joints = 21 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked + indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num) + mjm_mask[indices,:] = 0.0 + mjm_mask = torch.from_numpy(mjm_mask).float() + + mvm_mask = np.ones((195,1)) + if self.is_train: + num_vertices = 195 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked + indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num) + mvm_mask[indices,:] = 0.0 + mvm_mask = torch.from_numpy(mvm_mask).float() + + meta_data = {} + meta_data['ori_img'] = img + meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float() + meta_data['betas'] = torch.from_numpy(betas).float() + meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float() + meta_data['has_3d_joints'] = has_3d_joints + meta_data['has_smpl'] = has_smpl + meta_data['mjm_mask'] = mjm_mask + meta_data['mvm_mask'] = mvm_mask + + # Get 2D keypoints and apply augmentation transforms + meta_data['has_2d_joints'] = has_2d_joints + meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float() + + meta_data['scale'] = float(sc * scale) + meta_data['center'] = np.asarray(center).astype(np.float32) + + return img_key, transfromed_img, meta_data + + +class HandMeshTSVYamlDataset(HandMeshTSVDataset): + """ TSVDataset taking a Yaml file for easy function call + """ + def __init__(self, args, yaml_file, is_train=True, cv2_output=False, scale_factor=1): + self.cfg = load_from_yaml_file(yaml_file) + self.is_composite = self.cfg.get('composite', False) + self.root = op.dirname(yaml_file) + + if self.is_composite==False: + img_file = find_file_path_in_yaml(self.cfg['img'], self.root) + label_file = find_file_path_in_yaml(self.cfg.get('label', None), + self.root) + hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + else: + img_file = self.cfg['img'] + hw_file = self.cfg['hw'] + label_file = self.cfg.get('label', None) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + + super(HandMeshTSVYamlDataset, self).__init__( + args, img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) diff --git a/custom_mesh_graphormer/datasets/human_mesh_tsv.py b/custom_mesh_graphormer/datasets/human_mesh_tsv.py new file mode 100644 index 0000000000000000000000000000000000000000..b60e083079f82899c3ecaffc3b8c785ec53e42ae --- /dev/null +++ b/custom_mesh_graphormer/datasets/human_mesh_tsv.py @@ -0,0 +1,337 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import cv2 +import math +import json +from PIL import Image +import os.path as op +import numpy as np +import code + +from custom_mesh_graphormer.utils.tsv_file import TSVFile, CompositeTSVFile +from custom_mesh_graphormer.utils.tsv_file_ops import load_linelist_file, load_from_yaml_file, find_file_path_in_yaml +from custom_mesh_graphormer.utils.image_ops import img_from_base64, crop, flip_img, flip_pose, flip_kp, transform, rot_aa +import torch +import torchvision.transforms as transforms + + +class MeshTSVDataset(object): + def __init__(self, img_file, label_file=None, hw_file=None, + linelist_file=None, is_train=True, cv2_output=False, scale_factor=1): + + self.img_file = img_file + self.label_file = label_file + self.hw_file = hw_file + self.linelist_file = linelist_file + self.img_tsv = self.get_tsv_file(img_file) + self.label_tsv = None if label_file is None else self.get_tsv_file(label_file) + self.hw_tsv = None if hw_file is None else self.get_tsv_file(hw_file) + + if self.is_composite: + assert op.isfile(self.linelist_file) + self.line_list = [i for i in range(self.hw_tsv.num_rows())] + else: + self.line_list = load_linelist_file(linelist_file) + + self.cv2_output = cv2_output + self.normalize_img = transforms.Normalize(mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225]) + self.is_train = is_train + self.scale_factor = 0.25 # rescale bounding boxes by a factor of [1-options.scale_factor,1+options.scale_factor] + self.noise_factor = 0.4 + self.rot_factor = 30 # Random rotation in the range [-rot_factor, rot_factor] + self.img_res = 224 + + self.image_keys = self.prepare_image_keys() + + self.joints_definition = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', + 'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') + self.pelvis_index = self.joints_definition.index('Pelvis') + + def get_tsv_file(self, tsv_file): + if tsv_file: + if self.is_composite: + return CompositeTSVFile(tsv_file, self.linelist_file, + root=self.root) + tsv_path = find_file_path_in_yaml(tsv_file, self.root) + return TSVFile(tsv_path) + + def get_valid_tsv(self): + # sorted by file size + if self.hw_tsv: + return self.hw_tsv + if self.label_tsv: + return self.label_tsv + + def prepare_image_keys(self): + tsv = self.get_valid_tsv() + return [tsv.get_key(i) for i in range(tsv.num_rows())] + + def prepare_image_key_to_index(self): + tsv = self.get_valid_tsv() + return {tsv.get_key(i) : i for i in range(tsv.num_rows())} + + + def augm_params(self): + """Get augmentation parameters.""" + flip = 0 # flipping + pn = np.ones(3) # per channel pixel-noise + rot = 0 # rotation + sc = 1 # scaling + if self.is_train: + # We flip with probability 1/2 + if np.random.uniform() <= 0.5: + flip = 1 + + # Each channel is multiplied with a number + # in the area [1-opt.noiseFactor,1+opt.noiseFactor] + pn = np.random.uniform(1-self.noise_factor, 1+self.noise_factor, 3) + + # The rotation is a number in the area [-2*rotFactor, 2*rotFactor] + rot = min(2*self.rot_factor, + max(-2*self.rot_factor, np.random.randn()*self.rot_factor)) + + # The scale is multiplied with a number + # in the area [1-scaleFactor,1+scaleFactor] + sc = min(1+self.scale_factor, + max(1-self.scale_factor, np.random.randn()*self.scale_factor+1)) + # but it is zero with probability 3/5 + if np.random.uniform() <= 0.6: + rot = 0 + + return flip, pn, rot, sc + + def rgb_processing(self, rgb_img, center, scale, rot, flip, pn): + """Process rgb image and do augmentation.""" + rgb_img = crop(rgb_img, center, scale, + [self.img_res, self.img_res], rot=rot) + # flip the image + if flip: + rgb_img = flip_img(rgb_img) + # in the rgb image we add pixel noise in a channel-wise manner + rgb_img[:,:,0] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,0]*pn[0])) + rgb_img[:,:,1] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,1]*pn[1])) + rgb_img[:,:,2] = np.minimum(255.0, np.maximum(0.0, rgb_img[:,:,2]*pn[2])) + # (3,224,224),float,[0,1] + rgb_img = np.transpose(rgb_img.astype('float32'),(2,0,1))/255.0 + return rgb_img + + def j2d_processing(self, kp, center, scale, r, f): + """Process gt 2D keypoints and apply all augmentation transforms.""" + nparts = kp.shape[0] + for i in range(nparts): + kp[i,0:2] = transform(kp[i,0:2]+1, center, scale, + [self.img_res, self.img_res], rot=r) + # convert to normalized coordinates + kp[:,:-1] = 2.*kp[:,:-1]/self.img_res - 1. + # flip the x coordinates + if f: + kp = flip_kp(kp) + kp = kp.astype('float32') + return kp + + def j3d_processing(self, S, r, f): + """Process gt 3D keypoints and apply all augmentation transforms.""" + # in-plane rotation + rot_mat = np.eye(3) + if not r == 0: + rot_rad = -r * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + S[:, :-1] = np.einsum('ij,kj->ki', rot_mat, S[:, :-1]) + # flip the x coordinates + if f: + S = flip_kp(S) + S = S.astype('float32') + return S + + def pose_processing(self, pose, r, f): + """Process SMPL theta parameters and apply all augmentation transforms.""" + # rotation or the pose parameters + pose = pose.astype('float32') + pose[:3] = rot_aa(pose[:3], r) + # flip the pose parameters + if f: + pose = flip_pose(pose) + # (72),float + pose = pose.astype('float32') + return pose + + def get_line_no(self, idx): + return idx if self.line_list is None else self.line_list[idx] + + def get_image(self, idx): + line_no = self.get_line_no(idx) + row = self.img_tsv[line_no] + # use -1 to support old format with multiple columns. + cv2_im = img_from_base64(row[-1]) + if self.cv2_output: + return cv2_im.astype(np.float32, copy=True) + cv2_im = cv2.cvtColor(cv2_im, cv2.COLOR_BGR2RGB) + + return cv2_im + + def get_annotations(self, idx): + line_no = self.get_line_no(idx) + if self.label_tsv is not None: + row = self.label_tsv[line_no] + annotations = json.loads(row[1]) + return annotations + else: + return [] + + def get_target_from_annotations(self, annotations, img_size, idx): + # This function will be overwritten by each dataset to + # decode the labels to specific formats for each task. + return annotations + + + def get_img_info(self, idx): + if self.hw_tsv is not None: + line_no = self.get_line_no(idx) + row = self.hw_tsv[line_no] + try: + # json string format with "height" and "width" being the keys + return json.loads(row[1])[0] + except ValueError: + # list of strings representing height and width in order + hw_str = row[1].split(' ') + hw_dict = {"height": int(hw_str[0]), "width": int(hw_str[1])} + return hw_dict + + def get_img_key(self, idx): + line_no = self.get_line_no(idx) + # based on the overhead of reading each row. + if self.hw_tsv: + return self.hw_tsv[line_no][0] + elif self.label_tsv: + return self.label_tsv[line_no][0] + else: + return self.img_tsv[line_no][0] + + def __len__(self): + if self.line_list is None: + return self.img_tsv.num_rows() + else: + return len(self.line_list) + + def __getitem__(self, idx): + + img = self.get_image(idx) + img_key = self.get_img_key(idx) + annotations = self.get_annotations(idx) + + annotations = annotations[0] + center = annotations['center'] + scale = annotations['scale'] + has_2d_joints = annotations['has_2d_joints'] + has_3d_joints = annotations['has_3d_joints'] + joints_2d = np.asarray(annotations['2d_joints']) + joints_3d = np.asarray(annotations['3d_joints']) + + if joints_2d.ndim==3: + joints_2d = joints_2d[0] + if joints_3d.ndim==3: + joints_3d = joints_3d[0] + + # Get SMPL parameters, if available + has_smpl = np.asarray(annotations['has_smpl']) + pose = np.asarray(annotations['pose']) + betas = np.asarray(annotations['betas']) + + try: + gender = annotations['gender'] + except KeyError: + gender = 'none' + + # Get augmentation parameters + flip,pn,rot,sc = self.augm_params() + + # Process image + img = self.rgb_processing(img, center, sc*scale, rot, flip, pn) + img = torch.from_numpy(img).float() + # Store image before normalization to use it in visualization + transfromed_img = self.normalize_img(img) + + # normalize 3d pose by aligning the pelvis as the root (at origin) + root_pelvis = joints_3d[self.pelvis_index,:-1] + joints_3d[:,:-1] = joints_3d[:,:-1] - root_pelvis[None,:] + # 3d pose augmentation (random flip + rotation, consistent to image and SMPL) + joints_3d_transformed = self.j3d_processing(joints_3d.copy(), rot, flip) + # 2d pose augmentation + joints_2d_transformed = self.j2d_processing(joints_2d.copy(), center, sc*scale, rot, flip) + + ################################### + # Masking percantage + # We observe that 30% works better for human body mesh. Further details are reported in the paper. + mvm_percent = 0.3 + ################################### + + mjm_mask = np.ones((14,1)) + if self.is_train: + num_joints = 14 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_joints) # at most x% of the joints could be masked + indices = np.random.choice(np.arange(num_joints),replace=False,size=masked_num) + mjm_mask[indices,:] = 0.0 + mjm_mask = torch.from_numpy(mjm_mask).float() + + mvm_mask = np.ones((431,1)) + if self.is_train: + num_vertices = 431 + pb = np.random.random_sample() + masked_num = int(pb * mvm_percent * num_vertices) # at most x% of the vertices could be masked + indices = np.random.choice(np.arange(num_vertices),replace=False,size=masked_num) + mvm_mask[indices,:] = 0.0 + mvm_mask = torch.from_numpy(mvm_mask).float() + + meta_data = {} + meta_data['ori_img'] = img + meta_data['pose'] = torch.from_numpy(self.pose_processing(pose, rot, flip)).float() + meta_data['betas'] = torch.from_numpy(betas).float() + meta_data['joints_3d'] = torch.from_numpy(joints_3d_transformed).float() + meta_data['has_3d_joints'] = has_3d_joints + meta_data['has_smpl'] = has_smpl + + meta_data['mjm_mask'] = mjm_mask + meta_data['mvm_mask'] = mvm_mask + + # Get 2D keypoints and apply augmentation transforms + meta_data['has_2d_joints'] = has_2d_joints + meta_data['joints_2d'] = torch.from_numpy(joints_2d_transformed).float() + meta_data['scale'] = float(sc * scale) + meta_data['center'] = np.asarray(center).astype(np.float32) + meta_data['gender'] = gender + return img_key, transfromed_img, meta_data + + + +class MeshTSVYamlDataset(MeshTSVDataset): + """ TSVDataset taking a Yaml file for easy function call + """ + def __init__(self, yaml_file, is_train=True, cv2_output=False, scale_factor=1): + self.cfg = load_from_yaml_file(yaml_file) + self.is_composite = self.cfg.get('composite', False) + self.root = op.dirname(yaml_file) + + if self.is_composite==False: + img_file = find_file_path_in_yaml(self.cfg['img'], self.root) + label_file = find_file_path_in_yaml(self.cfg.get('label', None), + self.root) + hw_file = find_file_path_in_yaml(self.cfg.get('hw', None), self.root) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + else: + img_file = self.cfg['img'] + hw_file = self.cfg['hw'] + label_file = self.cfg.get('label', None) + linelist_file = find_file_path_in_yaml(self.cfg.get('linelist', None), + self.root) + + super(MeshTSVYamlDataset, self).__init__( + img_file, label_file, hw_file, linelist_file, is_train, cv2_output=cv2_output, scale_factor=scale_factor) diff --git a/custom_mesh_graphormer/modeling/__init__.py b/custom_mesh_graphormer/modeling/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_mesh_graphormer/modeling/_gcnn.py b/custom_mesh_graphormer/modeling/_gcnn.py new file mode 100644 index 0000000000000000000000000000000000000000..f2baaa97add7357d16d0cb5360e37be35eccf2dd --- /dev/null +++ b/custom_mesh_graphormer/modeling/_gcnn.py @@ -0,0 +1,184 @@ +from __future__ import division +import torch +import torch.nn.functional as F +import numpy as np +import scipy.sparse +import math +from pathlib import Path +data_path = Path(__file__).parent / "data" + +from comfy.model_management import get_torch_device +from wrapper_for_mps import sparse_to_dense +device = get_torch_device() + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def gelu(x): + """Implementation of the gelu activation function. + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): + 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + Also see https://arxiv.org/abs/1606.08415 + """ + return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0))) + +class BertLayerNorm(torch.nn.Module): + def __init__(self, hidden_size, eps=1e-12): + """Construct a layernorm module in the TF style (epsilon inside the square root). + """ + super(BertLayerNorm, self).__init__() + self.weight = torch.nn.Parameter(torch.ones(hidden_size)) + self.bias = torch.nn.Parameter(torch.zeros(hidden_size)) + self.variance_epsilon = eps + + def forward(self, x): + u = x.mean(-1, keepdim=True) + s = (x - u).pow(2).mean(-1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.variance_epsilon) + return self.weight * x + self.bias + + +class GraphResBlock(torch.nn.Module): + """ + Graph Residual Block similar to the Bottleneck Residual Block in ResNet + """ + def __init__(self, in_channels, out_channels, mesh_type='body'): + super(GraphResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lin1 = GraphLinear(in_channels, out_channels // 2) + self.conv = GraphConvolution(out_channels // 2, out_channels // 2, mesh_type) + self.lin2 = GraphLinear(out_channels // 2, out_channels) + self.skip_conv = GraphLinear(in_channels, out_channels) + # print('Use BertLayerNorm in GraphResBlock') + self.pre_norm = BertLayerNorm(in_channels) + self.norm1 = BertLayerNorm(out_channels // 2) + self.norm2 = BertLayerNorm(out_channels // 2) + + def forward(self, x): + trans_y = F.relu(self.pre_norm(x)).transpose(1,2) + y = self.lin1(trans_y).transpose(1,2) + + y = F.relu(self.norm1(y)) + y = self.conv(y) + + trans_y = F.relu(self.norm2(y)).transpose(1,2) + y = self.lin2(trans_y).transpose(1,2) + + z = x+y + + return z + +# class GraphResBlock(torch.nn.Module): +# """ +# Graph Residual Block similar to the Bottleneck Residual Block in ResNet +# """ +# def __init__(self, in_channels, out_channels, mesh_type='body'): +# super(GraphResBlock, self).__init__() +# self.in_channels = in_channels +# self.out_channels = out_channels +# self.conv = GraphConvolution(self.in_channels, self.out_channels, mesh_type) +# print('Use BertLayerNorm and GeLU in GraphResBlock') +# self.norm = BertLayerNorm(self.out_channels) +# def forward(self, x): +# y = self.conv(x) +# y = self.norm(y) +# y = gelu(y) +# z = x+y +# return z + +class GraphLinear(torch.nn.Module): + """ + Generalization of 1x1 convolutions on Graphs + """ + def __init__(self, in_channels, out_channels): + super(GraphLinear, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.W = torch.nn.Parameter(torch.FloatTensor(out_channels, in_channels)) + self.b = torch.nn.Parameter(torch.FloatTensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + w_stdv = 1 / (self.in_channels * self.out_channels) + self.W.data.uniform_(-w_stdv, w_stdv) + self.b.data.uniform_(-w_stdv, w_stdv) + + def forward(self, x): + return torch.matmul(self.W[None, :], x) + self.b[None, :, None] + +class GraphConvolution(torch.nn.Module): + """Simple GCN layer, similar to https://arxiv.org/abs/1609.02907.""" + def __init__(self, in_features, out_features, mesh='body', bias=True): + super(GraphConvolution, self).__init__() + self.in_features = in_features + self.out_features = out_features + + if mesh=='body': + adj_indices = torch.load(data_path / 'smpl_431_adjmat_indices.pt') + adj_mat_value = torch.load(data_path / 'smpl_431_adjmat_values.pt') + adj_mat_size = torch.load(data_path / 'smpl_431_adjmat_size.pt') + elif mesh=='hand': + adj_indices = torch.load(data_path / 'mano_195_adjmat_indices.pt') + adj_mat_value = torch.load(data_path / 'mano_195_adjmat_values.pt') + adj_mat_size = torch.load(data_path / 'mano_195_adjmat_size.pt') + + self.adjmat = sparse_to_dense(torch.sparse_coo_tensor(adj_indices, adj_mat_value, size=adj_mat_size)).to(device) + + self.weight = torch.nn.Parameter(torch.FloatTensor(in_features, out_features)) + if bias: + self.bias = torch.nn.Parameter(torch.FloatTensor(out_features)) + else: + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self): + # stdv = 1. / math.sqrt(self.weight.size(1)) + stdv = 6. / math.sqrt(self.weight.size(0) + self.weight.size(1)) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.uniform_(-stdv, stdv) + + def forward(self, x): + if x.ndimension() == 2: + support = torch.matmul(x, self.weight) + output = torch.matmul(self.adjmat, support) + if self.bias is not None: + output = output + self.bias + return output + else: + output = [] + for i in range(x.shape[0]): + support = torch.matmul(x[i], self.weight) + # output.append(torch.matmul(self.adjmat, support)) + output.append(spmm(self.adjmat, support)) + output = torch.stack(output, dim=0) + if self.bias is not None: + output = output + self.bias + return output + + def __repr__(self): + return self.__class__.__name__ + ' (' \ + + str(self.in_features) + ' -> ' \ + + str(self.out_features) + ')' \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/_mano.py b/custom_mesh_graphormer/modeling/_mano.py new file mode 100644 index 0000000000000000000000000000000000000000..8374b6300ce671ebe1ba6070c6002cdf07ad396c --- /dev/null +++ b/custom_mesh_graphormer/modeling/_mano.py @@ -0,0 +1,184 @@ +""" +This file contains the MANO defination and mesh sampling operations for MANO mesh + +Adapted from opensource projects +MANOPTH (https://github.com/hassony2/manopth) +Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) +GraphCMR (https://github.com/nkolot/GraphCMR/) +""" + +from __future__ import division +import numpy as np +import torch +import torch.nn as nn +import os.path as osp +import json +import code +from custom_manopth.manolayer import ManoLayer +import scipy.sparse +import custom_mesh_graphormer.modeling.data.config as cfg +from pathlib import Path + +from comfy.model_management import get_torch_device +from wrapper_for_mps import sparse_to_dense +device = get_torch_device() + +class MANO(nn.Module): + def __init__(self): + super(MANO, self).__init__() + + self.mano_dir = str(Path(__file__).parent / "data") + self.layer = self.get_layer() + self.vertex_num = 778 + self.face = self.layer.th_faces.numpy() + self.joint_regressor = self.layer.th_J_regressor.numpy() + + self.joint_num = 21 + self.joints_name = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', 'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') + self.skeleton = ( (0,1), (0,5), (0,9), (0,13), (0,17), (1,2), (2,3), (3,4), (5,6), (6,7), (7,8), (9,10), (10,11), (11,12), (13,14), (14,15), (15,16), (17,18), (18,19), (19,20) ) + self.root_joint_idx = self.joints_name.index('Wrist') + + # add fingertips to joint_regressor + self.fingertip_vertex_idx = [745, 317, 444, 556, 673] # mesh vertex idx (right hand) + thumbtip_onehot = np.array([1 if i == 745 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + indextip_onehot = np.array([1 if i == 317 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + middletip_onehot = np.array([1 if i == 445 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + ringtip_onehot = np.array([1 if i == 556 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + pinkytip_onehot = np.array([1 if i == 673 else 0 for i in range(self.joint_regressor.shape[1])], dtype=np.float32).reshape(1,-1) + self.joint_regressor = np.concatenate((self.joint_regressor, thumbtip_onehot, indextip_onehot, middletip_onehot, ringtip_onehot, pinkytip_onehot)) + self.joint_regressor = self.joint_regressor[[0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20],:] + joint_regressor_torch = torch.from_numpy(self.joint_regressor).float() + self.register_buffer('joint_regressor_torch', joint_regressor_torch) + + def get_layer(self): + return ManoLayer(mano_root=osp.join(self.mano_dir), flat_hand_mean=False, use_pca=False) # load right hand MANO model + + def get_3d_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 778, 3) + Output: + 3D joints: size = (B, 21, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.joint_regressor_torch]) + return joints + + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape))) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape))) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape)) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename=cfg.MANO_sampling_matrix, + num_downsampling=1, nsize=1, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + # self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x diff --git a/custom_mesh_graphormer/modeling/_smpl.py b/custom_mesh_graphormer/modeling/_smpl.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b46413005e564b0191ea8e8bb9b0641ea97766 --- /dev/null +++ b/custom_mesh_graphormer/modeling/_smpl.py @@ -0,0 +1,283 @@ +""" +This file contains the definition of the SMPL model + +It is adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) +""" +from __future__ import division + +import torch +import torch.nn as nn +import numpy as np +import scipy.sparse +try: + import cPickle as pickle +except ImportError: + import pickle + +from custom_mesh_graphormer.utils.geometric_layers import rodrigues +import custom_mesh_graphormer.modeling.data.config as cfg + +from comfy.model_management import get_torch_device +from wrapper_for_mps import sparse_to_dense +device = get_torch_device() + +class SMPL(nn.Module): + + def __init__(self, gender='neutral'): + super(SMPL, self).__init__() + + if gender=='m': + model_file=cfg.SMPL_Male + elif gender=='f': + model_file=cfg.SMPL_Female + else: + model_file=cfg.SMPL_FILE + + smpl_model = pickle.load(open(model_file, 'rb'), encoding='latin1') + J_regressor = smpl_model['J_regressor'].tocoo() + row = J_regressor.row + col = J_regressor.col + data = J_regressor.data + i = torch.LongTensor([row, col]) + v = torch.FloatTensor(data) + J_regressor_shape = [24, 6890] + self.register_buffer('J_regressor', torch.sparse_coo_tensor(i, v, J_regressor_shape).to_dense()) + self.register_buffer('weights', torch.FloatTensor(smpl_model['weights'])) + self.register_buffer('posedirs', torch.FloatTensor(smpl_model['posedirs'])) + self.register_buffer('v_template', torch.FloatTensor(smpl_model['v_template'])) + self.register_buffer('shapedirs', torch.FloatTensor(np.array(smpl_model['shapedirs']))) + self.register_buffer('faces', torch.from_numpy(smpl_model['f'].astype(np.int64))) + self.register_buffer('kintree_table', torch.from_numpy(smpl_model['kintree_table'].astype(np.int64))) + id_to_col = {self.kintree_table[1, i].item(): i for i in range(self.kintree_table.shape[1])} + self.register_buffer('parent', torch.LongTensor([id_to_col[self.kintree_table[0, it].item()] for it in range(1, self.kintree_table.shape[1])])) + + self.pose_shape = [24, 3] + self.beta_shape = [10] + self.translation_shape = [3] + + self.pose = torch.zeros(self.pose_shape) + self.beta = torch.zeros(self.beta_shape) + self.translation = torch.zeros(self.translation_shape) + + self.verts = None + self.J = None + self.R = None + + J_regressor_extra = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_TRAIN_EXTRA)).float() + self.register_buffer('J_regressor_extra', J_regressor_extra) + self.joints_idx = cfg.JOINTS_IDX + + J_regressor_h36m_correct = torch.from_numpy(np.load(cfg.JOINT_REGRESSOR_H36M_correct)).float() + self.register_buffer('J_regressor_h36m_correct', J_regressor_h36m_correct) + + + def forward(self, pose, beta): + device = pose.device + batch_size = pose.shape[0] + v_template = self.v_template[None, :] + shapedirs = self.shapedirs.view(-1,10)[None, :].expand(batch_size, -1, -1) + beta = beta[:, :, None] + v_shaped = torch.matmul(shapedirs, beta).view(-1, 6890, 3) + v_template + # batched sparse matmul not supported in pytorch + J = [] + for i in range(batch_size): + J.append(torch.matmul(self.J_regressor, v_shaped[i])) + J = torch.stack(J, dim=0) + # input it rotmat: (bs,24,3,3) + if pose.ndimension() == 4: + R = pose + # input it rotmat: (bs,72) + elif pose.ndimension() == 2: + pose_cube = pose.view(-1, 3) # (batch_size * 24, 1, 3) + R = rodrigues(pose_cube).view(batch_size, 24, 3, 3) + R = R.view(batch_size, 24, 3, 3) + I_cube = torch.eye(3)[None, None, :].to(device) + # I_cube = torch.eye(3)[None, None, :].expand(theta.shape[0], R.shape[1]-1, -1, -1) + lrotmin = (R[:,1:,:] - I_cube).view(batch_size, -1) + posedirs = self.posedirs.view(-1,207)[None, :].expand(batch_size, -1, -1) + v_posed = v_shaped + torch.matmul(posedirs, lrotmin[:, :, None]).view(-1, 6890, 3) + J_ = J.clone() + J_[:, 1:, :] = J[:, 1:, :] - J[:, self.parent, :] + G_ = torch.cat([R, J_[:, :, :, None]], dim=-1) + pad_row = torch.FloatTensor([0,0,0,1]).to(device).view(1,1,1,4).expand(batch_size, 24, -1, -1) + G_ = torch.cat([G_, pad_row], dim=2) + G = [G_[:, 0].clone()] + for i in range(1, 24): + G.append(torch.matmul(G[self.parent[i-1]], G_[:, i, :, :])) + G = torch.stack(G, dim=1) + + rest = torch.cat([J, torch.zeros(batch_size, 24, 1).to(device)], dim=2).view(batch_size, 24, 4, 1) + zeros = torch.zeros(batch_size, 24, 4, 3).to(device) + rest = torch.cat([zeros, rest], dim=-1) + rest = torch.matmul(G, rest) + G = G - rest + T = torch.matmul(self.weights, G.permute(1,0,2,3).contiguous().view(24,-1)).view(6890, batch_size, 4, 4).transpose(0,1) + rest_shape_h = torch.cat([v_posed, torch.ones_like(v_posed)[:, :, [0]]], dim=-1) + v = torch.matmul(T, rest_shape_h[:, :, :, None])[:, :, :3, 0] + return v + + def get_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 38, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor]) + joints_extra = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_extra]) + joints = torch.cat((joints, joints_extra), dim=1) + joints = joints[:, cfg.JOINTS_IDX] + return joints + + def get_h36m_joints(self, vertices): + """ + This method is used to get the joint locations from the SMPL mesh + Input: + vertices: size = (B, 6890, 3) + Output: + 3D joints: size = (B, 24, 3) + """ + joints = torch.einsum('bik,ji->bjk', [vertices, self.J_regressor_h36m_correct]) + return joints + +class SparseMM(torch.autograd.Function): + """Redefine sparse @ dense matrix multiplication to enable backpropagation. + The builtin matrix multiplication operation does not support backpropagation in some cases. + """ + @staticmethod + def forward(ctx, sparse, dense): + ctx.req_grad = dense.requires_grad + ctx.save_for_backward(sparse) + return torch.matmul(sparse, dense) + + @staticmethod + def backward(ctx, grad_output): + grad_input = None + sparse, = ctx.saved_tensors + if ctx.req_grad: + grad_input = torch.matmul(sparse.t(), grad_output) + return None, grad_input + +def spmm(sparse, dense): + sparse = sparse.to(device) + dense = dense.to(device) + return SparseMM.apply(sparse, dense) + + +def scipy_to_pytorch(A, U, D): + """Convert scipy sparse matrices to pytorch sparse matrix.""" + ptU = [] + ptD = [] + + for i in range(len(U)): + u = scipy.sparse.coo_matrix(U[i]) + i = torch.LongTensor(np.array([u.row, u.col])) + v = torch.FloatTensor(u.data) + ptU.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, u.shape))) + + for i in range(len(D)): + d = scipy.sparse.coo_matrix(D[i]) + i = torch.LongTensor(np.array([d.row, d.col])) + v = torch.FloatTensor(d.data) + ptD.append(sparse_to_dense(torch.sparse_coo_tensor(i, v, d.shape))) + + return ptU, ptD + + +def adjmat_sparse(adjmat, nsize=1): + """Create row-normalized sparse graph adjacency matrix.""" + adjmat = scipy.sparse.csr_matrix(adjmat) + if nsize > 1: + orig_adjmat = adjmat.copy() + for _ in range(1, nsize): + adjmat = adjmat * orig_adjmat + adjmat.data = np.ones_like(adjmat.data) + for i in range(adjmat.shape[0]): + adjmat[i,i] = 1 + num_neighbors = np.array(1 / adjmat.sum(axis=-1)) + adjmat = adjmat.multiply(num_neighbors) + adjmat = scipy.sparse.coo_matrix(adjmat) + row = adjmat.row + col = adjmat.col + data = adjmat.data + i = torch.LongTensor(np.array([row, col])) + v = torch.from_numpy(data).float() + adjmat = sparse_to_dense(torch.sparse_coo_tensor(i, v, adjmat.shape)) + return adjmat + +def get_graph_params(filename, nsize=1): + """Load and process graph adjacency matrix and upsampling/downsampling matrices.""" + data = np.load(filename, encoding='latin1', allow_pickle=True) + A = data['A'] + U = data['U'] + D = data['D'] + U, D = scipy_to_pytorch(A, U, D) + A = [adjmat_sparse(a, nsize=nsize) for a in A] + return A, U, D + + +class Mesh(object): + """Mesh object that is used for handling certain graph operations.""" + def __init__(self, filename=cfg.SMPL_sampling_matrix, + num_downsampling=1, nsize=1, device=torch.device('cuda')): + self._A, self._U, self._D = get_graph_params(filename=filename, nsize=nsize) + # self._A = [a.to(device) for a in self._A] + self._U = [u.to(device) for u in self._U] + self._D = [d.to(device) for d in self._D] + self.num_downsampling = num_downsampling + + # load template vertices from SMPL and normalize them + smpl = SMPL() + ref_vertices = smpl.v_template + center = 0.5*(ref_vertices.max(dim=0)[0] + ref_vertices.min(dim=0)[0])[None] + ref_vertices -= center + ref_vertices /= ref_vertices.abs().max().item() + + self._ref_vertices = ref_vertices.to(device) + self.faces = smpl.faces.int().to(device) + + # @property + # def adjmat(self): + # """Return the graph adjacency matrix at the specified subsampling level.""" + # return self._A[self.num_downsampling].float() + + @property + def ref_vertices(self): + """Return the template vertices at the specified subsampling level.""" + ref_vertices = self._ref_vertices + for i in range(self.num_downsampling): + ref_vertices = torch.spmm(self._D[i], ref_vertices) + return ref_vertices + + def downsample(self, x, n1=0, n2=None): + """Downsample mesh.""" + if n2 is None: + n2 = self.num_downsampling + if x.ndimension() < 3: + for i in range(n1, n2): + x = spmm(self._D[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in range(n1, n2): + y = spmm(self._D[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x + + def upsample(self, x, n1=1, n2=0): + """Upsample mesh.""" + if x.ndimension() < 3: + for i in reversed(range(n2, n1)): + x = spmm(self._U[i], x) + elif x.ndimension() == 3: + out = [] + for i in range(x.shape[0]): + y = x[i] + for j in reversed(range(n2, n1)): + y = spmm(self._U[j], y) + out.append(y) + x = torch.stack(out, dim=0) + return x diff --git a/custom_mesh_graphormer/modeling/bert/__init__.py b/custom_mesh_graphormer/modeling/bert/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..197c5b92786fcb15d754555c2cd4d3531f98a0e0 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/__init__.py @@ -0,0 +1,17 @@ +__version__ = "1.0.0" + +from .modeling_bert import (BertConfig, BertModel, + load_tf_weights_in_bert) + +from .modeling_graphormer import Graphormer + +from .e2e_body_network import Graphormer_Body_Network + +from .e2e_hand_network import Graphormer_Hand_Network + +CONFIG_NAME = "config.json" + +from .modeling_utils import (WEIGHTS_NAME, TF_WEIGHTS_NAME, + PretrainedConfig, PreTrainedModel, prune_layer, Conv1D) + +from .file_utils import (PYTORCH_PRETRAINED_BERT_CACHE) diff --git a/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json b/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json new file mode 100644 index 0000000000000000000000000000000000000000..79276673252f15cea400800731e0d4e3d3cba64f --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/bert-base-uncased/config.json @@ -0,0 +1,16 @@ +{ + "architectures": [ + "BertForMaskedLM" + ], + "attention_probs_dropout_prob": 0.1, + "hidden_act": "gelu", + "hidden_dropout_prob": 0.1, + "hidden_size": 768, + "initializer_range": 0.02, + "intermediate_size": 3072, + "max_position_embeddings": 512, + "num_attention_heads": 12, + "num_hidden_layers": 12, + "type_vocab_size": 2, + "vocab_size": 30522 +} diff --git a/custom_mesh_graphormer/modeling/bert/e2e_body_network.py b/custom_mesh_graphormer/modeling/bert/e2e_body_network.py new file mode 100644 index 0000000000000000000000000000000000000000..f570ec58f66c20802898103555949e74ff56930c --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/e2e_body_network.py @@ -0,0 +1,103 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import custom_mesh_graphormer.modeling.data.config as cfg +from comfy.model_management import get_torch_device +device = get_torch_device() + +class Graphormer_Body_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for human pose and mesh reconstruction from a single image. + ''' + def __init__(self, args, config, backbone, trans_encoder, mesh_sampler): + super(Graphormer_Body_Network, self).__init__() + self.config = config + self.config.device = device + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(431, 1723) + self.upsampling2 = torch.nn.Linear(1723, 6890) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(431, 250) + self.cam_param_fc3 = torch.nn.Linear(250, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + + def forward(self, images, smpl, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1,72)) + template_pose[:,0] = 3.1416 # Rectify "upside down" reference mesh in global coord + template_pose = template_pose.to(device) + template_betas = torch.zeros((1,10)).to(device) + template_vertices = smpl(template_pose, template_betas) + + # template mesh simplification + template_vertices_sub = mesh_sampler.downsample(template_vertices) + template_vertices_sub2 = mesh_sampler.downsample(template_vertices_sub, n1=1, n2=2) + + # template mesh-to-joint regression + template_3d_joints = smpl.get_h36m_joints(template_vertices) + template_pelvis = template_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + template_3d_joints = template_3d_joints[:,cfg.H36M_J17_TO_J14,:] + num_joints = template_3d_joints.shape[1] + + # normalize + template_3d_joints = template_3d_joints - template_pelvis[:, None, :] + template_vertices_sub2 = template_vertices_sub2 - template_pelvis[:, None, :] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub2],dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + # concatinate image feat and 3d mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1,2) + grid_feat = self.grid_feat_dim(grid_feat) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + # prepare input tokens including joint/vertex queries and grid features + features = torch.cat([features, grid_feat],dim=1) + + if is_train==True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01 + features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) + + # forward pass + if self.config.output_attentions==True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:,:num_joints,:] + pred_vertices_sub2 = features[:,num_joints:-49,:] + + # learn camera parameters + x = self.cam_param_fc(pred_vertices_sub2) + x = x.transpose(1,2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1,2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub2.transpose(1,2) + pred_vertices_sub = self.upsampling(temp_transpose) + pred_vertices_full = self.upsampling2(pred_vertices_sub) + pred_vertices_sub = pred_vertices_sub.transpose(1,2) + pred_vertices_full = pred_vertices_full.transpose(1,2) + + if self.config.output_attentions==True: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices_full \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py b/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py new file mode 100644 index 0000000000000000000000000000000000000000..57cba0a830c2ca975048fd1a3de9f8341196ba43 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/e2e_hand_network.py @@ -0,0 +1,94 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +import torch +import custom_mesh_graphormer.modeling.data.config as cfg +from comfy.model_management import get_torch_device +device = get_torch_device() + +class Graphormer_Hand_Network(torch.nn.Module): + ''' + End-to-end Graphormer network for hand pose and mesh reconstruction from a single image. + ''' + def __init__(self, args, config, backbone, trans_encoder): + super(Graphormer_Hand_Network, self).__init__() + self.config = config + self.backbone = backbone + self.trans_encoder = trans_encoder + self.upsampling = torch.nn.Linear(195, 778) + self.cam_param_fc = torch.nn.Linear(3, 1) + self.cam_param_fc2 = torch.nn.Linear(195+21, 150) + self.cam_param_fc3 = torch.nn.Linear(150, 3) + self.grid_feat_dim = torch.nn.Linear(1024, 2051) + + def forward(self, images, mesh_model, mesh_sampler, meta_masks=None, is_train=False): + batch_size = images.size(0) + # Generate T-pose template mesh + template_pose = torch.zeros((1,48)) + template_pose = template_pose.to(device) + template_betas = torch.zeros((1,10)).to(device) + template_vertices, template_3d_joints = mesh_model.layer(template_pose, template_betas) + template_vertices = template_vertices/1000.0 + template_3d_joints = template_3d_joints/1000.0 + + template_vertices_sub = mesh_sampler.downsample(template_vertices) + + # normalize + template_root = template_3d_joints[:,cfg.J_NAME.index('Wrist'),:] + template_3d_joints = template_3d_joints - template_root[:, None, :] + template_vertices = template_vertices - template_root[:, None, :] + template_vertices_sub = template_vertices_sub - template_root[:, None, :] + num_joints = template_3d_joints.shape[1] + + # concatinate template joints and template vertices, and then duplicate to batch size + ref_vertices = torch.cat([template_3d_joints, template_vertices_sub],dim=1) + ref_vertices = ref_vertices.expand(batch_size, -1, -1) + + # extract grid features and global image features using a CNN backbone + image_feat, grid_feat = self.backbone(images) + # concatinate image feat and mesh template + image_feat = image_feat.view(batch_size, 1, 2048).expand(-1, ref_vertices.shape[-2], -1) + # process grid features + grid_feat = torch.flatten(grid_feat, start_dim=2) + grid_feat = grid_feat.transpose(1,2) + grid_feat = self.grid_feat_dim(grid_feat) + # concatinate image feat and template mesh to form the joint/vertex queries + features = torch.cat([ref_vertices, image_feat], dim=2) + # prepare input tokens including joint/vertex queries and grid features + features = torch.cat([features, grid_feat],dim=1) + + if is_train==True: + # apply mask vertex/joint modeling + # meta_masks is a tensor of all the masks, randomly generated in dataloader + # we pre-define a [MASK] token, which is a floating-value vector with 0.01s + special_token = torch.ones_like(features[:,:-49,:]).to(device)*0.01 + features[:,:-49,:] = features[:,:-49,:]*meta_masks + special_token*(1-meta_masks) + + # forward pass + if self.config.output_attentions==True: + features, hidden_states, att = self.trans_encoder(features) + else: + features = self.trans_encoder(features) + + pred_3d_joints = features[:,:num_joints,:] + pred_vertices_sub = features[:,num_joints:-49,:] + + # learn camera parameters + x = self.cam_param_fc(features[:,:-49,:]) + x = x.transpose(1,2) + x = self.cam_param_fc2(x) + x = self.cam_param_fc3(x) + cam_param = x.transpose(1,2) + cam_param = cam_param.squeeze() + + temp_transpose = pred_vertices_sub.transpose(1,2) + pred_vertices = self.upsampling(temp_transpose) + pred_vertices = pred_vertices.transpose(1,2) + + if self.config.output_attentions==True: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att + else: + return cam_param, pred_3d_joints, pred_vertices_sub, pred_vertices \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/bert/file_utils.py b/custom_mesh_graphormer/modeling/bert/file_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0b26c2ba2e0a5371ee599dc5253b160fa04ba510 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/file_utils.py @@ -0,0 +1 @@ +from transformers.file_utils import * \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/bert/modeling_bert.py b/custom_mesh_graphormer/modeling/bert/modeling_bert.py new file mode 100644 index 0000000000000000000000000000000000000000..820b69ae270b9bc723f000ba78f467269de16684 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/modeling_bert.py @@ -0,0 +1 @@ +from transformers.models.bert.modeling_bert import * \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py b/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py new file mode 100644 index 0000000000000000000000000000000000000000..e7c6e2625afaee69f4f9350374443d4f21e881a1 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/modeling_graphormer.py @@ -0,0 +1,328 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import logging +import math +import os +import code +import torch +from torch import nn +from .modeling_bert import BertPreTrainedModel, BertEmbeddings, BertPooler, BertIntermediate, BertOutput, BertSelfOutput +import custom_mesh_graphormer.modeling.data.config as cfg +from custom_mesh_graphormer.modeling._gcnn import GraphConvolution, GraphResBlock +from .modeling_utils import prune_linear_layer +LayerNormClass = torch.nn.LayerNorm +BertLayerNorm = torch.nn.LayerNorm +from comfy.model_management import get_torch_device +device = get_torch_device() + + +class BertSelfAttention(nn.Module): + def __init__(self, config): + super(BertSelfAttention, self).__init__() + if config.hidden_size % config.num_attention_heads != 0: + raise ValueError( + "The hidden size (%d) is not a multiple of the number of attention " + "heads (%d)" % (config.hidden_size, config.num_attention_heads)) + self.output_attentions = config.output_attentions + + self.num_attention_heads = config.num_attention_heads + self.attention_head_size = int(config.hidden_size / config.num_attention_heads) + self.all_head_size = self.num_attention_heads * self.attention_head_size + + self.query = nn.Linear(config.hidden_size, self.all_head_size) + self.key = nn.Linear(config.hidden_size, self.all_head_size) + self.value = nn.Linear(config.hidden_size, self.all_head_size) + + self.dropout = nn.Dropout(config.attention_probs_dropout_prob) + + def transpose_for_scores(self, x): + new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) + x = x.view(*new_x_shape) + return x.permute(0, 2, 1, 3) + + def forward(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + if history_state is not None: + x_states = torch.cat([history_state, hidden_states], dim=1) + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(x_states) + mixed_value_layer = self.value(x_states) + else: + mixed_query_layer = self.query(hidden_states) + mixed_key_layer = self.key(hidden_states) + mixed_value_layer = self.value(hidden_states) + + query_layer = self.transpose_for_scores(mixed_query_layer) + key_layer = self.transpose_for_scores(mixed_key_layer) + value_layer = self.transpose_for_scores(mixed_value_layer) + + # Take the dot product between "query" and "key" to get the raw attention scores. + attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) + attention_scores = attention_scores / math.sqrt(self.attention_head_size) + # Apply the attention mask is (precomputed for all layers in BertModel forward() function) + attention_scores = attention_scores + attention_mask + + # Normalize the attention scores to probabilities. + attention_probs = nn.Softmax(dim=-1)(attention_scores) + + # This is actually dropping out entire tokens to attend to, which might + # seem a bit unusual, but is taken from the original Transformer paper. + attention_probs = self.dropout(attention_probs) + + # Mask heads if we want to + if head_mask is not None: + attention_probs = attention_probs * head_mask + + context_layer = torch.matmul(attention_probs, value_layer) + + context_layer = context_layer.permute(0, 2, 1, 3).contiguous() + new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) + context_layer = context_layer.view(*new_context_layer_shape) + + outputs = (context_layer, attention_probs) if self.output_attentions else (context_layer,) + return outputs + +class BertAttention(nn.Module): + def __init__(self, config): + super(BertAttention, self).__init__() + self.self = BertSelfAttention(config) + self.output = BertSelfOutput(config) + + def prune_heads(self, heads): + if len(heads) == 0: + return + mask = torch.ones(self.self.num_attention_heads, self.self.attention_head_size) + for head in heads: + mask[head] = 0 + mask = mask.view(-1).contiguous().eq(1) + index = torch.arange(len(mask))[mask].long() + # Prune linear layers + self.self.query = prune_linear_layer(self.self.query, index) + self.self.key = prune_linear_layer(self.self.key, index) + self.self.value = prune_linear_layer(self.self.value, index) + self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) + # Update hyper params + self.self.num_attention_heads = self.self.num_attention_heads - len(heads) + self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads + + def forward(self, input_tensor, attention_mask, head_mask=None, + history_state=None): + self_outputs = self.self(input_tensor, attention_mask, head_mask, + history_state) + attention_output = self.output(self_outputs[0], input_tensor) + outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them + return outputs + + +class GraphormerLayer(nn.Module): + def __init__(self, config): + super(GraphormerLayer, self).__init__() + self.attention = BertAttention(config) + self.has_graph_conv = config.graph_conv + self.mesh_type = config.mesh_type + + if self.has_graph_conv == True: + self.graph_conv = GraphResBlock(config.hidden_size, config.hidden_size, mesh_type=self.mesh_type) + + self.intermediate = BertIntermediate(config) + self.output = BertOutput(config) + + def MHA_GCN(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + attention_outputs = self.attention(hidden_states, attention_mask, + head_mask, history_state) + attention_output = attention_outputs[0] + + if self.has_graph_conv==True: + if self.mesh_type == 'body': + joints = attention_output[:,0:14,:] + vertices = attention_output[:,14:-49,:] + img_tokens = attention_output[:,-49:,:] + + elif self.mesh_type == 'hand': + joints = attention_output[:,0:21,:] + vertices = attention_output[:,21:-49,:] + img_tokens = attention_output[:,-49:,:] + + vertices = self.graph_conv(vertices) + joints_vertices = torch.cat([joints,vertices,img_tokens],dim=1) + else: + joints_vertices = attention_output + + intermediate_output = self.intermediate(joints_vertices) + layer_output = self.output(intermediate_output, joints_vertices) + outputs = (layer_output,) + attention_outputs[1:] # add attentions if we output them + return outputs + + def forward(self, hidden_states, attention_mask, head_mask=None, + history_state=None): + return self.MHA_GCN(hidden_states, attention_mask, head_mask,history_state) + + +class GraphormerEncoder(nn.Module): + def __init__(self, config): + super(GraphormerEncoder, self).__init__() + self.output_attentions = config.output_attentions + self.output_hidden_states = config.output_hidden_states + self.layer = nn.ModuleList([GraphormerLayer(config) for _ in range(config.num_hidden_layers)]) + + def forward(self, hidden_states, attention_mask, head_mask=None, + encoder_history_states=None): + all_hidden_states = () + all_attentions = () + for i, layer_module in enumerate(self.layer): + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + history_state = None if encoder_history_states is None else encoder_history_states[i] + layer_outputs = layer_module( + hidden_states, attention_mask, head_mask[i], + history_state) + hidden_states = layer_outputs[0] + + if self.output_attentions: + all_attentions = all_attentions + (layer_outputs[1],) + + # Add last layer + if self.output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + outputs = (hidden_states,) + if self.output_hidden_states: + outputs = outputs + (all_hidden_states,) + if self.output_attentions: + outputs = outputs + (all_attentions,) + + return outputs # outputs, (hidden states), (attentions) + +class EncoderBlock(BertPreTrainedModel): + def __init__(self, config): + super(EncoderBlock, self).__init__(config) + self.config = config + self.embeddings = BertEmbeddings(config) + self.encoder = GraphormerEncoder(config) + self.pooler = BertPooler(config) + self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size) + self.img_dim = config.img_feature_dim + + try: + self.use_img_layernorm = config.use_img_layernorm + except: + self.use_img_layernorm = None + + self.img_embedding = nn.Linear(self.img_dim, self.config.hidden_size, bias=True) + self.dropout = nn.Dropout(config.hidden_dropout_prob) + if self.use_img_layernorm: + self.LayerNorm = LayerNormClass(config.hidden_size, eps=config.img_layer_norm_eps) + + + def _prune_heads(self, heads_to_prune): + """ Prunes heads of the model. + heads_to_prune: dict of {layer_num: list of heads to prune in this layer} + See base class PreTrainedModel + """ + for layer, heads in heads_to_prune.items(): + self.encoder.layer[layer].attention.prune_heads(heads) + + def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, + position_ids=None, head_mask=None): + + batch_size = len(img_feats) + seq_length = len(img_feats[0]) + input_ids = torch.zeros([batch_size, seq_length],dtype=torch.long).to(device) + + if position_ids is None: + position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device) + position_ids = position_ids.unsqueeze(0).expand_as(input_ids) + + position_embeddings = self.position_embeddings(position_ids) + + if attention_mask is None: + attention_mask = torch.ones_like(input_ids) + + if token_type_ids is None: + token_type_ids = torch.zeros_like(input_ids) + + if attention_mask.dim() == 2: + extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + extended_attention_mask = attention_mask.unsqueeze(1) + else: + raise NotImplementedError + + extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype) # fp16 compatibility + extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0 + + if head_mask is not None: + if head_mask.dim() == 1: + head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1) + head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1) + elif head_mask.dim() == 2: + head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1) # We can specify head_mask for each layer + head_mask = head_mask.to(dtype=next(self.parameters()).dtype) # switch to fload if need + fp16 compatibility + else: + head_mask = [None] * self.config.num_hidden_layers + + # Project input token features to have spcified hidden size + img_embedding_output = self.img_embedding(img_feats) + + # We empirically observe that adding an additional learnable position embedding leads to more stable training + embeddings = position_embeddings + img_embedding_output + + if self.use_img_layernorm: + embeddings = self.LayerNorm(embeddings) + embeddings = self.dropout(embeddings) + + encoder_outputs = self.encoder(embeddings, + extended_attention_mask, head_mask=head_mask) + sequence_output = encoder_outputs[0] + + outputs = (sequence_output,) + if self.config.output_hidden_states: + all_hidden_states = encoder_outputs[1] + outputs = outputs + (all_hidden_states,) + if self.config.output_attentions: + all_attentions = encoder_outputs[-1] + outputs = outputs + (all_attentions,) + + return outputs + +class Graphormer(BertPreTrainedModel): + ''' + The archtecture of a transformer encoder block we used in Graphormer + ''' + def __init__(self, config): + super(Graphormer, self).__init__(config) + self.config = config + self.bert = EncoderBlock(config) + self.cls_head = nn.Linear(config.hidden_size, self.config.output_feature_dim) + self.residual = nn.Linear(config.img_feature_dim, self.config.output_feature_dim) + + def forward(self, img_feats, input_ids=None, token_type_ids=None, attention_mask=None, masked_lm_labels=None, + next_sentence_label=None, position_ids=None, head_mask=None): + ''' + # self.bert has three outputs + # predictions[0]: output tokens + # predictions[1]: all_hidden_states, if enable "self.config.output_hidden_states" + # predictions[2]: attentions, if enable "self.config.output_attentions" + ''' + predictions = self.bert(img_feats=img_feats, input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, + attention_mask=attention_mask, head_mask=head_mask) + + # We use "self.cls_head" to perform dimensionality reduction. We don't use it for classification. + pred_score = self.cls_head(predictions[0]) + res_img_feats = self.residual(img_feats) + pred_score = pred_score + res_img_feats + + if self.config.output_attentions and self.config.output_hidden_states: + return pred_score, predictions[1], predictions[-1] + else: + return pred_score + + \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/bert/modeling_utils.py b/custom_mesh_graphormer/modeling/bert/modeling_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fcfab434a6ce1f7d77f34c47e4756724029663c9 --- /dev/null +++ b/custom_mesh_graphormer/modeling/bert/modeling_utils.py @@ -0,0 +1 @@ +from transformers.modeling_utils import * \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy b/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy new file mode 100644 index 0000000000000000000000000000000000000000..c15c7c4294d859ee037404876073a969c0da5524 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/J_regressor_extra.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:40dfaa71fcc7eed6966a6ed046311b7e8ea0eb9a5172b298e3df6fc4b6ec0eb0 +size 771808 diff --git a/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy b/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy new file mode 100644 index 0000000000000000000000000000000000000000..dff7bedc5d08289a308299a6c82df39484e4b62b --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/J_regressor_h36m_correct.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1835d64133d5f66bd80a814ab1c1dc0900ef01950f568320acf5f9390c1f2c8c +size 937168 diff --git a/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl b/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl new file mode 100644 index 0000000000000000000000000000000000000000..94bb3a9fbdbbc001b985f6fe36cb290773cc55b2 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/MANO_LEFT.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b61cdb94a33582d626456752515624d7c558b5adcc997d13fb422963b5f791ed +size 3447713 diff --git a/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl b/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl new file mode 100644 index 0000000000000000000000000000000000000000..bc0e78214e1408bb7c8b9aa058fd0910fa08d248 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/MANO_RIGHT.pkl @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9e3fb9ac790637539011258e211415d3bcae8daa2759c86b9046f4b371f0c423 +size 3447679 diff --git a/custom_mesh_graphormer/modeling/data/README.md b/custom_mesh_graphormer/modeling/data/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e7cfc083291ab8ec837f5f63f2d04643a33659b8 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/README.md @@ -0,0 +1,30 @@ + +# Extra data +Adapted from open source project [GraphCMR](https://github.com/nkolot/GraphCMR/) and [Pose2Mesh](https://github.com/hongsukchoi/Pose2Mesh_RELEASE) + +Our code requires additional data to run smoothly. + +### J_regressor_extra.npy +Joints regressor for joints or landmarks that are not included in the standard set of SMPL joints. + +### J_regressor_h36m_correct.npy +Joints regressor reflecting the Human3.6M joints. + +### mesh_downsampling.npz +Extra file with precomputed downsampling for the SMPL body mesh. + +### mano_downsampling.npz +Extra file with precomputed downsampling for the MANO hand mesh. + +### basicModel_neutral_lbs_10_207_0_v1.0.0.pkl +SMPL neutral model. Please visit the official website to download the file [http://smplify.is.tue.mpg.de/](http://smplify.is.tue.mpg.de/) + +### basicModel_m_lbs_10_207_0_v1.0.0.pkl +SMPL male model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) + +### basicModel_f_lbs_10_207_0_v1.0.0.pkl +SMPL female model. Please visit the official website to download the file [https://smpl.is.tue.mpg.de/](https://smpl.is.tue.mpg.de/) + +### MANO_RIGHT.pkl +MANO hand model. Please visit the official website to download the file [https://mano.is.tue.mpg.de/](https://mano.is.tue.mpg.de/) + diff --git a/custom_mesh_graphormer/modeling/data/config.py b/custom_mesh_graphormer/modeling/data/config.py new file mode 100644 index 0000000000000000000000000000000000000000..b8eb40f7d5130616173489cbf8697d76da504c29 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/config.py @@ -0,0 +1,47 @@ +""" +This file contains definitions of useful data stuctures and the paths +for the datasets and data files necessary to run the code. + +Adapted from opensource project GraphCMR (https://github.com/nkolot/GraphCMR/) and Pose2Mesh (https://github.com/hongsukchoi/Pose2Mesh_RELEASE) + +""" + +from pathlib import Path +folder_path = Path(__file__).parent.parent +JOINT_REGRESSOR_TRAIN_EXTRA = folder_path / 'data/J_regressor_extra.npy' +JOINT_REGRESSOR_H36M_correct = folder_path / 'data/J_regressor_h36m_correct.npy' +SMPL_FILE = folder_path / 'data/basicModel_neutral_lbs_10_207_0_v1.0.0.pkl' +SMPL_Male = folder_path / 'data/basicModel_m_lbs_10_207_0_v1.0.0.pkl' +SMPL_Female = folder_path / 'data/basicModel_f_lbs_10_207_0_v1.0.0.pkl' +SMPL_sampling_matrix = folder_path / 'data/mesh_downsampling.npz' +MANO_FILE = folder_path / 'data/MANO_RIGHT.pkl' +MANO_sampling_matrix = folder_path / 'data/mano_downsampling.npz' + +JOINTS_IDX = [8, 5, 29, 30, 4, 7, 21, 19, 17, 16, 18, 20, 31, 32, 33, 34, 35, 36, 37, 24, 26, 25, 28, 27] + + +""" +We follow the body joint definition, loss functions, and evaluation metrics from +open source project GraphCMR (https://github.com/nkolot/GraphCMR/) + +Each dataset uses different sets of joints. +We use a superset of 24 joints such that we include all joints from every dataset. +If a dataset doesn't provide annotations for a specific joint, we simply ignore it. +The joints used here are: +""" +J24_NAME = ('R_Ankle', 'R_Knee', 'R_Hip', 'L_Hip', 'L_Knee', 'L_Ankle', 'R_Wrist', 'R_Elbow', 'R_Shoulder', 'L_Shoulder', +'L_Elbow','L_Wrist','Neck','Top_of_Head','Pelvis','Thorax','Spine','Jaw','Head','Nose','L_Eye','R_Eye','L_Ear','R_Ear') +H36M_J17_NAME = ( 'Pelvis', 'R_Hip', 'R_Knee', 'R_Ankle', 'L_Hip', 'L_Knee', 'L_Ankle', 'Torso', 'Neck', 'Nose', 'Head', + 'L_Shoulder', 'L_Elbow', 'L_Wrist', 'R_Shoulder', 'R_Elbow', 'R_Wrist') +J24_TO_J14 = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 18] +H36M_J17_TO_J14 = [3, 2, 1, 4, 5, 6, 16, 15, 14, 11, 12, 13, 8, 10] + +""" +We follow the hand joint definition and mesh topology from +open source project Manopth (https://github.com/hassony2/manopth) + +The hand joints used here are: +""" +J_NAME = ('Wrist', 'Thumb_1', 'Thumb_2', 'Thumb_3', 'Thumb_4', 'Index_1', 'Index_2', 'Index_3', 'Index_4', 'Middle_1', +'Middle_2', 'Middle_3', 'Middle_4', 'Ring_1', 'Ring_2', 'Ring_3', 'Ring_4', 'Pinky_1', 'Pinky_2', 'Pinky_3', 'Pinky_4') +ROOT_INDEX = 0 \ No newline at end of file diff --git a/custom_mesh_graphormer/modeling/data/mano_downsampling.npz b/custom_mesh_graphormer/modeling/data/mano_downsampling.npz new file mode 100644 index 0000000000000000000000000000000000000000..1e2197737db3ffaab05713e49784c85c7b83afe9 --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/mano_downsampling.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:db2b23b0ede7c34039f1d8960e2e839552e0f58d4b34a8612b4065d8a47f9c80 +size 176509 diff --git a/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz b/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz new file mode 100644 index 0000000000000000000000000000000000000000..ee14dddb06717c1312b81f3fa74a7ef2ea4357ac --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/mesh_downsampling.npz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5683b656d9acd7f1558db832527beb0cc6b3b45388cf41a7979dab52e2c57477 +size 1720359 diff --git a/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy b/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy new file mode 100644 index 0000000000000000000000000000000000000000..18d0ded99e79da8a8a364e476f527658a8b3ee3a --- /dev/null +++ b/custom_mesh_graphormer/modeling/data/smpl_431_faces.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9dc2cfedd4901e31a9ac3c9f60bcaf7647ef41b067e68ceb15ac194b7b6714ae +size 21128 diff --git a/custom_mesh_graphormer/modeling/hrnet/config/__init__.py b/custom_mesh_graphormer/modeling/hrnet/config/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..59be7491faa1cb598b023802f68e0e15732f962b --- /dev/null +++ b/custom_mesh_graphormer/modeling/hrnet/config/__init__.py @@ -0,0 +1,9 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# ------------------------------------------------------------------------------ + +from .default import _C as config +from .default import update_config +from .models import MODEL_EXTRAS diff --git a/custom_mesh_graphormer/modeling/hrnet/config/default.py b/custom_mesh_graphormer/modeling/hrnet/config/default.py new file mode 100644 index 0000000000000000000000000000000000000000..59b9843467c25b9512497f9704e1873206eebfa4 --- /dev/null +++ b/custom_mesh_graphormer/modeling/hrnet/config/default.py @@ -0,0 +1,138 @@ + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from yacs.config import CfgNode as CN + + +_C = CN() + +_C.OUTPUT_DIR = '' +_C.LOG_DIR = '' +_C.DATA_DIR = '' +_C.GPUS = (0,) +_C.WORKERS = 4 +_C.PRINT_FREQ = 20 +_C.AUTO_RESUME = False +_C.PIN_MEMORY = True +_C.RANK = 0 + +# Cudnn related params +_C.CUDNN = CN() +_C.CUDNN.BENCHMARK = True +_C.CUDNN.DETERMINISTIC = False +_C.CUDNN.ENABLED = True + +# common params for NETWORK +_C.MODEL = CN() +_C.MODEL.NAME = 'cls_hrnet' +_C.MODEL.INIT_WEIGHTS = True +_C.MODEL.PRETRAINED = '' +_C.MODEL.NUM_JOINTS = 17 +_C.MODEL.NUM_CLASSES = 1000 +_C.MODEL.TAG_PER_JOINT = True +_C.MODEL.TARGET_TYPE = 'gaussian' +_C.MODEL.IMAGE_SIZE = [256, 256] # width * height, ex: 192 * 256 +_C.MODEL.HEATMAP_SIZE = [64, 64] # width * height, ex: 24 * 32 +_C.MODEL.SIGMA = 2 +_C.MODEL.EXTRA = CN(new_allowed=True) + +_C.LOSS = CN() +_C.LOSS.USE_OHKM = False +_C.LOSS.TOPK = 8 +_C.LOSS.USE_TARGET_WEIGHT = True +_C.LOSS.USE_DIFFERENT_JOINTS_WEIGHT = False + +# DATASET related params +_C.DATASET = CN() +_C.DATASET.ROOT = '' +_C.DATASET.DATASET = 'mpii' +_C.DATASET.TRAIN_SET = 'train' +_C.DATASET.TEST_SET = 'valid' +_C.DATASET.DATA_FORMAT = 'jpg' +_C.DATASET.HYBRID_JOINTS_TYPE = '' +_C.DATASET.SELECT_DATA = False + +# training data augmentation +_C.DATASET.FLIP = True +_C.DATASET.SCALE_FACTOR = 0.25 +_C.DATASET.ROT_FACTOR = 30 +_C.DATASET.PROB_HALF_BODY = 0.0 +_C.DATASET.NUM_JOINTS_HALF_BODY = 8 +_C.DATASET.COLOR_RGB = False + +# train +_C.TRAIN = CN() + +_C.TRAIN.LR_FACTOR = 0.1 +_C.TRAIN.LR_STEP = [90, 110] +_C.TRAIN.LR = 0.001 + +_C.TRAIN.OPTIMIZER = 'adam' +_C.TRAIN.MOMENTUM = 0.9 +_C.TRAIN.WD = 0.0001 +_C.TRAIN.NESTEROV = False +_C.TRAIN.GAMMA1 = 0.99 +_C.TRAIN.GAMMA2 = 0.0 + +_C.TRAIN.BEGIN_EPOCH = 0 +_C.TRAIN.END_EPOCH = 140 + +_C.TRAIN.RESUME = False +_C.TRAIN.CHECKPOINT = '' + +_C.TRAIN.BATCH_SIZE_PER_GPU = 32 +_C.TRAIN.SHUFFLE = True + +# testing +_C.TEST = CN() + +# size of images for each device +_C.TEST.BATCH_SIZE_PER_GPU = 32 +# Test Model Epoch +_C.TEST.FLIP_TEST = False +_C.TEST.POST_PROCESS = False +_C.TEST.SHIFT_HEATMAP = False + +_C.TEST.USE_GT_BBOX = False + +# nms +_C.TEST.IMAGE_THRE = 0.1 +_C.TEST.NMS_THRE = 0.6 +_C.TEST.SOFT_NMS = False +_C.TEST.OKS_THRE = 0.5 +_C.TEST.IN_VIS_THRE = 0.0 +_C.TEST.COCO_BBOX_FILE = '' +_C.TEST.BBOX_THRE = 1.0 +_C.TEST.MODEL_FILE = '' + +# debug +_C.DEBUG = CN() +_C.DEBUG.DEBUG = False +_C.DEBUG.SAVE_BATCH_IMAGES_GT = False +_C.DEBUG.SAVE_BATCH_IMAGES_PRED = False +_C.DEBUG.SAVE_HEATMAPS_GT = False +_C.DEBUG.SAVE_HEATMAPS_PRED = False + + +def update_config(cfg, config_file): + cfg.defrost() + cfg.merge_from_file(config_file) + cfg.freeze() + + +if __name__ == '__main__': + import sys + with open(sys.argv[1], 'w') as f: + print(_C, file=f) + diff --git a/custom_mesh_graphormer/modeling/hrnet/config/models.py b/custom_mesh_graphormer/modeling/hrnet/config/models.py new file mode 100644 index 0000000000000000000000000000000000000000..7a73bc9ffd00bfe081496174c808edb93f240cfb --- /dev/null +++ b/custom_mesh_graphormer/modeling/hrnet/config/models.py @@ -0,0 +1,47 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Create by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from yacs.config import CfgNode as CN + +# high_resoluton_net related params for classification +POSE_HIGH_RESOLUTION_NET = CN() +POSE_HIGH_RESOLUTION_NET.PRETRAINED_LAYERS = ['*'] +POSE_HIGH_RESOLUTION_NET.STEM_INPLANES = 64 +POSE_HIGH_RESOLUTION_NET.FINAL_CONV_KERNEL = 1 +POSE_HIGH_RESOLUTION_NET.WITH_HEAD = True + +POSE_HIGH_RESOLUTION_NET.STAGE2 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BRANCHES = 2 +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_BLOCKS = [4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE2.NUM_CHANNELS = [32, 64] +POSE_HIGH_RESOLUTION_NET.STAGE2.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE2.FUSE_METHOD = 'SUM' + +POSE_HIGH_RESOLUTION_NET.STAGE3 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BRANCHES = 3 +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_BLOCKS = [4, 4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE3.NUM_CHANNELS = [32, 64, 128] +POSE_HIGH_RESOLUTION_NET.STAGE3.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE3.FUSE_METHOD = 'SUM' + +POSE_HIGH_RESOLUTION_NET.STAGE4 = CN() +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_MODULES = 1 +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BRANCHES = 4 +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +POSE_HIGH_RESOLUTION_NET.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] +POSE_HIGH_RESOLUTION_NET.STAGE4.BLOCK = 'BASIC' +POSE_HIGH_RESOLUTION_NET.STAGE4.FUSE_METHOD = 'SUM' + +MODEL_EXTRAS = { + 'cls_hrnet': POSE_HIGH_RESOLUTION_NET, +} diff --git a/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py b/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py new file mode 100644 index 0000000000000000000000000000000000000000..3388c230c981234b454b14224713720f7db04ce1 --- /dev/null +++ b/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net.py @@ -0,0 +1,523 @@ + + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# Modified by Kevin Lin (keli@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +import code +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i], + momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.classifier = nn.Linear(2048, 1000) + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, + channels, + head_channels[i], + 1, + stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels)-1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i+1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](y_list[i+1]) + \ + self.downsamp_modules[i](y) + + y = self.final_layer(y) + + if torch._C._get_tracing_state(): + y = y.flatten(start_dim=2).mean(dim=2) + else: + y = F.avg_pool2d(y, kernel_size=y.size() + [2:]).view(y.size(0), -1) + + # y = self.classifier(y) + + return y + + def init_weights(self, pretrained='',): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + print('=> loading pretrained model {}'.format(pretrained)) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + # for k, _ in pretrained_dict.items(): + # logger.info( + # '=> loading {} pretrained model {}'.format(k, pretrained)) + # print('=> loading {} pretrained model {}'.format(k, pretrained)) + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + # code.interact(local=locals()) + +def get_cls_net(config, pretrained, **kwargs): + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained=pretrained) + return model diff --git a/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py b/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..6c44c948335aa0de8781814beaa2404ea34e6574 --- /dev/null +++ b/custom_mesh_graphormer/modeling/hrnet/hrnet_cls_net_gridfeat.py @@ -0,0 +1,524 @@ + + +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. +# Written by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn) +# Modified by Kevin Lin (keli@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import logging +import functools + +import numpy as np + +import torch +import torch.nn as nn +import torch._utils +import torch.nn.functional as F +import code +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=1, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, + padding=1, bias=False) + self.bn2 = nn.BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, + bias=False) + self.bn3 = nn.BatchNorm2d(planes * self.expansion, + momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(False) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + nn.BatchNorm2d(num_inchannels[i], + momentum=BN_MOMENTUM), + nn.Upsample(scale_factor=2**(j-i), mode='nearest'))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + nn.BatchNorm2d(num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(False))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, cfg, **kwargs): + super(HighResolutionNet, self).__init__() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = nn.BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = cfg['MODEL']['EXTRA']['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + self.stage2_cfg = cfg['MODEL']['EXTRA']['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = cfg['MODEL']['EXTRA']['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = cfg['MODEL']['EXTRA']['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + # Classification Head + self.incre_modules, self.downsamp_modules, \ + self.final_layer = self._make_head(pre_stage_channels) + + self.classifier = nn.Linear(2048, 1000) + + def _make_head(self, pre_stage_channels): + head_block = Bottleneck + head_channels = [32, 64, 128, 256] + + # Increasing the #channels on each resolution + # from C, 2C, 4C, 8C to 128, 256, 512, 1024 + incre_modules = [] + for i, channels in enumerate(pre_stage_channels): + incre_module = self._make_layer(head_block, + channels, + head_channels[i], + 1, + stride=1) + incre_modules.append(incre_module) + incre_modules = nn.ModuleList(incre_modules) + + # downsampling modules + downsamp_modules = [] + for i in range(len(pre_stage_channels)-1): + in_channels = head_channels[i] * head_block.expansion + out_channels = head_channels[i+1] * head_block.expansion + + downsamp_module = nn.Sequential( + nn.Conv2d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=2, + padding=1), + nn.BatchNorm2d(out_channels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + downsamp_modules.append(downsamp_module) + downsamp_modules = nn.ModuleList(downsamp_modules) + + final_layer = nn.Sequential( + nn.Conv2d( + in_channels=head_channels[3] * head_block.expansion, + out_channels=2048, + kernel_size=1, + stride=1, + padding=0 + ), + nn.BatchNorm2d(2048, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True) + ) + + return incre_modules, downsamp_modules, final_layer + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + nn.BatchNorm2d( + num_channels_cur_layer[i], momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + nn.BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + nn.BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage4(x_list) + + # Classification Head + y = self.incre_modules[0](y_list[0]) + for i in range(len(self.downsamp_modules)): + y = self.incre_modules[i+1](y_list[i+1]) + \ + self.downsamp_modules[i](y) + + yy = self.final_layer(y) + + if torch._C._get_tracing_state(): + yy = yy.flatten(start_dim=2).mean(dim=2) + else: + yy = F.avg_pool2d(yy, kernel_size=yy.size() + [2:]).view(yy.size(0), -1) + + # y = self.classifier(y) + return yy, y + + + + def init_weights(self, pretrained='',): + logger.info('=> init weights from normal distribution') + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + if os.path.isfile(pretrained): + pretrained_dict = torch.load(pretrained) + logger.info('=> loading pretrained model {}'.format(pretrained)) + print('=> loading pretrained model {}'.format(pretrained)) + model_dict = self.state_dict() + pretrained_dict = {k: v for k, v in pretrained_dict.items() + if k in model_dict.keys()} + # for k, _ in pretrained_dict.items(): + # logger.info( + # '=> loading {} pretrained model {}'.format(k, pretrained)) + # print('=> loading {} pretrained model {}'.format(k, pretrained)) + model_dict.update(pretrained_dict) + self.load_state_dict(model_dict) + # code.interact(local=locals()) + +def get_cls_net_gridfeat(config, pretrained, **kwargs): + model = HighResolutionNet(config, **kwargs) + model.init_weights(pretrained=pretrained) + return model diff --git a/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py b/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py new file mode 100644 index 0000000000000000000000000000000000000000..83685c4eb09f4ef31b64b0070b514113456e6f9a --- /dev/null +++ b/custom_mesh_graphormer/tools/run_gphmer_bodymesh.py @@ -0,0 +1,750 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Training and evaluation codes for +3D human body mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer +from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network +from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh +from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import custom_mesh_graphormer.modeling.data.config as cfg +from custom_mesh_graphormer.datasets.build import make_data_loader + +from custom_mesh_graphormer.utils.logger import setup_logger +from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from custom_mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger +from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test +from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection + +from comfy.model_management import get_torch_device +device = get_torch_device() + +from azureml.core.run import Run +aml_run = Run.get_context() + +def save_checkpoint(model, args, epoch, iteration, num_trial=10): + checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format( + epoch, iteration)) + if not is_main_process(): + return checkpoint_dir + mkdir(checkpoint_dir) + model_to_save = model.module if hasattr(model, 'module') else model + for i in range(num_trial): + try: + torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin')) + torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin')) + torch.save(args, op.join(checkpoint_dir, 'training_args.bin')) + logger.info("Save checkpoint to {}".format(checkpoint_dir)) + break + except: + pass + else: + logger.info("Failed to save checkpoint after {} trails.".format(num_trial)) + return checkpoint_dir + +def save_scores(args, split, mpjpe, pampjpe, mpve): + eval_log = [] + res = {} + res['mPJPE'] = mpjpe + res['PAmPJPE'] = pampjpe + res['mPVE'] = mpve + eval_log.append(res) + with open(op.join(args.output_dir, split+'_eval_logs.json'), 'w') as f: + json.dump(eval_log, f) + logger.info("Save eval scores to {}".format(args.output_dir)) + return + +def adjust_learning_rate(optimizer, epoch, args): + """ + Sets the learning rate to the initial LR decayed by x every y epochs + x = 0.1, y = args.num_train_epochs/2.0 = 100 + """ + lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) )) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def mean_per_joint_position_error(pred, gt, has_3d_joints): + """ + Compute mPJPE + """ + gt = gt[has_3d_joints == 1] + gt = gt[:, :, :-1] + pred = pred[has_3d_joints == 1] + + with torch.no_grad(): + gt_pelvis = (gt[:, 2,:] + gt[:, 3,:]) / 2 + gt = gt - gt_pelvis[:, None, :] + pred_pelvis = (pred[:, 2,:] + pred[:, 3,:]) / 2 + pred = pred - pred_pelvis[:, None, :] + error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + return error + +def mean_per_vertex_error(pred, gt, has_smpl): + """ + Compute mPVE + """ + pred = pred[has_smpl == 1] + gt = gt[has_smpl == 1] + with torch.no_grad(): + error = torch.sqrt( ((pred - gt) ** 2).sum(dim=-1)).mean(dim=-1).cpu().numpy() + return error + +def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d): + """ + Compute 2D reprojection loss if 2D keypoint annotations are available. + The confidence (conf) is binary and indicates whether the keypoints exist or not. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() + return loss + +def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d, device): + """ + Compute 3D keypoint loss if 3D keypoint annotations are available. + """ + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() + gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] + conf = conf[has_pose_3d == 1] + pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] + if len(gt_keypoints_3d) > 0: + gt_pelvis = (gt_keypoints_3d[:, 2,:] + gt_keypoints_3d[:, 3,:]) / 2 + gt_keypoints_3d = gt_keypoints_3d - gt_pelvis[:, None, :] + pred_pelvis = (pred_keypoints_3d[:, 2,:] + pred_keypoints_3d[:, 3,:]) / 2 + pred_keypoints_3d = pred_keypoints_3d - pred_pelvis[:, None, :] + return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean() + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, device): + """ + Compute per-vertex loss if vertex annotations are available. + """ + pred_vertices_with_shape = pred_vertices[has_smpl == 1] + gt_vertices_with_shape = gt_vertices[has_smpl == 1] + if len(gt_vertices_with_shape) > 0: + return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape) + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def rectify_pose(pose): + pose = pose.copy() + R_mod = cv2.Rodrigues(np.array([np.pi, 0, 0]))[0] + R_root = cv2.Rodrigues(pose[:3])[0] + new_root = R_root.dot(R_mod) + pose[:3] = cv2.Rodrigues(new_root)[0].reshape(3) + return pose + +def run(args, train_dataloader, val_dataloader, Graphormer_model, smpl, mesh_sampler, renderer): + smpl.eval() + max_iter = len(train_dataloader) + iters_per_epoch = max_iter // args.num_train_epochs + if iters_per_epoch<1000: + args.logging_steps = 500 + + optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()), + lr=args.lr, + betas=(0.9, 0.999), + weight_decay=0) + + # define loss function (criterion) and optimizer + criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + + logger.info( + ' '.join( + ['Local rank: {o}', 'Max iteration: {a}', 'iters_per_epoch: {b}','num_train_epochs: {c}',] + ).format(o=args.local_rank, a=max_iter, b=iters_per_epoch, c=args.num_train_epochs) + ) + + start_training_time = time.time() + end = time.time() + Graphormer_model.train() + batch_time = AverageMeter() + data_time = AverageMeter() + log_losses = AverageMeter() + log_loss_2djoints = AverageMeter() + log_loss_3djoints = AverageMeter() + log_loss_vertices = AverageMeter() + log_eval_metrics = EvalMetricsLogger() + + for iteration, (img_keys, images, annotations) in enumerate(train_dataloader): + # gc.collect() + # torch.cuda.empty_cache() + Graphormer_model.train() + iteration += 1 + epoch = iteration // iters_per_epoch + batch_size = images.size(0) + adjust_learning_rate(optimizer, epoch, args) + data_time.update(time.time() - end) + + images = images.to(device) + gt_2d_joints = annotations['joints_2d'].to(device) + gt_2d_joints = gt_2d_joints[:,cfg.J24_TO_J14,:] + has_2d_joints = annotations['has_2d_joints'].to(device) + + gt_3d_joints = annotations['joints_3d'].to(device) + gt_3d_pelvis = gt_3d_joints[:,cfg.J24_NAME.index('Pelvis'),:3] + gt_3d_joints = gt_3d_joints[:,cfg.J24_TO_J14,:] + gt_3d_joints[:,:,:3] = gt_3d_joints[:,:,:3] - gt_3d_pelvis[:, None, :] + has_3d_joints = annotations['has_3d_joints'].to(device) + + gt_pose = annotations['pose'].to(device) + gt_betas = annotations['betas'].to(device) + has_smpl = annotations['has_smpl'].to(device) + mjm_mask = annotations['mjm_mask'].to(device) + mvm_mask = annotations['mvm_mask'].to(device) + + # generate simplified mesh + gt_vertices = smpl(gt_pose, gt_betas) + gt_vertices_sub2 = mesh_sampler.downsample(gt_vertices, n1=0, n2=2) + gt_vertices_sub = mesh_sampler.downsample(gt_vertices) + + # normalize gt based on smpl's pelvis + gt_smpl_3d_joints = smpl.get_h36m_joints(gt_vertices) + gt_smpl_3d_pelvis = gt_smpl_3d_joints[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + gt_vertices_sub2 = gt_vertices_sub2 - gt_smpl_3d_pelvis[:, None, :] + + # prepare masks for mask vertex/joint modeling + mjm_mask_ = mjm_mask.expand(-1,-1,2051) + mvm_mask_ = mvm_mask.expand(-1,-1,2051) + meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices = Graphormer_model(images, smpl, mesh_sampler, meta_masks=meta_masks, is_train=True) + + # normalize gt based on smpl's pelvis + gt_vertices_sub = gt_vertices_sub - gt_smpl_3d_pelvis[:, None, :] + gt_vertices = gt_vertices - gt_smpl_3d_pelvis[:, None, :] + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) + pred_2d_joints = orthographic_projection(pred_3d_joints, pred_camera) + + # compute 3d joint loss (where the joints are directly output from transformer) + loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints, has_3d_joints, args.device) + # compute 3d vertex loss + loss_vertices = ( args.vloss_w_sub2 * vertices_loss(criterion_vertices, pred_vertices_sub2, gt_vertices_sub2, has_smpl, args.device) + \ + args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_smpl, args.device) + \ + args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl, args.device) ) + # compute 3d joint loss (where the joints are regressed from full mesh) + loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_smpl, gt_3d_joints, has_3d_joints, args.device) + # compute 2d joint loss + loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ + keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_smpl, gt_2d_joints, has_2d_joints) + + loss_3d_joints = loss_3d_joints + loss_reg_3d_joints + + # we empirically use hyperparameters to balance difference losses + loss = args.joints_loss_weight*loss_3d_joints + \ + args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints + + # update logs + log_loss_2djoints.update(loss_2d_joints.item(), batch_size) + log_loss_3djoints.update(loss_3d_joints.item(), batch_size) + log_loss_vertices.update(loss_vertices.item(), batch_size) + log_losses.update(loss.item(), batch_size) + + # back prop + optimizer.zero_grad() + loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if iteration % args.logging_steps == 0 or iteration == max_iter: + eta_seconds = batch_time.avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + logger.info( + ' '.join( + ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] + ).format(eta=eta_string, ep=epoch, iter=iteration, + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) + + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( + log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, + optimizer.param_groups[0]['lr']) + ) + + aml_run.log(name='Loss', value=float(log_losses.avg)) + aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) + aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) + aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) + + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_smpl.detach()) + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + if is_main_process()==True: + stamp = str(epoch) + '_' + str(iteration) + temp_fname = args.output_dir + 'visual_' + stamp + '.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + aml_run.log_image(name='visual results', path=temp_fname) + + if iteration % iters_per_epoch == 0: + val_mPVE, val_mPJPE, val_PAmPJPE, val_count = run_validate(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + epoch, + smpl, + mesh_sampler) + aml_run.log(name='mPVE', value=float(1000*val_mPVE)) + aml_run.log(name='mPJPE', value=float(1000*val_mPJPE)) + aml_run.log(name='PAmPJPE', value=float(1000*val_PAmPJPE)) + logger.info( + ' '.join(['Validation', 'epoch: {ep}',]).format(ep=epoch) + + ' mPVE: {:6.2f}, mPJPE: {:6.2f}, PAmPJPE: {:6.2f}, Data Count: {:6.2f}'.format(1000*val_mPVE, 1000*val_mPJPE, 1000*val_PAmPJPE, val_count) + ) + + if val_PAmPJPE0: + mPVE.update(np.mean(error_vertices), int(torch.sum(has_smpl)) ) + if len(error_joints)>0: + mPJPE.update(np.mean(error_joints), int(torch.sum(has_3d_joints)) ) + if len(error_joints_pa)>0: + PAmPJPE.update(np.mean(error_joints_pa), int(torch.sum(has_3d_joints)) ) + + val_mPVE = all_gather(float(mPVE.avg)) + val_mPVE = sum(val_mPVE)/len(val_mPVE) + val_mPJPE = all_gather(float(mPJPE.avg)) + val_mPJPE = sum(val_mPJPE)/len(val_mPJPE) + + val_PAmPJPE = all_gather(float(PAmPJPE.avg)) + val_PAmPJPE = sum(val_PAmPJPE)/len(val_PAmPJPE) + + val_count = all_gather(float(mPVE.count)) + val_count = sum(val_count) + + return val_mPVE, val_mPJPE, val_PAmPJPE, val_count + + +def visualize_mesh( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(14)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_test( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d, + PAmPJPE_h36m_j14): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(14)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + score = PAmPJPE_h36m_j14[i] + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--data_dir", default='datasets', type=str, required=False, + help="Directory with all datasets, each in one subfolder") + parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, + help="Yaml file with all data for training.") + parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, + help="Yaml file with all data for validation.") + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + ######################################################### + # Training parameters + ######################################################### + parser.add_argument("--per_gpu_train_batch_size", default=30, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=30, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float, + help="The initial lr.") + parser.add_argument("--num_train_epochs", default=200, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--vertices_loss_weight", default=100.0, type=float) + parser.add_argument("--joints_loss_weight", default=1000.0, type=float) + parser.add_argument("--vloss_w_full", default=0.33, type=float) + parser.add_argument("--vloss_w_sub", default=0.33, type=float) + parser.add_argument("--vloss_w_sub2", default=0.33, type=float) + parser.add_argument("--drop_out", default=0.1, type=float, + help="Drop out ratio in BERT.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") + parser.add_argument("--interm_size_scale", default=2, type=int) + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=False, action='store_true',) + parser.add_argument('--logging_steps', type=int, default=1000, + help="Log every X steps.") + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + parser.add_argument("--local_rank", type=int, default=0, + help="For distributed training.") + + + args = parser.parse_args() + return args + + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + if args.distributed: + print("Init distributed training on local rank {} ({}), rank {}, world size {}".format(args.local_rank, int(os.environ["LOCAL_RANK"]), int(os.environ["NODE_RANK"]), args.num_gpus)) + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend='nccl', init_method='env://' + ) + local_rank = int(os.environ["LOCAL_RANK"]) + args.device = torch.device("cuda", local_rank) + synchronize() + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + smpl = SMPL().to(args.device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=smpl.faces.cpu().numpy()) + + # Load model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.hidden_dropout_prob = args.drop_out + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*args.interm_size_scale) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + + # init ImageNet pre-trained backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) + + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + states = torch.load(args.resume_checkpoint) + # states = checkpoint_loaded.state_dict() + for k, v in states.items(): + states[k] = v.cpu() + # del checkpoint_loaded + _model.load_state_dict(states, strict=False) + del states + gc.collect() + torch.cuda.empty_cache() + + + _model.to(args.device) + logger.info("Training parameters %s", args) + + if args.run_eval_only==True: + val_dataloader = make_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run_eval_general(args, val_dataloader, _model, smpl, mesh_sampler) + + else: + train_dataloader = make_data_loader(args, args.train_yaml, + args.distributed, is_train=True, scale_factor=args.img_scale_factor) + val_dataloader = make_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run(args, train_dataloader, val_dataloader, _model, smpl, mesh_sampler, renderer) + + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py b/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..7d67fc92740d2a34312d5a91f4facfe3ab1dd914 --- /dev/null +++ b/custom_mesh_graphormer/tools/run_gphmer_bodymesh_inference.py @@ -0,0 +1,351 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +End-to-end inference codes for +3D human body mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer +from custom_mesh_graphormer.modeling.bert import Graphormer_Body_Network as Graphormer_Network +from custom_mesh_graphormer.modeling._smpl import SMPL, Mesh +from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import custom_mesh_graphormer.modeling.data.config as cfg +from custom_mesh_graphormer.datasets.build import make_data_loader + +from custom_mesh_graphormer.utils.logger import setup_logger +from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from custom_mesh_graphormer.utils.metric_logger import AverageMeter, EvalMetricsLogger +from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text +from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection + +from PIL import Image +from torchvision import transforms + +from comfy.model_management import get_torch_device +device = get_torch_device() + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + +transform_visualize = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor()]) + +def run_inference(args, image_list, Graphormer_model, smpl, renderer, mesh_sampler): + # switch to evaluate mode + Graphormer_model.eval() + smpl.eval() + with torch.no_grad(): + for image_file in image_list: + if 'pred' not in image_file: + att_all = [] + img = Image.open(image_file) + img_tensor = transform(img) + img_visual = transform_visualize(img) + + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device) + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub2, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, smpl, mesh_sampler) + + # obtain 3d joints from full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + + pred_3d_pelvis = pred_3d_joints_from_smpl[:,cfg.H36M_J17_NAME.index('Pelvis'),:] + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + # save attantion + att_max_value = att[-1] + att_cpu = np.asarray(att_max_value.cpu().detach()) + att_all.append(att_cpu) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_smpl = smpl.get_h36m_joints(pred_vertices) + pred_3d_joints_from_smpl = pred_3d_joints_from_smpl[:,cfg.H36M_J17_TO_J14,:] + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_smpl = orthographic_projection(pred_3d_joints_from_smpl, pred_camera) + pred_2d_431_vertices_from_smpl = orthographic_projection(pred_vertices_sub2, pred_camera) + visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], + pred_vertices[0].detach(), + pred_camera.detach()) + # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], + # pred_vertices[0].detach(), + # pred_vertices_sub2[0].detach(), + # pred_2d_431_vertices_from_smpl[0].detach(), + # pred_2d_joints_from_smpl[0].detach(), + # pred_camera.detach(), + # att[-1][0].detach()) + + visual_imgs = visual_imgs_output.transpose(1,2,0) + visual_imgs = np.asarray(visual_imgs) + + temp_fname = image_file[:-4] + '_graphormer_pred.jpg' + print('save to ', temp_fname) + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + + return + +def visualize_mesh( renderer, images, + pred_vertices_full, + pred_camera): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + cam = pred_camera.cpu().numpy() + # Visualize only mesh reconstruction + rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def visualize_mesh_and_attention( renderer, images, + pred_vertices_full, + pred_vertices, + pred_2d_vertices, + pred_2d_joints, + pred_camera, + attention): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + vertices = pred_vertices.cpu().numpy() + vertices_2d = pred_2d_vertices.cpu().numpy() + joints_2d = pred_2d_joints.cpu().numpy() + cam = pred_camera.cpu().numpy() + att = attention.cpu().numpy() + # Visualize reconstruction and attention + rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + parser.add_argument("--image_file_or_path", default='./samples/human-body', type=str, + help="test data") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='body', type=str, help="body or hand") + parser.add_argument("--interm_size_scale", default=2, type=int) + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=True, action='store_true',) + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + + args = parser.parse_args() + return args + + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + smpl = SMPL().to(args.device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=smpl.faces.cpu().numpy()) + + # Load model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*args.interm_size_scale) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # init ImageNet pre-trained backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-2]) + + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder, mesh_sampler) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + states = torch.load(args.resume_checkpoint) + # states = checkpoint_loaded.state_dict() + for k, v in states.items(): + states[k] = v.cpu() + # del checkpoint_loaded + _model.load_state_dict(states, strict=False) + del states + gc.collect() + torch.cuda.empty_cache() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + logger.info("Run inference") + + image_list = [] + if not args.image_file_or_path: + raise ValueError("image_file_or_path not specified") + if op.isfile(args.image_file_or_path): + image_list = [args.image_file_or_path] + elif op.isdir(args.image_file_or_path): + # should be a path with images only + for filename in os.listdir(args.image_file_or_path): + if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename: + image_list.append(args.image_file_or_path+'/'+filename) + else: + raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) + + run_inference(args, image_list, _model, smpl, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/custom_mesh_graphormer/tools/run_gphmer_handmesh.py b/custom_mesh_graphormer/tools/run_gphmer_handmesh.py new file mode 100644 index 0000000000000000000000000000000000000000..5971f7bb8f457f7f59943b1b98e3c4645f639398 --- /dev/null +++ b/custom_mesh_graphormer/tools/run_gphmer_handmesh.py @@ -0,0 +1,713 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Training and evaluation codes for +3D hand mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer +from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from custom_mesh_graphormer.modeling._mano import MANO, Mesh +from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import custom_mesh_graphormer.modeling.data.config as cfg +from custom_mesh_graphormer.datasets.build import make_hand_data_loader + +from custom_mesh_graphormer.utils.logger import setup_logger +from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from custom_mesh_graphormer.utils.metric_logger import AverageMeter +from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction, visualize_reconstruction_test, visualize_reconstruction_no_text +from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection + +from comfy.model_management import get_torch_device +device = get_torch_device() + +from azureml.core.run import Run +aml_run = Run.get_context() + +def save_checkpoint(model, args, epoch, iteration, num_trial=10): + checkpoint_dir = op.join(args.output_dir, 'checkpoint-{}-{}'.format( + epoch, iteration)) + if not is_main_process(): + return checkpoint_dir + mkdir(checkpoint_dir) + model_to_save = model.module if hasattr(model, 'module') else model + for i in range(num_trial): + try: + torch.save(model_to_save, op.join(checkpoint_dir, 'model.bin')) + torch.save(model_to_save.state_dict(), op.join(checkpoint_dir, 'state_dict.bin')) + torch.save(args, op.join(checkpoint_dir, 'training_args.bin')) + logger.info("Save checkpoint to {}".format(checkpoint_dir)) + break + except: + pass + else: + logger.info("Failed to save checkpoint after {} trails.".format(num_trial)) + return checkpoint_dir + +def adjust_learning_rate(optimizer, epoch, args): + """ + Sets the learning rate to the initial LR decayed by x every y epochs + x = 0.1, y = args.num_train_epochs/2.0 = 100 + """ + lr = args.lr * (0.1 ** (epoch // (args.num_train_epochs/2.0) )) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def keypoint_2d_loss(criterion_keypoints, pred_keypoints_2d, gt_keypoints_2d, has_pose_2d): + """ + Compute 2D reprojection loss if 2D keypoint annotations are available. + The confidence is binary and indicates whether the keypoints exist or not. + """ + conf = gt_keypoints_2d[:, :, -1].unsqueeze(-1).clone() + loss = (conf * criterion_keypoints(pred_keypoints_2d, gt_keypoints_2d[:, :, :-1])).mean() + return loss + +def keypoint_3d_loss(criterion_keypoints, pred_keypoints_3d, gt_keypoints_3d, has_pose_3d): + """ + Compute 3D keypoint loss if 3D keypoint annotations are available. + """ + conf = gt_keypoints_3d[:, :, -1].unsqueeze(-1).clone() + gt_keypoints_3d = gt_keypoints_3d[:, :, :-1].clone() + gt_keypoints_3d = gt_keypoints_3d[has_pose_3d == 1] + conf = conf[has_pose_3d == 1] + pred_keypoints_3d = pred_keypoints_3d[has_pose_3d == 1] + if len(gt_keypoints_3d) > 0: + gt_root = gt_keypoints_3d[:, 0,:] + gt_keypoints_3d = gt_keypoints_3d - gt_root[:, None, :] + pred_root = pred_keypoints_3d[:, 0,:] + pred_keypoints_3d = pred_keypoints_3d - pred_root[:, None, :] + return (conf * criterion_keypoints(pred_keypoints_3d, gt_keypoints_3d)).mean() + else: + return torch.FloatTensor(1).fill_(0.).to(device) + +def vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_smpl): + """ + Compute per-vertex loss if vertex annotations are available. + """ + pred_vertices_with_shape = pred_vertices[has_smpl == 1] + gt_vertices_with_shape = gt_vertices[has_smpl == 1] + if len(gt_vertices_with_shape) > 0: + return criterion_vertices(pred_vertices_with_shape, gt_vertices_with_shape) + else: + return torch.FloatTensor(1).fill_(0.).to(device) + + +def run(args, train_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): + + max_iter = len(train_dataloader) + iters_per_epoch = max_iter // args.num_train_epochs + + optimizer = torch.optim.Adam(params=list(Graphormer_model.parameters()), + lr=args.lr, + betas=(0.9, 0.999), + weight_decay=0) + + # define loss function (criterion) and optimizer + criterion_2d_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + + start_training_time = time.time() + end = time.time() + Graphormer_model.train() + batch_time = AverageMeter() + data_time = AverageMeter() + log_losses = AverageMeter() + log_loss_2djoints = AverageMeter() + log_loss_3djoints = AverageMeter() + log_loss_vertices = AverageMeter() + + for iteration, (img_keys, images, annotations) in enumerate(train_dataloader): + + Graphormer_model.train() + iteration += 1 + epoch = iteration // iters_per_epoch + batch_size = images.size(0) + adjust_learning_rate(optimizer, epoch, args) + data_time.update(time.time() - end) + + images = images.to(device) + gt_2d_joints = annotations['joints_2d'].to(device) + gt_pose = annotations['pose'].to(device) + gt_betas = annotations['betas'].to(device) + has_mesh = annotations['has_smpl'].to(device) + has_3d_joints = has_mesh + has_2d_joints = has_mesh + mjm_mask = annotations['mjm_mask'].to(device) + mvm_mask = annotations['mvm_mask'].to(device) + + # generate mesh + gt_vertices, gt_3d_joints = mano_model.layer(gt_pose, gt_betas) + gt_vertices = gt_vertices/1000.0 + gt_3d_joints = gt_3d_joints/1000.0 + + gt_vertices_sub = mesh_sampler.downsample(gt_vertices) + # normalize gt based on hand's wrist + gt_3d_root = gt_3d_joints[:,cfg.J_NAME.index('Wrist'),:] + gt_vertices = gt_vertices - gt_3d_root[:, None, :] + gt_vertices_sub = gt_vertices_sub - gt_3d_root[:, None, :] + gt_3d_joints = gt_3d_joints - gt_3d_root[:, None, :] + gt_3d_joints_with_tag = torch.ones((batch_size,gt_3d_joints.shape[1],4)).to(device) + gt_3d_joints_with_tag[:,:,:3] = gt_3d_joints + + # prepare masks for mask vertex/joint modeling + mjm_mask_ = mjm_mask.expand(-1,-1,2051) + mvm_mask_ = mvm_mask.expand(-1,-1,2051) + meta_masks = torch.cat([mjm_mask_, mvm_mask_], dim=1) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler, meta_masks=meta_masks, is_train=True) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + + # obtain 2d joints, which are projected from 3d joints of smpl mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + pred_2d_joints = orthographic_projection(pred_3d_joints.contiguous(), pred_camera.contiguous()) + + # compute 3d joint loss (where the joints are directly output from transformer) + loss_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints, gt_3d_joints_with_tag, has_3d_joints) + + # compute 3d vertex loss + loss_vertices = ( args.vloss_w_sub * vertices_loss(criterion_vertices, pred_vertices_sub, gt_vertices_sub, has_mesh) + \ + args.vloss_w_full * vertices_loss(criterion_vertices, pred_vertices, gt_vertices, has_mesh) ) + + # compute 3d joint loss (where the joints are regressed from full mesh) + loss_reg_3d_joints = keypoint_3d_loss(criterion_keypoints, pred_3d_joints_from_mesh, gt_3d_joints_with_tag, has_3d_joints) + # compute 2d joint loss + loss_2d_joints = keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints, gt_2d_joints, has_2d_joints) + \ + keypoint_2d_loss(criterion_2d_keypoints, pred_2d_joints_from_mesh, gt_2d_joints, has_2d_joints) + + loss_3d_joints = loss_3d_joints + loss_reg_3d_joints + + # we empirically use hyperparameters to balance difference losses + loss = args.joints_loss_weight*loss_3d_joints + \ + args.vertices_loss_weight*loss_vertices + args.vertices_loss_weight*loss_2d_joints + + # update logs + log_loss_2djoints.update(loss_2d_joints.item(), batch_size) + log_loss_3djoints.update(loss_3d_joints.item(), batch_size) + log_loss_vertices.update(loss_vertices.item(), batch_size) + log_losses.update(loss.item(), batch_size) + + # back prop + optimizer.zero_grad() + loss.backward() + optimizer.step() + + batch_time.update(time.time() - end) + end = time.time() + + if iteration % args.logging_steps == 0 or iteration == max_iter: + eta_seconds = batch_time.avg * (max_iter - iteration) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + logger.info( + ' '.join( + ['eta: {eta}', 'epoch: {ep}', 'iter: {iter}', 'max mem : {memory:.0f}',] + ).format(eta=eta_string, ep=epoch, iter=iteration, + memory=torch.cuda.max_memory_allocated() / 1024.0 / 1024.0) + + ' loss: {:.4f}, 2d joint loss: {:.4f}, 3d joint loss: {:.4f}, vertex loss: {:.4f}, compute: {:.4f}, data: {:.4f}, lr: {:.6f}'.format( + log_losses.avg, log_loss_2djoints.avg, log_loss_3djoints.avg, log_loss_vertices.avg, batch_time.avg, data_time.avg, + optimizer.param_groups[0]['lr']) + ) + + aml_run.log(name='Loss', value=float(log_losses.avg)) + aml_run.log(name='3d joint Loss', value=float(log_loss_3djoints.avg)) + aml_run.log(name='2d joint Loss', value=float(log_loss_2djoints.avg)) + aml_run.log(name='vertex Loss', value=float(log_loss_vertices.avg)) + + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_mesh.detach()) + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + if is_main_process()==True: + stamp = str(epoch) + '_' + str(iteration) + temp_fname = args.output_dir + 'visual_' + stamp + '.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + aml_run.log_image(name='visual results', path=temp_fname) + + if iteration % iters_per_epoch == 0: + if epoch%10==0: + checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) + + total_training_time = time.time() - start_training_time + total_time_str = str(datetime.timedelta(seconds=total_training_time)) + logger.info('Total training time: {} ({:.4f} s / iter)'.format( + total_time_str, total_training_time / max_iter) + ) + checkpoint_dir = save_checkpoint(Graphormer_model, args, epoch, iteration) + +def run_eval_and_save(args, split, val_dataloader, Graphormer_model, mano_model, renderer, mesh_sampler): + + criterion_keypoints = torch.nn.MSELoss(reduction='none').to(device) + criterion_vertices = torch.nn.L1Loss().to(device) + + if args.distributed: + Graphormer_model = torch.nn.parallel.DistributedDataParallel( + Graphormer_model, device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=True, + ) + Graphormer_model.eval() + + if args.aml_eval==True: + run_aml_inference_hand_mesh(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + 0, + mano_model, mesh_sampler, + renderer, split) + else: + run_inference_hand_mesh(args, val_dataloader, + Graphormer_model, + criterion_keypoints, + criterion_vertices, + 0, + mano_model, mesh_sampler, + renderer, split) + checkpoint_dir = save_checkpoint(Graphormer_model, args, 0, 0) + return + +def run_aml_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): + # switch to evaluate mode + Graphormer_model.eval() + fname_output_save = [] + mesh_output_save = [] + joint_output_save = [] + world_size = get_world_size() + with torch.no_grad(): + for i, (img_keys, images, annotations) in enumerate(val_loader): + batch_size = images.size(0) + # compute output + images = images.to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + + for j in range(batch_size): + fname_output_save.append(img_keys[j]) + pred_vertices_list = pred_vertices[j].tolist() + mesh_output_save.append(pred_vertices_list) + pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() + joint_output_save.append(pred_3d_joints_from_mesh_list) + + if world_size > 1: + torch.distributed.barrier() + print('save results to pred.json') + output_json_file = 'pred.json' + print('save results to ', output_json_file) + with open(output_json_file, 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + azure_ckpt_name = '200' # args.resume_checkpoint.split('/')[-2].split('-')[1] + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + output_zip_file = args.output_dir + 'ckpt' + azure_ckpt_name + '-' + inference_setting +'-pred.zip' + + resolved_submit_cmd = 'zip ' + output_zip_file + ' ' + output_json_file + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm %s'%(output_json_file) + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + if world_size > 1: + torch.distributed.barrier() + + return + +def run_inference_hand_mesh(args, val_loader, Graphormer_model, criterion, criterion_vertices, epoch, mano_model, mesh_sampler, renderer, split): + # switch to evaluate mode + Graphormer_model.eval() + fname_output_save = [] + mesh_output_save = [] + joint_output_save = [] + with torch.no_grad(): + for i, (img_keys, images, annotations) in enumerate(val_loader): + batch_size = images.size(0) + # compute output + images = images.to(device) + + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices = Graphormer_model(images, mano_model, mesh_sampler) + + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] + pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + for j in range(batch_size): + fname_output_save.append(img_keys[j]) + pred_vertices_list = pred_vertices[j].tolist() + mesh_output_save.append(pred_vertices_list) + pred_3d_joints_from_mesh_list = pred_3d_joints_from_mesh[j].tolist() + joint_output_save.append(pred_3d_joints_from_mesh_list) + + if i%20==0: + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano_model.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + visual_imgs = visualize_mesh( renderer, + annotations['ori_img'].detach(), + annotations['joints_2d'].detach(), + pred_vertices.detach(), + pred_camera.detach(), + pred_2d_joints_from_mesh.detach()) + + visual_imgs = visual_imgs.transpose(0,1) + visual_imgs = visual_imgs.transpose(1,2) + visual_imgs = np.asarray(visual_imgs) + + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + temp_fname = args.output_dir + args.resume_checkpoint[0:-9] + 'freihand_results_'+inference_setting+'_batch'+str(i)+'.jpg' + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + + print('save results to pred.json') + with open('pred.json', 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + run_exp_name = args.resume_checkpoint.split('/')[-3] + run_ckpt_name = args.resume_checkpoint.split('/')[-2].split('-')[1] + inference_setting = 'sc%02d_rot%s'%(int(args.sc*10),str(int(args.rot))) + resolved_submit_cmd = 'zip ' + args.output_dir + run_exp_name + '-ckpt'+ run_ckpt_name + '-' + inference_setting +'-pred.zip ' + 'pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + return + +def visualize_mesh( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(21)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_test( renderer, + images, + gt_keypoints_2d, + pred_vertices, + pred_camera, + pred_keypoints_2d, + PAmPJPE): + """Tensorboard logging.""" + gt_keypoints_2d = gt_keypoints_2d.cpu().numpy() + to_lsp = list(range(21)) + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 10)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get LSP keypoints from the full list of keypoints + gt_keypoints_2d_ = gt_keypoints_2d[i, to_lsp] + pred_keypoints_2d_ = pred_keypoints_2d.cpu().numpy()[i, to_lsp] + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + score = PAmPJPE[i] + # Visualize reconstruction and detected pose + rend_img = visualize_reconstruction_test(img, 224, gt_keypoints_2d_, vertices, pred_keypoints_2d_, cam, renderer, score) + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def visualize_mesh_no_text( renderer, + images, + pred_vertices, + pred_camera): + """Tensorboard logging.""" + rend_imgs = [] + batch_size = pred_vertices.shape[0] + # Do visualization for the first 6 images of the batch + for i in range(min(batch_size, 1)): + img = images[i].cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices = pred_vertices[i].cpu().numpy() + cam = pred_camera[i].cpu().numpy() + # Visualize reconstruction only + rend_img = visualize_reconstruction_no_text(img, 224, vertices, cam, renderer, color='hand') + rend_img = rend_img.transpose(2,0,1) + rend_imgs.append(torch.from_numpy(rend_img)) + rend_imgs = make_grid(rend_imgs, nrow=1) + return rend_imgs + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--data_dir", default='datasets', type=str, required=False, + help="Directory with all datasets, each in one subfolder") + parser.add_argument("--train_yaml", default='imagenet2012/train.yaml', type=str, required=False, + help="Yaml file with all data for training.") + parser.add_argument("--val_yaml", default='imagenet2012/test.yaml', type=str, required=False, + help="Yaml file with all data for validation.") + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + ######################################################### + # Training parameters + ######################################################### + parser.add_argument("--per_gpu_train_batch_size", default=64, type=int, + help="Batch size per GPU/CPU for training.") + parser.add_argument("--per_gpu_eval_batch_size", default=64, type=int, + help="Batch size per GPU/CPU for evaluation.") + parser.add_argument('--lr', "--learning_rate", default=1e-4, type=float, + help="The initial lr.") + parser.add_argument("--num_train_epochs", default=200, type=int, + help="Total number of training epochs to perform.") + parser.add_argument("--vertices_loss_weight", default=1.0, type=float) + parser.add_argument("--joints_loss_weight", default=1.0, type=float) + parser.add_argument("--vloss_w_full", default=0.5, type=float) + parser.add_argument("--vloss_w_sub", default=0.5, type=float) + parser.add_argument("--drop_out", default=0.1, type=float, + help="Drop out ratio in BERT.") + ######################################################### + # Model architectures + ######################################################### + parser.add_argument("--num_hidden_layers", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=-1, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") + + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=False, action='store_true',) + parser.add_argument("--multiscale_inference", default=False, action='store_true',) + # if enable "multiscale_inference", dataloader will apply transformations to the test image based on + # the rotation "rot" and scale "sc" parameters below + parser.add_argument("--rot", default=0, type=float) + parser.add_argument("--sc", default=1.0, type=float) + parser.add_argument("--aml_eval", default=False, action='store_true',) + + parser.add_argument('--logging_steps', type=int, default=100, + help="Log every X steps.") + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + parser.add_argument("--local_rank", type=int, default=0, + help="For distributed training.") + args = parser.parse_args() + return args + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + args.distributed = args.num_gpus > 1 + args.device = torch.device(args.device) + if args.distributed: + print("Init distributed training on local rank {}".format(args.local_rank)) + torch.cuda.set_device(args.local_rank) + torch.distributed.init_process_group( + backend='nccl', init_method='env://' + ) + synchronize() + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and SMPL utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.hidden_dropout_prob = args.drop_out + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + torch.cuda.empty_cache() + + _model.to(args.device) + logger.info("Training parameters %s", args) + + if args.run_eval_only==True: + val_dataloader = make_hand_data_loader(args, args.val_yaml, + args.distributed, is_train=False, scale_factor=args.img_scale_factor) + run_eval_and_save(args, 'freihand', val_dataloader, _model, mano_model, renderer, mesh_sampler) + + else: + train_dataloader = make_hand_data_loader(args, args.train_yaml, + args.distributed, is_train=True, scale_factor=args.img_scale_factor) + run(args, train_dataloader, _model, mano_model, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py b/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7300a77dee4cfc11160abfbc19e2acf89d26ed --- /dev/null +++ b/custom_mesh_graphormer/tools/run_gphmer_handmesh_inference.py @@ -0,0 +1,338 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +End-to-end inference codes for +3D hand mesh reconstruction from an image +""" + +from __future__ import absolute_import, division, print_function +import argparse +import os +import os.path as op +import code +import json +import time +import datetime +import torch +import torchvision.models as models +from torchvision.utils import make_grid +import gc +import numpy as np +import cv2 +from custom_mesh_graphormer.modeling.bert import BertConfig, Graphormer +from custom_mesh_graphormer.modeling.bert import Graphormer_Hand_Network as Graphormer_Network +from custom_mesh_graphormer.modeling._mano import MANO, Mesh +from custom_mesh_graphormer.modeling.hrnet.hrnet_cls_net_gridfeat import get_cls_net_gridfeat +from custom_mesh_graphormer.modeling.hrnet.config import config as hrnet_config +from custom_mesh_graphormer.modeling.hrnet.config import update_config as hrnet_update_config +import custom_mesh_graphormer.modeling.data.config as cfg +from custom_mesh_graphormer.datasets.build import make_hand_data_loader + +from custom_mesh_graphormer.utils.logger import setup_logger +from custom_mesh_graphormer.utils.comm import synchronize, is_main_process, get_rank, get_world_size, all_gather +from custom_mesh_graphormer.utils.miscellaneous import mkdir, set_seed +from custom_mesh_graphormer.utils.metric_logger import AverageMeter +from custom_mesh_graphormer.utils.renderer import Renderer, visualize_reconstruction_and_att_local, visualize_reconstruction_no_text +from custom_mesh_graphormer.utils.metric_pampjpe import reconstruction_error +from custom_mesh_graphormer.utils.geometric_layers import orthographic_projection + +from PIL import Image +from torchvision import transforms + +from comfy.model_management import get_torch_device +device = get_torch_device() + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], + std=[0.229, 0.224, 0.225])]) + +transform_visualize = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), + transforms.ToTensor()]) + +def run_inference(args, image_list, Graphormer_model, mano, renderer, mesh_sampler): +# switch to evaluate mode + Graphormer_model.eval() + mano.eval() + with torch.no_grad(): + for image_file in image_list: + if 'pred' not in image_file: + att_all = [] + print(image_file) + img = Image.open(image_file) + img_tensor = transform(img) + img_visual = transform_visualize(img) + + batch_imgs = torch.unsqueeze(img_tensor, 0).to(device) + batch_visual_imgs = torch.unsqueeze(img_visual, 0).to(device) + # forward-pass + pred_camera, pred_3d_joints, pred_vertices_sub, pred_vertices, hidden_states, att = Graphormer_model(batch_imgs, mano, mesh_sampler) + # obtain 3d joints from full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + pred_3d_pelvis = pred_3d_joints_from_mesh[:,cfg.J_NAME.index('Wrist'),:] + pred_3d_joints_from_mesh = pred_3d_joints_from_mesh - pred_3d_pelvis[:, None, :] + pred_vertices = pred_vertices - pred_3d_pelvis[:, None, :] + + # save attantion + att_max_value = att[-1] + att_cpu = np.asarray(att_max_value.cpu().detach()) + att_all.append(att_cpu) + + # obtain 3d joints, which are regressed from the full mesh + pred_3d_joints_from_mesh = mano.get_3d_joints(pred_vertices) + # obtain 2d joints, which are projected from 3d joints of mesh + pred_2d_joints_from_mesh = orthographic_projection(pred_3d_joints_from_mesh.contiguous(), pred_camera.contiguous()) + pred_2d_coarse_vertices_from_mesh = orthographic_projection(pred_vertices_sub.contiguous(), pred_camera.contiguous()) + + + visual_imgs_output = visualize_mesh( renderer, batch_visual_imgs[0], + pred_vertices[0].detach(), + pred_camera.detach()) + # visual_imgs_output = visualize_mesh_and_attention( renderer, batch_visual_imgs[0], + # pred_vertices[0].detach(), + # pred_vertices_sub[0].detach(), + # pred_2d_coarse_vertices_from_mesh[0].detach(), + # pred_2d_joints_from_mesh[0].detach(), + # pred_camera.detach(), + # att[-1][0].detach()) + visual_imgs = visual_imgs_output.transpose(1,2,0) + visual_imgs = np.asarray(visual_imgs) + + temp_fname = image_file[:-4] + '_graphormer_pred.jpg' + print('save to ', temp_fname) + cv2.imwrite(temp_fname, np.asarray(visual_imgs[:,:,::-1]*255)) + return + +def visualize_mesh( renderer, images, + pred_vertices_full, + pred_camera): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + cam = pred_camera.cpu().numpy() + # Visualize only mesh reconstruction + rend_img = visualize_reconstruction_no_text(img, 224, vertices_full, cam, renderer, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def visualize_mesh_and_attention( renderer, images, + pred_vertices_full, + pred_vertices, + pred_2d_vertices, + pred_2d_joints, + pred_camera, + attention): + img = images.cpu().numpy().transpose(1,2,0) + # Get predict vertices for the particular example + vertices_full = pred_vertices_full.cpu().numpy() + vertices = pred_vertices.cpu().numpy() + vertices_2d = pred_2d_vertices.cpu().numpy() + joints_2d = pred_2d_joints.cpu().numpy() + cam = pred_camera.cpu().numpy() + att = attention.cpu().numpy() + # Visualize reconstruction and attention + rend_img = visualize_reconstruction_and_att_local(img, 224, vertices_full, vertices, vertices_2d, cam, renderer, joints_2d, att, color='light_blue') + rend_img = rend_img.transpose(2,0,1) + return rend_img + +def parse_args(): + parser = argparse.ArgumentParser() + ######################################################### + # Data related arguments + ######################################################### + parser.add_argument("--num_workers", default=4, type=int, + help="Workers in dataloader.") + parser.add_argument("--img_scale_factor", default=1, type=int, + help="adjust image resolution.") + parser.add_argument("--image_file_or_path", default='./samples/hand', type=str, + help="test data") + ######################################################### + # Loading/saving checkpoints + ######################################################### + parser.add_argument("--model_name_or_path", default='src/modeling/bert/bert-base-uncased/', type=str, required=False, + help="Path to pre-trained transformer model or model type.") + parser.add_argument("--resume_checkpoint", default=None, type=str, required=False, + help="Path to specific checkpoint for resume training.") + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + parser.add_argument("--config_name", default="", type=str, + help="Pretrained config name or path if not the same as model_name.") + parser.add_argument('-a', '--arch', default='hrnet-w64', + help='CNN backbone architecture: hrnet-w64, hrnet, resnet50') + ######################################################### + # Model architectures + ######################################################### + parser.add_argument("--num_hidden_layers", default=4, type=int, required=False, + help="Update model config if given") + parser.add_argument("--hidden_size", default=-1, type=int, required=False, + help="Update model config if given") + parser.add_argument("--num_attention_heads", default=4, type=int, required=False, + help="Update model config if given. Note that the division of " + "hidden_size / num_attention_heads should be in integer.") + parser.add_argument("--intermediate_size", default=-1, type=int, required=False, + help="Update model config if given.") + parser.add_argument("--input_feat_dim", default='2051,512,128', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--hidden_feat_dim", default='1024,256,64', type=str, + help="The Image Feature Dimension.") + parser.add_argument("--which_gcn", default='0,0,1', type=str, + help="which encoder block to have graph conv. Encoder1, Encoder2, Encoder3. Default: only Encoder3 has graph conv") + parser.add_argument("--mesh_type", default='hand', type=str, help="body or hand") + + ######################################################### + # Others + ######################################################### + parser.add_argument("--run_eval_only", default=True, action='store_true',) + parser.add_argument("--device", type=str, default='cuda', + help="cuda or cpu") + parser.add_argument('--seed', type=int, default=88, + help="random seed for initialization.") + args = parser.parse_args() + return args + +def main(args): + global logger + # Setup CUDA, GPU & distributed training + args.num_gpus = int(os.environ['WORLD_SIZE']) if 'WORLD_SIZE' in os.environ else 1 + os.environ['OMP_NUM_THREADS'] = str(args.num_workers) + print('set os.environ[OMP_NUM_THREADS] to {}'.format(os.environ['OMP_NUM_THREADS'])) + + mkdir(args.output_dir) + logger = setup_logger("Graphormer", args.output_dir, get_rank()) + set_seed(args.seed, args.num_gpus) + logger.info("Using {} GPUs".format(args.num_gpus)) + + # Mesh and MANO utils + mano_model = MANO().to(args.device) + mano_model.layer = mano_model.layer.to(device) + mesh_sampler = Mesh() + + # Renderer for visualization + renderer = Renderer(faces=mano_model.face) + + # Load pretrained model + trans_encoder = [] + + input_feat_dim = [int(item) for item in args.input_feat_dim.split(',')] + hidden_feat_dim = [int(item) for item in args.hidden_feat_dim.split(',')] + output_feat_dim = input_feat_dim[1:] + [3] + + # which encoder block to have graph convs + which_blk_graph = [int(item) for item in args.which_gcn.split(',')] + + if args.run_eval_only==True and args.resume_checkpoint!=None and args.resume_checkpoint!='None' and 'state_dict' not in args.resume_checkpoint: + # if only run eval, load checkpoint + logger.info("Evaluation: Loading from checkpoint {}".format(args.resume_checkpoint)) + _model = torch.load(args.resume_checkpoint) + + else: + # init three transformer-encoder blocks in a loop + for i in range(len(output_feat_dim)): + config_class, model_class = BertConfig, Graphormer + config = config_class.from_pretrained(args.config_name if args.config_name \ + else args.model_name_or_path) + + config.output_attentions = False + config.img_feature_dim = input_feat_dim[i] + config.output_feature_dim = output_feat_dim[i] + args.hidden_size = hidden_feat_dim[i] + args.intermediate_size = int(args.hidden_size*2) + + if which_blk_graph[i]==1: + config.graph_conv = True + logger.info("Add Graph Conv") + else: + config.graph_conv = False + + config.mesh_type = args.mesh_type + + # update model structure if specified in arguments + update_params = ['num_hidden_layers', 'hidden_size', 'num_attention_heads', 'intermediate_size'] + for idx, param in enumerate(update_params): + arg_param = getattr(args, param) + config_param = getattr(config, param) + if arg_param > 0 and arg_param != config_param: + logger.info("Update config parameter {}: {} -> {}".format(param, config_param, arg_param)) + setattr(config, param, arg_param) + + # init a transformer encoder and append it to a list + assert config.hidden_size % config.num_attention_heads == 0 + model = model_class(config=config) + logger.info("Init model from scratch.") + trans_encoder.append(model) + + # create backbone model + if args.arch=='hrnet': + hrnet_yaml = 'models/hrnet/cls_hrnet_w40_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w40_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w40 model') + elif args.arch=='hrnet-w64': + hrnet_yaml = 'models/hrnet/cls_hrnet_w64_sgd_lr5e-2_wd1e-4_bs32_x100.yaml' + hrnet_checkpoint = 'models/hrnet/hrnetv2_w64_imagenet_pretrained.pth' + hrnet_update_config(hrnet_config, hrnet_yaml) + backbone = get_cls_net_gridfeat(hrnet_config, pretrained=hrnet_checkpoint) + logger.info('=> loading hrnet-v2-w64 model') + else: + print("=> using pre-trained model '{}'".format(args.arch)) + backbone = models.__dict__[args.arch](pretrained=True) + # remove the last fc layer + backbone = torch.nn.Sequential(*list(backbone.children())[:-1]) + + trans_encoder = torch.nn.Sequential(*trans_encoder) + total_params = sum(p.numel() for p in trans_encoder.parameters()) + logger.info('Graphormer encoders total parameters: {}'.format(total_params)) + backbone_total_params = sum(p.numel() for p in backbone.parameters()) + logger.info('Backbone total parameters: {}'.format(backbone_total_params)) + + # build end-to-end Graphormer network (CNN backbone + multi-layer Graphormer encoder) + _model = Graphormer_Network(args, config, backbone, trans_encoder) + + if args.resume_checkpoint!=None and args.resume_checkpoint!='None': + # for fine-tuning or resume training or inference, load weights from checkpoint + logger.info("Loading state dict from checkpoint {}".format(args.resume_checkpoint)) + # workaround approach to load sparse tensor in graph conv. + state_dict = torch.load(args.resume_checkpoint) + _model.load_state_dict(state_dict, strict=False) + del state_dict + gc.collect() + torch.cuda.empty_cache() + + # update configs to enable attention outputs + setattr(_model.trans_encoder[-1].config,'output_attentions', True) + setattr(_model.trans_encoder[-1].config,'output_hidden_states', True) + _model.trans_encoder[-1].bert.encoder.output_attentions = True + _model.trans_encoder[-1].bert.encoder.output_hidden_states = True + for iter_layer in range(4): + _model.trans_encoder[-1].bert.encoder.layer[iter_layer].attention.self.output_attentions = True + for inter_block in range(3): + setattr(_model.trans_encoder[-1].config,'device', args.device) + + _model.to(args.device) + logger.info("Run inference") + + image_list = [] + if not args.image_file_or_path: + raise ValueError("image_file_or_path not specified") + if op.isfile(args.image_file_or_path): + image_list = [args.image_file_or_path] + elif op.isdir(args.image_file_or_path): + # should be a path with images only + for filename in os.listdir(args.image_file_or_path): + if filename.endswith(".png") or filename.endswith(".jpg") and 'pred' not in filename: + image_list.append(args.image_file_or_path+'/'+filename) + else: + raise ValueError("Cannot find images at {}".format(args.image_file_or_path)) + + run_inference(args, image_list, _model, mano_model, renderer, mesh_sampler) + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/custom_mesh_graphormer/tools/run_hand_multiscale.py b/custom_mesh_graphormer/tools/run_hand_multiscale.py new file mode 100644 index 0000000000000000000000000000000000000000..d656756b7c347d0c55bedf8b359f89b5f60ee2fb --- /dev/null +++ b/custom_mesh_graphormer/tools/run_hand_multiscale.py @@ -0,0 +1,136 @@ +from __future__ import absolute_import, division, print_function + +import argparse +import os +import os.path as op +import code +import json +import zipfile +import torch +import numpy as np +from custom_mesh_graphormer.utils.metric_pampjpe import get_alignMesh + + +def load_pred_json(filepath): + archive = zipfile.ZipFile(filepath, 'r') + jsondata = archive.read('pred.json') + reference = json.loads(jsondata.decode("utf-8")) + return reference[0], reference[1] + + +def multiscale_fusion(output_dir): + s = '10' + filepath = output_dir+'ckpt200-sc10_rot0-pred.zip' + ref_joints, ref_vertices = load_pred_json(filepath) + ref_joints_array = np.asarray(ref_joints) + ref_vertices_array = np.asarray(ref_vertices) + + rotations = [0.0] + for i in range(1,10): + rotations.append(i*10) + rotations.append(i*-10) + + scale = [0.7,0.8,0.9,1.0,1.1] + multiscale_joints = [] + multiscale_vertices = [] + + counter = 0 + for s in scale: + for r in rotations: + setting = 'sc%02d_rot%s'%(int(s*10),str(int(r))) + filepath = output_dir+'ckpt200-'+setting+'-pred.zip' + joints, vertices = load_pred_json(filepath) + joints_array = np.asarray(joints) + vertices_array = np.asarray(vertices) + + pa_joint_error, pa_joint_array, _ = get_alignMesh(joints_array, ref_joints_array, reduction=None) + pa_vertices_error, pa_vertices_array, _ = get_alignMesh(vertices_array, ref_vertices_array, reduction=None) + print('--------------------------') + print('scale:', s, 'rotate', r) + print('PAMPJPE:', 1000*np.mean(pa_joint_error)) + print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) + multiscale_joints.append(pa_joint_array) + multiscale_vertices.append(pa_vertices_array) + counter = counter + 1 + + overall_joints_array = ref_joints_array.copy() + overall_vertices_array = ref_vertices_array.copy() + for i in range(counter): + overall_joints_array += multiscale_joints[i] + overall_vertices_array += multiscale_vertices[i] + + overall_joints_array /= (1+counter) + overall_vertices_array /= (1+counter) + pa_joint_error, pa_joint_array, _ = get_alignMesh(overall_joints_array, ref_joints_array, reduction=None) + pa_vertices_error, pa_vertices_array, _ = get_alignMesh(overall_vertices_array, ref_vertices_array, reduction=None) + print('--------------------------') + print('overall:') + print('PAMPJPE:', 1000*np.mean(pa_joint_error)) + print('PAMPVPE:', 1000*np.mean(pa_vertices_error)) + + joint_output_save = overall_joints_array.tolist() + mesh_output_save = overall_vertices_array.tolist() + + print('save results to pred.json') + with open('pred.json', 'w') as f: + json.dump([joint_output_save, mesh_output_save], f) + + + filepath = output_dir+'ckpt200-multisc-pred.zip' + resolved_submit_cmd = 'zip ' + filepath + ' ' + 'pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + resolved_submit_cmd = 'rm pred.json' + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + + +def run_multiscale_inference(model_path, mode, output_dir): + + if mode==True: + rotations = [0.0] + for i in range(1,10): + rotations.append(i*10) + rotations.append(i*-10) + scale = [0.7,0.8,0.9,1.0,1.1] + else: + rotations = [0.0] + scale = [1.0] + + job_cmd = "python ./src/tools/run_gphmer_handmesh.py " \ + "--val_yaml freihand_v3/test.yaml " \ + "--resume_checkpoint %s " \ + "--per_gpu_eval_batch_size 32 --run_eval_only --num_worker 2 " \ + "--multiscale_inference " \ + "--rot %f " \ + "--sc %s " \ + "--arch hrnet-w64 " \ + "--num_hidden_layers 4 " \ + "--num_attention_heads 4 " \ + "--input_feat_dim 2051,512,128 " \ + "--hidden_feat_dim 1024,256,64 " \ + "--output_dir %s" + + for s in scale: + for r in rotations: + resolved_submit_cmd = job_cmd%(model_path, r, s, output_dir) + print(resolved_submit_cmd) + os.system(resolved_submit_cmd) + +def main(args): + model_path = args.model_path + mode = args.multiscale_inference + output_dir = args.output_dir + run_multiscale_inference(model_path, mode, output_dir) + if mode==True: + multiscale_fusion(output_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluate a checkpoint in the folder") + parser.add_argument("--model_path") + parser.add_argument("--multiscale_inference", default=False, action='store_true',) + parser.add_argument("--output_dir", default='output/', type=str, required=False, + help="The output directory to save checkpoint and test results.") + args = parser.parse_args() + main(args) diff --git a/custom_mesh_graphormer/utils/__init__.py b/custom_mesh_graphormer/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/custom_mesh_graphormer/utils/comm.py b/custom_mesh_graphormer/utils/comm.py new file mode 100644 index 0000000000000000000000000000000000000000..ae05350639005b43ee8c96803d0d545442a24968 --- /dev/null +++ b/custom_mesh_graphormer/utils/comm.py @@ -0,0 +1,176 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +This file contains primitives for multi-gpu communication. +This is useful when doing distributed training. +""" + +import pickle +import time + +import torch +import torch.distributed as dist + +from comfy.model_management import get_torch_device +device = get_torch_device() + + +def get_world_size(): + if not dist.is_available(): + return 1 + if not dist.is_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not dist.is_available(): + return 0 + if not dist.is_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def synchronize(): + """ + Helper function to synchronize (barrier) among all processes when + using distributed training + """ + if not dist.is_available(): + return + if not dist.is_initialized(): + return + world_size = dist.get_world_size() + if world_size == 1: + return + dist.barrier() + + +def gather_on_master(data): + """Same as all_gather, but gathers data on master process only, using CPU. + Thus, this does not work with NCCL backend unless they add CPU support. + + The memory consumption of this function is ~ 3x of data size. While in + principal, it should be ~2x, it's not easy to force Python to release + memory immediately and thus, peak memory usage could be up to 3x. + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + # trying to optimize memory, but in fact, it's not guaranteed to be released + del data + storage = torch.ByteStorage.from_buffer(buffer) + del buffer + tensor = torch.ByteTensor(storage) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]) + size_list = [torch.LongTensor([0]) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)) + tensor = torch.cat((tensor, padding), dim=0) + del padding + + if is_main_process(): + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,))) + dist.gather(tensor, gather_list=tensor_list, dst=0) + del tensor + else: + dist.gather(tensor, gather_list=[], dst=0) + del tensor + return + + data_list = [] + for tensor in tensor_list: + buffer = tensor.cpu().numpy().tobytes() + del tensor + data_list.append(pickle.loads(buffer)) + del buffer + + return data_list + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to(device) + + # obtain Tensor size of each rank + local_size = torch.LongTensor([tensor.numel()]).to(device) + size_list = [torch.LongTensor([0]).to(device) for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.ByteTensor(size=(max_size,)).to(device)) + if local_size != max_size: + padding = torch.ByteTensor(size=(max_size - local_size,)).to(device) + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that process with rank + 0 has the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.reduce(values, dst=0) + if dist.get_rank() == 0 and average: + # only main process gets accumulated, so only divide by + # world_size in this case + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict diff --git a/custom_mesh_graphormer/utils/dataset_utils.py b/custom_mesh_graphormer/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cb66451e456e4fd8591ef83b34db1d9192cb23cc --- /dev/null +++ b/custom_mesh_graphormer/utils/dataset_utils.py @@ -0,0 +1,66 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +""" + + +import os +import os.path as op +import numpy as np +import base64 +import cv2 +import yaml +from collections import OrderedDict + + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except: + return None + + +def load_labelmap(labelmap_file): + label_dict = None + if labelmap_file is not None and op.isfile(labelmap_file): + label_dict = OrderedDict() + with open(labelmap_file, 'r') as fp: + for line in fp: + label = line.strip().split('\t')[0] + if label in label_dict: + raise ValueError("Duplicate label " + label + " in labelmap.") + else: + label_dict[label] = len(label_dict) + return label_dict + + +def load_shuffle_file(shuf_file): + shuf_list = None + if shuf_file is not None: + with open(shuf_file, 'r') as fp: + shuf_list = [] + for i in fp: + shuf_list.append(int(i.strip())) + return shuf_list + + +def load_box_shuffle_file(shuf_file): + if shuf_file is not None: + with open(shuf_file, 'r') as fp: + img_shuf_list = [] + box_shuf_list = [] + for i in fp: + idx = [int(_) for _ in i.strip().split('\t')] + img_shuf_list.append(idx[0]) + box_shuf_list.append(idx[1]) + return [img_shuf_list, box_shuf_list] + return None + + +def load_from_yaml_file(file_name): + with open(file_name, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) diff --git a/custom_mesh_graphormer/utils/geometric_layers.py b/custom_mesh_graphormer/utils/geometric_layers.py new file mode 100644 index 0000000000000000000000000000000000000000..ad4bf9d10eb089e0887a38501dc8b78f36fc2553 --- /dev/null +++ b/custom_mesh_graphormer/utils/geometric_layers.py @@ -0,0 +1,58 @@ +""" +Useful geometric operations, e.g. Orthographic projection and a differentiable Rodrigues formula + +Parts of the code are taken from https://github.com/MandyMo/pytorch_HMR +""" +import torch + +def rodrigues(theta): + """Convert axis-angle representation to rotation matrix. + Args: + theta: size = [B, 3] + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + l1norm = torch.norm(theta + 1e-8, p = 2, dim = 1) + angle = torch.unsqueeze(l1norm, -1) + normalized = torch.div(theta, angle) + angle = angle * 0.5 + v_cos = torch.cos(angle) + v_sin = torch.sin(angle) + quat = torch.cat([v_cos, v_sin * normalized], dim = 1) + return quat2mat(quat) + +def quat2mat(quat): + """Convert quaternion coefficients to rotation matrix. + Args: + quat: size = [B, 4] 4 <===>(w, x, y, z) + Returns: + Rotation matrix corresponding to the quaternion -- size = [B, 3, 3] + """ + norm_quat = quat + norm_quat = norm_quat/norm_quat.norm(p=2, dim=1, keepdim=True) + w, x, y, z = norm_quat[:,0], norm_quat[:,1], norm_quat[:,2], norm_quat[:,3] + + B = quat.size(0) + + w2, x2, y2, z2 = w.pow(2), x.pow(2), y.pow(2), z.pow(2) + wx, wy, wz = w*x, w*y, w*z + xy, xz, yz = x*y, x*z, y*z + + rotMat = torch.stack([w2 + x2 - y2 - z2, 2*xy - 2*wz, 2*wy + 2*xz, + 2*wz + 2*xy, w2 - x2 + y2 - z2, 2*yz - 2*wx, + 2*xz - 2*wy, 2*wx + 2*yz, w2 - x2 - y2 + z2], dim=1).view(B, 3, 3) + return rotMat + +def orthographic_projection(X, camera): + """Perform orthographic projection of 3D points X using the camera parameters + Args: + X: size = [B, N, 3] + camera: size = [B, 3] + Returns: + Projected 2D points -- size = [B, N, 2] + """ + camera = camera.view(-1, 1, 3) + X_trans = X[:, :, :2] + camera[:, :, 1:] + shape = X_trans.shape + X_2d = (camera[:, :, 0] * X_trans.view(shape[0], -1)).view(shape) + return X_2d diff --git a/custom_mesh_graphormer/utils/image_ops.py b/custom_mesh_graphormer/utils/image_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..eb62d8fa2503a24acaa61c7df8cc999f195dfa11 --- /dev/null +++ b/custom_mesh_graphormer/utils/image_ops.py @@ -0,0 +1,208 @@ +""" +Image processing tools + +Modified from open source projects: +(https://github.com/nkolot/GraphCMR/) +(https://github.com/open-mmlab/mmdetection) + +""" + +import numpy as np +import base64 +import cv2 +import torch +import scipy.misc + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except ValueError: + return None + +def myimrotate(img, angle, center=None, scale=1.0, border_value=0, auto_bound=False): + if center is not None and auto_bound: + raise ValueError('`auto_bound` conflicts with `center`') + h, w = img.shape[:2] + if center is None: + center = ((w - 1) * 0.5, (h - 1) * 0.5) + assert isinstance(center, tuple) + + matrix = cv2.getRotationMatrix2D(center, angle, scale) + if auto_bound: + cos = np.abs(matrix[0, 0]) + sin = np.abs(matrix[0, 1]) + new_w = h * sin + w * cos + new_h = h * cos + w * sin + matrix[0, 2] += (new_w - w) * 0.5 + matrix[1, 2] += (new_h - h) * 0.5 + w = int(np.round(new_w)) + h = int(np.round(new_h)) + rotated = cv2.warpAffine(img, matrix, (w, h), borderValue=border_value) + return rotated + +def myimresize(img, size, return_scale=False, interpolation='bilinear'): + + h, w = img.shape[:2] + resized_img = cv2.resize( + img, (size[0],size[1]), interpolation=cv2.INTER_LINEAR) + if not return_scale: + return resized_img + else: + w_scale = size[0] / w + h_scale = size[1] / h + return resized_img, w_scale, h_scale + + +def get_transform(center, scale, res, rot=0): + """Generate transformation matrix.""" + h = 200 * scale + t = np.zeros((3, 3)) + t[0, 0] = float(res[1]) / h + t[1, 1] = float(res[0]) / h + t[0, 2] = res[1] * (-float(center[0]) / h + .5) + t[1, 2] = res[0] * (-float(center[1]) / h + .5) + t[2, 2] = 1 + if not rot == 0: + rot = -rot # To match direction of rotation from cropping + rot_mat = np.zeros((3,3)) + rot_rad = rot * np.pi / 180 + sn,cs = np.sin(rot_rad), np.cos(rot_rad) + rot_mat[0,:2] = [cs, -sn] + rot_mat[1,:2] = [sn, cs] + rot_mat[2,2] = 1 + # Need to rotate around center + t_mat = np.eye(3) + t_mat[0,2] = -res[1]/2 + t_mat[1,2] = -res[0]/2 + t_inv = t_mat.copy() + t_inv[:2,2] *= -1 + t = np.dot(t_inv,np.dot(rot_mat,np.dot(t_mat,t))) + return t + +def transform(pt, center, scale, res, invert=0, rot=0): + """Transform pixel location to different reference.""" + t = get_transform(center, scale, res, rot=rot) + if invert: + # t = np.linalg.inv(t) + t_torch = torch.from_numpy(t) + t_torch = torch.inverse(t_torch) + t = t_torch.numpy() + new_pt = np.array([pt[0]-1, pt[1]-1, 1.]).T + new_pt = np.dot(t, new_pt) + return new_pt[:2].astype(int)+1 + +def crop(img, center, scale, res, rot=0): + """Crop image according to the supplied bounding box.""" + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1, + res[1]+1], center, scale, res, invert=1))-1 + # Padding so that when rotated proper amount of context is included + pad = int(np.linalg.norm(br - ul) / 2 - float(br[1] - ul[1]) / 2) + if not rot == 0: + ul -= pad + br += pad + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(new_shape) + + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], len(img[0])) - ul[0] + new_y = max(0, -ul[1]), min(br[1], len(img)) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(len(img[0]), br[0]) + old_y = max(0, ul[1]), min(len(img), br[1]) + + new_img[new_y[0]:new_y[1], new_x[0]:new_x[1]] = img[old_y[0]:old_y[1], + old_x[0]:old_x[1]] + if not rot == 0: + # Remove padding + # new_img = scipy.misc.imrotate(new_img, rot) + new_img = myimrotate(new_img, rot) + new_img = new_img[pad:-pad, pad:-pad] + + # new_img = scipy.misc.imresize(new_img, res) + new_img = myimresize(new_img, [res[0], res[1]]) + return new_img + +def uncrop(img, center, scale, orig_shape, rot=0, is_rgb=True): + """'Undo' the image cropping/resizing. + This function is used when evaluating mask/part segmentation. + """ + res = img.shape[:2] + # Upper left point + ul = np.array(transform([1, 1], center, scale, res, invert=1))-1 + # Bottom right point + br = np.array(transform([res[0]+1,res[1]+1], center, scale, res, invert=1))-1 + # size of cropped image + crop_shape = [br[1] - ul[1], br[0] - ul[0]] + + new_shape = [br[1] - ul[1], br[0] - ul[0]] + if len(img.shape) > 2: + new_shape += [img.shape[2]] + new_img = np.zeros(orig_shape, dtype=np.uint8) + # Range to fill new array + new_x = max(0, -ul[0]), min(br[0], orig_shape[1]) - ul[0] + new_y = max(0, -ul[1]), min(br[1], orig_shape[0]) - ul[1] + # Range to sample from original image + old_x = max(0, ul[0]), min(orig_shape[1], br[0]) + old_y = max(0, ul[1]), min(orig_shape[0], br[1]) + # img = scipy.misc.imresize(img, crop_shape, interp='nearest') + img = myimresize(img, [crop_shape[0],crop_shape[1]]) + new_img[old_y[0]:old_y[1], old_x[0]:old_x[1]] = img[new_y[0]:new_y[1], new_x[0]:new_x[1]] + return new_img + +def rot_aa(aa, rot): + """Rotate axis angle parameters.""" + # pose parameters + R = np.array([[np.cos(np.deg2rad(-rot)), -np.sin(np.deg2rad(-rot)), 0], + [np.sin(np.deg2rad(-rot)), np.cos(np.deg2rad(-rot)), 0], + [0, 0, 1]]) + # find the rotation of the body in camera frame + per_rdg, _ = cv2.Rodrigues(aa) + # apply the global rotation to the global orientation + resrot, _ = cv2.Rodrigues(np.dot(R,per_rdg)) + aa = (resrot.T)[0] + return aa + +def flip_img(img): + """Flip rgb images or masks. + channels come last, e.g. (256,256,3). + """ + img = np.fliplr(img) + return img + +def flip_kp(kp): + """Flip keypoints.""" + flipped_parts = [5, 4, 3, 2, 1, 0, 11, 10, 9, 8, 7, 6, 12, 13, 14, 15, 16, 17, 18, 19, 21, 20, 23, 22] + kp = kp[flipped_parts] + kp[:,0] = - kp[:,0] + return kp + +def flip_pose(pose): + """Flip pose. + The flipping is based on SMPL parameters. + """ + flippedParts = [0, 1, 2, 6, 7, 8, 3, 4, 5, 9, 10, 11, 15, 16, 17, 12, 13, + 14 ,18, 19, 20, 24, 25, 26, 21, 22, 23, 27, 28, 29, 33, + 34, 35, 30, 31, 32, 36, 37, 38, 42, 43, 44, 39, 40, 41, + 45, 46, 47, 51, 52, 53, 48, 49, 50, 57, 58, 59, 54, 55, + 56, 63, 64, 65, 60, 61, 62, 69, 70, 71, 66, 67, 68] + pose = pose[flippedParts] + # we also negate the second and the third dimension of the axis-angle + pose[1::3] = -pose[1::3] + pose[2::3] = -pose[2::3] + return pose + +def flip_aa(aa): + """Flip axis-angle representation. + We negate the second and the third dimension of the axis-angle. + """ + aa[1] = -aa[1] + aa[2] = -aa[2] + return aa \ No newline at end of file diff --git a/custom_mesh_graphormer/utils/logger.py b/custom_mesh_graphormer/utils/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..013139605fc60634265b269d5260d4075e2b6478 --- /dev/null +++ b/custom_mesh_graphormer/utils/logger.py @@ -0,0 +1,100 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import logging +import os +import sys +from logging import StreamHandler, Handler, getLevelName + + +# this class is a copy of logging.FileHandler except we end self.close() +# at the end of each emit. While closing file and reopening file after each +# write is not efficient, it allows us to see partial logs when writing to +# fused Azure blobs, which is very convenient +class FileHandler(StreamHandler): + """ + A handler class which writes formatted logging records to disk files. + """ + def __init__(self, filename, mode='a', encoding=None, delay=False): + """ + Open the specified file and use it as the stream for logging. + """ + # Issue #27493: add support for Path objects to be passed in + filename = os.fspath(filename) + #keep the absolute path, otherwise derived classes which use this + #may come a cropper when the current directory changes + self.baseFilename = os.path.abspath(filename) + self.mode = mode + self.encoding = encoding + self.delay = delay + if delay: + #We don't open the stream, but we still need to call the + #Handler constructor to set level, formatter, lock etc. + Handler.__init__(self) + self.stream = None + else: + StreamHandler.__init__(self, self._open()) + + def close(self): + """ + Closes the stream. + """ + self.acquire() + try: + try: + if self.stream: + try: + self.flush() + finally: + stream = self.stream + self.stream = None + if hasattr(stream, "close"): + stream.close() + finally: + # Issue #19523: call unconditionally to + # prevent a handler leak when delay is set + StreamHandler.close(self) + finally: + self.release() + + def _open(self): + """ + Open the current base file with the (original) mode and encoding. + Return the resulting stream. + """ + return open(self.baseFilename, self.mode, encoding=self.encoding) + + def emit(self, record): + """ + Emit a record. + + If the stream was not opened because 'delay' was specified in the + constructor, open it before calling the superclass's emit. + """ + if self.stream is None: + self.stream = self._open() + StreamHandler.emit(self, record) + self.close() + + def __repr__(self): + level = getLevelName(self.level) + return '<%s %s (%s)>' % (self.__class__.__name__, self.baseFilename, level) + + +def setup_logger(name, save_dir, distributed_rank, filename="log.txt"): + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + # don't log results for the non-master process + if distributed_rank > 0: + return logger + ch = logging.StreamHandler(stream=sys.stdout) + ch.setLevel(logging.DEBUG) + formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") + ch.setFormatter(formatter) + logger.addHandler(ch) + + if save_dir: + fh = FileHandler(os.path.join(save_dir, filename)) + fh.setLevel(logging.DEBUG) + fh.setFormatter(formatter) + logger.addHandler(fh) + + return logger diff --git a/custom_mesh_graphormer/utils/metric_logger.py b/custom_mesh_graphormer/utils/metric_logger.py new file mode 100644 index 0000000000000000000000000000000000000000..ddaa0ab3ac95314cb94b5cb8c2a76a838f141a59 --- /dev/null +++ b/custom_mesh_graphormer/utils/metric_logger.py @@ -0,0 +1,45 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Basic logger. It Computes and stores the average and current value +""" + +class AverageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + + +class EvalMetricsLogger(object): + + def __init__(self): + self.reset() + + def reset(self): + # define a upper-bound performance (worst case) + # numbers are in unit millimeter + self.PAmPJPE = 100.0/1000.0 + self.mPJPE = 100.0/1000.0 + self.mPVE = 100.0/1000.0 + + self.epoch = 0 + + def update(self, mPVE, mPJPE, PAmPJPE, epoch): + self.PAmPJPE = PAmPJPE + self.mPJPE = mPJPE + self.mPVE = mPVE + self.epoch = epoch diff --git a/custom_mesh_graphormer/utils/metric_pampjpe.py b/custom_mesh_graphormer/utils/metric_pampjpe.py new file mode 100644 index 0000000000000000000000000000000000000000..89fe55b7f8d0973670f3cc141a666528dd19ebd1 --- /dev/null +++ b/custom_mesh_graphormer/utils/metric_pampjpe.py @@ -0,0 +1,99 @@ +""" +Functions for compuing Procrustes alignment and reconstruction error + +Parts of the code are adapted from https://github.com/akanazawa/hmr + +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np + +def compute_similarity_transform(S1, S2): + """Computes a similarity transform (sR, t) that takes + a set of 3D points S1 (3 x N) closest to a set of 3D points S2, + where R is an 3x3 rotation matrix, t 3x1 translation, s scale. + i.e. solves the orthogonal Procrutes problem. + """ + transposed = False + if S1.shape[0] != 3 and S1.shape[0] != 2: + S1 = S1.T + S2 = S2.T + transposed = True + assert(S2.shape[1] == S1.shape[1]) + + # 1. Remove mean. + mu1 = S1.mean(axis=1, keepdims=True) + mu2 = S2.mean(axis=1, keepdims=True) + X1 = S1 - mu1 + X2 = S2 - mu2 + + # 2. Compute variance of X1 used for scale. + var1 = np.sum(X1**2) + + # 3. The outer product of X1 and X2. + K = X1.dot(X2.T) + + # 4. Solution that Maximizes trace(R'K) is R=U*V', where U, V are + # singular vectors of K. + U, s, Vh = np.linalg.svd(K) + V = Vh.T + # Construct Z that fixes the orientation of R to get det(R)=1. + Z = np.eye(U.shape[0]) + Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T))) + # Construct R. + R = V.dot(Z.dot(U.T)) + + # 5. Recover scale. + scale = np.trace(R.dot(K)) / var1 + + # 6. Recover translation. + t = mu2 - scale*(R.dot(mu1)) + + # 7. Error: + S1_hat = scale*R.dot(S1) + t + + if transposed: + S1_hat = S1_hat.T + + return S1_hat + +def compute_similarity_transform_batch(S1, S2): + """Batched version of compute_similarity_transform.""" + S1_hat = np.zeros_like(S1) + for i in range(S1.shape[0]): + S1_hat[i] = compute_similarity_transform(S1[i], S2[i]) + return S1_hat + +def reconstruction_error(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re + + +def reconstruction_error_v2(S1, S2, J24_TO_J14, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + S1_hat = S1_hat[:,J24_TO_J14,:] + S2 = S2[:,J24_TO_J14,:] + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re + +def get_alignMesh(S1, S2, reduction='mean'): + """Do Procrustes alignment and compute reconstruction error.""" + S1_hat = compute_similarity_transform_batch(S1, S2) + re = np.sqrt( ((S1_hat - S2)** 2).sum(axis=-1)).mean(axis=-1) + if reduction == 'mean': + re = re.mean() + elif reduction == 'sum': + re = re.sum() + return re, S1_hat, S2 diff --git a/custom_mesh_graphormer/utils/miscellaneous.py b/custom_mesh_graphormer/utils/miscellaneous.py new file mode 100644 index 0000000000000000000000000000000000000000..3de72c69c8fcd4502dc5e8b58656b3ee08db7554 --- /dev/null +++ b/custom_mesh_graphormer/utils/miscellaneous.py @@ -0,0 +1,171 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. +import errno +import os +import os.path as op +import re +import logging +import numpy as np +import torch +import random +import shutil +from .comm import is_main_process +import yaml + + +def mkdir(path): + # if it is the current folder, skip. + # otherwise the original code will raise FileNotFoundError + if path == '': + return + try: + os.makedirs(path) + except OSError as e: + if e.errno != errno.EEXIST: + raise + + +def save_config(cfg, path): + if is_main_process(): + with open(path, 'w') as f: + f.write(cfg.dump()) + + +def config_iteration(output_dir, max_iter): + save_file = os.path.join(output_dir, 'last_checkpoint') + iteration = -1 + if os.path.exists(save_file): + with open(save_file, 'r') as f: + fname = f.read().strip() + model_name = os.path.basename(fname) + model_path = os.path.dirname(fname) + if model_name.startswith('model_') and len(model_name) == 17: + iteration = int(model_name[-11:-4]) + elif model_name == "model_final": + iteration = max_iter + elif model_path.startswith('checkpoint-') and len(model_path) == 18: + iteration = int(model_path.split('-')[-1]) + return iteration + + +def get_matching_parameters(model, regexp, none_on_empty=True): + """Returns parameters matching regular expression""" + if not regexp: + if none_on_empty: + return {} + else: + return dict(model.named_parameters()) + compiled_pattern = re.compile(regexp) + params = {} + for weight_name, weight in model.named_parameters(): + if compiled_pattern.match(weight_name): + params[weight_name] = weight + return params + + +def freeze_weights(model, regexp): + """Freeze weights based on regular expression.""" + logger = logging.getLogger("maskrcnn_benchmark.trainer") + for weight_name, weight in get_matching_parameters(model, regexp).items(): + weight.requires_grad = False + logger.info("Disabled training of {}".format(weight_name)) + + +def unfreeze_weights(model, regexp, backbone_freeze_at=-1, + is_distributed=False): + """Unfreeze weights based on regular expression. + This is helpful during training to unfreeze freezed weights after + other unfreezed weights have been trained for some iterations. + """ + logger = logging.getLogger("maskrcnn_benchmark.trainer") + for weight_name, weight in get_matching_parameters(model, regexp).items(): + weight.requires_grad = True + logger.info("Enabled training of {}".format(weight_name)) + if backbone_freeze_at >= 0: + logger.info("Freeze backbone at stage: {}".format(backbone_freeze_at)) + if is_distributed: + model.module.backbone.body._freeze_backbone(backbone_freeze_at) + else: + model.backbone.body._freeze_backbone(backbone_freeze_at) + + +def delete_tsv_files(tsvs): + for t in tsvs: + if op.isfile(t): + try_delete(t) + line = op.splitext(t)[0] + '.lineidx' + if op.isfile(line): + try_delete(line) + + +def concat_files(ins, out): + mkdir(op.dirname(out)) + out_tmp = out + '.tmp' + with open(out_tmp, 'wb') as fp_out: + for i, f in enumerate(ins): + logging.info('concating {}/{} - {}'.format(i, len(ins), f)) + with open(f, 'rb') as fp_in: + shutil.copyfileobj(fp_in, fp_out, 1024*1024*10) + os.rename(out_tmp, out) + + +def concat_tsv_files(tsvs, out_tsv): + concat_files(tsvs, out_tsv) + sizes = [os.stat(t).st_size for t in tsvs] + sizes = np.cumsum(sizes) + all_idx = [] + for i, t in enumerate(tsvs): + for idx in load_list_file(op.splitext(t)[0] + '.lineidx'): + if i == 0: + all_idx.append(idx) + else: + all_idx.append(str(int(idx) + sizes[i - 1])) + with open(op.splitext(out_tsv)[0] + '.lineidx', 'w') as f: + f.write('\n'.join(all_idx)) + + +def load_list_file(fname): + with open(fname, 'r') as fp: + lines = fp.readlines() + result = [line.strip() for line in lines] + if len(result) > 0 and result[-1] == '': + result = result[:-1] + return result + + +def try_once(func): + def func_wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logging.info('ignore error \n{}'.format(str(e))) + return func_wrapper + + +@try_once +def try_delete(f): + os.remove(f) + + +def set_seed(seed, n_gpu): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if n_gpu > 0: + torch.cuda.manual_seed_all(seed) + + +def print_and_run_cmd(cmd): + print(cmd) + os.system(cmd) + + +def write_to_yaml_file(context, file_name): + with open(file_name, 'w') as fp: + yaml.dump(context, fp, encoding='utf-8') + + +def load_from_yaml_file(yaml_file): + with open(yaml_file, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) + + diff --git a/custom_mesh_graphormer/utils/renderer.py b/custom_mesh_graphormer/utils/renderer.py new file mode 100644 index 0000000000000000000000000000000000000000..3b2f8a96e7c227466a07a19650e75228f5aa6860 --- /dev/null +++ b/custom_mesh_graphormer/utils/renderer.py @@ -0,0 +1,691 @@ +""" +Rendering tools for 3D mesh visualization on 2D image. + +Parts of the code are taken from https://github.com/akanazawa/hmr +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import cv2 +import code +from opendr.camera import ProjectPoints +from opendr.renderer import ColoredRenderer, TexturedRenderer +from opendr.lighting import LambertianPointLight +import random + + +# Rotate the points by a specified angle. +def rotateY(points, angle): + ry = np.array([ + [np.cos(angle), 0., np.sin(angle)], [0., 1., 0.], + [-np.sin(angle), 0., np.cos(angle)] + ]) + return np.dot(points, ry) + +def draw_skeleton(input_image, joints, draw_edges=True, vis=None, radius=None): + """ + joints is 3 x 19. but if not will transpose it. + 0: Right ankle + 1: Right knee + 2: Right hip + 3: Left hip + 4: Left knee + 5: Left ankle + 6: Right wrist + 7: Right elbow + 8: Right shoulder + 9: Left shoulder + 10: Left elbow + 11: Left wrist + 12: Neck + 13: Head top + 14: nose + 15: left_eye + 16: right_eye + 17: left_ear + 18: right_ear + """ + + if radius is None: + radius = max(4, (np.mean(input_image.shape[:2]) * 0.01).astype(int)) + + colors = { + 'pink': (197, 27, 125), # L lower leg + 'light_pink': (233, 163, 201), # L upper leg + 'light_green': (161, 215, 106), # L lower arm + 'green': (77, 146, 33), # L upper arm + 'red': (215, 48, 39), # head + 'light_red': (252, 146, 114), # head + 'light_orange': (252, 141, 89), # chest + 'purple': (118, 42, 131), # R lower leg + 'light_purple': (175, 141, 195), # R upper + 'light_blue': (145, 191, 219), # R lower arm + 'blue': (69, 117, 180), # R upper arm + 'gray': (130, 130, 130), # + 'white': (255, 255, 255), # + } + + image = input_image.copy() + input_is_float = False + + if np.issubdtype(image.dtype, np.float): + input_is_float = True + max_val = image.max() + if max_val <= 2.: # should be 1 but sometimes it's slightly above 1 + image = (image * 255).astype(np.uint8) + else: + image = (image).astype(np.uint8) + + if joints.shape[0] != 2: + joints = joints.T + joints = np.round(joints).astype(int) + + jcolors = [ + 'light_pink', 'light_pink', 'light_pink', 'pink', 'pink', 'pink', + 'light_blue', 'light_blue', 'light_blue', 'blue', 'blue', 'blue', + 'purple', 'purple', 'red', 'green', 'green', 'white', 'white', + 'purple', 'purple', 'red', 'green', 'green', 'white', 'white' + ] + + if joints.shape[1] == 19: + # parent indices -1 means no parents + parents = np.array([ + 1, 2, 8, 9, 3, 4, 7, 8, 12, 12, 9, 10, 14, -1, 13, -1, -1, 15, 16 + ]) + # Left is light and right is dark + ecolors = { + 0: 'light_pink', + 1: 'light_pink', + 2: 'light_pink', + 3: 'pink', + 4: 'pink', + 5: 'pink', + 6: 'light_blue', + 7: 'light_blue', + 8: 'light_blue', + 9: 'blue', + 10: 'blue', + 11: 'blue', + 12: 'purple', + 17: 'light_green', + 18: 'light_green', + 14: 'purple' + } + elif joints.shape[1] == 14: + parents = np.array([ + 1, + 2, + 8, + 9, + 3, + 4, + 7, + 8, + -1, + -1, + 9, + 10, + 13, + -1, + ]) + ecolors = { + 0: 'light_pink', + 1: 'light_pink', + 2: 'light_pink', + 3: 'pink', + 4: 'pink', + 5: 'pink', + 6: 'light_blue', + 7: 'light_blue', + 10: 'light_blue', + 11: 'blue', + 12: 'purple' + } + elif joints.shape[1] == 21: # hand + parents = np.array([ + -1, + 0, + 1, + 2, + 3, + 0, + 5, + 6, + 7, + 0, + 9, + 10, + 11, + 0, + 13, + 14, + 15, + 0, + 17, + 18, + 19, + ]) + ecolors = { + 0: 'light_purple', + 1: 'light_green', + 2: 'light_green', + 3: 'light_green', + 4: 'light_green', + 5: 'pink', + 6: 'pink', + 7: 'pink', + 8: 'pink', + 9: 'light_blue', + 10: 'light_blue', + 11: 'light_blue', + 12: 'light_blue', + 13: 'light_red', + 14: 'light_red', + 15: 'light_red', + 16: 'light_red', + 17: 'purple', + 18: 'purple', + 19: 'purple', + 20: 'purple', + } + else: + print('Unknown skeleton!!') + + for child in range(len(parents)): + point = joints[:, child] + # If invisible skip + if vis is not None and vis[child] == 0: + continue + if draw_edges: + cv2.circle(image, (point[0], point[1]), radius, colors['white'], + -1) + cv2.circle(image, (point[0], point[1]), radius - 1, + colors[jcolors[child]], -1) + else: + # cv2.circle(image, (point[0], point[1]), 5, colors['white'], 1) + cv2.circle(image, (point[0], point[1]), radius - 1, + colors[jcolors[child]], 1) + # cv2.circle(image, (point[0], point[1]), 5, colors['gray'], -1) + pa_id = parents[child] + if draw_edges and pa_id >= 0: + if vis is not None and vis[pa_id] == 0: + continue + point_pa = joints[:, pa_id] + cv2.circle(image, (point_pa[0], point_pa[1]), radius - 1, + colors[jcolors[pa_id]], -1) + if child not in ecolors.keys(): + print('bad') + import ipdb + ipdb.set_trace() + cv2.line(image, (point[0], point[1]), (point_pa[0], point_pa[1]), + colors[ecolors[child]], radius - 2) + + # Convert back in original dtype + if input_is_float: + if max_val <= 1.: + image = image.astype(np.float32) / 255. + else: + image = image.astype(np.float32) + + return image + +def draw_text(input_image, content): + """ + content is a dict. draws key: val on image + Assumes key is str, val is float + """ + image = input_image.copy() + input_is_float = False + if np.issubdtype(image.dtype, np.float): + input_is_float = True + image = (image * 255).astype(np.uint8) + + black = (255, 255, 0) + margin = 15 + start_x = 5 + start_y = margin + for key in sorted(content.keys()): + text = "%s: %.2g" % (key, content[key]) + cv2.putText(image, text, (start_x, start_y), 0, 0.45, black) + start_y += margin + + if input_is_float: + image = image.astype(np.float32) / 255. + return image + +def visualize_reconstruction(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + gt_vis = gt_kp[:, 2].astype(bool) + loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2) + debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss} + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + rend_img = draw_text(rend_img, debug_text) + + # Draw skeleton + gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size + pred_joint = ((pred_kp + 1) * 0.5) * img_size + img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis) + skel_img = draw_skeleton(img_with_gt, pred_joint) + + combined = np.hstack([skel_img, rend_img]) + + return combined + +def visualize_reconstruction_test(img, img_size, gt_kp, vertices, pred_kp, camera, renderer, score, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + gt_vis = gt_kp[:, 2].astype(bool) + loss = np.sum((gt_kp[gt_vis, :2] - pred_kp[gt_vis])**2) + debug_text = {"sc": camera[0], "tx": camera[1], "ty": camera[2], "kpl": loss, "pa-mpjpe": score*1000} + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + rend_img = draw_text(rend_img, debug_text) + + # Draw skeleton + gt_joint = ((gt_kp[:, :2] + 1) * 0.5) * img_size + pred_joint = ((pred_kp + 1) * 0.5) * img_size + img_with_gt = draw_skeleton( img, gt_joint, draw_edges=False, vis=gt_vis) + skel_img = draw_skeleton(img_with_gt, pred_joint) + + combined = np.hstack([skel_img, rend_img]) + + return combined + + + +def visualize_reconstruction_and_att(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices_full, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, body_color='light_blue') + + + heads_num, vertex_num, _ = attention.shape + + all_head = np.zeros((vertex_num,vertex_num)) + + ###### find max + # for i in range(vertex_num): + # for j in range(vertex_num): + # all_head[i,j] = np.max(attention[:,i,j]) + + ##### find avg + for h in range(4): + att_per_img = attention[h] + all_head = all_head + att_per_img + all_head = all_head/4 + + col_sums = all_head.sum(axis=0) + all_head = all_head / col_sums[np.newaxis, :] + + + # code.interact(local=locals()) + + combined = [] + if vertex_num>400: # body + selected_joints = [6,7,4,5,13] # [6,7,4,5,13,12] + else: # hand + selected_joints = [0, 4, 8, 12, 16, 20] + # Draw attention + for ii in range(len(selected_joints)): + reference_id = selected_joints[ii] + ref_point = ref_points[reference_id] + attention_to_show = all_head[reference_id][14::] + min_v = np.min(attention_to_show) + max_v = np.max(attention_to_show) + norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v) + + vertices_norm = ((vertices_2d + 1) * 0.5) * img_size + ref_norm = ((ref_point + 1) * 0.5) * img_size + image = np.zeros_like(rend_img) + + for jj in range(vertices_norm.shape[0]): + x = int(vertices_norm[jj,0]) + y = int(vertices_norm[jj,1]) + cv2.circle(image,(x,y), 1, (255,255,255), -1) + + total_to_draw = [] + for jj in range(vertices_norm.shape[0]): + thres = 0.0 + if norm_attention_to_show[jj]>thres: + things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]] + total_to_draw.append(things) + # plot_one_line(ref_norm, vertices_norm[jj], image, reference_id, alpha=0.4*(norm_attention_to_show[jj]-thres)/(1-thres) ) + total_to_draw.sort() + max_att_score = total_to_draw[-1][0] + for item in total_to_draw: + attention_score = item[0] + ref_point = item[1] + vertex = item[2] + plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) ) + # code.interact(local=locals()) + if len(combined)==0: + combined = image + else: + combined = np.hstack([combined, image]) + + final = np.hstack([img, combined, rend_img]) + + return final + + +def visualize_reconstruction_and_att_local(img, img_size, vertices_full, vertices, vertices_2d, camera, renderer, ref_points, attention, color='light_blue', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices_full, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, body_color=color) + heads_num, vertex_num, _ = attention.shape + all_head = np.zeros((vertex_num,vertex_num)) + + ##### compute avg attention for 4 attention heads + for h in range(4): + att_per_img = attention[h] + all_head = all_head + att_per_img + all_head = all_head/4 + + col_sums = all_head.sum(axis=0) + all_head = all_head / col_sums[np.newaxis, :] + + combined = [] + if vertex_num>400: # body + selected_joints = [7] # [6,7,4,5,13,12] + else: # hand + selected_joints = [0] # [0, 4, 8, 12, 16, 20] + # Draw attention + for ii in range(len(selected_joints)): + reference_id = selected_joints[ii] + ref_point = ref_points[reference_id] + attention_to_show = all_head[reference_id][14::] + min_v = np.min(attention_to_show) + max_v = np.max(attention_to_show) + norm_attention_to_show = (attention_to_show - min_v)/(max_v-min_v) + vertices_norm = ((vertices_2d + 1) * 0.5) * img_size + ref_norm = ((ref_point + 1) * 0.5) * img_size + image = rend_img*0.4 + + total_to_draw = [] + for jj in range(vertices_norm.shape[0]): + thres = 0.0 + if norm_attention_to_show[jj]>thres: + things = [norm_attention_to_show[jj], ref_norm, vertices_norm[jj]] + total_to_draw.append(things) + total_to_draw.sort() + max_att_score = total_to_draw[-1][0] + for item in total_to_draw: + attention_score = item[0] + ref_point = item[1] + vertex = item[2] + plot_one_line(ref_point, vertex, image, ii, alpha=(attention_score-thres)/(max_att_score-thres) ) + + for jj in range(vertices_norm.shape[0]): + x = int(vertices_norm[jj,0]) + y = int(vertices_norm[jj,1]) + cv2.circle(image,(x,y), 1, (255,255,255), -1) + + if len(combined)==0: + combined = image + else: + combined = np.hstack([combined, image]) + + final = np.hstack([img, combined, rend_img]) + + return final + + +def visualize_reconstruction_no_text(img, img_size, vertices, camera, renderer, color='pink', focal_length=1000): + """Overlays gt_kp and pred_kp on img. + Draws vert with text. + Renderer is an instance of SMPLRenderer. + """ + # Fix a flength so i can render this with persp correct scale + res = img.shape[1] + camera_t = np.array([camera[1], camera[2], 2*focal_length/(res * camera[0] +1e-9)]) + rend_img = renderer.render(vertices, camera_t=camera_t, + img=img, use_bg=True, + focal_length=focal_length, + body_color=color) + + + combined = np.hstack([img, rend_img]) + + return combined + + +def plot_one_line(ref, vertex, img, color_index, alpha=0.0, line_thickness=None): + # 13,6,7,8,3,4,5 + # att_colors = [(255, 221, 104), (255, 255, 0), (255, 215, 227), (210, 240, 119), \ + # (209, 238, 245), (244, 200, 243), (233, 242, 216)] + att_colors = [(255, 255, 0), (244, 200, 243), (210, 243, 119), (209, 238, 255), (200, 208, 255), (250, 238, 215)] + + + overlay = img.copy() + # output = img.copy() + # Plots one bounding box on image img + tl = line_thickness or round(0.002 * (img.shape[0] + img.shape[1]) / 2) + 1 # line/font thickness + + color = list(att_colors[color_index]) + c1, c2 = (int(ref[0]), int(ref[1])), (int(vertex[0]), int(vertex[1])) + cv2.line(overlay, c1, c2, (alpha*float(color[0])/255,alpha*float(color[1])/255,alpha*float(color[2])/255) , thickness=tl, lineType=cv2.LINE_AA) + cv2.addWeighted(overlay, alpha, img, 1 - alpha, 0, img) + + + +def cam2pixel(cam_coord, f, c): + x = cam_coord[:, 0] / (cam_coord[:, 2]) * f[0] + c[0] + y = cam_coord[:, 1] / (cam_coord[:, 2]) * f[1] + c[1] + z = cam_coord[:, 2] + img_coord = np.concatenate((x[:,None], y[:,None], z[:,None]),1) + return img_coord + + +class Renderer(object): + """ + Render mesh using OpenDR for visualization. + """ + + def __init__(self, width=800, height=600, near=0.5, far=1000, faces=None): + self.colors = {'hand': [.9, .9, .9], 'pink': [.9, .7, .7], 'light_blue': [0.65098039, 0.74117647, 0.85882353] } + self.width = width + self.height = height + self.faces = faces + self.renderer = ColoredRenderer() + + def render(self, vertices, faces=None, img=None, + camera_t=np.zeros([3], dtype=np.float32), + camera_rot=np.zeros([3], dtype=np.float32), + camera_center=None, + use_bg=False, + bg_color=(0.0, 0.0, 0.0), + body_color=None, + focal_length=5000, + disp_text=False, + gt_keyp=None, + pred_keyp=None, + **kwargs): + if img is not None: + height, width = img.shape[:2] + else: + height, width = self.height, self.width + + if faces is None: + faces = self.faces + + if camera_center is None: + camera_center = np.array([width * 0.5, + height * 0.5]) + + self.renderer.camera = ProjectPoints(rt=camera_rot, + t=camera_t, + f=focal_length * np.ones(2), + c=camera_center, + k=np.zeros(5)) + dist = np.abs(self.renderer.camera.t.r[2] - + np.mean(vertices, axis=0)[2]) + far = dist + 20 + + self.renderer.frustum = {'near': 1.0, 'far': far, + 'width': width, + 'height': height} + + if img is not None: + if use_bg: + self.renderer.background_image = img + else: + self.renderer.background_image = np.ones_like( + img) * np.array(bg_color) + + if body_color is None: + color = self.colors['light_blue'] + else: + color = self.colors[body_color] + + if isinstance(self.renderer, TexturedRenderer): + color = [1.,1.,1.] + + self.renderer.set(v=vertices, f=faces, + vc=color, bgcolor=np.ones(3)) + albedo = self.renderer.vc + # Construct Back Light (on back right corner) + yrot = np.radians(120) + + self.renderer.vc = LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-200, -100, -100]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Left Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([800, 10, 300]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Right Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-500, 500, 1000]), yrot), + vc=albedo, + light_color=np.array([.7, .7, .7])) + + return self.renderer.r + + + def render_vertex_color(self, vertices, faces=None, img=None, + camera_t=np.zeros([3], dtype=np.float32), + camera_rot=np.zeros([3], dtype=np.float32), + camera_center=None, + use_bg=False, + bg_color=(0.0, 0.0, 0.0), + vertex_color=None, + focal_length=5000, + disp_text=False, + gt_keyp=None, + pred_keyp=None, + **kwargs): + if img is not None: + height, width = img.shape[:2] + else: + height, width = self.height, self.width + + if faces is None: + faces = self.faces + + if camera_center is None: + camera_center = np.array([width * 0.5, + height * 0.5]) + + self.renderer.camera = ProjectPoints(rt=camera_rot, + t=camera_t, + f=focal_length * np.ones(2), + c=camera_center, + k=np.zeros(5)) + dist = np.abs(self.renderer.camera.t.r[2] - + np.mean(vertices, axis=0)[2]) + far = dist + 20 + + self.renderer.frustum = {'near': 1.0, 'far': far, + 'width': width, + 'height': height} + + if img is not None: + if use_bg: + self.renderer.background_image = img + else: + self.renderer.background_image = np.ones_like( + img) * np.array(bg_color) + + if vertex_color is None: + vertex_color = self.colors['light_blue'] + + + self.renderer.set(v=vertices, f=faces, + vc=vertex_color, bgcolor=np.ones(3)) + albedo = self.renderer.vc + # Construct Back Light (on back right corner) + yrot = np.radians(120) + + self.renderer.vc = LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-200, -100, -100]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Left Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([800, 10, 300]), yrot), + vc=albedo, + light_color=np.array([1, 1, 1])) + + # Construct Right Light + self.renderer.vc += LambertianPointLight( + f=self.renderer.f, + v=self.renderer.v, + num_verts=self.renderer.v.shape[0], + light_pos=rotateY(np.array([-500, 500, 1000]), yrot), + vc=albedo, + light_color=np.array([.7, .7, .7])) + + return self.renderer.r \ No newline at end of file diff --git a/custom_mesh_graphormer/utils/tsv_file.py b/custom_mesh_graphormer/utils/tsv_file.py new file mode 100644 index 0000000000000000000000000000000000000000..8e79a9ec55d7ee9a2ce622d7a4ae5af59f7d7c17 --- /dev/null +++ b/custom_mesh_graphormer/utils/tsv_file.py @@ -0,0 +1,162 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Definition of TSV class +""" + + +import logging +import os +import os.path as op + + +def generate_lineidx(filein, idxout): + idxout_tmp = idxout + '.tmp' + with open(filein, 'r') as tsvin, open(idxout_tmp,'w') as tsvout: + fsize = os.fstat(tsvin.fileno()).st_size + fpos = 0 + while fpos!=fsize: + tsvout.write(str(fpos)+"\n") + tsvin.readline() + fpos = tsvin.tell() + os.rename(idxout_tmp, idxout) + + +def read_to_character(fp, c): + result = [] + while True: + s = fp.read(32) + assert s != '' + if c in s: + result.append(s[: s.index(c)]) + break + else: + result.append(s) + return ''.join(result) + + +class TSVFile(object): + def __init__(self, tsv_file, generate_lineidx=False): + self.tsv_file = tsv_file + self.lineidx = op.splitext(tsv_file)[0] + '.lineidx' + self._fp = None + self._lineidx = None + # the process always keeps the process which opens the file. + # If the pid is not equal to the currrent pid, we will re-open the file. + self.pid = None + # generate lineidx if not exist + if not op.isfile(self.lineidx) and generate_lineidx: + generate_lineidx(self.tsv_file, self.lineidx) + + def __del__(self): + if self._fp: + self._fp.close() + + def __str__(self): + return "TSVFile(tsv_file='{}')".format(self.tsv_file) + + def __repr__(self): + return str(self) + + def num_rows(self): + self._ensure_lineidx_loaded() + return len(self._lineidx) + + def seek(self, idx): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + try: + pos = self._lineidx[idx] + except: + logging.info('{}-{}'.format(self.tsv_file, idx)) + raise + self._fp.seek(pos) + return [s.strip() for s in self._fp.readline().split('\t')] + + def seek_first_column(self, idx): + self._ensure_tsv_opened() + self._ensure_lineidx_loaded() + pos = self._lineidx[idx] + self._fp.seek(pos) + return read_to_character(self._fp, '\t') + + def get_key(self, idx): + return self.seek_first_column(idx) + + def __getitem__(self, index): + return self.seek(index) + + def __len__(self): + return self.num_rows() + + def _ensure_lineidx_loaded(self): + if self._lineidx is None: + logging.info('loading lineidx: {}'.format(self.lineidx)) + with open(self.lineidx, 'r') as fp: + self._lineidx = [int(i.strip()) for i in fp.readlines()] + + def _ensure_tsv_opened(self): + if self._fp is None: + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + if self.pid != os.getpid(): + logging.info('re-open {} because the process id changed'.format(self.tsv_file)) + self._fp = open(self.tsv_file, 'r') + self.pid = os.getpid() + + +class CompositeTSVFile(): + def __init__(self, file_list, seq_file, root='.'): + if isinstance(file_list, str): + self.file_list = load_list_file(file_list) + else: + assert isinstance(file_list, list) + self.file_list = file_list + + self.seq_file = seq_file + self.root = root + self.initialized = False + self.initialize() + + def get_key(self, index): + idx_source, idx_row = self.seq[index] + k = self.tsvs[idx_source].get_key(idx_row) + return '_'.join([self.file_list[idx_source], k]) + + def num_rows(self): + return len(self.seq) + + def __getitem__(self, index): + idx_source, idx_row = self.seq[index] + return self.tsvs[idx_source].seek(idx_row) + + def __len__(self): + return len(self.seq) + + def initialize(self): + ''' + this function has to be called in init function if cache_policy is + enabled. Thus, let's always call it in init funciton to make it simple. + ''' + if self.initialized: + return + self.seq = [] + with open(self.seq_file, 'r') as fp: + for line in fp: + parts = line.strip().split('\t') + self.seq.append([int(parts[0]), int(parts[1])]) + self.tsvs = [TSVFile(op.join(self.root, f)) for f in self.file_list] + self.initialized = True + + +def load_list_file(fname): + with open(fname, 'r') as fp: + lines = fp.readlines() + result = [line.strip() for line in lines] + if len(result) > 0 and result[-1] == '': + result = result[:-1] + return result + + diff --git a/custom_mesh_graphormer/utils/tsv_file_ops.py b/custom_mesh_graphormer/utils/tsv_file_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..70f2798b931fd937d100b2093228013eb6176113 --- /dev/null +++ b/custom_mesh_graphormer/utils/tsv_file_ops.py @@ -0,0 +1,116 @@ +""" +Copyright (c) Microsoft Corporation. +Licensed under the MIT license. + +Basic operations for TSV files +""" + + +import os +import os.path as op +import json +import numpy as np +import base64 +import cv2 +from tqdm import tqdm +import yaml +from custom_mesh_graphormer.utils.miscellaneous import mkdir +from custom_mesh_graphormer.utils.tsv_file import TSVFile + + +def img_from_base64(imagestring): + try: + jpgbytestring = base64.b64decode(imagestring) + nparr = np.frombuffer(jpgbytestring, np.uint8) + r = cv2.imdecode(nparr, cv2.IMREAD_COLOR) + return r + except ValueError: + return None + +def load_linelist_file(linelist_file): + if linelist_file is not None: + line_list = [] + with open(linelist_file, 'r') as fp: + for i in fp: + line_list.append(int(i.strip())) + return line_list + +def tsv_writer(values, tsv_file, sep='\t'): + mkdir(op.dirname(tsv_file)) + lineidx_file = op.splitext(tsv_file)[0] + '.lineidx' + idx = 0 + tsv_file_tmp = tsv_file + '.tmp' + lineidx_file_tmp = lineidx_file + '.tmp' + with open(tsv_file_tmp, 'w') as fp, open(lineidx_file_tmp, 'w') as fpidx: + assert values is not None + for value in values: + assert value is not None + value = [v if type(v)!=bytes else v.decode('utf-8') for v in value] + v = '{0}\n'.format(sep.join(map(str, value))) + fp.write(v) + fpidx.write(str(idx) + '\n') + idx = idx + len(v) + os.rename(tsv_file_tmp, tsv_file) + os.rename(lineidx_file_tmp, lineidx_file) + +def tsv_reader(tsv_file, sep='\t'): + with open(tsv_file, 'r') as fp: + for i, line in enumerate(fp): + yield [x.strip() for x in line.split(sep)] + +def config_save_file(tsv_file, save_file=None, append_str='.new.tsv'): + if save_file is not None: + return save_file + return op.splitext(tsv_file)[0] + append_str + +def get_line_list(linelist_file=None, num_rows=None): + if linelist_file is not None: + return load_linelist_file(linelist_file) + + if num_rows is not None: + return [i for i in range(num_rows)] + +def generate_hw_file(img_file, save_file=None): + rows = tsv_reader(img_file) + def gen_rows(): + for i, row in tqdm(enumerate(rows)): + row1 = [row[0]] + img = img_from_base64(row[-1]) + height = img.shape[0] + width = img.shape[1] + row1.append(json.dumps([{"height":height, "width": width}])) + yield row1 + + save_file = config_save_file(img_file, save_file, '.hw.tsv') + tsv_writer(gen_rows(), save_file) + +def generate_linelist_file(label_file, save_file=None, ignore_attrs=()): + # generate a list of image that has labels + # images with only ignore labels are not selected. + line_list = [] + rows = tsv_reader(label_file) + for i, row in tqdm(enumerate(rows)): + labels = json.loads(row[1]) + if labels: + if ignore_attrs and all([any([lab[attr] for attr in ignore_attrs if attr in lab]) \ + for lab in labels]): + continue + line_list.append([i]) + + save_file = config_save_file(label_file, save_file, '.linelist.tsv') + tsv_writer(line_list, save_file) + +def load_from_yaml_file(yaml_file): + with open(yaml_file, 'r') as fp: + return yaml.load(fp, Loader=yaml.CLoader) + +def find_file_path_in_yaml(fname, root): + if fname is not None: + if op.isfile(fname): + return fname + elif op.isfile(op.join(root, fname)): + return op.join(root, fname) + else: + raise FileNotFoundError( + errno.ENOENT, os.strerror(errno.ENOENT), op.join(root, fname) + )