Upload 777 files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +2 -0
- __init__.py +1 -214
- custom_albumentations/LICENSE +21 -0
- custom_albumentations/__init__.py +15 -0
- custom_albumentations/augmentations/__init__.py +21 -0
- custom_albumentations/augmentations/blur/__init__.py +2 -0
- custom_albumentations/augmentations/blur/functional.py +106 -0
- custom_albumentations/augmentations/blur/transforms.py +486 -0
- custom_albumentations/augmentations/crops/__init__.py +2 -0
- custom_albumentations/augmentations/crops/functional.py +317 -0
- custom_albumentations/augmentations/crops/transforms.py +943 -0
- custom_albumentations/augmentations/domain_adaptation.py +337 -0
- custom_albumentations/augmentations/dropout/__init__.py +5 -0
- custom_albumentations/augmentations/dropout/channel_dropout.py +72 -0
- custom_albumentations/augmentations/dropout/coarse_dropout.py +187 -0
- custom_albumentations/augmentations/dropout/cutout.py +79 -0
- custom_albumentations/augmentations/dropout/functional.py +29 -0
- custom_albumentations/augmentations/dropout/grid_dropout.py +155 -0
- custom_albumentations/augmentations/dropout/mask_dropout.py +99 -0
- custom_albumentations/augmentations/functional.py +1380 -0
- custom_albumentations/augmentations/geometric/__init__.py +4 -0
- custom_albumentations/augmentations/geometric/functional.py +1300 -0
- custom_albumentations/augmentations/geometric/resize.py +198 -0
- custom_albumentations/augmentations/geometric/rotate.py +294 -0
- custom_albumentations/augmentations/geometric/transforms.py +1499 -0
- custom_albumentations/augmentations/transforms.py +2667 -0
- custom_albumentations/augmentations/utils.py +211 -0
- custom_albumentations/core/__init__.py +0 -0
- custom_albumentations/core/bbox_utils.py +522 -0
- custom_albumentations/core/composition.py +552 -0
- custom_albumentations/core/keypoints_utils.py +286 -0
- custom_albumentations/core/serialization.py +247 -0
- custom_albumentations/core/transforms_interface.py +293 -0
- custom_albumentations/core/utils.py +137 -0
- custom_albumentations/imgaug/__init__.py +0 -0
- custom_albumentations/imgaug/stubs.py +77 -0
- custom_albumentations/imgaug/transforms.py +391 -0
- custom_albumentations/pytorch/__init__.py +3 -0
- custom_albumentations/pytorch/functional.py +31 -0
- custom_albumentations/pytorch/transforms.py +104 -0
- custom_albumentations/random_utils.py +96 -0
- custom_controlnet_aux/__init__.py +1 -0
- custom_controlnet_aux/anime_face_segment/__init__.py +66 -0
- custom_controlnet_aux/anime_face_segment/anime_segmentation.py +58 -0
- custom_controlnet_aux/anime_face_segment/isnet.py +619 -0
- custom_controlnet_aux/anime_face_segment/network.py +100 -0
- custom_controlnet_aux/anime_face_segment/util.py +40 -0
- custom_controlnet_aux/binary/__init__.py +38 -0
- custom_controlnet_aux/canny/__init__.py +17 -0
- custom_controlnet_aux/color/__init__.py +37 -0
.gitattributes
CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text
 NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
+custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
+custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
__init__.py
CHANGED
@@ -1,214 +1 @@
-
-from .utils import here, define_preprocessor_inputs, INPUT
-from pathlib import Path
-import traceback
-import importlib
-from .log import log, blue_text, cyan_text, get_summary, get_label
-from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS
-from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS
-#Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17
-sys.path.insert(0, str(Path(here, "src").resolve()))
-for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]:
-    sys.path.append(str(Path(here, "src", pkg_name).resolve()))
-
-#Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out
-#https://github.com/pytorch/pytorch/issues/77764
-#https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485
-os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1')
-
-
-def load_nodes():
-    shorted_errors = []
-    full_error_messages = []
-    node_class_mappings = {}
-    node_display_name_mappings = {}
-
-    for filename in (here / "node_wrappers").iterdir():
-        module_name = filename.stem
-        if module_name.startswith('.'): continue #Skip hidden files created by the OS (e.g. [.DS_Store](https://en.wikipedia.org/wiki/.DS_Store))
-        try:
-            module = importlib.import_module(
-                f".node_wrappers.{module_name}", package=__package__
-            )
-            node_class_mappings.update(getattr(module, "NODE_CLASS_MAPPINGS"))
-            if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
-                node_display_name_mappings.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS"))
-
-            log.debug(f"Imported {module_name} nodes")
-
-        except AttributeError:
-            pass # wip nodes
-        except Exception:
-            error_message = traceback.format_exc()
-            full_error_messages.append(error_message)
-            error_message = error_message.splitlines()[-1]
-            shorted_errors.append(
-                f"Failed to import module {module_name} because {error_message}"
-            )
-
-    if len(shorted_errors) > 0:
-        full_err_log = '\n\n'.join(full_error_messages)
-        print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n")
-        log.info(
-            f"Some nodes failed to load:\n\t"
-            + "\n\t".join(shorted_errors)
-            + "\n\n"
-            + "Check that you properly installed the dependencies.\n"
-            + "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)"
-        )
-    return node_class_mappings, node_display_name_mappings
-
-AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes()
-
-#For nodes not mapping image to image or has special requirements
-AIO_NOT_SUPPORTED = ["InpaintPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "DiffusionEdge_Preprocessor"]
-AIO_NOT_SUPPORTED += ["SavePoseKpsAsJsonFile", "FacialPartColoringFromPoseKps", "UpperBodyTrackingFromPoseKps", "RenderPeopleKps", "RenderAnimalKps"]
-AIO_NOT_SUPPORTED += ["Unimatch_OptFlowPreprocessor", "MaskOptFlow"]
-
-def preprocessor_options():
-    auxs = list(AUX_NODE_MAPPINGS.keys())
-    auxs.insert(0, "none")
-    for name in AIO_NOT_SUPPORTED:
-        if name in auxs:
-            auxs.remove(name)
-    return auxs
-
-
-PREPROCESSOR_OPTIONS = preprocessor_options()
-
-class AIO_Preprocessor:
-    @classmethod
-    def INPUT_TYPES(s):
-        return define_preprocessor_inputs(
-            preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"),
-            resolution=INPUT.RESOLUTION()
-        )
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "execute"
-
-    CATEGORY = "ControlNet Preprocessors"
-
-    def execute(self, preprocessor, image, resolution=512):
-        if preprocessor == "none":
-            return (image, )
-        else:
-            aux_class = AUX_NODE_MAPPINGS[preprocessor]
-            input_types = aux_class.INPUT_TYPES()
-            input_types = {
-                **input_types["required"],
-                **(input_types["optional"] if "optional" in input_types else {})
-            }
-            params = {}
-            for name, input_type in input_types.items():
-                if name == "image":
-                    params[name] = image
-                    continue
-
-                if name == "resolution":
-                    params[name] = resolution
-                    continue
-
-                if len(input_type) == 2 and ("default" in input_type[1]):
-                    params[name] = input_type[1]["default"]
-                    continue
-
-                default_values = { "INT": 0, "FLOAT": 0.0 }
-                if input_type[0] in default_values:
-                    params[name] = default_values[input_type[0]]
-
-            return getattr(aux_class(), aux_class.FUNCTION)(**params)
-
-class ControlNetAuxSimpleAddText:
-    @classmethod
-    def INPUT_TYPES(s):
-        return dict(
-            required=dict(image=INPUT.IMAGE(), text=INPUT.STRING())
-        )
-
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "execute"
-    CATEGORY = "ControlNet Preprocessors"
-    def execute(self, image, text):
-        from PIL import Image, ImageDraw, ImageFont
-        import numpy as np
-        import torch
-
-        font = ImageFont.truetype(str((here / "NotoSans-Regular.ttf").resolve()), 40)
-        img = Image.fromarray(image[0].cpu().numpy().__mul__(255.).astype(np.uint8))
-        ImageDraw.Draw(img).text((0,0), text, fill=(0,255,0), font=font)
-        return (torch.from_numpy(np.array(img)).unsqueeze(0) / 255.,)
-
-class ExecuteAllControlNetPreprocessors:
-    @classmethod
-    def INPUT_TYPES(s):
-        return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
-    RETURN_TYPES = ("IMAGE",)
-    FUNCTION = "execute"
-
-    CATEGORY = "ControlNet Preprocessors"
-
-    def execute(self, image, resolution=512):
-        try:
-            from comfy_execution.graph_utils import GraphBuilder
-        except:
-            raise RuntimeError("ExecuteAllControlNetPreprocessor requries [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature")
-
-        graph = GraphBuilder()
-        curr_outputs = []
-        for preprocc in PREPROCESSOR_OPTIONS:
-            preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution)
-            hint_img = preprocc_node.out(0)
-            add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc)
-            curr_outputs.append(add_text_node.out(0))
-
-        while len(curr_outputs) > 1:
-            _outputs = []
-            for i in range(0, len(curr_outputs), 2):
-                if i+1 < len(curr_outputs):
-                    image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i+1])
-                    _outputs.append(image_batch.out(0))
-                else:
-                    _outputs.append(curr_outputs[i])
-            curr_outputs = _outputs
-
-        return {
-            "result": (curr_outputs[0],),
-            "expand": graph.finalize(),
-        }
-
-class ControlNetPreprocessorSelector:
-    @classmethod
-    def INPUT_TYPES(s):
-        return {
-            "required": {
-                "preprocessor": (PREPROCESSOR_OPTIONS,),
-            }
-        }
-
-    RETURN_TYPES = (PREPROCESSOR_OPTIONS,)
-    RETURN_NAMES = ("preprocessor",)
-    FUNCTION = "get_preprocessor"
-
-    CATEGORY = "ControlNet Preprocessors"
-
-    def get_preprocessor(self, preprocessor: str):
-        return (preprocessor,)
-
-
-NODE_CLASS_MAPPINGS = {
-    **AUX_NODE_MAPPINGS,
-    "AIO_Preprocessor": AIO_Preprocessor,
-    "ControlNetPreprocessorSelector": ControlNetPreprocessorSelector,
-    **HIE_NODE_CLASS_MAPPINGS,
-    "ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors,
-    "ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText
-}
-
-NODE_DISPLAY_NAME_MAPPINGS = {
-    **AUX_DISPLAY_NAME_MAPPINGS,
-    "AIO_Preprocessor": "AIO Aux Preprocessor",
-    "ControlNetPreprocessorSelector": "Preprocessor Selector",
-    **HIE_NODE_DISPLAY_NAME_MAPPINGS,
-    "ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors"
-}
+#Dummy file ensuring this package will be recognized
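The removed top-level __init__.py built NODE_CLASS_MAPPINGS by importing every module under node_wrappers/; the replacement is only a package marker, so that registration presumably lives elsewhere in this upload. For reference, a minimal sketch of the contract the removed load_nodes() loop expected from each wrapper module; the ExampleCannyPreprocessor name below is hypothetical and only illustrates the shape of those mappings:

# Hypothetical node_wrappers module matching the contract used by the removed
# load_nodes(): expose NODE_CLASS_MAPPINGS and, optionally,
# NODE_DISPLAY_NAME_MAPPINGS, with each node class declaring INPUT_TYPES,
# RETURN_TYPES, FUNCTION, and CATEGORY as in the removed classes above.
class ExampleCannyPreprocessor:
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "execute"
    CATEGORY = "ControlNet Preprocessors"

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {"image": ("IMAGE",)}}

    def execute(self, image):
        # A real wrapper would run a preprocessor here; this sketch passes through.
        return (image,)

NODE_CLASS_MAPPINGS = {"ExampleCannyPreprocessor": ExampleCannyPreprocessor}
NODE_DISPLAY_NAME_MAPPINGS = {"ExampleCannyPreprocessor": "Example Canny Preprocessor"}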
custom_albumentations/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2017 Buslaev Alexander, Alexander Parinov, Vladimir Iglovikov

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
custom_albumentations/__init__.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import absolute_import

__version__ = "1.3.1"

from .augmentations import *
from .core.composition import *
from .core.serialization import *
from .core.transforms_interface import *

try:
    from .imgaug.transforms import *  # type: ignore
except ImportError:
    # imgaug is not installed by default, so we import stubs.
    # Run `pip install -U albumentations[imgaug] if you need augmentations from imgaug.`
    from .imgaug.stubs import *  # type: ignore
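For orientation, a minimal usage sketch of the vendored package, assuming it mirrors the upstream albumentations 1.3.1 API that the star imports above re-export (Compose comes from core.composition, the blur transforms from augmentations):

# Minimal sketch: build and run a pipeline against the vendored package.
import numpy as np
import custom_albumentations as A

pipeline = A.Compose([
    A.Blur(blur_limit=(3, 7), p=0.5),
    A.GaussianBlur(blur_limit=(3, 7), p=0.5),
])

image = np.zeros((64, 64, 3), dtype=np.uint8)  # dummy uint8 image
augmented = pipeline(image=image)["image"]

Vendoring the library under custom_albumentations keeps these imports isolated from any albumentations version a user may already have installed.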
custom_albumentations/augmentations/__init__.py
ADDED
@@ -0,0 +1,21 @@
# Common classes
from .blur.functional import *
from .blur.transforms import *
from .crops.functional import *
from .crops.transforms import *

# New transformations goes to individual files listed below
from .domain_adaptation import *
from .dropout.channel_dropout import *
from .dropout.coarse_dropout import *
from .dropout.cutout import *
from .dropout.functional import *
from .dropout.grid_dropout import *
from .dropout.mask_dropout import *
from .functional import *
from .geometric.functional import *
from .geometric.resize import *
from .geometric.rotate import *
from .geometric.transforms import *
from .transforms import *
from .utils import *
custom_albumentations/augmentations/blur/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .functional import *
from .transforms import *
custom_albumentations/augmentations/blur/functional.py
ADDED
@@ -0,0 +1,106 @@
from itertools import product
from math import ceil
from typing import Sequence, Union

import cv2
import numpy as np

from custom_albumentations.augmentations.functional import convolve
from custom_albumentations.augmentations.geometric.functional import scale
from custom_albumentations.augmentations.utils import (
    _maybe_process_in_chunks,
    clipped,
    preserve_shape,
)

__all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"]


@preserve_shape
def blur(img: np.ndarray, ksize: int) -> np.ndarray:
    blur_fn = _maybe_process_in_chunks(cv2.blur, ksize=(ksize, ksize))
    return blur_fn(img)


@preserve_shape
def median_blur(img: np.ndarray, ksize: int) -> np.ndarray:
    if img.dtype == np.float32 and ksize not in {3, 5}:
        raise ValueError(f"Invalid ksize value {ksize}. For a float32 image the only valid ksize values are 3 and 5")

    blur_fn = _maybe_process_in_chunks(cv2.medianBlur, ksize=ksize)
    return blur_fn(img)


@preserve_shape
def gaussian_blur(img: np.ndarray, ksize: int, sigma: float = 0) -> np.ndarray:
    # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`
    blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma)
    return blur_fn(img)


@preserve_shape
def glass_blur(
    img: np.ndarray, sigma: float, max_delta: int, iterations: int, dxy: np.ndarray, mode: str
) -> np.ndarray:
    x = cv2.GaussianBlur(np.array(img), sigmaX=sigma, ksize=(0, 0))

    if mode == "fast":
        hs = np.arange(img.shape[0] - max_delta, max_delta, -1)
        ws = np.arange(img.shape[1] - max_delta, max_delta, -1)
        h: Union[int, np.ndarray] = np.tile(hs, ws.shape[0])
        w: Union[int, np.ndarray] = np.repeat(ws, hs.shape[0])

        for i in range(iterations):
            dy = dxy[:, i, 0]
            dx = dxy[:, i, 1]
            x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w]

    elif mode == "exact":
        for ind, (i, h, w) in enumerate(
            product(
                range(iterations),
                range(img.shape[0] - max_delta, max_delta, -1),
                range(img.shape[1] - max_delta, max_delta, -1),
            )
        ):
            ind = ind if ind < len(dxy) else ind % len(dxy)
            dy = dxy[ind, i, 0]
            dx = dxy[ind, i, 1]
            x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w]
    else:
        ValueError(f"Unsupported mode `{mode}`. Supports only `fast` and `exact`.")

    return cv2.GaussianBlur(x, sigmaX=sigma, ksize=(0, 0))


def defocus(img: np.ndarray, radius: int, alias_blur: float) -> np.ndarray:
    length = np.arange(-max(8, radius), max(8, radius) + 1)
    ksize = 3 if radius <= 8 else 5

    x, y = np.meshgrid(length, length)
    aliased_disk = np.array((x**2 + y**2) <= radius**2, dtype=np.float32)
    aliased_disk /= np.sum(aliased_disk)

    kernel = gaussian_blur(aliased_disk, ksize, sigma=alias_blur)
    return convolve(img, kernel=kernel)


def central_zoom(img: np.ndarray, zoom_factor: int) -> np.ndarray:
    h, w = img.shape[:2]
    h_ch, w_ch = ceil(h / zoom_factor), ceil(w / zoom_factor)
    h_top, w_top = (h - h_ch) // 2, (w - w_ch) // 2

    img = scale(img[h_top : h_top + h_ch, w_top : w_top + w_ch], zoom_factor, cv2.INTER_LINEAR)
    h_trim_top, w_trim_top = (img.shape[0] - h) // 2, (img.shape[1] - w) // 2
    return img[h_trim_top : h_trim_top + h, w_trim_top : w_trim_top + w]


@clipped
def zoom_blur(img: np.ndarray, zoom_factors: Union[np.ndarray, Sequence[int]]) -> np.ndarray:
    out = np.zeros_like(img, dtype=np.float32)
    for zoom_factor in zoom_factors:
        out += central_zoom(img, zoom_factor)

    img = ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype)

    return img
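These helpers operate directly on numpy arrays, with the transform classes in transforms.py (next file) handling parameter sampling. A quick sketch of calling them on a raw uint8 image (illustrative values only):

# Quick sketch of the functional blur API above on a raw numpy image.
import numpy as np
from custom_albumentations.augmentations.blur import functional as F

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
box_blurred = F.blur(img, ksize=3)             # 3x3 box filter via cv2.blur
gauss_blurred = F.gaussian_blur(img, ksize=5)  # sigma derived from ksize when 0
median_blurred = F.median_blur(img, ksize=3)   # ksize must be odd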
custom_albumentations/augmentations/blur/transforms.py
ADDED
@@ -0,0 +1,486 @@
import random
import warnings
from typing import Any, Dict, List, Sequence, Tuple

import cv2
import numpy as np

from custom_albumentations import random_utils
from custom_albumentations.augmentations import functional as FMain
from custom_albumentations.augmentations.blur import functional as F
from custom_albumentations.core.transforms_interface import (
    ImageOnlyTransform,
    ScaleFloatType,
    ScaleIntType,
    to_tuple,
)

__all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"]


class Blur(ImageOnlyTransform):
    """Blur the input image using a random-sized kernel.

    Args:
        blur_limit (int, (int, int)): maximum kernel size for blurring the input image.
            Should be in range [3, inf). Default: (3, 7).
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5):
        super().__init__(always_apply, p)
        self.blur_limit = to_tuple(blur_limit, 3)

    def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray:
        return F.blur(img, ksize)

    def get_params(self) -> Dict[str, Any]:
        return {"ksize": int(random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2))))}

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        return ("blur_limit",)


class MotionBlur(Blur):
    """Apply motion blur to the input image using a random-sized kernel.

    Args:
        blur_limit (int): maximum kernel size for blurring the input image.
            Should be in range [3, inf). Default: (3, 7).
        allow_shifted (bool): if set to true creates non shifted kernels only,
            otherwise creates randomly shifted kernels. Default: True.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        blur_limit: ScaleIntType = 7,
        allow_shifted: bool = True,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(blur_limit=blur_limit, always_apply=always_apply, p=p)
        self.allow_shifted = allow_shifted

        if not allow_shifted and self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1:
            raise ValueError(f"Blur limit must be odd when centered=True. Got: {self.blur_limit}")

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        return super().get_transform_init_args_names() + ("allow_shifted",)

    def apply(self, img: np.ndarray, kernel: np.ndarray = None, **params) -> np.ndarray:  # type: ignore
        return FMain.convolve(img, kernel=kernel)

    def get_params(self) -> Dict[str, Any]:
        ksize = random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2)))
        if ksize <= 2:
            raise ValueError("ksize must be > 2. Got: {}".format(ksize))
        kernel = np.zeros((ksize, ksize), dtype=np.uint8)
        x1, x2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
        if x1 == x2:
            y1, y2 = random.sample(range(ksize), 2)
        else:
            y1, y2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)

        def make_odd_val(v1, v2):
            len_v = abs(v1 - v2) + 1
            if len_v % 2 != 1:
                if v2 > v1:
                    v2 -= 1
                else:
                    v1 -= 1
            return v1, v2

        if not self.allow_shifted:
            x1, x2 = make_odd_val(x1, x2)
            y1, y2 = make_odd_val(y1, y2)

            xc = (x1 + x2) / 2
            yc = (y1 + y2) / 2

            center = ksize / 2 - 0.5
            dx = xc - center
            dy = yc - center
            x1, x2 = [int(i - dx) for i in [x1, x2]]
            y1, y2 = [int(i - dy) for i in [y1, y2]]

        cv2.line(kernel, (x1, y1), (x2, y2), 1, thickness=1)

        # Normalize kernel
        return {"kernel": kernel.astype(np.float32) / np.sum(kernel)}


class MedianBlur(Blur):
    """Blur the input image using a median filter with a random aperture linear size.

    Args:
        blur_limit (int): maximum aperture linear size for blurring the input image.
            Must be odd and in range [3, inf). Default: (3, 7).
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5):
        super().__init__(blur_limit, always_apply, p)

        if self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1:
            raise ValueError("MedianBlur supports only odd blur limits.")

    def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray:
        return F.median_blur(img, ksize)


class GaussianBlur(ImageOnlyTransform):
    """Blur the input image using a Gaussian filter with a random kernel size.

    Args:
        blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image.
            Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
            as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
            If set single value `blur_limit` will be in range (0, blur_limit).
            Default: (3, 7).
        sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf).
            If set single value `sigma_limit` will be in range (0, sigma_limit).
            If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        blur_limit: ScaleIntType = (3, 7),
        sigma_limit: ScaleFloatType = 0,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.blur_limit = to_tuple(blur_limit, 0)
        self.sigma_limit = to_tuple(sigma_limit if sigma_limit is not None else 0, 0)

        if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0:
            self.blur_limit = 3, max(3, self.blur_limit[1])
            warnings.warn(
                "blur_limit and sigma_limit minimum value can not be both equal to 0. "
                "blur_limit minimum value changed to 3."
            )

        if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
            self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
        ):
            raise ValueError("GaussianBlur supports only odd blur limits.")

    def apply(self, img: np.ndarray, ksize: int = 3, sigma: float = 0, **params) -> np.ndarray:
        return F.gaussian_blur(img, ksize, sigma=sigma)

    def get_params(self) -> Dict[str, float]:
        ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1)
        if ksize != 0 and ksize % 2 != 1:
            ksize = (ksize + 1) % (self.blur_limit[1] + 1)

        return {"ksize": ksize, "sigma": random.uniform(*self.sigma_limit)}

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        return ("blur_limit", "sigma_limit")


class GlassBlur(Blur):
    """Apply glass noise to the input image.

    Args:
        sigma (float): standard deviation for Gaussian kernel.
        max_delta (int): max distance between pixels which are swapped.
        iterations (int): number of repeats.
            Should be in range [1, inf). Default: (2).
        mode (str): mode of computation: fast or exact. Default: "fast".
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32

    Reference:
    |  https://arxiv.org/abs/1903.12261
    |  https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py
    """

    def __init__(
        self,
        sigma: float = 0.7,
        max_delta: int = 4,
        iterations: int = 2,
        always_apply: bool = False,
        mode: str = "fast",
        p: float = 0.5,
    ):
        super().__init__(always_apply=always_apply, p=p)
        if iterations < 1:
            raise ValueError(f"Iterations should be more or equal to 1, but we got {iterations}")

        if mode not in ["fast", "exact"]:
            raise ValueError(f"Mode should be 'fast' or 'exact', but we got {mode}")

        self.sigma = sigma
        self.max_delta = max_delta
        self.iterations = iterations
        self.mode = mode

    def apply(self, img: np.ndarray, dxy: np.ndarray = None, **params) -> np.ndarray:  # type: ignore
        assert dxy is not None
        return F.glass_blur(img, self.sigma, self.max_delta, self.iterations, dxy, self.mode)

    def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, np.ndarray]:
        img = params["image"]

        # generate array containing all necessary values for transformations
        width_pixels = img.shape[0] - self.max_delta * 2
        height_pixels = img.shape[1] - self.max_delta * 2
        total_pixels = width_pixels * height_pixels
        dxy = random_utils.randint(-self.max_delta, self.max_delta, size=(total_pixels, self.iterations, 2))

        return {"dxy": dxy}

    def get_transform_init_args_names(self) -> Tuple[str, str, str]:
        return ("sigma", "max_delta", "iterations")

    @property
    def targets_as_params(self) -> List[str]:
        return ["image"]


class AdvancedBlur(ImageOnlyTransform):
    """Blur the input image using a Generalized Normal filter with a randomly selected parameters.
    This transform also adds multiplicative noise to generated kernel before convolution.

    Args:
        blur_limit: maximum Gaussian kernel size for blurring the input image.
            Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
            as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
            If set single value `blur_limit` will be in range (0, blur_limit).
            Default: (3, 7).
        sigmaX_limit: Gaussian kernel standard deviation. Must be in range [0, inf).
            If set single value `sigmaX_limit` will be in range (0, sigma_limit).
            If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
        sigmaY_limit: Same as `sigmaY_limit` for another dimension.
        rotate_limit: Range from which a random angle used to rotate Gaussian kernel is picked.
            If limit is a single int an angle is picked from (-rotate_limit, rotate_limit). Default: (-90, 90).
        beta_limit: Distribution shape parameter, 1 is the normal distribution. Values below 1.0 make distribution
            tails heavier than normal, values above 1.0 make it lighter than normal. Default: (0.5, 8.0).
        noise_limit: Multiplicative factor that control strength of kernel noise. Must be positive and preferably
            centered around 1.0. If set single value `noise_limit` will be in range (0, noise_limit).
            Default: (0.75, 1.25).
        p (float): probability of applying the transform. Default: 0.5.

    Reference:
        https://arxiv.org/abs/2107.10833

    Targets:
        image
    Image types:
        uint8, float32
    """

    def __init__(
        self,
        blur_limit: ScaleIntType = (3, 7),
        sigmaX_limit: ScaleFloatType = (0.2, 1.0),
        sigmaY_limit: ScaleFloatType = (0.2, 1.0),
        rotate_limit: ScaleIntType = 90,
        beta_limit: ScaleFloatType = (0.5, 8.0),
        noise_limit: ScaleFloatType = (0.9, 1.1),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.blur_limit = to_tuple(blur_limit, 3)
        self.sigmaX_limit = self.__check_values(to_tuple(sigmaX_limit, 0.0), name="sigmaX_limit")
        self.sigmaY_limit = self.__check_values(to_tuple(sigmaY_limit, 0.0), name="sigmaY_limit")
        self.rotate_limit = to_tuple(rotate_limit)
        self.beta_limit = to_tuple(beta_limit, low=0.0)
        self.noise_limit = self.__check_values(to_tuple(noise_limit, 0.0), name="noise_limit")

        if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
            self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
        ):
            raise ValueError("AdvancedBlur supports only odd blur limits.")

        if self.sigmaX_limit[0] == 0 and self.sigmaY_limit[0] == 0:
            raise ValueError("sigmaX_limit and sigmaY_limit minimum value can not be both equal to 0.")

        if not (self.beta_limit[0] < 1.0 < self.beta_limit[1]):
            raise ValueError("Beta limit is expected to include 1.0")

    @staticmethod
    def __check_values(
        value: Sequence[float], name: str, bounds: Tuple[float, float] = (0, float("inf"))
    ) -> Sequence[float]:
        if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
            raise ValueError(f"{name} values should be between {bounds}")
        return value

    def apply(self, img: np.ndarray, kernel: np.ndarray = np.array(None), **params) -> np.ndarray:
        return FMain.convolve(img, kernel=kernel)

    def get_params(self) -> Dict[str, np.ndarray]:
        ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2)
        sigmaX = random.uniform(*self.sigmaX_limit)
        sigmaY = random.uniform(*self.sigmaY_limit)
        angle = np.deg2rad(random.uniform(*self.rotate_limit))

        # Split into 2 cases to avoid selection of narrow kernels (beta > 1) too often.
        if random.random() < 0.5:
            beta = random.uniform(self.beta_limit[0], 1)
        else:
            beta = random.uniform(1, self.beta_limit[1])

        noise_matrix = random_utils.uniform(self.noise_limit[0], self.noise_limit[1], size=[ksize, ksize])

        # Generate mesh grid centered at zero.
        ax = np.arange(-ksize // 2 + 1.0, ksize // 2 + 1.0)
        # Shape (ksize, ksize, 2)
        grid = np.stack(np.meshgrid(ax, ax), axis=-1)

        # Calculate rotated sigma matrix
        d_matrix = np.array([[sigmaX**2, 0], [0, sigmaY**2]])
        u_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
        sigma_matrix = np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))

        inverse_sigma = np.linalg.inv(sigma_matrix)
        # Described in "Parameter Estimation For Multivariate Generalized Gaussian Distributions"
        kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
        # Add noise
        kernel = kernel * noise_matrix

        # Normalize kernel
        kernel = kernel.astype(np.float32) / np.sum(kernel)
        return {"kernel": kernel}

    def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]:
        return (
            "blur_limit",
            "sigmaX_limit",
            "sigmaY_limit",
            "rotate_limit",
            "beta_limit",
            "noise_limit",
        )


class Defocus(ImageOnlyTransform):
    """
    Apply defocus transform. See https://arxiv.org/abs/1903.12261.

    Args:
        radius ((int, int) or int): range for radius of defocusing.
            If limit is a single int, the range will be [1, limit]. Default: (3, 10).
        alias_blur ((float, float) or float): range for alias_blur of defocusing (sigma of gaussian blur).
            If limit is a single float, the range will be (0, limit). Default: (0.1, 0.5).
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        Any
    """

    def __init__(
        self,
        radius: ScaleIntType = (3, 10),
        alias_blur: ScaleFloatType = (0.1, 0.5),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.radius = to_tuple(radius, low=1)
        self.alias_blur = to_tuple(alias_blur, low=0)

        if self.radius[0] <= 0:
            raise ValueError("Parameter radius must be positive")

        if self.alias_blur[0] < 0:
            raise ValueError("Parameter alias_blur must be non-negative")

    def apply(self, img: np.ndarray, radius: int = 3, alias_blur: float = 0.5, **params) -> np.ndarray:
        return F.defocus(img, radius, alias_blur)

    def get_params(self) -> Dict[str, Any]:
        return {
            "radius": random_utils.randint(self.radius[0], self.radius[1] + 1),
            "alias_blur": random_utils.uniform(self.alias_blur[0], self.alias_blur[1]),
        }

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        return ("radius", "alias_blur")


class ZoomBlur(ImageOnlyTransform):
    """
    Apply zoom blur transform. See https://arxiv.org/abs/1903.12261.

    Args:
        max_factor ((float, float) or float): range for max factor for blurring.
            If max_factor is a single float, the range will be (1, limit). Default: (1, 1.31).
            All max_factor values should be larger than 1.
        step_factor ((float, float) or float): If single float will be used as step parameter for np.arange.
            If tuple of float step_factor will be in range `[step_factor[0], step_factor[1])`. Default: (0.01, 0.03).
            All step_factor values should be positive.
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        Any
    """

    def __init__(
        self,
        max_factor: ScaleFloatType = 1.31,
        step_factor: ScaleFloatType = (0.01, 0.03),
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply, p)
        self.max_factor = to_tuple(max_factor, low=1.0)
        self.step_factor = to_tuple(step_factor, step_factor)

        if self.max_factor[0] < 1:
            raise ValueError("Max factor must be larger or equal 1")
        if self.step_factor[0] <= 0:
            raise ValueError("Step factor must be positive")

    def apply(self, img: np.ndarray, zoom_factors: np.ndarray = np.array(None), **params) -> np.ndarray:
        assert zoom_factors is not None
        return F.zoom_blur(img, zoom_factors)

    def get_params(self) -> Dict[str, Any]:
        max_factor = random.uniform(self.max_factor[0], self.max_factor[1])
        step_factor = random.uniform(self.step_factor[0], self.step_factor[1])
        return {"zoom_factors": np.arange(1.0, max_factor, step_factor)}

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        return ("max_factor", "step_factor")
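Each transform above samples its parameters in get_params() and hands them to the matching functional op in apply(). A short sketch composing a few of them into a pipeline, assuming the vendored Compose from core.composition (p=1.0 so every transform fires on each call):

# Sketch: compose several of the blur transforms defined above.
import numpy as np
from custom_albumentations import Compose
from custom_albumentations.augmentations.blur.transforms import (
    AdvancedBlur,
    Defocus,
    MotionBlur,
)

aug = Compose([
    MotionBlur(blur_limit=7, allow_shifted=True, p=1.0),
    AdvancedBlur(blur_limit=(3, 7), rotate_limit=90, p=1.0),
    Defocus(radius=(3, 10), alias_blur=(0.1, 0.5), p=1.0),
])

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
out = aug(image=img)["image"]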
custom_albumentations/augmentations/crops/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .functional import *
from .transforms import *
custom_albumentations/augmentations/crops/functional.py
ADDED
@@ -0,0 +1,317 @@
from typing import Optional, Sequence, Tuple

import cv2
import numpy as np

from custom_albumentations.augmentations.utils import (
    _maybe_process_in_chunks,
    preserve_channel_dim,
)

from ...core.bbox_utils import denormalize_bbox, normalize_bbox
from ...core.transforms_interface import BoxInternalType, KeypointInternalType
from ..geometric import functional as FGeometric

__all__ = [
    "get_random_crop_coords",
    "random_crop",
    "crop_bbox_by_coords",
    "bbox_random_crop",
    "crop_keypoint_by_coords",
    "keypoint_random_crop",
    "get_center_crop_coords",
    "center_crop",
    "bbox_center_crop",
    "keypoint_center_crop",
    "crop",
    "bbox_crop",
    "clamping_crop",
    "crop_and_pad",
    "crop_and_pad_bbox",
    "crop_and_pad_keypoint",
]


def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float, w_start: float):
    # h_start is [0, 1) and should map to [0, (height - crop_height)] (note inclusive)
    # This is conceptually equivalent to mapping onto `range(0, (height - crop_height + 1))`
    # See: https://github.com/albumentations-team/albumentations/pull/1080
    y1 = int((height - crop_height + 1) * h_start)
    y2 = y1 + crop_height
    x1 = int((width - crop_width + 1) * w_start)
    x2 = x1 + crop_width
    return x1, y1, x2, y2


def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float):
    height, width = img.shape[:2]
    if height < crop_height or width < crop_width:
        raise ValueError(
            "Requested crop size ({crop_height}, {crop_width}) is "
            "larger than the image size ({height}, {width})".format(
                crop_height=crop_height, crop_width=crop_width, height=height, width=width
            )
        )
    x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start)
    img = img[y1:y2, x1:x2]
    return img


def crop_bbox_by_coords(
    bbox: BoxInternalType,
    crop_coords: Tuple[int, int, int, int],
    crop_height: int,
    crop_width: int,
    rows: int,
    cols: int,
):
    """Crop a bounding box using the provided coordinates of bottom-left and top-right corners in pixels and the
    required height and width of the crop.

    Args:
        bbox (tuple): A cropped box `(x_min, y_min, x_max, y_max)`.
        crop_coords (tuple): Crop coordinates `(x1, y1, x2, y2)`.
        crop_height (int):
        crop_width (int):
        rows (int): Image rows.
        cols (int): Image cols.

    Returns:
        tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`.

    """
    bbox = denormalize_bbox(bbox, rows, cols)
    x_min, y_min, x_max, y_max = bbox[:4]
    x1, y1, _, _ = crop_coords
    cropped_bbox = x_min - x1, y_min - y1, x_max - x1, y_max - y1
    return normalize_bbox(cropped_bbox, crop_height, crop_width)


def bbox_random_crop(
    bbox: BoxInternalType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
):
    crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start)
    return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)


def crop_keypoint_by_coords(
    keypoint: KeypointInternalType, crop_coords: Tuple[int, int, int, int]
):  # skipcq: PYL-W0613
    """Crop a keypoint using the provided coordinates of bottom-left and top-right corners in pixels and the
    required height and width of the crop.

    Args:
        keypoint (tuple): A keypoint `(x, y, angle, scale)`.
        crop_coords (tuple): Crop box coords `(x1, x2, y1, y2)`.

    Returns:
        A keypoint `(x, y, angle, scale)`.

    """
    x, y, angle, scale = keypoint[:4]
    x1, y1, _, _ = crop_coords
    return x - x1, y - y1, angle, scale


def keypoint_random_crop(
    keypoint: KeypointInternalType,
    crop_height: int,
    crop_width: int,
    h_start: float,
    w_start: float,
    rows: int,
    cols: int,
):
    """Keypoint random crop.

    Args:
        keypoint: (tuple): A keypoint `(x, y, angle, scale)`.
        crop_height (int): Crop height.
        crop_width (int): Crop width.
        h_start (int): Crop height start.
        w_start (int): Crop width start.
        rows (int): Image height.
        cols (int): Image width.

    Returns:
        A keypoint `(x, y, angle, scale)`.

    """
    crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start)
    return crop_keypoint_by_coords(keypoint, crop_coords)


def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
    y1 = (height - crop_height) // 2
    y2 = y1 + crop_height
    x1 = (width - crop_width) // 2
    x2 = x1 + crop_width
    return x1, y1, x2, y2


def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
    height, width = img.shape[:2]
    if height < crop_height or width < crop_width:
        raise ValueError(
            "Requested crop size ({crop_height}, {crop_width}) is "
            "larger than the image size ({height}, {width})".format(
                crop_height=crop_height, crop_width=crop_width, height=height, width=width
            )
        )
    x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
    img = img[y1:y2, x1:x2]
    return img


def bbox_center_crop(bbox: BoxInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
    crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
    return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)


def keypoint_center_crop(keypoint: KeypointInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
    """Keypoint center crop.

    Args:
        keypoint (tuple): A keypoint `(x, y, angle, scale)`.
        crop_height (int): Crop height.
        crop_width (int): Crop width.
        rows (int): Image height.
        cols (int): Image width.

    Returns:
        tuple: A keypoint `(x, y, angle, scale)`.

    """
    crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
    return crop_keypoint_by_coords(keypoint, crop_coords)


def crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int):
    height, width = img.shape[:2]
    if x_max <= x_min or y_max <= y_min:
        raise ValueError(
            "We should have x_min < x_max and y_min < y_max. But we got"
            " (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})".format(
                x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max
            )
        )

    if x_min < 0 or x_max > width or y_min < 0 or y_max > height:
        raise ValueError(
            "Values for crop should be non negative and equal or smaller than image sizes"
            "(x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max}, "
            "height = {height}, width = {width})".format(
                x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, height=height, width=width
            )
        )

    return img[y_min:y_max, x_min:x_max]


def bbox_crop(bbox: BoxInternalType, x_min: int, y_min: int, x_max: int, y_max: int, rows: int, cols: int):
    """Crop a bounding box.

    Args:
        bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`.
        x_min (int):
        y_min (int):
        x_max (int):
        y_max (int):
        rows (int): Image rows.
        cols (int): Image cols.

    Returns:
        tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`.

    """
    crop_coords = x_min, y_min, x_max, y_max
    crop_height = y_max - y_min
    crop_width = x_max - x_min
    return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)


def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int):
    h, w = img.shape[:2]
    if x_min < 0:
        x_min = 0
    if y_min < 0:
        y_min = 0
    if y_max >= h:
        y_max = h - 1
    if x_max >= w:
        x_max = w - 1
    return img[int(y_min) : int(y_max), int(x_min) : int(x_max)]


@preserve_channel_dim
def crop_and_pad(
    img: np.ndarray,
    crop_params: Optional[Sequence[int]],
    pad_params: Optional[Sequence[int]],
    pad_value: Optional[float],
    rows: int,
    cols: int,
    interpolation: int,
    pad_mode: int,
    keep_size: bool,
) -> np.ndarray:
    if crop_params is not None and any(i != 0 for i in crop_params):
        img = crop(img, *crop_params)
    if pad_params is not None and any(i != 0 for i in pad_params):
        img = FGeometric.pad_with_params(
            img, pad_params[0], pad_params[1], pad_params[2], pad_params[3], border_mode=pad_mode, value=pad_value
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
if keep_size:
|
| 266 |
+
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(cols, rows), interpolation=interpolation)
|
| 267 |
+
img = resize_fn(img)
|
| 268 |
+
|
| 269 |
+
return img
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def crop_and_pad_bbox(
|
| 273 |
+
bbox: BoxInternalType,
|
| 274 |
+
crop_params: Optional[Sequence[int]],
|
| 275 |
+
pad_params: Optional[Sequence[int]],
|
| 276 |
+
rows,
|
| 277 |
+
cols,
|
| 278 |
+
result_rows,
|
| 279 |
+
result_cols,
|
| 280 |
+
) -> BoxInternalType:
|
| 281 |
+
x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4]
|
| 282 |
+
|
| 283 |
+
if crop_params is not None:
|
| 284 |
+
crop_x, crop_y = crop_params[:2]
|
| 285 |
+
x1, y1, x2, y2 = x1 - crop_x, y1 - crop_y, x2 - crop_x, y2 - crop_y
|
| 286 |
+
if pad_params is not None:
|
| 287 |
+
top, bottom, left, right = pad_params
|
| 288 |
+
x1, y1, x2, y2 = x1 + left, y1 + top, x2 + left, y2 + top
|
| 289 |
+
|
| 290 |
+
return normalize_bbox((x1, y1, x2, y2), result_rows, result_cols)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def crop_and_pad_keypoint(
|
| 294 |
+
keypoint: KeypointInternalType,
|
| 295 |
+
crop_params: Optional[Sequence[int]],
|
| 296 |
+
pad_params: Optional[Sequence[int]],
|
| 297 |
+
rows: int,
|
| 298 |
+
cols: int,
|
| 299 |
+
result_rows: int,
|
| 300 |
+
result_cols: int,
|
| 301 |
+
keep_size: bool,
|
| 302 |
+
) -> KeypointInternalType:
|
| 303 |
+
x, y, angle, scale = keypoint[:4]
|
| 304 |
+
|
| 305 |
+
if crop_params is not None:
|
| 306 |
+
crop_x1, crop_y1, crop_x2, crop_y2 = crop_params
|
| 307 |
+
x, y = x - crop_x1, y - crop_y1
|
| 308 |
+
if pad_params is not None:
|
| 309 |
+
top, bottom, left, right = pad_params
|
| 310 |
+
x, y = x + left, y + top
|
| 311 |
+
|
| 312 |
+
if keep_size and (result_cols != cols or result_rows != rows):
|
| 313 |
+
scale_x = cols / result_cols
|
| 314 |
+
scale_y = rows / result_rows
|
| 315 |
+
return FGeometric.keypoint_scale((x, y, angle, scale), scale_x, scale_y)
|
| 316 |
+
|
| 317 |
+
return x, y, angle, scale
|
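For reference, a minimal sketch of how the crop helpers above compose across targets. It assumes the module is importable as `custom_albumentations.augmentations.crops.functional` and that `random_crop`/`get_random_crop_coords` are defined earlier in this file (they are referenced by the transforms below); the bounding box uses the library's normalized internal `(x_min, y_min, x_max, y_max)` format and the sample data is illustrative only.

import numpy as np

from custom_albumentations.augmentations.crops import functional as F

# Dummy targets: a 100x200 image, one normalized bbox, one keypoint (x, y, angle, scale) in pixels.
image = np.zeros((100, 200, 3), dtype=np.uint8)
bbox = (0.25, 0.25, 0.75, 0.75)
keypoint = (60.0, 40.0, 0.0, 1.0)

rows, cols = image.shape[:2]
crop_h, crop_w, h_start, w_start = 64, 64, 0.5, 0.5

# The same (crop_h, crop_w, h_start, w_start) parameters are applied to every target,
# which is what RandomCrop in transforms.py does via its get_params().
cropped_image = F.random_crop(image, crop_h, crop_w, h_start, w_start)
cropped_bbox = F.bbox_random_crop(bbox, crop_h, crop_w, h_start, w_start, rows, cols)
cropped_keypoint = F.keypoint_random_crop(keypoint, crop_h, crop_w, h_start, w_start, rows, cols)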
custom_albumentations/augmentations/crops/transforms.py
ADDED
@@ -0,0 +1,943 @@
import math
import random
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import cv2
import numpy as np

from custom_albumentations.core.bbox_utils import union_of_bboxes

from ...core.transforms_interface import (
    BoxInternalType,
    DualTransform,
    KeypointInternalType,
    to_tuple,
)
from ..geometric import functional as FGeometric
from . import functional as F

__all__ = [
    "RandomCrop",
    "CenterCrop",
    "Crop",
    "CropNonEmptyMaskIfExists",
    "RandomSizedCrop",
    "RandomResizedCrop",
    "RandomCropNearBBox",
    "RandomSizedBBoxSafeCrop",
    "CropAndPad",
    "RandomCropFromBorders",
    "BBoxSafeRandomCrop",
]


class RandomCrop(DualTransform):
    """Crop a random part of the input.

    Args:
        height (int): height of the crop.
        width (int): width of the crop.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(self, height, width, always_apply=False, p=1.0):
        super().__init__(always_apply, p)
        self.height = height
        self.width = width

    def apply(self, img, h_start=0, w_start=0, **params):
        return F.random_crop(img, self.height, self.width, h_start, w_start)

    def get_params(self):
        return {"h_start": random.random(), "w_start": random.random()}

    def apply_to_bbox(self, bbox, **params):
        return F.bbox_random_crop(bbox, self.height, self.width, **params)

    def apply_to_keypoint(self, keypoint, **params):
        return F.keypoint_random_crop(keypoint, self.height, self.width, **params)

    def get_transform_init_args_names(self):
        return ("height", "width")


class CenterCrop(DualTransform):
    """Crop the central part of the input.

    Args:
        height (int): height of the crop.
        width (int): width of the crop.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32

    Note:
        It is recommended to use uint8 images as input.
        Otherwise the operation will require internal conversion
        float32 -> uint8 -> float32 that causes worse performance.
    """

    def __init__(self, height, width, always_apply=False, p=1.0):
        super(CenterCrop, self).__init__(always_apply, p)
        self.height = height
        self.width = width

    def apply(self, img, **params):
        return F.center_crop(img, self.height, self.width)

    def apply_to_bbox(self, bbox, **params):
        return F.bbox_center_crop(bbox, self.height, self.width, **params)

    def apply_to_keypoint(self, keypoint, **params):
        return F.keypoint_center_crop(keypoint, self.height, self.width, **params)

    def get_transform_init_args_names(self):
        return ("height", "width")


class Crop(DualTransform):
    """Crop region from image.

    Args:
        x_min (int): Minimum upper left x coordinate.
        y_min (int): Minimum upper left y coordinate.
        x_max (int): Maximum lower right x coordinate.
        y_max (int): Maximum lower right y coordinate.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(self, x_min=0, y_min=0, x_max=1024, y_max=1024, always_apply=False, p=1.0):
        super(Crop, self).__init__(always_apply, p)
        self.x_min = x_min
        self.y_min = y_min
        self.x_max = x_max
        self.y_max = y_max

    def apply(self, img, **params):
        return F.crop(img, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)

    def apply_to_bbox(self, bbox, **params):
        return F.bbox_crop(bbox, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max, **params)

    def apply_to_keypoint(self, keypoint, **params):
        return F.crop_keypoint_by_coords(keypoint, crop_coords=(self.x_min, self.y_min, self.x_max, self.y_max))

    def get_transform_init_args_names(self):
        return ("x_min", "y_min", "x_max", "y_max")


class CropNonEmptyMaskIfExists(DualTransform):
    """Crop area with mask if mask is non-empty, else make random crop.

    Args:
        height (int): vertical size of crop in pixels
        width (int): horizontal size of crop in pixels
        ignore_values (list of int): values to ignore in mask, `0` values are always ignored
            (e.g. if background value is 5 set `ignore_values=[5]` to ignore)
        ignore_channels (list of int): channels to ignore in mask
            (e.g. if background is a first channel set `ignore_channels=[0]` to ignore)
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(self, height, width, ignore_values=None, ignore_channels=None, always_apply=False, p=1.0):
        super(CropNonEmptyMaskIfExists, self).__init__(always_apply, p)

        if ignore_values is not None and not isinstance(ignore_values, list):
            raise ValueError("Expected `ignore_values` of type `list`, got `{}`".format(type(ignore_values)))
        if ignore_channels is not None and not isinstance(ignore_channels, list):
            raise ValueError("Expected `ignore_channels` of type `list`, got `{}`".format(type(ignore_channels)))

        self.height = height
        self.width = width
        self.ignore_values = ignore_values
        self.ignore_channels = ignore_channels

    def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.crop(img, x_min, y_min, x_max, y_max)

    def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.bbox_crop(
            bbox, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, rows=params["rows"], cols=params["cols"]
        )

    def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))

    def _preprocess_mask(self, mask):
        mask_height, mask_width = mask.shape[:2]

        if self.ignore_values is not None:
            ignore_values_np = np.array(self.ignore_values)
            mask = np.where(np.isin(mask, ignore_values_np), 0, mask)

        if mask.ndim == 3 and self.ignore_channels is not None:
            target_channels = np.array([ch for ch in range(mask.shape[-1]) if ch not in self.ignore_channels])
            mask = np.take(mask, target_channels, axis=-1)

        if self.height > mask_height or self.width > mask_width:
            raise ValueError(
                "Crop size ({},{}) is larger than image ({},{})".format(
                    self.height, self.width, mask_height, mask_width
                )
            )

        return mask

    def update_params(self, params, **kwargs):
        super().update_params(params, **kwargs)
        if "mask" in kwargs:
            mask = self._preprocess_mask(kwargs["mask"])
        elif "masks" in kwargs and len(kwargs["masks"]):
            masks = kwargs["masks"]
            mask = self._preprocess_mask(np.copy(masks[0]))  # need copy as we perform in-place mod afterwards
            for m in masks[1:]:
                mask |= self._preprocess_mask(m)
        else:
            raise RuntimeError("Can not find mask for CropNonEmptyMaskIfExists")

        mask_height, mask_width = mask.shape[:2]

        if mask.any():
            mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
            non_zero_yx = np.argwhere(mask)
            y, x = random.choice(non_zero_yx)
            x_min = x - random.randint(0, self.width - 1)
            y_min = y - random.randint(0, self.height - 1)
            x_min = np.clip(x_min, 0, mask_width - self.width)
            y_min = np.clip(y_min, 0, mask_height - self.height)
        else:
            x_min = random.randint(0, mask_width - self.width)
            y_min = random.randint(0, mask_height - self.height)

        x_max = x_min + self.width
        y_max = y_min + self.height

        params.update({"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max})
        return params

    def get_transform_init_args_names(self):
        return ("height", "width", "ignore_values", "ignore_channels")


class _BaseRandomSizedCrop(DualTransform):
    # Base class for RandomSizedCrop and RandomResizedCrop

    def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
        super(_BaseRandomSizedCrop, self).__init__(always_apply, p)
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
        crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
        return FGeometric.resize(crop, self.height, self.width, interpolation)

    def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
        return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)

    def apply_to_keypoint(self, keypoint, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
        keypoint = F.keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, rows, cols)
        scale_x = self.width / crop_width
        scale_y = self.height / crop_height
        keypoint = FGeometric.keypoint_scale(keypoint, scale_x, scale_y)
        return keypoint


class RandomSizedCrop(_BaseRandomSizedCrop):
    """Crop a random part of the input and rescale it to some size.

    Args:
        min_max_height ((int, int)): crop size limits.
        height (int): height after crop and resize.
        width (int): width after crop and resize.
        w2h_ratio (float): aspect ratio of crop.
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(
        self, min_max_height, height, width, w2h_ratio=1.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0
    ):
        super(RandomSizedCrop, self).__init__(
            height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p
        )
        self.min_max_height = min_max_height
        self.w2h_ratio = w2h_ratio

    def get_params(self):
        crop_height = random.randint(self.min_max_height[0], self.min_max_height[1])
        return {
            "h_start": random.random(),
            "w_start": random.random(),
            "crop_height": crop_height,
            "crop_width": int(crop_height * self.w2h_ratio),
        }

    def get_transform_init_args_names(self):
        return "min_max_height", "height", "width", "w2h_ratio", "interpolation"


class RandomResizedCrop(_BaseRandomSizedCrop):
    """Torchvision's variant of crop a random part of the input and rescale it to some size.

    Args:
        height (int): height after crop and resize.
        width (int): width after crop and resize.
        scale ((float, float)): range of size of the origin size cropped
        ratio ((float, float)): range of aspect ratio of the origin aspect ratio cropped
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        height,
        width,
        scale=(0.08, 1.0),
        ratio=(0.75, 1.3333333333333333),
        interpolation=cv2.INTER_LINEAR,
        always_apply=False,
        p=1.0,
    ):
        super(RandomResizedCrop, self).__init__(
            height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p
        )
        self.scale = scale
        self.ratio = ratio

    def get_params_dependent_on_targets(self, params):
        img = params["image"]
        area = img.shape[0] * img.shape[1]

        for _attempt in range(10):
            target_area = random.uniform(*self.scale) * area
            log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
            aspect_ratio = math.exp(random.uniform(*log_ratio))

            w = int(round(math.sqrt(target_area * aspect_ratio)))  # skipcq: PTC-W0028
            h = int(round(math.sqrt(target_area / aspect_ratio)))  # skipcq: PTC-W0028

            if 0 < w <= img.shape[1] and 0 < h <= img.shape[0]:
                i = random.randint(0, img.shape[0] - h)
                j = random.randint(0, img.shape[1] - w)
                return {
                    "crop_height": h,
                    "crop_width": w,
                    "h_start": i * 1.0 / (img.shape[0] - h + 1e-10),
                    "w_start": j * 1.0 / (img.shape[1] - w + 1e-10),
                }

        # Fallback to central crop
        in_ratio = img.shape[1] / img.shape[0]
        if in_ratio < min(self.ratio):
            w = img.shape[1]
            h = int(round(w / min(self.ratio)))
        elif in_ratio > max(self.ratio):
            h = img.shape[0]
            w = int(round(h * max(self.ratio)))
        else:  # whole image
            w = img.shape[1]
            h = img.shape[0]
        i = (img.shape[0] - h) // 2
        j = (img.shape[1] - w) // 2
        return {
            "crop_height": h,
            "crop_width": w,
            "h_start": i * 1.0 / (img.shape[0] - h + 1e-10),
            "w_start": j * 1.0 / (img.shape[1] - w + 1e-10),
        }

    def get_params(self):
        return {}

    @property
    def targets_as_params(self):
        return ["image"]

    def get_transform_init_args_names(self):
        return "height", "width", "scale", "ratio", "interpolation"


class RandomCropNearBBox(DualTransform):
    """Crop bbox from image with random shift by x,y coordinates

    Args:
        max_part_shift (float, (float, float)): Max shift in `height` and `width` dimensions relative
            to `cropping_bbox` dimension.
            If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift).
            Default (0.3, 0.3).
        cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox`
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32

    Examples:
        >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')],
        >>>               bbox_params=BboxParams("pascal_voc"))
        >>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20])

    """

    def __init__(
        self,
        max_part_shift: Union[float, Tuple[float, float]] = (0.3, 0.3),
        cropping_box_key: str = "cropping_bbox",
        always_apply: bool = False,
        p: float = 1.0,
    ):
        super(RandomCropNearBBox, self).__init__(always_apply, p)
        self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift)
        self.cropping_bbox_key = cropping_box_key

        if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1:
            raise ValueError("Invalid max_part_shift. Got: {}".format(max_part_shift))

    def apply(
        self, img: np.ndarray, x_min: int = 0, x_max: int = 0, y_min: int = 0, y_max: int = 0, **params
    ) -> np.ndarray:
        return F.clamping_crop(img, x_min, y_min, x_max, y_max)

    def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, int]:
        bbox = params[self.cropping_bbox_key]
        h_max_shift = round((bbox[3] - bbox[1]) * self.max_part_shift[0])
        w_max_shift = round((bbox[2] - bbox[0]) * self.max_part_shift[1])

        x_min = bbox[0] - random.randint(-w_max_shift, w_max_shift)
        x_max = bbox[2] + random.randint(-w_max_shift, w_max_shift)

        y_min = bbox[1] - random.randint(-h_max_shift, h_max_shift)
        y_max = bbox[3] + random.randint(-h_max_shift, h_max_shift)

        x_min = max(0, x_min)
        y_min = max(0, y_min)

        return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return F.bbox_crop(bbox, **params)

    def apply_to_keypoint(
        self,
        keypoint: Tuple[float, float, float, float],
        x_min: int = 0,
        x_max: int = 0,
        y_min: int = 0,
        y_max: int = 0,
        **params
    ) -> Tuple[float, float, float, float]:
        return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))

    @property
    def targets_as_params(self) -> List[str]:
        return [self.cropping_bbox_key]

    def get_transform_init_args_names(self) -> Tuple[str]:
        return ("max_part_shift",)


class BBoxSafeRandomCrop(DualTransform):
    """Crop a random part of the input without loss of bboxes.
    Args:
        erosion_rate (float): erosion rate applied on input image height before crop.
        p (float): probability of applying the transform. Default: 1.
    Targets:
        image, mask, bboxes
    Image types:
        uint8, float32
    """

    def __init__(self, erosion_rate=0.0, always_apply=False, p=1.0):
        super(BBoxSafeRandomCrop, self).__init__(always_apply, p)
        self.erosion_rate = erosion_rate

    def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, **params):
        return F.random_crop(img, crop_height, crop_width, h_start, w_start)

    def get_params_dependent_on_targets(self, params):
        img_h, img_w = params["image"].shape[:2]
        if len(params["bboxes"]) == 0:  # less likely, this class is for use with bboxes.
            erosive_h = int(img_h * (1.0 - self.erosion_rate))
            crop_height = img_h if erosive_h >= img_h else random.randint(erosive_h, img_h)
            return {
                "h_start": random.random(),
                "w_start": random.random(),
                "crop_height": crop_height,
                "crop_width": int(crop_height * img_w / img_h),
            }
        # get union of all bboxes
        x, y, x2, y2 = union_of_bboxes(
            width=img_w, height=img_h, bboxes=params["bboxes"], erosion_rate=self.erosion_rate
        )
        # find bigger region
        bx, by = x * random.random(), y * random.random()
        bx2, by2 = x2 + (1 - x2) * random.random(), y2 + (1 - y2) * random.random()
        bw, bh = bx2 - bx, by2 - by
        crop_height = img_h if bh >= 1.0 else int(img_h * bh)
        crop_width = img_w if bw >= 1.0 else int(img_w * bw)
        h_start = np.clip(0.0 if bh >= 1.0 else by / (1.0 - bh), 0.0, 1.0)
        w_start = np.clip(0.0 if bw >= 1.0 else bx / (1.0 - bw), 0.0, 1.0)
        return {"h_start": h_start, "w_start": w_start, "crop_height": crop_height, "crop_width": crop_width}

    def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
        return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)

    @property
    def targets_as_params(self):
        return ["image", "bboxes"]

    def get_transform_init_args_names(self):
        return ("erosion_rate",)


class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop):
    """Crop a random part of the input and rescale it to some size without loss of bboxes.
    Args:
        height (int): height after crop and resize.
        width (int): width after crop and resize.
        erosion_rate (float): erosion rate applied on input image height before crop.
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        p (float): probability of applying the transform. Default: 1.
    Targets:
        image, mask, bboxes
    Image types:
        uint8, float32
    """

    def __init__(self, height, width, erosion_rate=0.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
        super(RandomSizedBBoxSafeCrop, self).__init__(erosion_rate, always_apply, p)
        self.height = height
        self.width = width
        self.interpolation = interpolation

    def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
        crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
        return FGeometric.resize(crop, self.height, self.width, interpolation)

    def get_transform_init_args_names(self):
        return super().get_transform_init_args_names() + ("height", "width", "interpolation")


class CropAndPad(DualTransform):
    """Crop and pad images by pixel amounts or fractions of image sizes.
    Cropping removes pixels at the sides (i.e. extracts a subimage from a given full image).
    Padding adds pixels to the sides (e.g. black pixels).
    This transformation will never crop images below a height or width of ``1``.

    Note:
        This transformation automatically resizes images back to their original size. To deactivate this, add the
        parameter ``keep_size=False``.

    Args:
        px (int or tuple):
            The number of pixels to crop (negative values) or pad (positive values)
            on each side of the image. Either this or the parameter `percent` may
            be set, not both at the same time.
                * If ``None``, then pixel-based cropping/padding will not be used.
                * If ``int``, then that exact number of pixels will always be cropped/padded.
                * If a ``tuple`` of two ``int`` s with values ``a`` and ``b``,
                  then each side will be cropped/padded by a random amount sampled
                  uniformly per image and side from the interval ``[a, b]``. If
                  however `sample_independently` is set to ``False``, only one
                  value will be sampled per image and used for all sides.
                * If a ``tuple`` of four entries, then the entries represent top,
                  right, bottom, left. Each entry may be a single ``int`` (always
                  crop/pad by exactly that value), a ``tuple`` of two ``int`` s
                  ``a`` and ``b`` (crop/pad by an amount within ``[a, b]``), a
                  ``list`` of ``int`` s (crop/pad by a random value that is
                  contained in the ``list``).
        percent (float or tuple):
            The number of pixels to crop (negative values) or pad (positive values)
            on each side of the image given as a *fraction* of the image
            height/width. E.g. if this is set to ``-0.1``, the transformation will
            always crop away ``10%`` of the image's height at both the top and the
            bottom (both ``10%`` each), as well as ``10%`` of the width at the
            right and left.
            Expected value range is ``(-1.0, inf)``.
            Either this or the parameter `px` may be set, not both
            at the same time.
                * If ``None``, then fraction-based cropping/padding will not be
                  used.
                * If ``float``, then that fraction will always be cropped/padded.
                * If a ``tuple`` of two ``float`` s with values ``a`` and ``b``,
                  then each side will be cropped/padded by a random fraction
                  sampled uniformly per image and side from the interval
                  ``[a, b]``. If however `sample_independently` is set to
                  ``False``, only one value will be sampled per image and used for
                  all sides.
                * If a ``tuple`` of four entries, then the entries represent top,
                  right, bottom, left. Each entry may be a single ``float``
                  (always crop/pad by exactly that percent value), a ``tuple`` of
                  two ``float`` s ``a`` and ``b`` (crop/pad by a fraction from
                  ``[a, b]``), a ``list`` of ``float`` s (crop/pad by a random
                  value that is contained in the list).
        pad_mode (int): OpenCV border mode.
        pad_cval (number, Sequence[number]):
            The constant value to use if the pad mode is ``BORDER_CONSTANT``.
                * If ``number``, then that value will be used.
                * If a ``tuple`` of two ``number`` s and at least one of them is
                  a ``float``, then a random number will be uniformly sampled per
                  image from the continuous interval ``[a, b]`` and used as the
                  value. If both ``number`` s are ``int`` s, the interval is
                  discrete.
                * If a ``list`` of ``number``, then a random value will be chosen
                  from the elements of the ``list`` and used as the value.
        pad_cval_mask (number, Sequence[number]): Same as pad_cval but only for masks.
        keep_size (bool):
            After cropping and padding, the result image will usually have a
            different height/width compared to the original input image. If this
            parameter is set to ``True``, then the cropped/padded image will be
            resized to the input image's size, i.e. the output shape is always identical to the input shape.
        sample_independently (bool):
            If ``False`` *and* the values for `px`/`percent` result in exactly
            *one* probability distribution for all image sides, only one single
            value will be sampled from that probability distribution and used for
            all sides. I.e. the crop/pad amount then is the same for all sides.
            If ``True``, four values will be sampled independently, one per side.
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        any
    """

    def __init__(
        self,
        px: Optional[Union[int, Sequence[float], Sequence[Tuple]]] = None,
        percent: Optional[Union[float, Sequence[float], Sequence[Tuple]]] = None,
        pad_mode: int = cv2.BORDER_CONSTANT,
        pad_cval: Union[float, Sequence[float]] = 0,
        pad_cval_mask: Union[float, Sequence[float]] = 0,
        keep_size: bool = True,
        sample_independently: bool = True,
        interpolation: int = cv2.INTER_LINEAR,
        always_apply: bool = False,
        p: float = 1.0,
    ):
        super().__init__(always_apply, p)

        if px is None and percent is None:
            raise ValueError("px and percent are empty!")
        if px is not None and percent is not None:
            raise ValueError("Only px or percent may be set!")

        self.px = px
        self.percent = percent

        self.pad_mode = pad_mode
        self.pad_cval = pad_cval
        self.pad_cval_mask = pad_cval_mask

        self.keep_size = keep_size
        self.sample_independently = sample_independently

        self.interpolation = interpolation

    def apply(
        self,
        img: np.ndarray,
        crop_params: Sequence[int] = (),
        pad_params: Sequence[int] = (),
        pad_value: Union[int, float] = 0,
        rows: int = 0,
        cols: int = 0,
        interpolation: int = cv2.INTER_LINEAR,
        **params
    ) -> np.ndarray:
        return F.crop_and_pad(
            img, crop_params, pad_params, pad_value, rows, cols, interpolation, self.pad_mode, self.keep_size
        )

    def apply_to_mask(
        self,
        img: np.ndarray,
        crop_params: Optional[Sequence[int]] = None,
        pad_params: Optional[Sequence[int]] = None,
        pad_value_mask: Optional[float] = None,
        rows: int = 0,
        cols: int = 0,
        interpolation: int = cv2.INTER_NEAREST,
        **params
    ) -> np.ndarray:
        return F.crop_and_pad(
            img, crop_params, pad_params, pad_value_mask, rows, cols, interpolation, self.pad_mode, self.keep_size
        )

    def apply_to_bbox(
        self,
        bbox: BoxInternalType,
        crop_params: Optional[Sequence[int]] = None,
        pad_params: Optional[Sequence[int]] = None,
        rows: int = 0,
        cols: int = 0,
        result_rows: int = 0,
        result_cols: int = 0,
        **params
    ) -> BoxInternalType:
        return F.crop_and_pad_bbox(bbox, crop_params, pad_params, rows, cols, result_rows, result_cols)

    def apply_to_keypoint(
        self,
        keypoint: KeypointInternalType,
        crop_params: Optional[Sequence[int]] = None,
        pad_params: Optional[Sequence[int]] = None,
        rows: int = 0,
        cols: int = 0,
        result_rows: int = 0,
        result_cols: int = 0,
        **params
    ) -> KeypointInternalType:
        return F.crop_and_pad_keypoint(
            keypoint, crop_params, pad_params, rows, cols, result_rows, result_cols, self.keep_size
        )

    @property
    def targets_as_params(self) -> List[str]:
        return ["image"]

    @staticmethod
    def __prevent_zero(val1: int, val2: int, max_val: int) -> Tuple[int, int]:
        regain = abs(max_val) + 1
        regain1 = regain // 2
        regain2 = regain // 2
        if regain1 + regain2 < regain:
            regain1 += 1

        if regain1 > val1:
            diff = regain1 - val1
            regain1 = val1
            regain2 += diff
        elif regain2 > val2:
            diff = regain2 - val2
            regain2 = val2
            regain1 += diff

        val1 = val1 - regain1
        val2 = val2 - regain2

        return val1, val2

    @staticmethod
    def _prevent_zero(crop_params: List[int], height: int, width: int) -> Sequence[int]:
        top, right, bottom, left = crop_params

        remaining_height = height - (top + bottom)
        remaining_width = width - (left + right)

        if remaining_height < 1:
            top, bottom = CropAndPad.__prevent_zero(top, bottom, height)
        if remaining_width < 1:
            left, right = CropAndPad.__prevent_zero(left, right, width)

        return [max(top, 0), max(right, 0), max(bottom, 0), max(left, 0)]

    def get_params_dependent_on_targets(self, params) -> dict:
        height, width = params["image"].shape[:2]

        if self.px is not None:
            params = self._get_px_params()
        else:
            params = self._get_percent_params()
            params[0] = int(params[0] * height)
            params[1] = int(params[1] * width)
            params[2] = int(params[2] * height)
            params[3] = int(params[3] * width)

        pad_params = [max(i, 0) for i in params]

        crop_params = self._prevent_zero([-min(i, 0) for i in params], height, width)

        top, right, bottom, left = crop_params
        crop_params = [left, top, width - right, height - bottom]
        result_rows = crop_params[3] - crop_params[1]
        result_cols = crop_params[2] - crop_params[0]
        if result_cols == width and result_rows == height:
            crop_params = []

        top, right, bottom, left = pad_params
        pad_params = [top, bottom, left, right]
        if any(pad_params):
            result_rows += top + bottom
            result_cols += left + right
        else:
            pad_params = []

        return {
            "crop_params": crop_params or None,
            "pad_params": pad_params or None,
            "pad_value": None if pad_params is None else self._get_pad_value(self.pad_cval),
            "pad_value_mask": None if pad_params is None else self._get_pad_value(self.pad_cval_mask),
            "result_rows": result_rows,
            "result_cols": result_cols,
        }

    def _get_px_params(self) -> List[int]:
        if self.px is None:
            raise ValueError("px is not set")

        if isinstance(self.px, int):
            params = [self.px] * 4
        elif len(self.px) == 2:
            if self.sample_independently:
                params = [random.randrange(*self.px) for _ in range(4)]
            else:
                px = random.randrange(*self.px)
                params = [px] * 4
        else:
            params = [i if isinstance(i, int) else random.randrange(*i) for i in self.px]  # type: ignore

        return params  # [top, right, bottom, left]

    def _get_percent_params(self) -> List[float]:
        if self.percent is None:
            raise ValueError("percent is not set")

        if isinstance(self.percent, float):
            params = [self.percent] * 4
        elif len(self.percent) == 2:
            if self.sample_independently:
                params = [random.uniform(*self.percent) for _ in range(4)]
            else:
                px = random.uniform(*self.percent)
                params = [px] * 4
        else:
            params = [i if isinstance(i, (int, float)) else random.uniform(*i) for i in self.percent]

        return params  # params = [top, right, bottom, left]

    @staticmethod
    def _get_pad_value(pad_value: Union[float, Sequence[float]]) -> Union[int, float]:
        if isinstance(pad_value, (int, float)):
            return pad_value

        if len(pad_value) == 2:
            a, b = pad_value
            if isinstance(a, int) and isinstance(b, int):
                return random.randint(a, b)

            return random.uniform(a, b)

        return random.choice(pad_value)

    def get_transform_init_args_names(self) -> Tuple[str, ...]:
        return (
            "px",
            "percent",
            "pad_mode",
            "pad_cval",
            "pad_cval_mask",
            "keep_size",
            "sample_independently",
            "interpolation",
        )


class RandomCropFromBorders(DualTransform):
    """Crop the image by randomly cutting parts from its borders, without resizing at the end.

    Args:
        crop_left (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
            from left side in range [0, crop_left * width)
        crop_right (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
            from right side in range [(1 - crop_right) * width, width)
        crop_top (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
            from top side in range [0, crop_top * height)
        crop_bottom (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
            from bottom side in range [(1 - crop_bottom) * height, height)
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        crop_left=0.1,
        crop_right=0.1,
        crop_top=0.1,
        crop_bottom=0.1,
        always_apply=False,
        p=1.0,
    ):
        super(RandomCropFromBorders, self).__init__(always_apply, p)
        self.crop_left = crop_left
        self.crop_right = crop_right
        self.crop_top = crop_top
        self.crop_bottom = crop_bottom

    def get_params_dependent_on_targets(self, params):
        img = params["image"]
        x_min = random.randint(0, int(self.crop_left * img.shape[1]))
        x_max = random.randint(max(x_min + 1, int((1 - self.crop_right) * img.shape[1])), img.shape[1])
        y_min = random.randint(0, int(self.crop_top * img.shape[0]))
        y_max = random.randint(max(y_min + 1, int((1 - self.crop_bottom) * img.shape[0])), img.shape[0])
        return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}

    def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.clamping_crop(img, x_min, y_min, x_max, y_max)

    def apply_to_mask(self, mask, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.clamping_crop(mask, x_min, y_min, x_max, y_max)

    def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        rows, cols = params["rows"], params["cols"]
        return F.bbox_crop(bbox, x_min, y_min, x_max, y_max, rows, cols)

    def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params):
        return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))

    @property
    def targets_as_params(self):
        return ["image"]

    def get_transform_init_args_names(self):
        return "crop_left", "crop_right", "crop_top", "crop_bottom"
custom_albumentations/augmentations/domain_adaptation.py
ADDED
@@ -0,0 +1,337 @@
import random
from typing import Any, Callable, Literal, Sequence, Tuple

import cv2
import numpy as np
from custom_qudida import DomainAdapter
from skimage.exposure import match_histograms
from sklearn.decomposition import PCA
from sklearn.preprocessing import MinMaxScaler, StandardScaler

from custom_albumentations.augmentations.utils import (
    clipped,
    get_opencv_dtype_from_numpy,
    is_grayscale_image,
    is_multispectral_image,
    preserve_shape,
    read_rgb_image,
)

from ..core.transforms_interface import ImageOnlyTransform, ScaleFloatType, to_tuple

__all__ = [
    "HistogramMatching",
    "FDA",
    "PixelDistributionAdaptation",
    "fourier_domain_adaptation",
    "apply_histogram",
    "adapt_pixel_distribution",
]


@clipped
@preserve_shape
def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray:
    """
    Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA

    Args:
        img: source image
        target_img: target image for domain adaptation
        beta: coefficient from source paper

    Returns:
        transformed image

    """

    img = np.squeeze(img)
    target_img = np.squeeze(target_img)

    if target_img.shape != img.shape:
        raise ValueError(
            "The source and target images must have the same shape,"
            " but got {} and {} respectively.".format(img.shape, target_img.shape)
        )

    # get fft of both source and target
    fft_src = np.fft.fft2(img.astype(np.float32), axes=(0, 1))
    fft_trg = np.fft.fft2(target_img.astype(np.float32), axes=(0, 1))

    # extract amplitude and phase of both fft-s
    amplitude_src, phase_src = np.abs(fft_src), np.angle(fft_src)
    amplitude_trg = np.abs(fft_trg)

    # mutate the amplitude part of source with target
    amplitude_src = np.fft.fftshift(amplitude_src, axes=(0, 1))
    amplitude_trg = np.fft.fftshift(amplitude_trg, axes=(0, 1))
    height, width = amplitude_src.shape[:2]
    border = np.floor(min(height, width) * beta).astype(int)
    center_y, center_x = np.floor([height / 2.0, width / 2.0]).astype(int)

    y1, y2 = center_y - border, center_y + border + 1
    x1, x2 = center_x - border, center_x + border + 1

    amplitude_src[y1:y2, x1:x2] = amplitude_trg[y1:y2, x1:x2]
    amplitude_src = np.fft.ifftshift(amplitude_src, axes=(0, 1))

    # get mutated image
    src_image_transformed = np.fft.ifft2(amplitude_src * np.exp(1j * phase_src), axes=(0, 1))
    src_image_transformed = np.real(src_image_transformed)

    return src_image_transformed


@preserve_shape
def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray:
    if img.dtype != reference_image.dtype:
        raise RuntimeError(
            f"Dtype of image and reference image must be the same. Got {img.dtype} and {reference_image.dtype}"
        )
    if img.shape[:2] != reference_image.shape[:2]:
        reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0]))

    img, reference_image = np.squeeze(img), np.squeeze(reference_image)

    try:
        matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == 3 else None)
    except TypeError:
        matched = match_histograms(img, reference_image, multichannel=True)  # case for scikit-image<0.19.1
    img = cv2.addWeighted(
        matched,
        blend_ratio,
        img,
        1 - blend_ratio,
        0,
        dtype=get_opencv_dtype_from_numpy(img.dtype),
    )
    return img


@preserve_shape
def adapt_pixel_distribution(
    img: np.ndarray, ref: np.ndarray, transform_type: str = "pca", weight: float = 0.5
) -> np.ndarray:
    initial_type = img.dtype
    transformer = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}[transform_type]()
    adapter = DomainAdapter(transformer=transformer, ref_img=ref)
    result = adapter(img).astype("float32")
    blended = (img.astype("float32") * (1 - weight) + result * weight).astype(initial_type)
    return blended


class HistogramMatching(ImageOnlyTransform):
    """
    Apply histogram matching. It manipulates the pixels of an input image so that its histogram matches
|
| 126 |
+
the histogram of the reference image. If the images have multiple channels, the matching is done independently
|
| 127 |
+
for each channel, as long as the number of channels is equal in the input image and the reference.
|
| 128 |
+
|
| 129 |
+
Histogram matching can be used as a lightweight normalisation for image processing,
|
| 130 |
+
such as feature matching, especially in circumstances where the images have been taken from different
|
| 131 |
+
sources or in different conditions (i.e. lighting).
|
| 132 |
+
|
| 133 |
+
See:
|
| 134 |
+
https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_histogram_matching.html
|
| 135 |
+
|
| 136 |
+
Args:
|
| 137 |
+
reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
|
| 138 |
+
it expects a sequence of paths to images.
|
| 139 |
+
blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
|
| 140 |
+
with random blend factor for increased diversity of generated images.
|
| 141 |
+
read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
|
| 142 |
+
and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
|
| 143 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 144 |
+
|
| 145 |
+
Targets:
|
| 146 |
+
image
|
| 147 |
+
|
| 148 |
+
Image types:
|
| 149 |
+
uint8, uint16, float32
|
| 150 |
+
"""
|
| 151 |
+
|
| 152 |
+
def __init__(
|
| 153 |
+
self,
|
| 154 |
+
reference_images: Sequence[Any],
|
| 155 |
+
blend_ratio: Tuple[float, float] = (0.5, 1.0),
|
| 156 |
+
read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
|
| 157 |
+
always_apply: bool = False,
|
| 158 |
+
p: float = 0.5,
|
| 159 |
+
):
|
| 160 |
+
super().__init__(always_apply=always_apply, p=p)
|
| 161 |
+
self.reference_images = reference_images
|
| 162 |
+
self.read_fn = read_fn
|
| 163 |
+
self.blend_ratio = blend_ratio
|
| 164 |
+
|
| 165 |
+
def apply(self, img, reference_image=None, blend_ratio=0.5, **params):
|
| 166 |
+
return apply_histogram(img, reference_image, blend_ratio)
|
| 167 |
+
|
| 168 |
+
def get_params(self):
|
| 169 |
+
return {
|
| 170 |
+
"reference_image": self.read_fn(random.choice(self.reference_images)),
|
| 171 |
+
"blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
def get_transform_init_args_names(self):
|
| 175 |
+
return ("reference_images", "blend_ratio", "read_fn")
|
| 176 |
+
|
| 177 |
+
def _to_dict(self):
|
| 178 |
+
raise NotImplementedError("HistogramMatching can not be serialized.")
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
class FDA(ImageOnlyTransform):
|
| 182 |
+
"""
|
| 183 |
+
Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA
|
| 184 |
+
Simple "style transfer".
|
| 185 |
+
|
| 186 |
+
Args:
|
| 187 |
+
reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
|
| 188 |
+
it expects a sequence of paths to images.
|
| 189 |
+
beta_limit (float or tuple of float): coefficient beta from the paper. Recommended values are below 0.3.
|
| 190 |
+
read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
|
| 191 |
+
and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
|
| 192 |
+
|
| 193 |
+
Targets:
|
| 194 |
+
image
|
| 195 |
+
|
| 196 |
+
Image types:
|
| 197 |
+
uint8, float32
|
| 198 |
+
|
| 199 |
+
Reference:
|
| 200 |
+
https://github.com/YanchaoYang/FDA
|
| 201 |
+
https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf
|
| 202 |
+
|
| 203 |
+
Example:
|
| 204 |
+
>>> import numpy as np
|
| 205 |
+
>>> import custom_albumentations as A
|
| 206 |
+
>>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
|
| 207 |
+
>>> target_image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
|
| 208 |
+
>>> aug = A.Compose([A.FDA([target_image], p=1, read_fn=lambda x: x)])
|
| 209 |
+
>>> result = aug(image=image)
|
| 210 |
+
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
reference_images: Sequence[Any],
|
| 216 |
+
beta_limit: ScaleFloatType = 0.1,
|
| 217 |
+
read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
|
| 218 |
+
always_apply: bool = False,
|
| 219 |
+
p: float = 0.5,
|
| 220 |
+
):
|
| 221 |
+
super(FDA, self).__init__(always_apply=always_apply, p=p)
|
| 222 |
+
self.reference_images = reference_images
|
| 223 |
+
self.read_fn = read_fn
|
| 224 |
+
self.beta_limit = to_tuple(beta_limit, low=0)
|
| 225 |
+
|
| 226 |
+
def apply(self, img, target_image=None, beta=0.1, **params):
|
| 227 |
+
return fourier_domain_adaptation(img=img, target_img=target_image, beta=beta)
|
| 228 |
+
|
| 229 |
+
def get_params_dependent_on_targets(self, params):
|
| 230 |
+
img = params["image"]
|
| 231 |
+
target_img = self.read_fn(random.choice(self.reference_images))
|
| 232 |
+
target_img = cv2.resize(target_img, dsize=(img.shape[1], img.shape[0]))
|
| 233 |
+
|
| 234 |
+
return {"target_image": target_img}
|
| 235 |
+
|
| 236 |
+
def get_params(self):
|
| 237 |
+
return {"beta": random.uniform(self.beta_limit[0], self.beta_limit[1])}
|
| 238 |
+
|
| 239 |
+
@property
|
| 240 |
+
def targets_as_params(self):
|
| 241 |
+
return ["image"]
|
| 242 |
+
|
| 243 |
+
def get_transform_init_args_names(self):
|
| 244 |
+
return ("reference_images", "beta_limit", "read_fn")
|
| 245 |
+
|
| 246 |
+
def _to_dict(self):
|
| 247 |
+
raise NotImplementedError("FDA can not be serialized.")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class PixelDistributionAdaptation(ImageOnlyTransform):
|
| 251 |
+
"""
|
| 252 |
+
Another naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler
|
| 253 |
+
or MinMaxScaler) on both original and reference image, transforms original image with transform trained on this
|
| 254 |
+
image and then performs inverse transformation using transform fitted on reference image.
|
| 255 |
+
|
| 256 |
+
Args:
|
| 257 |
+
reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
|
| 258 |
+
it expects a sequence of paths to images.
|
| 259 |
+
blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
|
| 260 |
+
with random blend factor for increased diversity of generated images.
|
| 261 |
+
read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
|
| 262 |
+
and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
|
| 263 |
+
transform_type (str): type of transform; "pca", "standard", "minmax" are allowed.
|
| 264 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 265 |
+
|
| 266 |
+
Targets:
|
| 267 |
+
image
|
| 268 |
+
|
| 269 |
+
Image types:
|
| 270 |
+
uint8, float32
|
| 271 |
+
|
| 272 |
+
See also: https://github.com/arsenyinfo/qudida
|
| 273 |
+
"""
|
| 274 |
+
|
| 275 |
+
def __init__(
|
| 276 |
+
self,
|
| 277 |
+
reference_images: Sequence[Any],
|
| 278 |
+
blend_ratio: Tuple[float, float] = (0.25, 1.0),
|
| 279 |
+
read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
|
| 280 |
+
transform_type: Literal["pca", "standard", "minmax"] = "pca",
|
| 281 |
+
always_apply: bool = False,
|
| 282 |
+
p: float = 0.5,
|
| 283 |
+
):
|
| 284 |
+
super().__init__(always_apply=always_apply, p=p)
|
| 285 |
+
self.reference_images = reference_images
|
| 286 |
+
self.read_fn = read_fn
|
| 287 |
+
self.blend_ratio = blend_ratio
|
| 288 |
+
expected_transformers = ("pca", "standard", "minmax")
|
| 289 |
+
if transform_type not in expected_transformers:
|
| 290 |
+
raise ValueError(f"Got unexpected transform_type {transform_type}. Expected one of {expected_transformers}")
|
| 291 |
+
self.transform_type = transform_type
|
| 292 |
+
|
| 293 |
+
@staticmethod
|
| 294 |
+
def _validate_shape(img: np.ndarray):
|
| 295 |
+
if is_grayscale_image(img) or is_multispectral_image(img):
|
| 296 |
+
raise ValueError(
|
| 297 |
+
f"Unexpected image shape: expected 3 dimensions, got {len(img.shape)}."
|
| 298 |
+
f"Is it a grayscale or multispectral image? It's not supported for now."
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def ensure_uint8(self, img: np.ndarray) -> Tuple[np.ndarray, bool]:
|
| 302 |
+
if img.dtype == np.float32:
|
| 303 |
+
if img.min() < 0 or img.max() > 1:
|
| 304 |
+
message = (
|
| 305 |
+
"PixelDistributionAdaptation uses uint8 under the hood, so float32 should be converted,"
|
| 306 |
+
"Can not do it automatically when the image is out of [0..1] range."
|
| 307 |
+
)
|
| 308 |
+
raise TypeError(message)
|
| 309 |
+
return (img * 255).astype("uint8"), True
|
| 310 |
+
return img, False
|
| 311 |
+
|
| 312 |
+
def apply(self, img, reference_image, blend_ratio, **params):
|
| 313 |
+
self._validate_shape(img)
|
| 314 |
+
reference_image, _ = self.ensure_uint8(reference_image)
|
| 315 |
+
img, needs_reconvert = self.ensure_uint8(img)
|
| 316 |
+
|
| 317 |
+
adapted = adapt_pixel_distribution(
|
| 318 |
+
img=img,
|
| 319 |
+
ref=reference_image,
|
| 320 |
+
weight=blend_ratio,
|
| 321 |
+
transform_type=self.transform_type,
|
| 322 |
+
)
|
| 323 |
+
if needs_reconvert:
|
| 324 |
+
adapted = adapted.astype("float32") * (1 / 255)
|
| 325 |
+
return adapted
|
| 326 |
+
|
| 327 |
+
def get_params(self):
|
| 328 |
+
return {
|
| 329 |
+
"reference_image": self.read_fn(random.choice(self.reference_images)),
|
| 330 |
+
"blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
|
| 331 |
+
}
|
| 332 |
+
|
| 333 |
+
def get_transform_init_args_names(self):
|
| 334 |
+
return ("reference_images", "blend_ratio", "read_fn", "transform_type")
|
| 335 |
+
|
| 336 |
+
def _to_dict(self):
|
| 337 |
+
raise NotImplementedError("PixelDistributionAdaptation can not be serialized.")
|
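The three transforms in this file share one idea: borrow low-level statistics from a reference image and blend the result back into the input. As a quick illustration of what fourier_domain_adaptation computes, here is a minimal NumPy sketch of the centered low-frequency amplitude swap, assuming two same-shape float arrays; it mirrors the function above but is only a sketch, not a drop-in replacement.

import numpy as np

rng = np.random.default_rng(0)
src = rng.integers(0, 256, (64, 64, 3)).astype(np.float32)
trg = rng.integers(0, 256, (64, 64, 3)).astype(np.float32)
beta = 0.05  # fraction of the (centered) spectrum taken from the target

fft_src = np.fft.fft2(src, axes=(0, 1))
fft_trg = np.fft.fft2(trg, axes=(0, 1))
amp_src, phase_src = np.abs(fft_src), np.angle(fft_src)
amp_trg = np.abs(fft_trg)

# swap the centered low-frequency amplitudes, keep the source phase
amp_src = np.fft.fftshift(amp_src, axes=(0, 1))
amp_trg = np.fft.fftshift(amp_trg, axes=(0, 1))
h, w = src.shape[:2]
b = int(min(h, w) * beta)
cy, cx = h // 2, w // 2
amp_src[cy - b:cy + b + 1, cx - b:cx + b + 1] = amp_trg[cy - b:cy + b + 1, cx - b:cx + b + 1]
amp_src = np.fft.ifftshift(amp_src, axes=(0, 1))

adapted = np.real(np.fft.ifft2(amp_src * np.exp(1j * phase_src), axes=(0, 1)))

The class wrappers (HistogramMatching, FDA, PixelDistributionAdaptation) add reference sampling, blending, and pipeline integration on top of this kind of per-call function; the FDA docstring above shows the intended Compose usage.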
custom_albumentations/augmentations/dropout/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
| 1 |
+
from .channel_dropout import *
|
| 2 |
+
from .coarse_dropout import *
|
| 3 |
+
from .cutout import *
|
| 4 |
+
from .grid_dropout import *
|
| 5 |
+
from .mask_dropout import *
|
custom_albumentations/augmentations/dropout/channel_dropout.py
ADDED
|
@@ -0,0 +1,72 @@
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Mapping, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from custom_albumentations.core.transforms_interface import ImageOnlyTransform
|
| 7 |
+
|
| 8 |
+
from .functional import channel_dropout
|
| 9 |
+
|
| 10 |
+
__all__ = ["ChannelDropout"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ChannelDropout(ImageOnlyTransform):
|
| 14 |
+
"""Randomly Drop Channels in the input Image.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
channel_drop_range (int, int): range from which we choose the number of channels to drop.
|
| 18 |
+
fill_value (int, float): pixel value for the dropped channel.
|
| 19 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 20 |
+
|
| 21 |
+
Targets:
|
| 22 |
+
image
|
| 23 |
+
|
| 24 |
+
Image types:
|
| 25 |
+
uint8, uint16, uint32, float32
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(
|
| 29 |
+
self,
|
| 30 |
+
channel_drop_range: Tuple[int, int] = (1, 1),
|
| 31 |
+
fill_value: Union[int, float] = 0,
|
| 32 |
+
always_apply: bool = False,
|
| 33 |
+
p: float = 0.5,
|
| 34 |
+
):
|
| 35 |
+
super(ChannelDropout, self).__init__(always_apply, p)
|
| 36 |
+
|
| 37 |
+
self.channel_drop_range = channel_drop_range
|
| 38 |
+
|
| 39 |
+
self.min_channels = channel_drop_range[0]
|
| 40 |
+
self.max_channels = channel_drop_range[1]
|
| 41 |
+
|
| 42 |
+
if not 1 <= self.min_channels <= self.max_channels:
|
| 43 |
+
raise ValueError("Invalid channel_drop_range. Got: {}".format(channel_drop_range))
|
| 44 |
+
|
| 45 |
+
self.fill_value = fill_value
|
| 46 |
+
|
| 47 |
+
def apply(self, img: np.ndarray, channels_to_drop: Tuple[int, ...] = (0,), **params) -> np.ndarray:
|
| 48 |
+
return channel_dropout(img, channels_to_drop, self.fill_value)
|
| 49 |
+
|
| 50 |
+
def get_params_dependent_on_targets(self, params: Mapping[str, Any]):
|
| 51 |
+
img = params["image"]
|
| 52 |
+
|
| 53 |
+
num_channels = img.shape[-1]
|
| 54 |
+
|
| 55 |
+
if len(img.shape) == 2 or num_channels == 1:
|
| 56 |
+
raise NotImplementedError("Images has one channel. ChannelDropout is not defined.")
|
| 57 |
+
|
| 58 |
+
if self.max_channels >= num_channels:
|
| 59 |
+
raise ValueError("Can not drop all channels in ChannelDropout.")
|
| 60 |
+
|
| 61 |
+
num_drop_channels = random.randint(self.min_channels, self.max_channels)
|
| 62 |
+
|
| 63 |
+
channels_to_drop = random.sample(range(num_channels), k=num_drop_channels)
|
| 64 |
+
|
| 65 |
+
return {"channels_to_drop": channels_to_drop}
|
| 66 |
+
|
| 67 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 68 |
+
return "channel_drop_range", "fill_value"
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def targets_as_params(self):
|
| 72 |
+
return ["image"]
|
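A minimal usage sketch for ChannelDropout, assuming the package in this commit is importable and re-exports its transforms at the top level the same way upstream albumentations does:

import numpy as np
import custom_albumentations as A  # assumed top-level re-export, as in upstream albumentations

image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

# drop 1 or 2 of the 3 channels and fill them with 128
aug = A.Compose([A.ChannelDropout(channel_drop_range=(1, 2), fill_value=128, p=1.0)])
out = aug(image=image)["image"]

Note that the transform refuses single-channel inputs and raises if max_channels would drop every channel, as enforced in get_params_dependent_on_targets above.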
custom_albumentations/augmentations/dropout/coarse_dropout.py
ADDED
|
@@ -0,0 +1,187 @@
| 1 |
+
import random
|
| 2 |
+
from typing import Iterable, List, Optional, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ...core.transforms_interface import DualTransform, KeypointType
|
| 7 |
+
from .functional import cutout
|
| 8 |
+
|
| 9 |
+
__all__ = ["CoarseDropout"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CoarseDropout(DualTransform):
|
| 13 |
+
"""CoarseDropout of the rectangular regions in the image.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
max_holes (int): Maximum number of regions to zero out.
|
| 17 |
+
max_height (int, float): Maximum height of the hole.
|
| 18 |
+
If float, it is calculated as a fraction of the image height.
|
| 19 |
+
max_width (int, float): Maximum width of the hole.
|
| 20 |
+
If float, it is calculated as a fraction of the image width.
|
| 21 |
+
min_holes (int): Minimum number of regions to zero out. If `None`,
|
| 22 |
+
`min_holes` is set to `max_holes`. Default: `None`.
|
| 23 |
+
min_height (int, float): Minimum height of the hole. Default: None. If `None`,
|
| 24 |
+
`min_height` is set to `max_height`. Default: `None`.
|
| 25 |
+
If float, it is calculated as a fraction of the image height.
|
| 26 |
+
min_width (int, float): Minimum width of the hole. If `None`, `min_width` is
|
| 27 |
+
set to `max_width`. Default: `None`.
|
| 28 |
+
If float, it is calculated as a fraction of the image width.
|
| 29 |
+
|
| 30 |
+
fill_value (int, float, list of int, list of float): value for dropped pixels.
|
| 31 |
+
mask_fill_value (int, float, list of int, list of float): fill value for dropped pixels
|
| 32 |
+
in mask. If `None` - mask is not affected. Default: `None`.
|
| 33 |
+
|
| 34 |
+
Targets:
|
| 35 |
+
image, mask, keypoints
|
| 36 |
+
|
| 37 |
+
Image types:
|
| 38 |
+
uint8, float32
|
| 39 |
+
|
| 40 |
+
Reference:
|
| 41 |
+
| https://arxiv.org/abs/1708.04552
|
| 42 |
+
| https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
|
| 43 |
+
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
|
| 44 |
+
"""
|
| 45 |
+
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
max_holes: int = 8,
|
| 49 |
+
max_height: int = 8,
|
| 50 |
+
max_width: int = 8,
|
| 51 |
+
min_holes: Optional[int] = None,
|
| 52 |
+
min_height: Optional[int] = None,
|
| 53 |
+
min_width: Optional[int] = None,
|
| 54 |
+
fill_value: int = 0,
|
| 55 |
+
mask_fill_value: Optional[int] = None,
|
| 56 |
+
always_apply: bool = False,
|
| 57 |
+
p: float = 0.5,
|
| 58 |
+
):
|
| 59 |
+
super(CoarseDropout, self).__init__(always_apply, p)
|
| 60 |
+
self.max_holes = max_holes
|
| 61 |
+
self.max_height = max_height
|
| 62 |
+
self.max_width = max_width
|
| 63 |
+
self.min_holes = min_holes if min_holes is not None else max_holes
|
| 64 |
+
self.min_height = min_height if min_height is not None else max_height
|
| 65 |
+
self.min_width = min_width if min_width is not None else max_width
|
| 66 |
+
self.fill_value = fill_value
|
| 67 |
+
self.mask_fill_value = mask_fill_value
|
| 68 |
+
if not 0 < self.min_holes <= self.max_holes:
|
| 69 |
+
raise ValueError("Invalid combination of min_holes and max_holes. Got: {}".format([min_holes, max_holes]))
|
| 70 |
+
|
| 71 |
+
self.check_range(self.max_height)
|
| 72 |
+
self.check_range(self.min_height)
|
| 73 |
+
self.check_range(self.max_width)
|
| 74 |
+
self.check_range(self.min_width)
|
| 75 |
+
|
| 76 |
+
if not 0 < self.min_height <= self.max_height:
|
| 77 |
+
raise ValueError(
|
| 78 |
+
"Invalid combination of min_height and max_height. Got: {}".format([min_height, max_height])
|
| 79 |
+
)
|
| 80 |
+
if not 0 < self.min_width <= self.max_width:
|
| 81 |
+
raise ValueError("Invalid combination of min_width and max_width. Got: {}".format([min_width, max_width]))
|
| 82 |
+
|
| 83 |
+
def check_range(self, dimension):
|
| 84 |
+
if isinstance(dimension, float) and not 0 <= dimension < 1.0:
|
| 85 |
+
raise ValueError(
|
| 86 |
+
"Invalid value {}. If using floats, the value should be in the range [0.0, 1.0)".format(dimension)
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def apply(
|
| 90 |
+
self,
|
| 91 |
+
img: np.ndarray,
|
| 92 |
+
fill_value: Union[int, float] = 0,
|
| 93 |
+
holes: Iterable[Tuple[int, int, int, int]] = (),
|
| 94 |
+
**params
|
| 95 |
+
) -> np.ndarray:
|
| 96 |
+
return cutout(img, holes, fill_value)
|
| 97 |
+
|
| 98 |
+
def apply_to_mask(
|
| 99 |
+
self,
|
| 100 |
+
img: np.ndarray,
|
| 101 |
+
mask_fill_value: Union[int, float] = 0,
|
| 102 |
+
holes: Iterable[Tuple[int, int, int, int]] = (),
|
| 103 |
+
**params
|
| 104 |
+
) -> np.ndarray:
|
| 105 |
+
if mask_fill_value is None:
|
| 106 |
+
return img
|
| 107 |
+
return cutout(img, holes, mask_fill_value)
|
| 108 |
+
|
| 109 |
+
def get_params_dependent_on_targets(self, params):
|
| 110 |
+
img = params["image"]
|
| 111 |
+
height, width = img.shape[:2]
|
| 112 |
+
|
| 113 |
+
holes = []
|
| 114 |
+
for _n in range(random.randint(self.min_holes, self.max_holes)):
|
| 115 |
+
if all(
|
| 116 |
+
[
|
| 117 |
+
isinstance(self.min_height, int),
|
| 118 |
+
isinstance(self.min_width, int),
|
| 119 |
+
isinstance(self.max_height, int),
|
| 120 |
+
isinstance(self.max_width, int),
|
| 121 |
+
]
|
| 122 |
+
):
|
| 123 |
+
hole_height = random.randint(self.min_height, self.max_height)
|
| 124 |
+
hole_width = random.randint(self.min_width, self.max_width)
|
| 125 |
+
elif all(
|
| 126 |
+
[
|
| 127 |
+
isinstance(self.min_height, float),
|
| 128 |
+
isinstance(self.min_width, float),
|
| 129 |
+
isinstance(self.max_height, float),
|
| 130 |
+
isinstance(self.max_width, float),
|
| 131 |
+
]
|
| 132 |
+
):
|
| 133 |
+
hole_height = int(height * random.uniform(self.min_height, self.max_height))
|
| 134 |
+
hole_width = int(width * random.uniform(self.min_width, self.max_width))
|
| 135 |
+
else:
|
| 136 |
+
raise ValueError(
|
| 137 |
+
"Min width, max width, \
|
| 138 |
+
min height and max height \
|
| 139 |
+
should all either be ints or floats. \
|
| 140 |
+
Got: {} respectively".format(
|
| 141 |
+
[
|
| 142 |
+
type(self.min_width),
|
| 143 |
+
type(self.max_width),
|
| 144 |
+
type(self.min_height),
|
| 145 |
+
type(self.max_height),
|
| 146 |
+
]
|
| 147 |
+
)
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
y1 = random.randint(0, height - hole_height)
|
| 151 |
+
x1 = random.randint(0, width - hole_width)
|
| 152 |
+
y2 = y1 + hole_height
|
| 153 |
+
x2 = x1 + hole_width
|
| 154 |
+
holes.append((x1, y1, x2, y2))
|
| 155 |
+
|
| 156 |
+
return {"holes": holes}
|
| 157 |
+
|
| 158 |
+
@property
|
| 159 |
+
def targets_as_params(self):
|
| 160 |
+
return ["image"]
|
| 161 |
+
|
| 162 |
+
def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
|
| 163 |
+
x1, y1, x2, y2 = hole
|
| 164 |
+
x, y = keypoint[:2]
|
| 165 |
+
return x1 <= x < x2 and y1 <= y < y2
|
| 166 |
+
|
| 167 |
+
def apply_to_keypoints(
|
| 168 |
+
self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params
|
| 169 |
+
) -> List[KeypointType]:
|
| 170 |
+
result = set(keypoints)
|
| 171 |
+
for hole in holes:
|
| 172 |
+
for kp in keypoints:
|
| 173 |
+
if self._keypoint_in_hole(kp, hole):
|
| 174 |
+
result.discard(kp)
|
| 175 |
+
return list(result)
|
| 176 |
+
|
| 177 |
+
def get_transform_init_args_names(self):
|
| 178 |
+
return (
|
| 179 |
+
"max_holes",
|
| 180 |
+
"max_height",
|
| 181 |
+
"max_width",
|
| 182 |
+
"min_holes",
|
| 183 |
+
"min_height",
|
| 184 |
+
"min_width",
|
| 185 |
+
"fill_value",
|
| 186 |
+
"mask_fill_value",
|
| 187 |
+
)
|
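A usage sketch for CoarseDropout applied jointly to an image and its mask, assuming the same top-level re-exports as upstream albumentations:

import numpy as np
import custom_albumentations as A  # assumed top-level re-export

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (128, 128), dtype=np.uint8)

aug = A.Compose([
    A.CoarseDropout(max_holes=4, max_height=16, max_width=16,
                    fill_value=0, mask_fill_value=0, p=1.0)
])
out = aug(image=image, mask=mask)
image_out, mask_out = out["image"], out["mask"]

Because min_holes, min_height, and min_width default to their max counterparts, this configuration always cuts exactly four 16x16 holes; pass smaller min values (or float fractions) for variable hole counts and sizes.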
custom_albumentations/augmentations/dropout/cutout.py
ADDED
|
@@ -0,0 +1,79 @@
| 1 |
+
import random
|
| 2 |
+
import warnings
|
| 3 |
+
from typing import Any, Dict, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from custom_albumentations.core.transforms_interface import ImageOnlyTransform
|
| 8 |
+
|
| 9 |
+
from .functional import cutout
|
| 10 |
+
|
| 11 |
+
__all__ = ["Cutout"]
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Cutout(ImageOnlyTransform):
|
| 15 |
+
"""CoarseDropout of the square regions in the image.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
num_holes (int): number of regions to zero out
|
| 19 |
+
max_h_size (int): maximum height of the hole
|
| 20 |
+
max_w_size (int): maximum width of the hole
|
| 21 |
+
fill_value (int, float, list of int, list of float): value for dropped pixels.
|
| 22 |
+
|
| 23 |
+
Targets:
|
| 24 |
+
image
|
| 25 |
+
|
| 26 |
+
Image types:
|
| 27 |
+
uint8, float32
|
| 28 |
+
|
| 29 |
+
Reference:
|
| 30 |
+
| https://arxiv.org/abs/1708.04552
|
| 31 |
+
| https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
|
| 32 |
+
| https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
|
| 33 |
+
"""
|
| 34 |
+
|
| 35 |
+
def __init__(
|
| 36 |
+
self,
|
| 37 |
+
num_holes: int = 8,
|
| 38 |
+
max_h_size: int = 8,
|
| 39 |
+
max_w_size: int = 8,
|
| 40 |
+
fill_value: Union[int, float] = 0,
|
| 41 |
+
always_apply: bool = False,
|
| 42 |
+
p: float = 0.5,
|
| 43 |
+
):
|
| 44 |
+
super(Cutout, self).__init__(always_apply, p)
|
| 45 |
+
self.num_holes = num_holes
|
| 46 |
+
self.max_h_size = max_h_size
|
| 47 |
+
self.max_w_size = max_w_size
|
| 48 |
+
self.fill_value = fill_value
|
| 49 |
+
warnings.warn(
|
| 50 |
+
f"{self.__class__.__name__} has been deprecated. Please use CoarseDropout",
|
| 51 |
+
FutureWarning,
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
def apply(self, img: np.ndarray, fill_value: Union[int, float] = 0, holes=(), **params):
|
| 55 |
+
return cutout(img, holes, fill_value)
|
| 56 |
+
|
| 57 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 58 |
+
img = params["image"]
|
| 59 |
+
height, width = img.shape[:2]
|
| 60 |
+
|
| 61 |
+
holes = []
|
| 62 |
+
for _n in range(self.num_holes):
|
| 63 |
+
y = random.randint(0, height)
|
| 64 |
+
x = random.randint(0, width)
|
| 65 |
+
|
| 66 |
+
y1 = np.clip(y - self.max_h_size // 2, 0, height)
|
| 67 |
+
y2 = np.clip(y1 + self.max_h_size, 0, height)
|
| 68 |
+
x1 = np.clip(x - self.max_w_size // 2, 0, width)
|
| 69 |
+
x2 = np.clip(x1 + self.max_w_size, 0, width)
|
| 70 |
+
holes.append((x1, y1, x2, y2))
|
| 71 |
+
|
| 72 |
+
return {"holes": holes}
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def targets_as_params(self):
|
| 76 |
+
return ["image"]
|
| 77 |
+
|
| 78 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 79 |
+
return ("num_holes", "max_h_size", "max_w_size")
|
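Cutout is kept only for backward compatibility and warns at construction time; a migration sketch, assuming the usual top-level re-exports:

import custom_albumentations as A  # assumed top-level re-export

# deprecated spelling (emits a FutureWarning):
old = A.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.5)

# roughly equivalent CoarseDropout configuration (hole placement differs slightly:
# Cutout centers holes on random points and clips them at the border)
new = A.CoarseDropout(max_holes=8, max_height=8, max_width=8, fill_value=0, p=0.5)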
custom_albumentations/augmentations/dropout/functional.py
ADDED
|
@@ -0,0 +1,29 @@
| 1 |
+
from typing import Iterable, List, Tuple, Union
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from custom_albumentations.augmentations.utils import preserve_shape
|
| 6 |
+
|
| 7 |
+
__all__ = ["cutout", "channel_dropout"]
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@preserve_shape
|
| 11 |
+
def channel_dropout(
|
| 12 |
+
img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: Union[int, float] = 0
|
| 13 |
+
) -> np.ndarray:
|
| 14 |
+
if len(img.shape) == 2 or img.shape[2] == 1:
|
| 15 |
+
raise NotImplementedError("Only one channel. ChannelDropout is not defined.")
|
| 16 |
+
|
| 17 |
+
img = img.copy()
|
| 18 |
+
img[..., channels_to_drop] = fill_value
|
| 19 |
+
return img
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def cutout(
|
| 23 |
+
img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: Union[int, float] = 0
|
| 24 |
+
) -> np.ndarray:
|
| 25 |
+
# Make a copy of the input image since we don't want to modify it directly
|
| 26 |
+
img = img.copy()
|
| 27 |
+
for x1, y1, x2, y2 in holes:
|
| 28 |
+
img[y1:y2, x1:x2] = fill_value
|
| 29 |
+
return img
|
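Both helpers above can also be called directly on arrays; a tiny worked example using the module path added in this commit:

import numpy as np
from custom_albumentations.augmentations.dropout.functional import channel_dropout, cutout

img = np.arange(2 * 2 * 3, dtype=np.uint8).reshape(2, 2, 3)

# zero out channel 0 everywhere
print(channel_dropout(img, (0,), fill_value=0)[..., 0])    # 2x2 block of zeros

# one hole covering only the top-left pixel: (x1, y1, x2, y2) = (0, 0, 1, 1)
print(cutout(img, [(0, 0, 1, 1)], fill_value=255)[0, 0])   # [255 255 255]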
custom_albumentations/augmentations/dropout/grid_dropout.py
ADDED
|
@@ -0,0 +1,155 @@
| 1 |
+
import random
|
| 2 |
+
from typing import Iterable, Optional, Tuple
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from ...core.transforms_interface import DualTransform
|
| 7 |
+
from . import functional as F
|
| 8 |
+
|
| 9 |
+
__all__ = ["GridDropout"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GridDropout(DualTransform):
|
| 13 |
+
"""GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions).
|
| 17 |
+
Must be between 0 and 1. Default: 0.5.
|
| 18 |
+
unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge.
|
| 19 |
+
If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
|
| 20 |
+
unit_size_max (int): maximum size of the grid unit. Must be between 2 and the image shorter edge.
|
| 21 |
+
If 'None', holes_number_x and holes_number_y are used to setup the grid. Default: `None`.
|
| 22 |
+
holes_number_x (int): the number of grid units in x direction. Must be between 1 and image width//2.
|
| 23 |
+
If 'None', grid unit width is set as image_width//10. Default: `None`.
|
| 24 |
+
holes_number_y (int): the number of grid units in y direction. Must be between 1 and image height//2.
|
| 25 |
+
If `None`, grid unit height is set equal to the grid unit width or image height, whatever is smaller.
|
| 26 |
+
shift_x (int): offsets of the grid start in x direction from (0,0) coordinate.
|
| 27 |
+
Clipped between 0 and grid unit_width - hole_width. Default: 0.
|
| 28 |
+
shift_y (int): offsets of the grid start in y direction from (0,0) coordinate.
|
| 29 |
+
Clipped between 0 and grid unit height - hole_height. Default: 0.
|
| 30 |
+
random_offset (boolean): whether to offset the grid randomly between 0 and grid unit size - hole size.
|
| 31 |
+
If 'True', entered shift_x, shift_y are ignored and set randomly. Default: `False`.
|
| 32 |
+
fill_value (int): value for the dropped pixels. Default = 0
|
| 33 |
+
mask_fill_value (int): value for the dropped pixels in mask.
|
| 34 |
+
If `None`, transformation is not applied to the mask. Default: `None`.
|
| 35 |
+
|
| 36 |
+
Targets:
|
| 37 |
+
image, mask
|
| 38 |
+
|
| 39 |
+
Image types:
|
| 40 |
+
uint8, float32
|
| 41 |
+
|
| 42 |
+
References:
|
| 43 |
+
https://arxiv.org/abs/2001.04086
|
| 44 |
+
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __init__(
|
| 48 |
+
self,
|
| 49 |
+
ratio: float = 0.5,
|
| 50 |
+
unit_size_min: Optional[int] = None,
|
| 51 |
+
unit_size_max: Optional[int] = None,
|
| 52 |
+
holes_number_x: Optional[int] = None,
|
| 53 |
+
holes_number_y: Optional[int] = None,
|
| 54 |
+
shift_x: int = 0,
|
| 55 |
+
shift_y: int = 0,
|
| 56 |
+
random_offset: bool = False,
|
| 57 |
+
fill_value: int = 0,
|
| 58 |
+
mask_fill_value: Optional[int] = None,
|
| 59 |
+
always_apply: bool = False,
|
| 60 |
+
p: float = 0.5,
|
| 61 |
+
):
|
| 62 |
+
super(GridDropout, self).__init__(always_apply, p)
|
| 63 |
+
self.ratio = ratio
|
| 64 |
+
self.unit_size_min = unit_size_min
|
| 65 |
+
self.unit_size_max = unit_size_max
|
| 66 |
+
self.holes_number_x = holes_number_x
|
| 67 |
+
self.holes_number_y = holes_number_y
|
| 68 |
+
self.shift_x = shift_x
|
| 69 |
+
self.shift_y = shift_y
|
| 70 |
+
self.random_offset = random_offset
|
| 71 |
+
self.fill_value = fill_value
|
| 72 |
+
self.mask_fill_value = mask_fill_value
|
| 73 |
+
if not 0 < self.ratio <= 1:
|
| 74 |
+
raise ValueError("ratio must be between 0 and 1.")
|
| 75 |
+
|
| 76 |
+
def apply(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
|
| 77 |
+
return F.cutout(img, holes, self.fill_value)
|
| 78 |
+
|
| 79 |
+
def apply_to_mask(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
|
| 80 |
+
if self.mask_fill_value is None:
|
| 81 |
+
return img
|
| 82 |
+
|
| 83 |
+
return F.cutout(img, holes, self.mask_fill_value)
|
| 84 |
+
|
| 85 |
+
def get_params_dependent_on_targets(self, params):
|
| 86 |
+
img = params["image"]
|
| 87 |
+
height, width = img.shape[:2]
|
| 88 |
+
# set grid using unit size limits
|
| 89 |
+
if self.unit_size_min and self.unit_size_max:
|
| 90 |
+
if not 2 <= self.unit_size_min <= self.unit_size_max:
|
| 91 |
+
raise ValueError("Max unit size should be >= min size, both at least 2 pixels.")
|
| 92 |
+
if self.unit_size_max > min(height, width):
|
| 93 |
+
raise ValueError("Grid size limits must be within the shortest image edge.")
|
| 94 |
+
unit_width = random.randint(self.unit_size_min, self.unit_size_max)  # randint is inclusive of both bounds
|
| 95 |
+
unit_height = unit_width
|
| 96 |
+
else:
|
| 97 |
+
# set grid using holes numbers
|
| 98 |
+
if self.holes_number_x is None:
|
| 99 |
+
unit_width = max(2, width // 10)
|
| 100 |
+
else:
|
| 101 |
+
if not 1 <= self.holes_number_x <= width // 2:
|
| 102 |
+
raise ValueError("The hole_number_x must be between 1 and image width//2.")
|
| 103 |
+
unit_width = width // self.holes_number_x
|
| 104 |
+
if self.holes_number_y is None:
|
| 105 |
+
unit_height = max(min(unit_width, height), 2)
|
| 106 |
+
else:
|
| 107 |
+
if not 1 <= self.holes_number_y <= height // 2:
|
| 108 |
+
raise ValueError("The hole_number_y must be between 1 and image height//2.")
|
| 109 |
+
unit_height = height // self.holes_number_y
|
| 110 |
+
|
| 111 |
+
hole_width = int(unit_width * self.ratio)
|
| 112 |
+
hole_height = int(unit_height * self.ratio)
|
| 113 |
+
# min 1 pixel and max unit length - 1
|
| 114 |
+
hole_width = min(max(hole_width, 1), unit_width - 1)
|
| 115 |
+
hole_height = min(max(hole_height, 1), unit_height - 1)
|
| 116 |
+
# set offset of the grid
|
| 117 |
+
if self.shift_x is None:
|
| 118 |
+
shift_x = 0
|
| 119 |
+
else:
|
| 120 |
+
shift_x = min(max(0, self.shift_x), unit_width - hole_width)
|
| 121 |
+
if self.shift_y is None:
|
| 122 |
+
shift_y = 0
|
| 123 |
+
else:
|
| 124 |
+
shift_y = min(max(0, self.shift_y), unit_height - hole_height)
|
| 125 |
+
if self.random_offset:
|
| 126 |
+
shift_x = random.randint(0, unit_width - hole_width)
|
| 127 |
+
shift_y = random.randint(0, unit_height - hole_height)
|
| 128 |
+
holes = []
|
| 129 |
+
for i in range(width // unit_width + 1):
|
| 130 |
+
for j in range(height // unit_height + 1):
|
| 131 |
+
x1 = min(shift_x + unit_width * i, width)
|
| 132 |
+
y1 = min(shift_y + unit_height * j, height)
|
| 133 |
+
x2 = min(x1 + hole_width, width)
|
| 134 |
+
y2 = min(y1 + hole_height, height)
|
| 135 |
+
holes.append((x1, y1, x2, y2))
|
| 136 |
+
|
| 137 |
+
return {"holes": holes}
|
| 138 |
+
|
| 139 |
+
@property
|
| 140 |
+
def targets_as_params(self):
|
| 141 |
+
return ["image"]
|
| 142 |
+
|
| 143 |
+
def get_transform_init_args_names(self):
|
| 144 |
+
return (
|
| 145 |
+
"ratio",
|
| 146 |
+
"unit_size_min",
|
| 147 |
+
"unit_size_max",
|
| 148 |
+
"holes_number_x",
|
| 149 |
+
"holes_number_y",
|
| 150 |
+
"shift_x",
|
| 151 |
+
"shift_y",
|
| 152 |
+
"random_offset",
|
| 153 |
+
"fill_value",
|
| 154 |
+
"mask_fill_value",
|
| 155 |
+
)
|
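A usage sketch for GridDropout with an explicit hole count, assuming the usual top-level re-exports; the numbers follow directly from the parameter logic above:

import numpy as np
import custom_albumentations as A  # assumed top-level re-export

image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)

# 10 grid units per axis on a 100x100 image -> 10x10 px units, 5x5 px holes at ratio=0.5
aug = A.Compose([A.GridDropout(ratio=0.5, holes_number_x=10, holes_number_y=10, p=1.0)])
out = aug(image=image)["image"]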
custom_albumentations/augmentations/dropout/mask_dropout.py
ADDED
|
@@ -0,0 +1,99 @@
| 1 |
+
import random
|
| 2 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
from skimage.measure import label
|
| 7 |
+
|
| 8 |
+
from ...core.transforms_interface import DualTransform, to_tuple
|
| 9 |
+
|
| 10 |
+
__all__ = ["MaskDropout"]
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MaskDropout(DualTransform):
|
| 14 |
+
"""
|
| 15 |
+
Image & mask augmentation that zero out mask and image regions corresponding
|
| 16 |
+
to randomly chosen object instance from mask.
|
| 17 |
+
|
| 18 |
+
Mask must be a single-channel image; zero values are treated as background.
|
| 19 |
+
Image can be any number of channels.
|
| 20 |
+
|
| 21 |
+
Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
max_objects: Maximum number of labels that can be zeroed out. Can be a tuple; in this case it is [min, max].
|
| 25 |
+
image_fill_value: Fill value to use when filling image.
|
| 26 |
+
Can be 'inpaint' to apply inpainting (works only for 3-channel images).
|
| 27 |
+
mask_fill_value: Fill value to use when filling mask.
|
| 28 |
+
|
| 29 |
+
Targets:
|
| 30 |
+
image, mask
|
| 31 |
+
|
| 32 |
+
Image types:
|
| 33 |
+
uint8, float32
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
max_objects: int = 1,
|
| 39 |
+
image_fill_value: Union[int, float, str] = 0,
|
| 40 |
+
mask_fill_value: Union[int, float] = 0,
|
| 41 |
+
always_apply: bool = False,
|
| 42 |
+
p: float = 0.5,
|
| 43 |
+
):
|
| 44 |
+
super(MaskDropout, self).__init__(always_apply, p)
|
| 45 |
+
self.max_objects = to_tuple(max_objects, 1)
|
| 46 |
+
self.image_fill_value = image_fill_value
|
| 47 |
+
self.mask_fill_value = mask_fill_value
|
| 48 |
+
|
| 49 |
+
@property
|
| 50 |
+
def targets_as_params(self):
|
| 51 |
+
return ["mask"]
|
| 52 |
+
|
| 53 |
+
def get_params_dependent_on_targets(self, params) -> Dict[str, Any]:
|
| 54 |
+
mask = params["mask"]
|
| 55 |
+
|
| 56 |
+
label_image, num_labels = label(mask, return_num=True)
|
| 57 |
+
|
| 58 |
+
if num_labels == 0:
|
| 59 |
+
dropout_mask = None
|
| 60 |
+
else:
|
| 61 |
+
objects_to_drop = random.randint(int(self.max_objects[0]), int(self.max_objects[1]))
|
| 62 |
+
objects_to_drop = min(num_labels, objects_to_drop)
|
| 63 |
+
|
| 64 |
+
if objects_to_drop == num_labels:
|
| 65 |
+
dropout_mask = mask > 0
|
| 66 |
+
else:
|
| 67 |
+
labels_index = random.sample(range(1, num_labels + 1), objects_to_drop)
|
| 68 |
+
dropout_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=bool)
|
| 69 |
+
for label_index in labels_index:
|
| 70 |
+
dropout_mask |= label_image == label_index
|
| 71 |
+
|
| 72 |
+
params.update({"dropout_mask": dropout_mask})
|
| 73 |
+
return params
|
| 74 |
+
|
| 75 |
+
def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
|
| 76 |
+
if dropout_mask is None:
|
| 77 |
+
return img
|
| 78 |
+
|
| 79 |
+
if self.image_fill_value == "inpaint":
|
| 80 |
+
dropout_mask = dropout_mask.astype(np.uint8)
|
| 81 |
+
_, _, w, h = cv2.boundingRect(dropout_mask)
|
| 82 |
+
radius = min(3, max(w, h) // 2)
|
| 83 |
+
img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)
|
| 84 |
+
else:
|
| 85 |
+
img = img.copy()
|
| 86 |
+
img[dropout_mask] = self.image_fill_value
|
| 87 |
+
|
| 88 |
+
return img
|
| 89 |
+
|
| 90 |
+
def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
|
| 91 |
+
if dropout_mask is None:
|
| 92 |
+
return img
|
| 93 |
+
|
| 94 |
+
img = img.copy()
|
| 95 |
+
img[dropout_mask] = self.mask_fill_value
|
| 96 |
+
return img
|
| 97 |
+
|
| 98 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 99 |
+
return "max_objects", "image_fill_value", "mask_fill_value"
|
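A usage sketch for MaskDropout on a mask with two object instances, assuming the usual top-level re-exports (the connected components come from skimage.measure.label, as in get_params_dependent_on_targets above):

import numpy as np
import custom_albumentations as A  # assumed top-level re-export

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[5:15, 5:15] = 1    # first instance
mask[40:50, 40:50] = 1  # second instance

# zero out exactly one randomly chosen instance in both image and mask
aug = A.Compose([A.MaskDropout(max_objects=1, image_fill_value=0, mask_fill_value=0, p=1.0)])
out = aug(image=image, mask=mask)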
custom_albumentations/augmentations/functional.py
ADDED
|
@@ -0,0 +1,1380 @@
| 1 |
+
from __future__ import division
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Sequence, Union
|
| 4 |
+
from warnings import warn
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import skimage
|
| 9 |
+
|
| 10 |
+
from custom_albumentations import random_utils
|
| 11 |
+
from custom_albumentations.augmentations.utils import (
|
| 12 |
+
MAX_VALUES_BY_DTYPE,
|
| 13 |
+
_maybe_process_in_chunks,
|
| 14 |
+
clip,
|
| 15 |
+
clipped,
|
| 16 |
+
ensure_contiguous,
|
| 17 |
+
is_grayscale_image,
|
| 18 |
+
is_rgb_image,
|
| 19 |
+
non_rgb_warning,
|
| 20 |
+
preserve_channel_dim,
|
| 21 |
+
preserve_shape,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"add_fog",
|
| 26 |
+
"add_rain",
|
| 27 |
+
"add_shadow",
|
| 28 |
+
"add_gravel",
|
| 29 |
+
"add_snow",
|
| 30 |
+
"add_sun_flare",
|
| 31 |
+
"add_weighted",
|
| 32 |
+
"adjust_brightness_torchvision",
|
| 33 |
+
"adjust_contrast_torchvision",
|
| 34 |
+
"adjust_hue_torchvision",
|
| 35 |
+
"adjust_saturation_torchvision",
|
| 36 |
+
"brightness_contrast_adjust",
|
| 37 |
+
"channel_shuffle",
|
| 38 |
+
"clahe",
|
| 39 |
+
"convolve",
|
| 40 |
+
"downscale",
|
| 41 |
+
"equalize",
|
| 42 |
+
"fancy_pca",
|
| 43 |
+
"from_float",
|
| 44 |
+
"gamma_transform",
|
| 45 |
+
"gauss_noise",
|
| 46 |
+
"image_compression",
|
| 47 |
+
"invert",
|
| 48 |
+
"iso_noise",
|
| 49 |
+
"linear_transformation_rgb",
|
| 50 |
+
"move_tone_curve",
|
| 51 |
+
"multiply",
|
| 52 |
+
"noop",
|
| 53 |
+
"normalize",
|
| 54 |
+
"posterize",
|
| 55 |
+
"shift_hsv",
|
| 56 |
+
"shift_rgb",
|
| 57 |
+
"solarize",
|
| 58 |
+
"superpixels",
|
| 59 |
+
"swap_tiles_on_image",
|
| 60 |
+
"to_float",
|
| 61 |
+
"to_gray",
|
| 62 |
+
"gray_to_rgb",
|
| 63 |
+
"unsharp_mask",
|
| 64 |
+
]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def normalize_cv2(img, mean, denominator):
|
| 68 |
+
if mean.shape and len(mean) != 4 and mean.shape != img.shape:
|
| 69 |
+
mean = np.array(mean.tolist() + [0] * (4 - len(mean)), dtype=np.float64)
|
| 70 |
+
if not denominator.shape:
|
| 71 |
+
denominator = np.array([denominator.tolist()] * 4, dtype=np.float64)
|
| 72 |
+
elif len(denominator) != 4 and denominator.shape != img.shape:
|
| 73 |
+
denominator = np.array(denominator.tolist() + [1] * (4 - len(denominator)), dtype=np.float64)
|
| 74 |
+
|
| 75 |
+
img = np.ascontiguousarray(img.astype("float32"))
|
| 76 |
+
cv2.subtract(img, mean.astype(np.float64), img)
|
| 77 |
+
cv2.multiply(img, denominator.astype(np.float64), img)
|
| 78 |
+
return img
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def normalize_numpy(img, mean, denominator):
|
| 82 |
+
img = img.astype(np.float32)
|
| 83 |
+
img -= mean
|
| 84 |
+
img *= denominator
|
| 85 |
+
return img
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def normalize(img, mean, std, max_pixel_value=255.0):
|
| 89 |
+
mean = np.array(mean, dtype=np.float32)
|
| 90 |
+
mean *= max_pixel_value
|
| 91 |
+
|
| 92 |
+
std = np.array(std, dtype=np.float32)
|
| 93 |
+
std *= max_pixel_value
|
| 94 |
+
|
| 95 |
+
denominator = np.reciprocal(std, dtype=np.float32)
|
| 96 |
+
|
| 97 |
+
if img.ndim == 3 and img.shape[-1] == 3:
|
| 98 |
+
return normalize_cv2(img, mean, denominator)
|
| 99 |
+
return normalize_numpy(img, mean, denominator)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift):
|
| 103 |
+
dtype = img.dtype
|
| 104 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
|
| 105 |
+
hue, sat, val = cv2.split(img)
|
| 106 |
+
|
| 107 |
+
if hue_shift != 0:
|
| 108 |
+
lut_hue = np.arange(0, 256, dtype=np.int16)
|
| 109 |
+
lut_hue = np.mod(lut_hue + hue_shift, 180).astype(dtype)
|
| 110 |
+
hue = cv2.LUT(hue, lut_hue)
|
| 111 |
+
|
| 112 |
+
if sat_shift != 0:
|
| 113 |
+
lut_sat = np.arange(0, 256, dtype=np.int16)
|
| 114 |
+
lut_sat = np.clip(lut_sat + sat_shift, 0, 255).astype(dtype)
|
| 115 |
+
sat = cv2.LUT(sat, lut_sat)
|
| 116 |
+
|
| 117 |
+
if val_shift != 0:
|
| 118 |
+
lut_val = np.arange(0, 256, dtype=np.int16)
|
| 119 |
+
lut_val = np.clip(lut_val + val_shift, 0, 255).astype(dtype)
|
| 120 |
+
val = cv2.LUT(val, lut_val)
|
| 121 |
+
|
| 122 |
+
img = cv2.merge((hue, sat, val)).astype(dtype)
|
| 123 |
+
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
|
| 124 |
+
return img
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift):
|
| 128 |
+
dtype = img.dtype
|
| 129 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
|
| 130 |
+
hue, sat, val = cv2.split(img)
|
| 131 |
+
|
| 132 |
+
if hue_shift != 0:
|
| 133 |
+
hue = cv2.add(hue, hue_shift)
|
| 134 |
+
hue = np.mod(hue, 360) # OpenCV fails with negative values
|
| 135 |
+
|
| 136 |
+
if sat_shift != 0:
|
| 137 |
+
sat = clip(cv2.add(sat, sat_shift), dtype, 1.0)
|
| 138 |
+
|
| 139 |
+
if val_shift != 0:
|
| 140 |
+
val = clip(cv2.add(val, val_shift), dtype, 1.0)
|
| 141 |
+
|
| 142 |
+
img = cv2.merge((hue, sat, val))
|
| 143 |
+
img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
|
| 144 |
+
return img
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@preserve_shape
|
| 148 |
+
def shift_hsv(img, hue_shift, sat_shift, val_shift):
|
| 149 |
+
if hue_shift == 0 and sat_shift == 0 and val_shift == 0:
|
| 150 |
+
return img
|
| 151 |
+
|
| 152 |
+
is_gray = is_grayscale_image(img)
|
| 153 |
+
if is_gray:
|
| 154 |
+
if hue_shift != 0 or sat_shift != 0:
|
| 155 |
+
hue_shift = 0
|
| 156 |
+
sat_shift = 0
|
| 157 |
+
warn(
|
| 158 |
+
"HueSaturationValue: hue_shift and sat_shift are not applicable to grayscale image. "
|
| 159 |
+
"Set them to 0 or use RGB image"
|
| 160 |
+
)
|
| 161 |
+
img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 162 |
+
|
| 163 |
+
if img.dtype == np.uint8:
|
| 164 |
+
img = _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift)
|
| 165 |
+
else:
|
| 166 |
+
img = _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift)
|
| 167 |
+
|
| 168 |
+
if is_gray:
|
| 169 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 170 |
+
|
| 171 |
+
return img
|
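A short sketch of shift_hsv on a uint8 image (import path assumed as above):

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
# Hue is shifted modulo 180 (OpenCV range); saturation and value are clipped to [0, 255].
shifted = F.shift_hsv(img, hue_shift=10, sat_shift=20, val_shift=-15)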
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def solarize(img, threshold=128):
|
| 175 |
+
"""Invert all pixel values above a threshold.
|
| 176 |
+
|
| 177 |
+
Args:
|
| 178 |
+
img (numpy.ndarray): The image to solarize.
|
| 179 |
+
threshold (int): All pixels above this greyscale level are inverted.
|
| 180 |
+
|
| 181 |
+
Returns:
|
| 182 |
+
numpy.ndarray: Solarized image.
|
| 183 |
+
|
| 184 |
+
"""
|
| 185 |
+
dtype = img.dtype
|
| 186 |
+
max_val = MAX_VALUES_BY_DTYPE[dtype]
|
| 187 |
+
|
| 188 |
+
if dtype == np.dtype("uint8"):
|
| 189 |
+
lut = [(i if i < threshold else max_val - i) for i in range(max_val + 1)]
|
| 190 |
+
|
| 191 |
+
prev_shape = img.shape
|
| 192 |
+
img = cv2.LUT(img, np.array(lut, dtype=dtype))
|
| 193 |
+
|
| 194 |
+
if len(prev_shape) != len(img.shape):
|
| 195 |
+
img = np.expand_dims(img, -1)
|
| 196 |
+
return img
|
| 197 |
+
|
| 198 |
+
result_img = img.copy()
|
| 199 |
+
cond = img >= threshold
|
| 200 |
+
result_img[cond] = max_val - result_img[cond]
|
| 201 |
+
return result_img
|
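A usage sketch for solarize (import path assumed as above); note that pixels at or above the threshold are inverted:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
solarized = F.solarize(img, threshold=128)  # pixels >= 128 become 255 - value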
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@preserve_shape
|
| 205 |
+
def posterize(img, bits):
|
| 206 |
+
"""Reduce the number of bits for each color channel.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
img (numpy.ndarray): image to posterize.
|
| 210 |
+
bits (int): number of high bits. Must be in range [0, 8]
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
numpy.ndarray: Image with reduced color channels.
|
| 214 |
+
|
| 215 |
+
"""
|
| 216 |
+
bits = np.uint8(bits)
|
| 217 |
+
|
| 218 |
+
if img.dtype != np.uint8:
|
| 219 |
+
raise TypeError("Image must have uint8 channel type")
|
| 220 |
+
if np.any((bits < 0) | (bits > 8)):
|
| 221 |
+
raise ValueError("bits must be in range [0, 8]")
|
| 222 |
+
|
| 223 |
+
if not bits.shape or len(bits) == 1:
|
| 224 |
+
if bits == 0:
|
| 225 |
+
return np.zeros_like(img)
|
| 226 |
+
if bits == 8:
|
| 227 |
+
return img.copy()
|
| 228 |
+
|
| 229 |
+
lut = np.arange(0, 256, dtype=np.uint8)
|
| 230 |
+
mask = ~np.uint8(2 ** (8 - bits) - 1)
|
| 231 |
+
lut &= mask
|
| 232 |
+
|
| 233 |
+
return cv2.LUT(img, lut)
|
| 234 |
+
|
| 235 |
+
if not is_rgb_image(img):
|
| 236 |
+
raise TypeError("If bits is iterable image must be RGB")
|
| 237 |
+
|
| 238 |
+
result_img = np.empty_like(img)
|
| 239 |
+
for i, channel_bits in enumerate(bits):
|
| 240 |
+
if channel_bits == 0:
|
| 241 |
+
result_img[..., i] = np.zeros_like(img[..., i])
|
| 242 |
+
elif channel_bits == 8:
|
| 243 |
+
result_img[..., i] = img[..., i].copy()
|
| 244 |
+
else:
|
| 245 |
+
lut = np.arange(0, 256, dtype=np.uint8)
|
| 246 |
+
mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
|
| 247 |
+
lut &= mask
|
| 248 |
+
|
| 249 |
+
result_img[..., i] = cv2.LUT(img[..., i], lut)
|
| 250 |
+
|
| 251 |
+
return result_img
|
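A usage sketch for posterize (import path assumed as above), showing a single bit depth and per-channel bit depths:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
post = F.posterize(img, bits=3)              # keep the 3 most significant bits everywhere
post_rgb = F.posterize(img, bits=[1, 4, 8])  # per-channel bit depths for an RGB image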
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _equalize_pil(img, mask=None):
|
| 255 |
+
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
|
| 256 |
+
h = [_f for _f in histogram if _f]
|
| 257 |
+
|
| 258 |
+
if len(h) <= 1:
|
| 259 |
+
return img.copy()
|
| 260 |
+
|
| 261 |
+
step = np.sum(h[:-1]) // 255
|
| 262 |
+
if not step:
|
| 263 |
+
return img.copy()
|
| 264 |
+
|
| 265 |
+
lut = np.empty(256, dtype=np.uint8)
|
| 266 |
+
n = step // 2
|
| 267 |
+
for i in range(256):
|
| 268 |
+
lut[i] = min(n // step, 255)
|
| 269 |
+
n += histogram[i]
|
| 270 |
+
|
| 271 |
+
return cv2.LUT(img, np.array(lut))
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def _equalize_cv(img, mask=None):
|
| 275 |
+
if mask is None:
|
| 276 |
+
return cv2.equalizeHist(img)
|
| 277 |
+
|
| 278 |
+
histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
|
| 279 |
+
i = 0
|
| 280 |
+
for val in histogram:
|
| 281 |
+
if val > 0:
|
| 282 |
+
break
|
| 283 |
+
i += 1
|
| 284 |
+
i = min(i, 255)
|
| 285 |
+
|
| 286 |
+
total = np.sum(histogram)
|
| 287 |
+
if histogram[i] == total:
|
| 288 |
+
return np.full_like(img, i)
|
| 289 |
+
|
| 290 |
+
scale = 255.0 / (total - histogram[i])
|
| 291 |
+
_sum = 0
|
| 292 |
+
|
| 293 |
+
lut = np.zeros(256, dtype=np.uint8)
|
| 294 |
+
i += 1
|
| 295 |
+
for i in range(i, len(histogram)):
|
| 296 |
+
_sum += histogram[i]
|
| 297 |
+
lut[i] = clip(round(_sum * scale), np.dtype("uint8"), 255)
|
| 298 |
+
|
| 299 |
+
return cv2.LUT(img, lut)
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@preserve_channel_dim
|
| 303 |
+
def equalize(img, mask=None, mode="cv", by_channels=True):
|
| 304 |
+
"""Equalize the image histogram.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
img (numpy.ndarray): RGB or grayscale image.
|
| 308 |
+
mask (numpy.ndarray): An optional mask. If given, only the pixels selected by
|
| 309 |
+
the mask are included in the analysis. May be a 1-channel or 3-channel array.
|
| 310 |
+
mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method.
|
| 311 |
+
by_channels (bool): If True, use equalization by channels separately,
|
| 312 |
+
else convert image to YCbCr representation and use equalization by `Y` channel.
|
| 313 |
+
|
| 314 |
+
Returns:
|
| 315 |
+
numpy.ndarray: Equalized image.
|
| 316 |
+
|
| 317 |
+
"""
|
| 318 |
+
if img.dtype != np.uint8:
|
| 319 |
+
raise TypeError("Image must have uint8 channel type")
|
| 320 |
+
|
| 321 |
+
modes = ["cv", "pil"]
|
| 322 |
+
|
| 323 |
+
if mode not in modes:
|
| 324 |
+
raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode))
|
| 325 |
+
if mask is not None:
|
| 326 |
+
if is_rgb_image(mask) and is_grayscale_image(img):
|
| 327 |
+
raise ValueError("Wrong mask shape. Image shape: {}. " "Mask shape: {}".format(img.shape, mask.shape))
|
| 328 |
+
if not by_channels and not is_grayscale_image(mask):
|
| 329 |
+
raise ValueError(
|
| 330 |
+
"When by_channels=False only 1-channel mask supports. " "Mask shape: {}".format(mask.shape)
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
if mode == "pil":
|
| 334 |
+
function = _equalize_pil
|
| 335 |
+
else:
|
| 336 |
+
function = _equalize_cv
|
| 337 |
+
|
| 338 |
+
if mask is not None:
|
| 339 |
+
mask = mask.astype(np.uint8)
|
| 340 |
+
|
| 341 |
+
if is_grayscale_image(img):
|
| 342 |
+
return function(img, mask)
|
| 343 |
+
|
| 344 |
+
if not by_channels:
|
| 345 |
+
result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
|
| 346 |
+
result_img[..., 0] = function(result_img[..., 0], mask)
|
| 347 |
+
return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB)
|
| 348 |
+
|
| 349 |
+
result_img = np.empty_like(img)
|
| 350 |
+
for i in range(3):
|
| 351 |
+
if mask is None:
|
| 352 |
+
_mask = None
|
| 353 |
+
elif is_grayscale_image(mask):
|
| 354 |
+
_mask = mask
|
| 355 |
+
else:
|
| 356 |
+
_mask = mask[..., i]
|
| 357 |
+
|
| 358 |
+
result_img[..., i] = function(img[..., i], _mask)
|
| 359 |
+
|
| 360 |
+
return result_img
|
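A usage sketch for equalize (import path assumed as above):

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
eq_cv = F.equalize(img, mode="cv", by_channels=True)   # OpenCV equalization, each channel separately
eq_y = F.equalize(img, mode="pil", by_channels=False)  # Pillow-style equalization on the luma channel only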
| 361 |
+
|
| 362 |
+
|
| 363 |
+
@preserve_shape
|
| 364 |
+
def move_tone_curve(img, low_y, high_y):
|
| 365 |
+
"""Rescales the relationship between bright and dark areas of the image by manipulating its tone curve.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
img (numpy.ndarray): RGB or grayscale image.
|
| 369 |
+
low_y (float): y-position of a Bezier control point used
|
| 370 |
+
to adjust the tone curve, must be in range [0, 1]
|
| 371 |
+
high_y (float): y-position of a Bezier control point used
|
| 372 |
+
to adjust image tone curve, must be in range [0, 1]
|
| 373 |
+
"""
|
| 374 |
+
input_dtype = img.dtype
|
| 375 |
+
|
| 376 |
+
if low_y < 0 or low_y > 1:
|
| 377 |
+
raise ValueError("low_shift must be in range [0, 1]")
|
| 378 |
+
if high_y < 0 or high_y > 1:
|
| 379 |
+
raise ValueError("high_shift must be in range [0, 1]")
|
| 380 |
+
|
| 381 |
+
if input_dtype != np.uint8:
|
| 382 |
+
raise ValueError("Unsupported image type {}".format(input_dtype))
|
| 383 |
+
|
| 384 |
+
t = np.linspace(0.0, 1.0, 256)
|
| 385 |
+
|
| 386 |
+
# Defines the response of a four-point Bezier curve
|
| 387 |
+
def evaluate_bez(t):
|
| 388 |
+
return 3 * (1 - t) ** 2 * t * low_y + 3 * (1 - t) * t**2 * high_y + t**3
|
| 389 |
+
|
| 390 |
+
evaluate_bez = np.vectorize(evaluate_bez)
|
| 391 |
+
remapping = np.rint(evaluate_bez(t) * 255).astype(np.uint8)
|
| 392 |
+
|
| 393 |
+
lut_fn = _maybe_process_in_chunks(cv2.LUT, lut=remapping)
|
| 394 |
+
img = lut_fn(img)
|
| 395 |
+
return img
|
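A usage sketch for move_tone_curve (import path assumed as above); low_y below 0.25 and high_y above 0.75 approximate an S-shaped, contrast-increasing curve:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
contrasty = F.move_tone_curve(img, low_y=0.1, high_y=0.9)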
| 396 |
+
|
| 397 |
+
|
| 398 |
+
@clipped
|
| 399 |
+
def _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift):
|
| 400 |
+
if r_shift == g_shift == b_shift:
|
| 401 |
+
return img + r_shift
|
| 402 |
+
|
| 403 |
+
result_img = np.empty_like(img)
|
| 404 |
+
shifts = [r_shift, g_shift, b_shift]
|
| 405 |
+
for i, shift in enumerate(shifts):
|
| 406 |
+
result_img[..., i] = img[..., i] + shift
|
| 407 |
+
|
| 408 |
+
return result_img
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
def _shift_image_uint8(img, value):
|
| 412 |
+
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
|
| 413 |
+
|
| 414 |
+
lut = np.arange(0, max_value + 1).astype("float32")
|
| 415 |
+
lut += value
|
| 416 |
+
|
| 417 |
+
lut = np.clip(lut, 0, max_value).astype(img.dtype)
|
| 418 |
+
return cv2.LUT(img, lut)
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
@preserve_shape
|
| 422 |
+
def _shift_rgb_uint8(img, r_shift, g_shift, b_shift):
|
| 423 |
+
if r_shift == g_shift == b_shift:
|
| 424 |
+
h, w, c = img.shape
|
| 425 |
+
img = img.reshape([h, w * c])
|
| 426 |
+
|
| 427 |
+
return _shift_image_uint8(img, r_shift)
|
| 428 |
+
|
| 429 |
+
result_img = np.empty_like(img)
|
| 430 |
+
shifts = [r_shift, g_shift, b_shift]
|
| 431 |
+
for i, shift in enumerate(shifts):
|
| 432 |
+
result_img[..., i] = _shift_image_uint8(img[..., i], shift)
|
| 433 |
+
|
| 434 |
+
return result_img
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
def shift_rgb(img, r_shift, g_shift, b_shift):
|
| 438 |
+
if img.dtype == np.uint8:
|
| 439 |
+
return _shift_rgb_uint8(img, r_shift, g_shift, b_shift)
|
| 440 |
+
|
| 441 |
+
return _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
@clipped
|
| 445 |
+
def linear_transformation_rgb(img, transformation_matrix):
|
| 446 |
+
result_img = cv2.transform(img, transformation_matrix)
|
| 447 |
+
|
| 448 |
+
return result_img
|
| 449 |
+
|
| 450 |
+
|
| 451 |
+
@preserve_channel_dim
|
| 452 |
+
def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
|
| 453 |
+
if img.dtype != np.uint8:
|
| 454 |
+
raise TypeError("clahe supports only uint8 inputs")
|
| 455 |
+
|
| 456 |
+
clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
|
| 457 |
+
|
| 458 |
+
if len(img.shape) == 2 or img.shape[2] == 1:
|
| 459 |
+
img = clahe_mat.apply(img)
|
| 460 |
+
else:
|
| 461 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
|
| 462 |
+
img[:, :, 0] = clahe_mat.apply(img[:, :, 0])
|
| 463 |
+
img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
|
| 464 |
+
|
| 465 |
+
return img
|
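A usage sketch for clahe (import path assumed as above); for RGB input the equalization is applied to the L channel of a LAB conversion, as in the code above:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
out = F.clahe(img, clip_limit=2.0, tile_grid_size=(8, 8))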
| 466 |
+
|
| 467 |
+
|
| 468 |
+
@preserve_shape
|
| 469 |
+
def convolve(img, kernel):
|
| 470 |
+
conv_fn = _maybe_process_in_chunks(cv2.filter2D, ddepth=-1, kernel=kernel)
|
| 471 |
+
return conv_fn(img)
|
| 472 |
+
|
| 473 |
+
|
| 474 |
+
@preserve_shape
|
| 475 |
+
def image_compression(img, quality, image_type):
|
| 476 |
+
if image_type in [".jpeg", ".jpg"]:
|
| 477 |
+
quality_flag = cv2.IMWRITE_JPEG_QUALITY
|
| 478 |
+
elif image_type == ".webp":
|
| 479 |
+
quality_flag = cv2.IMWRITE_WEBP_QUALITY
|
| 480 |
+
else:
|
| 481 |
+
raise NotImplementedError("Only '.jpg' and '.webp' compression transforms are implemented.")
|
| 482 |
+
|
| 483 |
+
input_dtype = img.dtype
|
| 484 |
+
needs_float = False
|
| 485 |
+
|
| 486 |
+
if input_dtype == np.float32:
|
| 487 |
+
warn(
|
| 488 |
+
"Image compression augmentation "
|
| 489 |
+
"is most effective with uint8 inputs, "
|
| 490 |
+
"{} is used as input.".format(input_dtype),
|
| 491 |
+
UserWarning,
|
| 492 |
+
)
|
| 493 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 494 |
+
needs_float = True
|
| 495 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 496 |
+
raise ValueError("Unexpected dtype {} for image augmentation".format(input_dtype))
|
| 497 |
+
|
| 498 |
+
_, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
|
| 499 |
+
img = cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED)
|
| 500 |
+
|
| 501 |
+
if needs_float:
|
| 502 |
+
img = to_float(img, max_value=255)
|
| 503 |
+
return img
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
@preserve_shape
|
| 507 |
+
def add_snow(img, snow_point, brightness_coeff):
|
| 508 |
+
"""Bleaches out pixels, imitation snow.
|
| 509 |
+
|
| 510 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
img (numpy.ndarray): Image.
|
| 514 |
+
snow_point: Number of snow points.
|
| 515 |
+
brightness_coeff: Brightness coefficient.
|
| 516 |
+
|
| 517 |
+
Returns:
|
| 518 |
+
numpy.ndarray: Image.
|
| 519 |
+
|
| 520 |
+
"""
|
| 521 |
+
non_rgb_warning(img)
|
| 522 |
+
|
| 523 |
+
input_dtype = img.dtype
|
| 524 |
+
needs_float = False
|
| 525 |
+
|
| 526 |
+
snow_point *= 127.5 # = 255 / 2
|
| 527 |
+
snow_point += 85 # = 255 / 3
|
| 528 |
+
|
| 529 |
+
if input_dtype == np.float32:
|
| 530 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 531 |
+
needs_float = True
|
| 532 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 533 |
+
raise ValueError("Unexpected dtype {} for RandomSnow augmentation".format(input_dtype))
|
| 534 |
+
|
| 535 |
+
image_HLS = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
|
| 536 |
+
image_HLS = np.array(image_HLS, dtype=np.float32)
|
| 537 |
+
|
| 538 |
+
image_HLS[:, :, 1][image_HLS[:, :, 1] < snow_point] *= brightness_coeff
|
| 539 |
+
|
| 540 |
+
image_HLS[:, :, 1] = clip(image_HLS[:, :, 1], np.uint8, 255)
|
| 541 |
+
|
| 542 |
+
image_HLS = np.array(image_HLS, dtype=np.uint8)
|
| 543 |
+
|
| 544 |
+
image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB)
|
| 545 |
+
|
| 546 |
+
if needs_float:
|
| 547 |
+
image_RGB = to_float(image_RGB, max_value=255)
|
| 548 |
+
|
| 549 |
+
return image_RGB
|
| 550 |
+
|
| 551 |
+
|
| 552 |
+
@preserve_shape
|
| 553 |
+
def add_rain(
|
| 554 |
+
img,
|
| 555 |
+
slant,
|
| 556 |
+
drop_length,
|
| 557 |
+
drop_width,
|
| 558 |
+
drop_color,
|
| 559 |
+
blur_value,
|
| 560 |
+
brightness_coefficient,
|
| 561 |
+
rain_drops,
|
| 562 |
+
):
|
| 563 |
+
"""
|
| 564 |
+
|
| 565 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 566 |
+
|
| 567 |
+
Args:
|
| 568 |
+
img (numpy.ndarray): Image.
|
| 569 |
+
slant (int):
|
| 570 |
+
drop_length:
|
| 571 |
+
drop_width:
|
| 572 |
+
drop_color:
|
| 573 |
+
blur_value (int): Rainy views are blurry.
|
| 574 |
+
brightness_coefficient (float): Rainy days are usually shady.
|
| 575 |
+
rain_drops:
|
| 576 |
+
|
| 577 |
+
Returns:
|
| 578 |
+
numpy.ndarray: Image.
|
| 579 |
+
|
| 580 |
+
"""
|
| 581 |
+
non_rgb_warning(img)
|
| 582 |
+
|
| 583 |
+
input_dtype = img.dtype
|
| 584 |
+
needs_float = False
|
| 585 |
+
|
| 586 |
+
if input_dtype == np.float32:
|
| 587 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 588 |
+
needs_float = True
|
| 589 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 590 |
+
raise ValueError("Unexpected dtype {} for RandomRain augmentation".format(input_dtype))
|
| 591 |
+
|
| 592 |
+
image = img.copy()
|
| 593 |
+
|
| 594 |
+
for rain_drop_x0, rain_drop_y0 in rain_drops:
|
| 595 |
+
rain_drop_x1 = rain_drop_x0 + slant
|
| 596 |
+
rain_drop_y1 = rain_drop_y0 + drop_length
|
| 597 |
+
|
| 598 |
+
cv2.line(
|
| 599 |
+
image,
|
| 600 |
+
(rain_drop_x0, rain_drop_y0),
|
| 601 |
+
(rain_drop_x1, rain_drop_y1),
|
| 602 |
+
drop_color,
|
| 603 |
+
drop_width,
|
| 604 |
+
)
|
| 605 |
+
|
| 606 |
+
image = cv2.blur(image, (blur_value, blur_value))  # rainy views are blurry
|
| 607 |
+
image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
|
| 608 |
+
image_hsv[:, :, 2] *= brightness_coefficient
|
| 609 |
+
|
| 610 |
+
image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
|
| 611 |
+
|
| 612 |
+
if needs_float:
|
| 613 |
+
image_rgb = to_float(image_rgb, max_value=255)
|
| 614 |
+
|
| 615 |
+
return image_rgb
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@preserve_shape
|
| 619 |
+
def add_fog(img, fog_coef, alpha_coef, haze_list):
|
| 620 |
+
"""Add fog to the image.
|
| 621 |
+
|
| 622 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 623 |
+
|
| 624 |
+
Args:
|
| 625 |
+
img (numpy.ndarray): Image.
|
| 626 |
+
fog_coef (float): Fog coefficient.
|
| 627 |
+
alpha_coef (float): Alpha coefficient.
|
| 628 |
+
haze_list (list):
|
| 629 |
+
|
| 630 |
+
Returns:
|
| 631 |
+
numpy.ndarray: Image.
|
| 632 |
+
|
| 633 |
+
"""
|
| 634 |
+
non_rgb_warning(img)
|
| 635 |
+
|
| 636 |
+
input_dtype = img.dtype
|
| 637 |
+
needs_float = False
|
| 638 |
+
|
| 639 |
+
if input_dtype == np.float32:
|
| 640 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 641 |
+
needs_float = True
|
| 642 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 643 |
+
raise ValueError("Unexpected dtype {} for RandomFog augmentation".format(input_dtype))
|
| 644 |
+
|
| 645 |
+
width = img.shape[1]
|
| 646 |
+
|
| 647 |
+
hw = max(int(width // 3 * fog_coef), 10)
|
| 648 |
+
|
| 649 |
+
for haze_points in haze_list:
|
| 650 |
+
x, y = haze_points
|
| 651 |
+
overlay = img.copy()
|
| 652 |
+
output = img.copy()
|
| 653 |
+
alpha = alpha_coef * fog_coef
|
| 654 |
+
rad = hw // 2
|
| 655 |
+
point = (x + hw // 2, y + hw // 2)
|
| 656 |
+
cv2.circle(overlay, point, int(rad), (255, 255, 255), -1)
|
| 657 |
+
cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)
|
| 658 |
+
|
| 659 |
+
img = output.copy()
|
| 660 |
+
|
| 661 |
+
image_rgb = cv2.blur(img, (hw // 10, hw // 10))
|
| 662 |
+
|
| 663 |
+
if needs_float:
|
| 664 |
+
image_rgb = to_float(image_rgb, max_value=255)
|
| 665 |
+
|
| 666 |
+
return image_rgb
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
@preserve_shape
|
| 670 |
+
def add_sun_flare(img, flare_center_x, flare_center_y, src_radius, src_color, circles):
|
| 671 |
+
"""Add sun flare.
|
| 672 |
+
|
| 673 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 674 |
+
|
| 675 |
+
Args:
|
| 676 |
+
img (numpy.ndarray):
|
| 677 |
+
flare_center_x (float):
|
| 678 |
+
flare_center_y (float):
|
| 679 |
+
src_radius:
|
| 680 |
+
src_color (int, int, int):
|
| 681 |
+
circles (list):
|
| 682 |
+
|
| 683 |
+
Returns:
|
| 684 |
+
numpy.ndarray:
|
| 685 |
+
|
| 686 |
+
"""
|
| 687 |
+
non_rgb_warning(img)
|
| 688 |
+
|
| 689 |
+
input_dtype = img.dtype
|
| 690 |
+
needs_float = False
|
| 691 |
+
|
| 692 |
+
if input_dtype == np.float32:
|
| 693 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 694 |
+
needs_float = True
|
| 695 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 696 |
+
raise ValueError("Unexpected dtype {} for RandomSunFlareaugmentation".format(input_dtype))
|
| 697 |
+
|
| 698 |
+
overlay = img.copy()
|
| 699 |
+
output = img.copy()
|
| 700 |
+
|
| 701 |
+
for alpha, (x, y), rad3, (r_color, g_color, b_color) in circles:
|
| 702 |
+
cv2.circle(overlay, (x, y), rad3, (r_color, g_color, b_color), -1)
|
| 703 |
+
|
| 704 |
+
cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)
|
| 705 |
+
|
| 706 |
+
point = (int(flare_center_x), int(flare_center_y))
|
| 707 |
+
|
| 708 |
+
overlay = output.copy()
|
| 709 |
+
num_times = src_radius // 10
|
| 710 |
+
alpha = np.linspace(0.0, 1, num=num_times)
|
| 711 |
+
rad = np.linspace(1, src_radius, num=num_times)
|
| 712 |
+
for i in range(num_times):
|
| 713 |
+
cv2.circle(overlay, point, int(rad[i]), src_color, -1)
|
| 714 |
+
alp = alpha[num_times - i - 1] * alpha[num_times - i - 1] * alpha[num_times - i - 1]
|
| 715 |
+
cv2.addWeighted(overlay, alp, output, 1 - alp, 0, output)
|
| 716 |
+
|
| 717 |
+
image_rgb = output
|
| 718 |
+
|
| 719 |
+
if needs_float:
|
| 720 |
+
image_rgb = to_float(image_rgb, max_value=255)
|
| 721 |
+
|
| 722 |
+
return image_rgb
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
@ensure_contiguous
|
| 726 |
+
@preserve_shape
|
| 727 |
+
def add_shadow(img, vertices_list):
|
| 728 |
+
"""Add shadows to the image.
|
| 729 |
+
|
| 730 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 731 |
+
|
| 732 |
+
Args:
|
| 733 |
+
img (numpy.ndarray):
|
| 734 |
+
vertices_list (list):
|
| 735 |
+
|
| 736 |
+
Returns:
|
| 737 |
+
numpy.ndarray:
|
| 738 |
+
|
| 739 |
+
"""
|
| 740 |
+
non_rgb_warning(img)
|
| 741 |
+
input_dtype = img.dtype
|
| 742 |
+
needs_float = False
|
| 743 |
+
|
| 744 |
+
if input_dtype == np.float32:
|
| 745 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 746 |
+
needs_float = True
|
| 747 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 748 |
+
raise ValueError("Unexpected dtype {} for RandomShadow augmentation".format(input_dtype))
|
| 749 |
+
|
| 750 |
+
image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
|
| 751 |
+
mask = np.zeros_like(img)
|
| 752 |
+
|
| 753 |
+
# add all shadow polygons onto an empty mask; the single 255 value marks only the red channel
|
| 754 |
+
for vertices in vertices_list:
|
| 755 |
+
cv2.fillPoly(mask, vertices, 255)
|
| 756 |
+
|
| 757 |
+
# if red channel is hot, image's "Lightness" channel's brightness is lowered
|
| 758 |
+
red_max_value_ind = mask[:, :, 0] == 255
|
| 759 |
+
image_hls[:, :, 1][red_max_value_ind] = image_hls[:, :, 1][red_max_value_ind] * 0.5
|
| 760 |
+
|
| 761 |
+
image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
|
| 762 |
+
|
| 763 |
+
if needs_float:
|
| 764 |
+
image_rgb = to_float(image_rgb, max_value=255)
|
| 765 |
+
|
| 766 |
+
return image_rgb
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
@ensure_contiguous
|
| 770 |
+
@preserve_shape
|
| 771 |
+
def add_gravel(img: np.ndarray, gravels: list):
|
| 772 |
+
"""Add gravel to the image.
|
| 773 |
+
|
| 774 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 775 |
+
|
| 776 |
+
Args:
|
| 777 |
+
img (numpy.ndarray): image to add gravel to
|
| 778 |
+
gravels (list): list of gravel parameters. (float, float, float, float):
|
| 779 |
+
(top-left x, top-left y, bottom-right x, bottom right y)
|
| 780 |
+
|
| 781 |
+
Returns:
|
| 782 |
+
numpy.ndarray:
|
| 783 |
+
"""
|
| 784 |
+
non_rgb_warning(img)
|
| 785 |
+
input_dtype = img.dtype
|
| 786 |
+
needs_float = False
|
| 787 |
+
|
| 788 |
+
if input_dtype == np.float32:
|
| 789 |
+
img = from_float(img, dtype=np.dtype("uint8"))
|
| 790 |
+
needs_float = True
|
| 791 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 792 |
+
raise ValueError("Unexpected dtype {} for AddGravel augmentation".format(input_dtype))
|
| 793 |
+
|
| 794 |
+
image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
|
| 795 |
+
|
| 796 |
+
for gravel in gravels:
|
| 797 |
+
y1, y2, x1, x2, sat = gravel
|
| 798 |
+
image_hls[x1:x2, y1:y2, 1] = sat
|
| 799 |
+
|
| 800 |
+
image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
|
| 801 |
+
|
| 802 |
+
if needs_float:
|
| 803 |
+
image_rgb = to_float(image_rgb, max_value=255)
|
| 804 |
+
|
| 805 |
+
return image_rgb
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def invert(img: np.ndarray) -> np.ndarray:
|
| 809 |
+
# Supports all the valid dtypes
|
| 810 |
+
# clips the img to avoid unexpected behaviour.
|
| 811 |
+
return MAX_VALUES_BY_DTYPE[img.dtype] - img
|
| 812 |
+
|
| 813 |
+
|
| 814 |
+
def channel_shuffle(img, channels_shuffled):
|
| 815 |
+
img = img[..., channels_shuffled]
|
| 816 |
+
return img
|
| 817 |
+
|
| 818 |
+
|
| 819 |
+
@preserve_shape
|
| 820 |
+
def gamma_transform(img, gamma):
|
| 821 |
+
if img.dtype == np.uint8:
|
| 822 |
+
table = (np.arange(0, 256.0 / 255, 1.0 / 255) ** gamma) * 255
|
| 823 |
+
img = cv2.LUT(img, table.astype(np.uint8))
|
| 824 |
+
else:
|
| 825 |
+
img = np.power(img, gamma)
|
| 826 |
+
|
| 827 |
+
return img
|
| 828 |
+
|
| 829 |
+
|
| 830 |
+
@clipped
|
| 831 |
+
def gauss_noise(image, gauss):
|
| 832 |
+
image = image.astype("float32")
|
| 833 |
+
return image + gauss
|
| 834 |
+
|
| 835 |
+
|
| 836 |
+
@clipped
|
| 837 |
+
def _brightness_contrast_adjust_non_uint(img, alpha=1, beta=0, beta_by_max=False):
|
| 838 |
+
dtype = img.dtype
|
| 839 |
+
img = img.astype("float32")
|
| 840 |
+
|
| 841 |
+
if alpha != 1:
|
| 842 |
+
img *= alpha
|
| 843 |
+
if beta != 0:
|
| 844 |
+
if beta_by_max:
|
| 845 |
+
max_value = MAX_VALUES_BY_DTYPE[dtype]
|
| 846 |
+
img += beta * max_value
|
| 847 |
+
else:
|
| 848 |
+
img += beta * np.mean(img)
|
| 849 |
+
return img
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
@preserve_shape
|
| 853 |
+
def _brightness_contrast_adjust_uint(img, alpha=1, beta=0, beta_by_max=False):
|
| 854 |
+
dtype = np.dtype("uint8")
|
| 855 |
+
|
| 856 |
+
max_value = MAX_VALUES_BY_DTYPE[dtype]
|
| 857 |
+
|
| 858 |
+
lut = np.arange(0, max_value + 1).astype("float32")
|
| 859 |
+
|
| 860 |
+
if alpha != 1:
|
| 861 |
+
lut *= alpha
|
| 862 |
+
if beta != 0:
|
| 863 |
+
if beta_by_max:
|
| 864 |
+
lut += beta * max_value
|
| 865 |
+
else:
|
| 866 |
+
lut += (alpha * beta) * np.mean(img)
|
| 867 |
+
|
| 868 |
+
lut = np.clip(lut, 0, max_value).astype(dtype)
|
| 869 |
+
img = cv2.LUT(img, lut)
|
| 870 |
+
return img
|
| 871 |
+
|
| 872 |
+
|
| 873 |
+
def brightness_contrast_adjust(img, alpha=1, beta=0, beta_by_max=False):
|
| 874 |
+
if img.dtype == np.uint8:
|
| 875 |
+
return _brightness_contrast_adjust_uint(img, alpha, beta, beta_by_max)
|
| 876 |
+
|
| 877 |
+
return _brightness_contrast_adjust_non_uint(img, alpha, beta, beta_by_max)
|
| 878 |
+
|
| 879 |
+
|
| 880 |
+
@clipped
|
| 881 |
+
def iso_noise(image, color_shift=0.05, intensity=0.5, random_state=None, **kwargs):
|
| 882 |
+
"""
|
| 883 |
+
Apply Poisson noise to the image to simulate camera sensor noise.
|
| 884 |
+
|
| 885 |
+
Args:
|
| 886 |
+
image (numpy.ndarray): Input image, currently, only RGB, uint8 images are supported.
|
| 887 |
+
color_shift (float):
|
| 888 |
+
intensity (float): Multiplication factor for noise values. Values of ~0.5 produce a noticeable,
|
| 889 |
+
yet acceptable level of noise.
|
| 890 |
+
random_state:
|
| 891 |
+
**kwargs:
|
| 892 |
+
|
| 893 |
+
Returns:
|
| 894 |
+
numpy.ndarray: Noised image
|
| 895 |
+
|
| 896 |
+
"""
|
| 897 |
+
if image.dtype != np.uint8:
|
| 898 |
+
raise TypeError("Image must have uint8 channel type")
|
| 899 |
+
if not is_rgb_image(image):
|
| 900 |
+
raise TypeError("Image must be RGB")
|
| 901 |
+
|
| 902 |
+
one_over_255 = float(1.0 / 255.0)
|
| 903 |
+
image = np.multiply(image, one_over_255, dtype=np.float32)
|
| 904 |
+
hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)
|
| 905 |
+
_, stddev = cv2.meanStdDev(hls)
|
| 906 |
+
|
| 907 |
+
luminance_noise = random_utils.poisson(stddev[1] * intensity * 255, size=hls.shape[:2], random_state=random_state)
|
| 908 |
+
color_noise = random_utils.normal(0, color_shift * 360 * intensity, size=hls.shape[:2], random_state=random_state)
|
| 909 |
+
|
| 910 |
+
hue = hls[..., 0]
|
| 911 |
+
hue += color_noise
|
| 912 |
+
hue[hue < 0] += 360
|
| 913 |
+
hue[hue > 360] -= 360
|
| 914 |
+
|
| 915 |
+
luminance = hls[..., 1]
|
| 916 |
+
luminance += (luminance_noise / 255) * (1.0 - luminance)
|
| 917 |
+
|
| 918 |
+
image = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB) * 255
|
| 919 |
+
return image.astype(np.uint8)
|
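A usage sketch for iso_noise (import path assumed as above); only uint8 RGB input is accepted:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
noisy = F.iso_noise(img, color_shift=0.05, intensity=0.5)  # hue jitter plus Poisson luminance noise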
| 920 |
+
|
| 921 |
+
|
| 922 |
+
def to_gray(img):
|
| 923 |
+
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 924 |
+
return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
|
| 925 |
+
|
| 926 |
+
|
| 927 |
+
def gray_to_rgb(img):
|
| 928 |
+
return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
|
| 929 |
+
|
| 930 |
+
|
| 931 |
+
@preserve_shape
|
| 932 |
+
def downscale(img, scale, down_interpolation=cv2.INTER_AREA, up_interpolation=cv2.INTER_LINEAR):
|
| 933 |
+
h, w = img.shape[:2]
|
| 934 |
+
|
| 935 |
+
need_cast = (
|
| 936 |
+
up_interpolation != cv2.INTER_NEAREST or down_interpolation != cv2.INTER_NEAREST
|
| 937 |
+
) and img.dtype == np.uint8
|
| 938 |
+
if need_cast:
|
| 939 |
+
img = to_float(img)
|
| 940 |
+
downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=down_interpolation)
|
| 941 |
+
upscaled = cv2.resize(downscaled, (w, h), interpolation=up_interpolation)
|
| 942 |
+
if need_cast:
|
| 943 |
+
upscaled = from_float(np.clip(upscaled, 0, 1), dtype=np.dtype("uint8"))
|
| 944 |
+
return upscaled
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def to_float(img, max_value=None):
|
| 948 |
+
if max_value is None:
|
| 949 |
+
try:
|
| 950 |
+
max_value = MAX_VALUES_BY_DTYPE[img.dtype]
|
| 951 |
+
except KeyError:
|
| 952 |
+
raise RuntimeError(
|
| 953 |
+
"Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by "
|
| 954 |
+
"passing the max_value argument".format(img.dtype)
|
| 955 |
+
)
|
| 956 |
+
return img.astype("float32") / max_value
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
def from_float(img, dtype, max_value=None):
|
| 960 |
+
if max_value is None:
|
| 961 |
+
try:
|
| 962 |
+
max_value = MAX_VALUES_BY_DTYPE[dtype]
|
| 963 |
+
except KeyError:
|
| 964 |
+
raise RuntimeError(
|
| 965 |
+
"Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by "
|
| 966 |
+
"passing the max_value argument".format(dtype)
|
| 967 |
+
)
|
| 968 |
+
return (img * max_value).astype(dtype)
|
| 969 |
+
|
| 970 |
+
|
| 971 |
+
def noop(input_obj, **params): # skipcq: PYL-W0613
|
| 972 |
+
return input_obj
|
| 973 |
+
|
| 974 |
+
|
| 975 |
+
def swap_tiles_on_image(image, tiles):
|
| 976 |
+
"""
|
| 977 |
+
Swap tiles on image.
|
| 978 |
+
|
| 979 |
+
Args:
|
| 980 |
+
image (np.ndarray): Input image.
|
| 981 |
+
tiles (np.ndarray): array of tuples(
|
| 982 |
+
current_left_up_corner_row, current_left_up_corner_col,
|
| 983 |
+
old_left_up_corner_row, old_left_up_corner_col,
|
| 984 |
+
height_tile, width_tile)
|
| 985 |
+
|
| 986 |
+
Returns:
|
| 987 |
+
np.ndarray: Output image.
|
| 988 |
+
|
| 989 |
+
"""
|
| 990 |
+
new_image = image.copy()
|
| 991 |
+
|
| 992 |
+
for tile in tiles:
|
| 993 |
+
new_image[tile[0] : tile[0] + tile[4], tile[1] : tile[1] + tile[5]] = image[
|
| 994 |
+
tile[2] : tile[2] + tile[4], tile[3] : tile[3] + tile[5]
|
| 995 |
+
]
|
| 996 |
+
|
| 997 |
+
return new_image
|
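A usage sketch for swap_tiles_on_image (import path assumed as above), swapping two 16x16 tiles:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
# Each row: (new_row, new_col, old_row, old_col, tile_height, tile_width).
tiles = np.array([
    [0, 0, 32, 32, 16, 16],
    [32, 32, 0, 0, 16, 16],
])
swapped = F.swap_tiles_on_image(img, tiles)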
| 998 |
+
|
| 999 |
+
|
| 1000 |
+
@clipped
|
| 1001 |
+
def _multiply_uint8(img, multiplier):
|
| 1002 |
+
img = img.astype(np.float32)
|
| 1003 |
+
return np.multiply(img, multiplier)
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
@preserve_shape
|
| 1007 |
+
def _multiply_uint8_optimized(img, multiplier):
|
| 1008 |
+
if is_grayscale_image(img) or len(multiplier) == 1:
|
| 1009 |
+
multiplier = multiplier[0]
|
| 1010 |
+
lut = np.arange(0, 256, dtype=np.float32)
|
| 1011 |
+
lut *= multiplier
|
| 1012 |
+
lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype])
|
| 1013 |
+
func = _maybe_process_in_chunks(cv2.LUT, lut=lut)
|
| 1014 |
+
return func(img)
|
| 1015 |
+
|
| 1016 |
+
channels = img.shape[-1]
|
| 1017 |
+
lut = [np.arange(0, 256, dtype=np.float32)] * channels
|
| 1018 |
+
lut = np.stack(lut, axis=-1)
|
| 1019 |
+
|
| 1020 |
+
lut *= multiplier
|
| 1021 |
+
lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype])
|
| 1022 |
+
|
| 1023 |
+
images = []
|
| 1024 |
+
for i in range(channels):
|
| 1025 |
+
func = _maybe_process_in_chunks(cv2.LUT, lut=lut[:, i])
|
| 1026 |
+
images.append(func(img[:, :, i]))
|
| 1027 |
+
return np.stack(images, axis=-1)
|
| 1028 |
+
|
| 1029 |
+
|
| 1030 |
+
@clipped
|
| 1031 |
+
def _multiply_non_uint8(img, multiplier):
|
| 1032 |
+
return img * multiplier
|
| 1033 |
+
|
| 1034 |
+
|
| 1035 |
+
def multiply(img, multiplier):
|
| 1036 |
+
"""
|
| 1037 |
+
Args:
|
| 1038 |
+
img (numpy.ndarray): Image.
|
| 1039 |
+
multiplier (numpy.ndarray): Multiplier coefficient.
|
| 1040 |
+
|
| 1041 |
+
Returns:
|
| 1042 |
+
numpy.ndarray: Image multiplied by `multiplier` coefficient.
|
| 1043 |
+
|
| 1044 |
+
"""
|
| 1045 |
+
if img.dtype == np.uint8:
|
| 1046 |
+
if len(multiplier.shape) == 1:
|
| 1047 |
+
return _multiply_uint8_optimized(img, multiplier)
|
| 1048 |
+
|
| 1049 |
+
return _multiply_uint8(img, multiplier)
|
| 1050 |
+
|
| 1051 |
+
return _multiply_non_uint8(img, multiplier)
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def bbox_from_mask(mask):
|
| 1055 |
+
"""Create bounding box from binary mask (fast version)
|
| 1056 |
+
|
| 1057 |
+
Args:
|
| 1058 |
+
mask (numpy.ndarray): binary mask.
|
| 1059 |
+
|
| 1060 |
+
Returns:
|
| 1061 |
+
tuple: A bounding box tuple `(x_min, y_min, x_max, y_max)`.
|
| 1062 |
+
|
| 1063 |
+
"""
|
| 1064 |
+
rows = np.any(mask, axis=1)
|
| 1065 |
+
if not rows.any():
|
| 1066 |
+
return -1, -1, -1, -1
|
| 1067 |
+
cols = np.any(mask, axis=0)
|
| 1068 |
+
y_min, y_max = np.where(rows)[0][[0, -1]]
|
| 1069 |
+
x_min, x_max = np.where(cols)[0][[0, -1]]
|
| 1070 |
+
return x_min, y_min, x_max + 1, y_max + 1
|
| 1071 |
+
|
| 1072 |
+
|
| 1073 |
+
def mask_from_bbox(img, bbox):
|
| 1074 |
+
"""Create binary mask from bounding box
|
| 1075 |
+
|
| 1076 |
+
Args:
|
| 1077 |
+
img (numpy.ndarray): input image
|
| 1078 |
+
bbox: A bounding box tuple `(x_min, y_min, x_max, y_max)`
|
| 1079 |
+
|
| 1080 |
+
Returns:
|
| 1081 |
+
mask (numpy.ndarray): binary mask
|
| 1082 |
+
|
| 1083 |
+
"""
|
| 1084 |
+
|
| 1085 |
+
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
| 1086 |
+
x_min, y_min, x_max, y_max = bbox
|
| 1087 |
+
mask[y_min:y_max, x_min:x_max] = 1
|
| 1088 |
+
return mask
|
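A usage sketch for the round trip between bbox_from_mask and mask_from_bbox (import path assumed as above); the returned box is in pixel coordinates with exclusive right/bottom edges:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.zeros((64, 64, 3), dtype=np.uint8)
mask = np.zeros((64, 64), dtype=np.uint8)
mask[10:20, 30:50] = 1

bbox = F.bbox_from_mask(mask)            # (30, 10, 50, 20)
restored = F.mask_from_bbox(img, bbox)
assert np.array_equal(mask, restored)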
| 1089 |
+
|
| 1090 |
+
|
| 1091 |
+
def fancy_pca(img, alpha=0.1):
|
| 1092 |
+
"""Perform 'Fancy PCA' augmentation from:
|
| 1093 |
+
http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
|
| 1094 |
+
|
| 1095 |
+
Args:
|
| 1096 |
+
img (numpy.ndarray): numpy array with (h, w, rgb) shape, as ints between 0-255
|
| 1097 |
+
alpha (float): how much to perturb/scale the eigen vecs and vals
|
| 1098 |
+
the paper used std=0.1
|
| 1099 |
+
|
| 1100 |
+
Returns:
|
| 1101 |
+
numpy.ndarray: numpy image-like array as uint8 range(0, 255)
|
| 1102 |
+
|
| 1103 |
+
"""
|
| 1104 |
+
if not is_rgb_image(img) or img.dtype != np.uint8:
|
| 1105 |
+
raise TypeError("Image must be RGB image in uint8 format.")
|
| 1106 |
+
|
| 1107 |
+
orig_img = img.astype(float).copy()
|
| 1108 |
+
|
| 1109 |
+
img = img / 255.0 # rescale to 0 to 1 range
|
| 1110 |
+
|
| 1111 |
+
# flatten image to columns of RGB
|
| 1112 |
+
img_rs = img.reshape(-1, 3)
|
| 1113 |
+
# img_rs shape (640000, 3)
|
| 1114 |
+
|
| 1115 |
+
# center mean
|
| 1116 |
+
img_centered = img_rs - np.mean(img_rs, axis=0)
|
| 1117 |
+
|
| 1118 |
+
# paper says 3x3 covariance matrix
|
| 1119 |
+
img_cov = np.cov(img_centered, rowvar=False)
|
| 1120 |
+
|
| 1121 |
+
# eigen values and eigen vectors
|
| 1122 |
+
eig_vals, eig_vecs = np.linalg.eigh(img_cov)
|
| 1123 |
+
|
| 1124 |
+
# sort values and vector
|
| 1125 |
+
sort_perm = eig_vals[::-1].argsort()
|
| 1126 |
+
eig_vals[::-1].sort()
|
| 1127 |
+
eig_vecs = eig_vecs[:, sort_perm]
|
| 1128 |
+
|
| 1129 |
+
# get [p1, p2, p3]
|
| 1130 |
+
m1 = np.column_stack((eig_vecs))
|
| 1131 |
+
|
| 1132 |
+
# get 3x1 matrix of eigen values multiplied by a random variable drawn from a normal
|
| 1133 |
+
# distribution with mean of 0 and standard deviation of 0.1
|
| 1134 |
+
m2 = np.zeros((3, 1))
|
| 1135 |
+
# according to the paper, alpha should only be drawn once per augmentation (not once per channel)
|
| 1136 |
+
# alpha = np.random.normal(0, alpha_std)
|
| 1137 |
+
|
| 1138 |
+
# broadcast to speed things up
|
| 1139 |
+
m2[:, 0] = alpha * eig_vals[:]
|
| 1140 |
+
|
| 1141 |
+
# this is the vector that we're going to add to each pixel in a moment
|
| 1142 |
+
add_vect = np.matrix(m1) * np.matrix(m2)
|
| 1143 |
+
|
| 1144 |
+
for idx in range(3): # RGB
|
| 1145 |
+
orig_img[..., idx] += add_vect[idx] * 255
|
| 1146 |
+
|
| 1147 |
+
# for image processing it was found that working with float 0.0 to 1.0
|
| 1148 |
+
# was easier than integers between 0-255
|
| 1149 |
+
# orig_img /= 255.0
|
| 1150 |
+
orig_img = np.clip(orig_img, 0.0, 255.0)
|
| 1151 |
+
|
| 1152 |
+
# orig_img *= 255
|
| 1153 |
+
orig_img = orig_img.astype(np.uint8)
|
| 1154 |
+
|
| 1155 |
+
return orig_img
|
| 1156 |
+
|
| 1157 |
+
|
| 1158 |
+
def _adjust_brightness_torchvision_uint8(img, factor):
|
| 1159 |
+
lut = np.arange(0, 256) * factor
|
| 1160 |
+
lut = np.clip(lut, 0, 255).astype(np.uint8)
|
| 1161 |
+
return cv2.LUT(img, lut)
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
@preserve_shape
|
| 1165 |
+
def adjust_brightness_torchvision(img, factor):
|
| 1166 |
+
if factor == 0:
|
| 1167 |
+
return np.zeros_like(img)
|
| 1168 |
+
elif factor == 1:
|
| 1169 |
+
return img
|
| 1170 |
+
|
| 1171 |
+
if img.dtype == np.uint8:
|
| 1172 |
+
return _adjust_brightness_torchvision_uint8(img, factor)
|
| 1173 |
+
|
| 1174 |
+
return clip(img * factor, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype])
|
| 1175 |
+
|
| 1176 |
+
|
| 1177 |
+
def _adjust_contrast_torchvision_uint8(img, factor, mean):
|
| 1178 |
+
lut = np.arange(0, 256) * factor
|
| 1179 |
+
lut = lut + mean * (1 - factor)
|
| 1180 |
+
lut = clip(lut, img.dtype, 255)
|
| 1181 |
+
|
| 1182 |
+
return cv2.LUT(img, lut)
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
@preserve_shape
|
| 1186 |
+
def adjust_contrast_torchvision(img, factor):
|
| 1187 |
+
if factor == 1:
|
| 1188 |
+
return img
|
| 1189 |
+
|
| 1190 |
+
if is_grayscale_image(img):
|
| 1191 |
+
mean = img.mean()
|
| 1192 |
+
else:
|
| 1193 |
+
mean = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean()
|
| 1194 |
+
|
| 1195 |
+
if factor == 0:
|
| 1196 |
+
if img.dtype != np.float32:
|
| 1197 |
+
mean = int(mean + 0.5)
|
| 1198 |
+
return np.full_like(img, mean, dtype=img.dtype)
|
| 1199 |
+
|
| 1200 |
+
if img.dtype == np.uint8:
|
| 1201 |
+
return _adjust_contrast_torchvision_uint8(img, factor, mean)
|
| 1202 |
+
|
| 1203 |
+
return clip(
|
| 1204 |
+
img.astype(np.float32) * factor + mean * (1 - factor),
|
| 1205 |
+
img.dtype,
|
| 1206 |
+
MAX_VALUES_BY_DTYPE[img.dtype],
|
| 1207 |
+
)
|
| 1208 |
+
|
| 1209 |
+
|
| 1210 |
+
@preserve_shape
|
| 1211 |
+
def adjust_saturation_torchvision(img, factor, gamma=0):
|
| 1212 |
+
if factor == 1:
|
| 1213 |
+
return img
|
| 1214 |
+
|
| 1215 |
+
if is_grayscale_image(img):
|
| 1216 |
+
gray = img
|
| 1217 |
+
return gray
|
| 1218 |
+
else:
|
| 1219 |
+
gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
|
| 1220 |
+
gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
|
| 1221 |
+
|
| 1222 |
+
if factor == 0:
|
| 1223 |
+
return gray
|
| 1224 |
+
|
| 1225 |
+
result = cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma)
|
| 1226 |
+
if img.dtype == np.uint8:
|
| 1227 |
+
return result
|
| 1228 |
+
|
| 1229 |
+
# OpenCV does not clip values for float dtype
|
| 1230 |
+
return clip(result, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype])
|
| 1231 |
+
|
| 1232 |
+
|
| 1233 |
+
def _adjust_hue_torchvision_uint8(img, factor):
|
| 1234 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
|
| 1235 |
+
|
| 1236 |
+
lut = np.arange(0, 256, dtype=np.int16)
|
| 1237 |
+
lut = np.mod(lut + 180 * factor, 180).astype(np.uint8)
|
| 1238 |
+
img[..., 0] = cv2.LUT(img[..., 0], lut)
|
| 1239 |
+
|
| 1240 |
+
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
|
| 1241 |
+
|
| 1242 |
+
|
| 1243 |
+
def adjust_hue_torchvision(img, factor):
|
| 1244 |
+
if is_grayscale_image(img):
|
| 1245 |
+
return img
|
| 1246 |
+
|
| 1247 |
+
if factor == 0:
|
| 1248 |
+
return img
|
| 1249 |
+
|
| 1250 |
+
if img.dtype == np.uint8:
|
| 1251 |
+
return _adjust_hue_torchvision_uint8(img, factor)
|
| 1252 |
+
|
| 1253 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
|
| 1254 |
+
img[..., 0] = np.mod(img[..., 0] + factor * 360, 360)
|
| 1255 |
+
return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
|
| 1256 |
+
|
| 1257 |
+
|
| 1258 |
+
@preserve_shape
|
| 1259 |
+
def superpixels(
|
| 1260 |
+
image: np.ndarray, n_segments: int, replace_samples: Sequence[bool], max_size: Optional[int], interpolation: int
|
| 1261 |
+
) -> np.ndarray:
|
| 1262 |
+
if not np.any(replace_samples):
|
| 1263 |
+
return image
|
| 1264 |
+
|
| 1265 |
+
orig_shape = image.shape
|
| 1266 |
+
if max_size is not None:
|
| 1267 |
+
size = max(image.shape[:2])
|
| 1268 |
+
if size > max_size:
|
| 1269 |
+
scale = max_size / size
|
| 1270 |
+
height, width = image.shape[:2]
|
| 1271 |
+
new_height, new_width = int(height * scale), int(width * scale)
|
| 1272 |
+
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(new_width, new_height), interpolation=interpolation)
|
| 1273 |
+
image = resize_fn(image)
|
| 1274 |
+
|
| 1275 |
+
segments = skimage.segmentation.slic(
|
| 1276 |
+
image, n_segments=n_segments, compactness=10, channel_axis=-1 if image.ndim > 2 else None
|
| 1277 |
+
)
|
| 1278 |
+
|
| 1279 |
+
min_value = 0
|
| 1280 |
+
max_value = MAX_VALUES_BY_DTYPE[image.dtype]
|
| 1281 |
+
image = np.copy(image)
|
| 1282 |
+
if image.ndim == 2:
|
| 1283 |
+
image = image.reshape(*image.shape, 1)
|
| 1284 |
+
nb_channels = image.shape[2]
|
| 1285 |
+
for c in range(nb_channels):
|
| 1286 |
+
# segments+1 here because otherwise regionprops always misses the last label
|
| 1287 |
+
regions = skimage.measure.regionprops(segments + 1, intensity_image=image[..., c])
|
| 1288 |
+
for ridx, region in enumerate(regions):
|
| 1289 |
+
# with mod here, because slic can sometimes create more superpixels than requested.
|
| 1290 |
+
# replace_samples then does not have enough values, so we just start over with the first one again.
|
| 1291 |
+
if replace_samples[ridx % len(replace_samples)]:
|
| 1292 |
+
mean_intensity = region.mean_intensity
|
| 1293 |
+
image_sp_c = image[..., c]
|
| 1294 |
+
|
| 1295 |
+
if image_sp_c.dtype.kind in ["i", "u", "b"]:
|
| 1296 |
+
# After rounding the value can end up slightly outside of the value_range. Hence, we need to clip.
|
| 1297 |
+
# We do clip via min(max(...)) instead of np.clip because
|
| 1298 |
+
# the latter one does not seem to keep dtypes for dtypes with large itemsizes (e.g. uint64).
|
| 1299 |
+
value: Union[int, float]
|
| 1300 |
+
value = int(np.round(mean_intensity))
|
| 1301 |
+
value = min(max(value, min_value), max_value)
|
| 1302 |
+
else:
|
| 1303 |
+
value = mean_intensity
|
| 1304 |
+
|
| 1305 |
+
image_sp_c[segments == ridx] = value
|
| 1306 |
+
|
| 1307 |
+
if orig_shape != image.shape:
|
| 1308 |
+
resize_fn = _maybe_process_in_chunks(
|
| 1309 |
+
cv2.resize, dsize=(orig_shape[1], orig_shape[0]), interpolation=interpolation
|
| 1310 |
+
)
|
| 1311 |
+
image = resize_fn(image)
|
| 1312 |
+
|
| 1313 |
+
return image
|
| 1314 |
+
|
| 1315 |
+
|
| 1316 |
+
@clipped
|
| 1317 |
+
def add_weighted(img1, alpha, img2, beta):
|
| 1318 |
+
return img1.astype(float) * alpha + img2.astype(float) * beta
|
| 1319 |
+
|
| 1320 |
+
|
| 1321 |
+
@clipped
|
| 1322 |
+
@preserve_shape
|
| 1323 |
+
def unsharp_mask(image: np.ndarray, ksize: int, sigma: float = 0.0, alpha: float = 0.2, threshold: int = 10):
|
| 1324 |
+
blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma)
|
| 1325 |
+
|
| 1326 |
+
input_dtype = image.dtype
|
| 1327 |
+
if input_dtype == np.uint8:
|
| 1328 |
+
image = to_float(image)
|
| 1329 |
+
elif input_dtype not in (np.uint8, np.float32):
|
| 1330 |
+
raise ValueError("Unexpected dtype {} for UnsharpMask augmentation".format(input_dtype))
|
| 1331 |
+
|
| 1332 |
+
blur = blur_fn(image)
|
| 1333 |
+
residual = image - blur
|
| 1334 |
+
|
| 1335 |
+
# Do not sharpen noise
|
| 1336 |
+
mask = np.abs(residual) * 255 > threshold
|
| 1337 |
+
mask = mask.astype("float32")
|
| 1338 |
+
|
| 1339 |
+
sharp = image + alpha * residual
|
| 1340 |
+
# Avoid color noise artefacts.
|
| 1341 |
+
sharp = np.clip(sharp, 0, 1)
|
| 1342 |
+
|
| 1343 |
+
soft_mask = blur_fn(mask)
|
| 1344 |
+
output = soft_mask * sharp + (1 - soft_mask) * image
|
| 1345 |
+
return from_float(output, dtype=input_dtype)
|
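A usage sketch for unsharp_mask (import path assumed as above); ksize must be a valid odd Gaussian kernel size, and sigma=0 lets OpenCV derive sigma from ksize:

import numpy as np
from custom_albumentations.augmentations import functional as F

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
sharp = F.unsharp_mask(img, ksize=5, sigma=0.0, alpha=0.3, threshold=10)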
| 1346 |
+
|
| 1347 |
+
|
| 1348 |
+
@preserve_shape
|
| 1349 |
+
def pixel_dropout(image: np.ndarray, drop_mask: np.ndarray, drop_value: Union[float, Sequence[float]]) -> np.ndarray:
|
| 1350 |
+
if isinstance(drop_value, (int, float)) and drop_value == 0:
|
| 1351 |
+
drop_values = np.zeros_like(image)
|
| 1352 |
+
else:
|
| 1353 |
+
drop_values = np.full_like(image, drop_value) # type: ignore
|
| 1354 |
+
return np.where(drop_mask, drop_values, image)
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
@clipped
|
| 1358 |
+
@preserve_shape
|
| 1359 |
+
def spatter(
|
| 1360 |
+
img: np.ndarray,
|
| 1361 |
+
non_mud: Optional[np.ndarray],
|
| 1362 |
+
mud: Optional[np.ndarray],
|
| 1363 |
+
rain: Optional[np.ndarray],
|
| 1364 |
+
mode: str,
|
| 1365 |
+
) -> np.ndarray:
|
| 1366 |
+
non_rgb_warning(img)
|
| 1367 |
+
|
| 1368 |
+
coef = MAX_VALUES_BY_DTYPE[img.dtype]
|
| 1369 |
+
img = img.astype(np.float32) * (1 / coef)
|
| 1370 |
+
|
| 1371 |
+
if mode == "rain":
|
| 1372 |
+
assert rain is not None
|
| 1373 |
+
img = img + rain
|
| 1374 |
+
elif mode == "mud":
|
| 1375 |
+
assert non_mud is not None and mud is not None
|
| 1376 |
+
img = img * non_mud + mud
|
| 1377 |
+
else:
|
| 1378 |
+
raise ValueError("Unsupported spatter mode: " + str(mode))
|
| 1379 |
+
|
| 1380 |
+
return img * 255
|
custom_albumentations/augmentations/geometric/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
from .functional import *
from .resize import *
from .rotate import *
from .transforms import *
custom_albumentations/augmentations/geometric/functional.py
ADDED
|
@@ -0,0 +1,1300 @@
|
| 1 |
+
import math
|
| 2 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
import skimage.transform
|
| 7 |
+
from scipy.ndimage import gaussian_filter
|
| 8 |
+
|
| 9 |
+
from custom_albumentations.augmentations.utils import (
|
| 10 |
+
_maybe_process_in_chunks,
|
| 11 |
+
angle_2pi_range,
|
| 12 |
+
clipped,
|
| 13 |
+
preserve_channel_dim,
|
| 14 |
+
preserve_shape,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
from ... import random_utils
|
| 18 |
+
from ...core.bbox_utils import denormalize_bbox, normalize_bbox
|
| 19 |
+
from ...core.transforms_interface import (
|
| 20 |
+
BoxInternalType,
|
| 21 |
+
FillValueType,
|
| 22 |
+
ImageColorType,
|
| 23 |
+
KeypointInternalType,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
__all__ = [
|
| 27 |
+
"optical_distortion",
|
| 28 |
+
"elastic_transform_approx",
|
| 29 |
+
"grid_distortion",
|
| 30 |
+
"pad",
|
| 31 |
+
"pad_with_params",
|
| 32 |
+
"bbox_rot90",
|
| 33 |
+
"keypoint_rot90",
|
| 34 |
+
"rotate",
|
| 35 |
+
"bbox_rotate",
|
| 36 |
+
"keypoint_rotate",
|
| 37 |
+
"shift_scale_rotate",
|
| 38 |
+
"keypoint_shift_scale_rotate",
|
| 39 |
+
"bbox_shift_scale_rotate",
|
| 40 |
+
"elastic_transform",
|
| 41 |
+
"resize",
|
| 42 |
+
"scale",
|
| 43 |
+
"keypoint_scale",
|
| 44 |
+
"py3round",
|
| 45 |
+
"_func_max_size",
|
| 46 |
+
"longest_max_size",
|
| 47 |
+
"smallest_max_size",
|
| 48 |
+
"perspective",
|
| 49 |
+
"perspective_bbox",
|
| 50 |
+
"rotation2DMatrixToEulerAngles",
|
| 51 |
+
"perspective_keypoint",
|
| 52 |
+
"_is_identity_matrix",
|
| 53 |
+
"warp_affine",
|
| 54 |
+
"keypoint_affine",
|
| 55 |
+
"bbox_affine",
|
| 56 |
+
"safe_rotate",
|
| 57 |
+
"bbox_safe_rotate",
|
| 58 |
+
"keypoint_safe_rotate",
|
| 59 |
+
"piecewise_affine",
|
| 60 |
+
"to_distance_maps",
|
| 61 |
+
"from_distance_maps",
|
| 62 |
+
"keypoint_piecewise_affine",
|
| 63 |
+
"bbox_piecewise_affine",
|
| 64 |
+
"bbox_flip",
|
| 65 |
+
"bbox_hflip",
|
| 66 |
+
"bbox_transpose",
|
| 67 |
+
"bbox_vflip",
|
| 68 |
+
"hflip",
|
| 69 |
+
"hflip_cv2",
|
| 70 |
+
"transpose",
|
| 71 |
+
"keypoint_flip",
|
| 72 |
+
"keypoint_hflip",
|
| 73 |
+
"keypoint_transpose",
|
| 74 |
+
"keypoint_vflip",
|
| 75 |
+
]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def bbox_rot90(bbox: BoxInternalType, factor: int, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
|
| 79 |
+
"""Rotates a bounding box by 90 degrees CCW (see np.rot90)
|
| 80 |
+
|
| 81 |
+
Args:
|
| 82 |
+
bbox: A bounding box tuple (x_min, y_min, x_max, y_max).
|
| 83 |
+
factor: Number of CCW rotations. Must be in set {0, 1, 2, 3}. See np.rot90.
|
| 84 |
+
rows: Image rows.
|
| 85 |
+
cols: Image cols.
|
| 86 |
+
|
| 87 |
+
Returns:
|
| 88 |
+
tuple: A bounding box tuple (x_min, y_min, x_max, y_max).
|
| 89 |
+
|
| 90 |
+
"""
|
| 91 |
+
if factor not in {0, 1, 2, 3}:
|
| 92 |
+
raise ValueError("Parameter n must be in set {0, 1, 2, 3}")
|
| 93 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 94 |
+
if factor == 1:
|
| 95 |
+
bbox = y_min, 1 - x_max, y_max, 1 - x_min
|
| 96 |
+
elif factor == 2:
|
| 97 |
+
bbox = 1 - x_max, 1 - y_max, 1 - x_min, 1 - y_min
|
| 98 |
+
elif factor == 3:
|
| 99 |
+
bbox = 1 - y_max, x_min, 1 - y_min, x_max
|
| 100 |
+
return bbox
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@angle_2pi_range
|
| 104 |
+
def keypoint_rot90(keypoint: KeypointInternalType, factor: int, rows: int, cols: int, **params) -> KeypointInternalType:
|
| 105 |
+
"""Rotates a keypoint by 90 degrees CCW (see np.rot90)
|
| 106 |
+
|
| 107 |
+
Args:
|
| 108 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 109 |
+
factor: Number of CCW rotations. Must be in set {0, 1, 2, 3}. See np.rot90.
|
| 110 |
+
rows: Image height.
|
| 111 |
+
cols: Image width.
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
tuple: A keypoint `(x, y, angle, scale)`.
|
| 115 |
+
|
| 116 |
+
Raises:
|
| 117 |
+
ValueError: if factor not in set {0, 1, 2, 3}
|
| 118 |
+
|
| 119 |
+
"""
|
| 120 |
+
x, y, angle, scale = keypoint[:4]
|
| 121 |
+
|
| 122 |
+
if factor not in {0, 1, 2, 3}:
|
| 123 |
+
raise ValueError("Parameter n must be in set {0, 1, 2, 3}")
|
| 124 |
+
|
| 125 |
+
if factor == 1:
|
| 126 |
+
x, y, angle = y, (cols - 1) - x, angle - math.pi / 2
|
| 127 |
+
elif factor == 2:
|
| 128 |
+
x, y, angle = (cols - 1) - x, (rows - 1) - y, angle - math.pi
|
| 129 |
+
elif factor == 3:
|
| 130 |
+
x, y, angle = (rows - 1) - y, x, angle + math.pi / 2
|
| 131 |
+
|
| 132 |
+
return x, y, angle, scale
|
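As a quick illustration of the two rot90 helpers above, the sketch below rotates one normalized bounding box and one pixel-space keypoint by a single 90-degree turn. This is a minimal sketch, not part of the uploaded file: it assumes the package in this upload is importable as `custom_albumentations` with its dependencies (numpy, cv2, scipy, scikit-image) installed, and the sample values are arbitrary.

from custom_albumentations.augmentations.geometric import functional as F

rows, cols = 100, 200
bbox = (0.1, 0.2, 0.4, 0.5)        # normalized (x_min, y_min, x_max, y_max)
keypoint = (10.0, 20.0, 0.0, 1.0)  # pixel-space (x, y, angle, scale)

print(F.bbox_rot90(bbox, factor=1, rows=rows, cols=cols))          # (0.2, 0.6, 0.5, 0.9)
print(F.keypoint_rot90(keypoint, factor=1, rows=rows, cols=cols))  # x/y swapped, angle shifted by -pi/2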
| 133 |
+
|
| 134 |
+
|
| 135 |
+
@preserve_channel_dim
|
| 136 |
+
def rotate(
|
| 137 |
+
img: np.ndarray,
|
| 138 |
+
angle: float,
|
| 139 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 140 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 141 |
+
value: Optional[ImageColorType] = None,
|
| 142 |
+
):
|
| 143 |
+
height, width = img.shape[:2]
|
| 144 |
+
# for images we use additional shifts of (0.5, 0.5) as otherwise
|
| 145 |
+
# we get an ugly black border for 90deg rotations
|
| 146 |
+
matrix = cv2.getRotationMatrix2D((width / 2 - 0.5, height / 2 - 0.5), angle, 1.0)
|
| 147 |
+
|
| 148 |
+
warp_fn = _maybe_process_in_chunks(
|
| 149 |
+
cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
|
| 150 |
+
)
|
| 151 |
+
return warp_fn(img)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, cols: int) -> BoxInternalType:
|
| 155 |
+
"""Rotates a bounding box by angle degrees.
|
| 156 |
+
|
| 157 |
+
Args:
|
| 158 |
+
bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 159 |
+
angle: Angle of rotation in degrees.
|
| 160 |
+
method: Rotation method used. Should be one of: "largest_box", "ellipse". Default: "largest_box".
|
| 161 |
+
rows: Image rows.
|
| 162 |
+
cols: Image cols.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 166 |
+
|
| 167 |
+
References:
|
| 168 |
+
https://arxiv.org/abs/2109.13488
|
| 169 |
+
|
| 170 |
+
"""
|
| 171 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 172 |
+
scale = cols / float(rows)
|
| 173 |
+
if method == "largest_box":
|
| 174 |
+
x = np.array([x_min, x_max, x_max, x_min]) - 0.5
|
| 175 |
+
y = np.array([y_min, y_min, y_max, y_max]) - 0.5
|
| 176 |
+
elif method == "ellipse":
|
| 177 |
+
w = (x_max - x_min) / 2
|
| 178 |
+
h = (y_max - y_min) / 2
|
| 179 |
+
data = np.arange(0, 360, dtype=np.float32)
|
| 180 |
+
x = w * np.sin(np.radians(data)) + (w + x_min - 0.5)
|
| 181 |
+
y = h * np.cos(np.radians(data)) + (h + y_min - 0.5)
|
| 182 |
+
else:
|
| 183 |
+
raise ValueError(f"Method {method} is not a valid rotation method.")
|
| 184 |
+
angle = np.deg2rad(angle)
|
| 185 |
+
x_t = (np.cos(angle) * x * scale + np.sin(angle) * y) / scale
|
| 186 |
+
y_t = -np.sin(angle) * x * scale + np.cos(angle) * y
|
| 187 |
+
x_t = x_t + 0.5
|
| 188 |
+
y_t = y_t + 0.5
|
| 189 |
+
|
| 190 |
+
x_min, x_max = min(x_t), max(x_t)
|
| 191 |
+
y_min, y_max = min(y_t), max(y_t)
|
| 192 |
+
|
| 193 |
+
return x_min, y_min, x_max, y_max
|
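To see how the two rotation methods documented above differ, the following sketch (illustrative only, with an arbitrary box) rotates the same normalized box with "largest_box" and with "ellipse"; per the reference above, the ellipse method generally yields a tighter box for elongated objects.

from custom_albumentations.augmentations.geometric import functional as F

box = (0.25, 0.25, 0.75, 0.50)  # normalized (x_min, y_min, x_max, y_max)
for method in ("largest_box", "ellipse"):
    print(method, F.bbox_rotate(box, angle=30, method=method, rows=480, cols=640))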
| 194 |
+
|
| 195 |
+
|
| 196 |
+
@angle_2pi_range
|
| 197 |
+
def keypoint_rotate(keypoint, angle, rows, cols, **params):
|
| 198 |
+
"""Rotate a keypoint by angle.
|
| 199 |
+
|
| 200 |
+
Args:
|
| 201 |
+
keypoint (tuple): A keypoint `(x, y, angle, scale)`.
|
| 202 |
+
angle (float): Rotation angle.
|
| 203 |
+
rows (int): Image height.
|
| 204 |
+
cols (int): Image width.
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
tuple: A keypoint `(x, y, angle, scale)`.
|
| 208 |
+
|
| 209 |
+
"""
|
| 210 |
+
center = (cols - 1) * 0.5, (rows - 1) * 0.5
|
| 211 |
+
matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
|
| 212 |
+
x, y, a, s = keypoint[:4]
|
| 213 |
+
x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
|
| 214 |
+
return x, y, a + math.radians(angle), s
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
@preserve_channel_dim
|
| 218 |
+
def shift_scale_rotate(
|
| 219 |
+
img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=None
|
| 220 |
+
):
|
| 221 |
+
height, width = img.shape[:2]
|
| 222 |
+
# for images we use additional shifts of (0.5, 0.5) as otherwise
|
| 223 |
+
# we get an ugly black border for 90deg rotations
|
| 224 |
+
center = (width / 2 - 0.5, height / 2 - 0.5)
|
| 225 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
| 226 |
+
matrix[0, 2] += dx * width
|
| 227 |
+
matrix[1, 2] += dy * height
|
| 228 |
+
|
| 229 |
+
warp_affine_fn = _maybe_process_in_chunks(
|
| 230 |
+
cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
|
| 231 |
+
)
|
| 232 |
+
return warp_affine_fn(img)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@angle_2pi_range
|
| 236 |
+
def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params):
|
| 237 |
+
(
|
| 238 |
+
x,
|
| 239 |
+
y,
|
| 240 |
+
a,
|
| 241 |
+
s,
|
| 242 |
+
) = keypoint[:4]
|
| 243 |
+
height, width = rows, cols
|
| 244 |
+
center = (cols - 1) * 0.5, (rows - 1) * 0.5
|
| 245 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
| 246 |
+
matrix[0, 2] += dx * width
|
| 247 |
+
matrix[1, 2] += dy * height
|
| 248 |
+
|
| 249 |
+
x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
|
| 250 |
+
angle = a + math.radians(angle)
|
| 251 |
+
scale = s * scale
|
| 252 |
+
|
| 253 |
+
return x, y, angle, scale
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rotate_method, rows, cols, **kwargs): # skipcq: PYL-W0613
|
| 257 |
+
"""Rotates, shifts and scales a bounding box. Rotation is made by angle degrees,
|
| 258 |
+
scaling is made by scale factor and shifting is made by dx and dy.
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
Args:
|
| 262 |
+
bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 263 |
+
angle (int): Angle of rotation in degrees.
|
| 264 |
+
scale (int): Scale factor.
|
| 265 |
+
dx (int): Shift along x-axis in pixel units.
|
| 266 |
+
dy (int): Shift along y-axis in pixel units.
|
| 267 |
+
rotate_method(str): Rotation method used. Should be one of: "largest_box", "ellipse".
|
| 268 |
+
Default: "largest_box".
|
| 269 |
+
rows (int): Image rows.
|
| 270 |
+
cols (int): Image cols.
|
| 271 |
+
|
| 272 |
+
Returns:
|
| 273 |
+
A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 274 |
+
|
| 275 |
+
"""
|
| 276 |
+
height, width = rows, cols
|
| 277 |
+
center = (width / 2, height / 2)
|
| 278 |
+
if rotate_method == "ellipse":
|
| 279 |
+
x_min, y_min, x_max, y_max = bbox_rotate(bbox, angle, rotate_method, rows, cols)
|
| 280 |
+
matrix = cv2.getRotationMatrix2D(center, 0, scale)
|
| 281 |
+
else:
|
| 282 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 283 |
+
matrix = cv2.getRotationMatrix2D(center, angle, scale)
|
| 284 |
+
matrix[0, 2] += dx * width
|
| 285 |
+
matrix[1, 2] += dy * height
|
| 286 |
+
x = np.array([x_min, x_max, x_max, x_min])
|
| 287 |
+
y = np.array([y_min, y_min, y_max, y_max])
|
| 288 |
+
ones = np.ones(shape=(len(x)))
|
| 289 |
+
points_ones = np.vstack([x, y, ones]).transpose()
|
| 290 |
+
points_ones[:, 0] *= width
|
| 291 |
+
points_ones[:, 1] *= height
|
| 292 |
+
tr_points = matrix.dot(points_ones.T).T
|
| 293 |
+
tr_points[:, 0] /= width
|
| 294 |
+
tr_points[:, 1] /= height
|
| 295 |
+
|
| 296 |
+
x_min, x_max = min(tr_points[:, 0]), max(tr_points[:, 0])
|
| 297 |
+
y_min, y_max = min(tr_points[:, 1]), max(tr_points[:, 1])
|
| 298 |
+
|
| 299 |
+
return x_min, y_min, x_max, y_max
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
@preserve_shape
|
| 303 |
+
def elastic_transform(
|
| 304 |
+
img: np.ndarray,
|
| 305 |
+
alpha: float,
|
| 306 |
+
sigma: float,
|
| 307 |
+
alpha_affine: float,
|
| 308 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 309 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 310 |
+
value: Optional[ImageColorType] = None,
|
| 311 |
+
random_state: Optional[np.random.RandomState] = None,
|
| 312 |
+
approximate: bool = False,
|
| 313 |
+
same_dxdy: bool = False,
|
| 314 |
+
):
|
| 315 |
+
"""Elastic deformation of images as described in [Simard2003]_ (with modifications).
|
| 316 |
+
Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
|
| 317 |
+
|
| 318 |
+
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
| 319 |
+
Convolutional Neural Networks applied to Visual Document Analysis", in
|
| 320 |
+
Proc. of the International Conference on Document Analysis and
|
| 321 |
+
Recognition, 2003.
|
| 322 |
+
"""
|
| 323 |
+
height, width = img.shape[:2]
|
| 324 |
+
|
| 325 |
+
# Random affine
|
| 326 |
+
center_square = np.array((height, width), dtype=np.float32) // 2
|
| 327 |
+
square_size = min((height, width)) // 3
|
| 328 |
+
alpha = float(alpha)
|
| 329 |
+
sigma = float(sigma)
|
| 330 |
+
alpha_affine = float(alpha_affine)
|
| 331 |
+
|
| 332 |
+
pts1 = np.array(
|
| 333 |
+
[
|
| 334 |
+
center_square + square_size,
|
| 335 |
+
[center_square[0] + square_size, center_square[1] - square_size],
|
| 336 |
+
center_square - square_size,
|
| 337 |
+
],
|
| 338 |
+
dtype=np.float32,
|
| 339 |
+
)
|
| 340 |
+
pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype(
|
| 341 |
+
np.float32
|
| 342 |
+
)
|
| 343 |
+
matrix = cv2.getAffineTransform(pts1, pts2)
|
| 344 |
+
|
| 345 |
+
warp_fn = _maybe_process_in_chunks(
|
| 346 |
+
cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
|
| 347 |
+
)
|
| 348 |
+
img = warp_fn(img)
|
| 349 |
+
|
| 350 |
+
if approximate:
|
| 351 |
+
# Approximate computation smooth displacement map with a large enough kernel.
|
| 352 |
+
# On large images (512+) this is approximately 2X times faster
|
| 353 |
+
dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
|
| 354 |
+
cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
|
| 355 |
+
dx *= alpha
|
| 356 |
+
if same_dxdy:
|
| 357 |
+
# Speed up even more
|
| 358 |
+
dy = dx
|
| 359 |
+
else:
|
| 360 |
+
dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
|
| 361 |
+
cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
|
| 362 |
+
dy *= alpha
|
| 363 |
+
else:
|
| 364 |
+
dx = np.float32(
|
| 365 |
+
gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha
|
| 366 |
+
)
|
| 367 |
+
if same_dxdy:
|
| 368 |
+
# Speed up
|
| 369 |
+
dy = dx
|
| 370 |
+
else:
|
| 371 |
+
dy = np.float32(
|
| 372 |
+
gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
x, y = np.meshgrid(np.arange(width), np.arange(height))
|
| 376 |
+
|
| 377 |
+
map_x = np.float32(x + dx)
|
| 378 |
+
map_y = np.float32(y + dy)
|
| 379 |
+
|
| 380 |
+
remap_fn = _maybe_process_in_chunks(
|
| 381 |
+
cv2.remap, map1=map_x, map2=map_y, interpolation=interpolation, borderMode=border_mode, borderValue=value
|
| 382 |
+
)
|
| 383 |
+
return remap_fn(img)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
@preserve_channel_dim
|
| 387 |
+
def resize(img, height, width, interpolation=cv2.INTER_LINEAR):
|
| 388 |
+
img_height, img_width = img.shape[:2]
|
| 389 |
+
if height == img_height and width == img_width:
|
| 390 |
+
return img
|
| 391 |
+
resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
|
| 392 |
+
return resize_fn(img)
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
@preserve_channel_dim
|
| 396 |
+
def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
|
| 397 |
+
height, width = img.shape[:2]
|
| 398 |
+
new_height, new_width = int(height * scale), int(width * scale)
|
| 399 |
+
return resize(img, new_height, new_width, interpolation)
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def keypoint_scale(keypoint: KeypointInternalType, scale_x: float, scale_y: float) -> KeypointInternalType:
|
| 403 |
+
"""Scales a keypoint by scale_x and scale_y.
|
| 404 |
+
|
| 405 |
+
Args:
|
| 406 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 407 |
+
scale_x: Scale coefficient x-axis.
|
| 408 |
+
scale_y: Scale coefficient y-axis.
|
| 409 |
+
|
| 410 |
+
Returns:
|
| 411 |
+
A keypoint `(x, y, angle, scale)`.
|
| 412 |
+
|
| 413 |
+
"""
|
| 414 |
+
x, y, angle, scale = keypoint[:4]
|
| 415 |
+
return x * scale_x, y * scale_y, angle, scale * max(scale_x, scale_y)
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def py3round(number):
|
| 419 |
+
"""Unified rounding in all python versions."""
|
| 420 |
+
if abs(round(number) - number) == 0.5:
|
| 421 |
+
return int(2.0 * round(number / 2.0))
|
| 422 |
+
|
| 423 |
+
return int(round(number))
|
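py3round implements half-to-even ("banker's") rounding, matching Python 3's built-in round() regardless of interpreter version. A small sketch of that behavior, assuming the module import path used in this upload:

from custom_albumentations.augmentations.geometric.functional import py3round

assert py3round(0.5) == 0   # halves round to the nearest even integer
assert py3round(1.5) == 2
assert py3round(2.5) == 2
assert py3round(2.4) == 2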
| 424 |
+
|
| 425 |
+
|
| 426 |
+
def _func_max_size(img, max_size, interpolation, func):
|
| 427 |
+
height, width = img.shape[:2]
|
| 428 |
+
|
| 429 |
+
scale = max_size / float(func(width, height))
|
| 430 |
+
|
| 431 |
+
if scale != 1.0:
|
| 432 |
+
new_height, new_width = tuple(py3round(dim * scale) for dim in (height, width))
|
| 433 |
+
img = resize(img, height=new_height, width=new_width, interpolation=interpolation)
|
| 434 |
+
return img
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
@preserve_channel_dim
|
| 438 |
+
def longest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray:
|
| 439 |
+
return _func_max_size(img, max_size, interpolation, max)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
@preserve_channel_dim
|
| 443 |
+
def smallest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray:
|
| 444 |
+
return _func_max_size(img, max_size, interpolation, min)
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
@preserve_channel_dim
|
| 448 |
+
def perspective(
|
| 449 |
+
img: np.ndarray,
|
| 450 |
+
matrix: np.ndarray,
|
| 451 |
+
max_width: int,
|
| 452 |
+
max_height: int,
|
| 453 |
+
border_val: Union[int, float, List[int], List[float], np.ndarray],
|
| 454 |
+
border_mode: int,
|
| 455 |
+
keep_size: bool,
|
| 456 |
+
interpolation: int,
|
| 457 |
+
):
|
| 458 |
+
h, w = img.shape[:2]
|
| 459 |
+
perspective_func = _maybe_process_in_chunks(
|
| 460 |
+
cv2.warpPerspective,
|
| 461 |
+
M=matrix,
|
| 462 |
+
dsize=(max_width, max_height),
|
| 463 |
+
borderMode=border_mode,
|
| 464 |
+
borderValue=border_val,
|
| 465 |
+
flags=interpolation,
|
| 466 |
+
)
|
| 467 |
+
warped = perspective_func(img)
|
| 468 |
+
|
| 469 |
+
if keep_size:
|
| 470 |
+
return resize(warped, h, w, interpolation=interpolation)
|
| 471 |
+
|
| 472 |
+
return warped
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
def perspective_bbox(
|
| 476 |
+
bbox: BoxInternalType,
|
| 477 |
+
height: int,
|
| 478 |
+
width: int,
|
| 479 |
+
matrix: np.ndarray,
|
| 480 |
+
max_width: int,
|
| 481 |
+
max_height: int,
|
| 482 |
+
keep_size: bool,
|
| 483 |
+
) -> BoxInternalType:
|
| 484 |
+
x1, y1, x2, y2 = denormalize_bbox(bbox, height, width)[:4]
|
| 485 |
+
|
| 486 |
+
points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
|
| 487 |
+
|
| 488 |
+
x1, y1, x2, y2 = float("inf"), float("inf"), 0, 0
|
| 489 |
+
for pt in points:
|
| 490 |
+
pt = perspective_keypoint(pt.tolist() + [0, 0], height, width, matrix, max_width, max_height, keep_size)
|
| 491 |
+
x, y = pt[:2]
|
| 492 |
+
x1 = min(x1, x)
|
| 493 |
+
x2 = max(x2, x)
|
| 494 |
+
y1 = min(y1, y)
|
| 495 |
+
y2 = max(y2, y)
|
| 496 |
+
|
| 497 |
+
return normalize_bbox((x1, y1, x2, y2), height if keep_size else max_height, width if keep_size else max_width)
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
def rotation2DMatrixToEulerAngles(matrix: np.ndarray, y_up: bool = False) -> float:
|
| 501 |
+
"""
|
| 502 |
+
Args:
|
| 503 |
+
matrix (np.ndarray): Rotation matrix
|
| 504 |
+
y_up (bool): whether the Y axis points up (True) or down (False)
|
| 505 |
+
"""
|
| 506 |
+
if y_up:
|
| 507 |
+
return np.arctan2(matrix[1, 0], matrix[0, 0])
|
| 508 |
+
return np.arctan2(-matrix[1, 0], matrix[0, 0])
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
@angle_2pi_range
|
| 512 |
+
def perspective_keypoint(
|
| 513 |
+
keypoint: KeypointInternalType,
|
| 514 |
+
height: int,
|
| 515 |
+
width: int,
|
| 516 |
+
matrix: np.ndarray,
|
| 517 |
+
max_width: int,
|
| 518 |
+
max_height: int,
|
| 519 |
+
keep_size: bool,
|
| 520 |
+
) -> KeypointInternalType:
|
| 521 |
+
x, y, angle, scale = keypoint
|
| 522 |
+
|
| 523 |
+
keypoint_vector = np.array([x, y], dtype=np.float32).reshape([1, 1, 2])
|
| 524 |
+
|
| 525 |
+
x, y = cv2.perspectiveTransform(keypoint_vector, matrix)[0, 0]
|
| 526 |
+
angle += rotation2DMatrixToEulerAngles(matrix[:2, :2], y_up=True)
|
| 527 |
+
|
| 528 |
+
scale_x = np.sign(matrix[0, 0]) * np.sqrt(matrix[0, 0] ** 2 + matrix[0, 1] ** 2)
|
| 529 |
+
scale_y = np.sign(matrix[1, 1]) * np.sqrt(matrix[1, 0] ** 2 + matrix[1, 1] ** 2)
|
| 530 |
+
scale *= max(scale_x, scale_y)
|
| 531 |
+
|
| 532 |
+
if keep_size:
|
| 533 |
+
scale_x = width / max_width
|
| 534 |
+
scale_y = height / max_height
|
| 535 |
+
return keypoint_scale((x, y, angle, scale), scale_x, scale_y)
|
| 536 |
+
|
| 537 |
+
return x, y, angle, scale
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
def _is_identity_matrix(matrix: skimage.transform.ProjectiveTransform) -> bool:
|
| 541 |
+
return np.allclose(matrix.params, np.eye(3, dtype=np.float32))
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
@preserve_channel_dim
|
| 545 |
+
def warp_affine(
|
| 546 |
+
image: np.ndarray,
|
| 547 |
+
matrix: skimage.transform.ProjectiveTransform,
|
| 548 |
+
interpolation: int,
|
| 549 |
+
cval: Union[int, float, Sequence[int], Sequence[float]],
|
| 550 |
+
mode: int,
|
| 551 |
+
output_shape: Sequence[int],
|
| 552 |
+
) -> np.ndarray:
|
| 553 |
+
if _is_identity_matrix(matrix):
|
| 554 |
+
return image
|
| 555 |
+
|
| 556 |
+
dsize = int(np.round(output_shape[1])), int(np.round(output_shape[0]))
|
| 557 |
+
warp_fn = _maybe_process_in_chunks(
|
| 558 |
+
cv2.warpAffine, M=matrix.params[:2], dsize=dsize, flags=interpolation, borderMode=mode, borderValue=cval
|
| 559 |
+
)
|
| 560 |
+
tmp = warp_fn(image)
|
| 561 |
+
return tmp
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
@angle_2pi_range
|
| 565 |
+
def keypoint_affine(
|
| 566 |
+
keypoint: KeypointInternalType,
|
| 567 |
+
matrix: skimage.transform.ProjectiveTransform,
|
| 568 |
+
scale: dict,
|
| 569 |
+
) -> KeypointInternalType:
|
| 570 |
+
if _is_identity_matrix(matrix):
|
| 571 |
+
return keypoint
|
| 572 |
+
|
| 573 |
+
x, y, a, s = keypoint[:4]
|
| 574 |
+
x, y = cv2.transform(np.array([[[x, y]]]), matrix.params[:2]).squeeze()
|
| 575 |
+
a += rotation2DMatrixToEulerAngles(matrix.params[:2])
|
| 576 |
+
s *= np.max([scale["x"], scale["y"]])
|
| 577 |
+
return x, y, a, s
|
| 578 |
+
|
| 579 |
+
|
| 580 |
+
def bbox_affine(
|
| 581 |
+
bbox: BoxInternalType,
|
| 582 |
+
matrix: skimage.transform.ProjectiveTransform,
|
| 583 |
+
rotate_method: str,
|
| 584 |
+
rows: int,
|
| 585 |
+
cols: int,
|
| 586 |
+
output_shape: Sequence[int],
|
| 587 |
+
) -> BoxInternalType:
|
| 588 |
+
if _is_identity_matrix(matrix):
|
| 589 |
+
return bbox
|
| 590 |
+
x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
|
| 591 |
+
if rotate_method == "largest_box":
|
| 592 |
+
points = np.array(
|
| 593 |
+
[
|
| 594 |
+
[x_min, y_min],
|
| 595 |
+
[x_max, y_min],
|
| 596 |
+
[x_max, y_max],
|
| 597 |
+
[x_min, y_max],
|
| 598 |
+
]
|
| 599 |
+
)
|
| 600 |
+
elif rotate_method == "ellipse":
|
| 601 |
+
w = (x_max - x_min) / 2
|
| 602 |
+
h = (y_max - y_min) / 2
|
| 603 |
+
data = np.arange(0, 360, dtype=np.float32)
|
| 604 |
+
x = w * np.sin(np.radians(data)) + (w + x_min - 0.5)
|
| 605 |
+
y = h * np.cos(np.radians(data)) + (h + y_min - 0.5)
|
| 606 |
+
points = np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)])
|
| 607 |
+
else:
|
| 608 |
+
raise ValueError(f"Method {rotate_method} is not a valid rotation method.")
|
| 609 |
+
points = skimage.transform.matrix_transform(points, matrix.params)
|
| 610 |
+
x_min = np.min(points[:, 0])
|
| 611 |
+
x_max = np.max(points[:, 0])
|
| 612 |
+
y_min = np.min(points[:, 1])
|
| 613 |
+
y_max = np.max(points[:, 1])
|
| 614 |
+
|
| 615 |
+
return normalize_bbox((x_min, y_min, x_max, y_max), output_shape[0], output_shape[1])
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
@preserve_channel_dim
|
| 619 |
+
def safe_rotate(
|
| 620 |
+
img: np.ndarray,
|
| 621 |
+
matrix: np.ndarray,
|
| 622 |
+
interpolation: int,
|
| 623 |
+
value: FillValueType = None,
|
| 624 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 625 |
+
) -> np.ndarray:
|
| 626 |
+
h, w = img.shape[:2]
|
| 627 |
+
warp_fn = _maybe_process_in_chunks(
|
| 628 |
+
cv2.warpAffine,
|
| 629 |
+
M=matrix,
|
| 630 |
+
dsize=(w, h),
|
| 631 |
+
flags=interpolation,
|
| 632 |
+
borderMode=border_mode,
|
| 633 |
+
borderValue=value,
|
| 634 |
+
)
|
| 635 |
+
return warp_fn(img)
|
| 636 |
+
|
| 637 |
+
|
| 638 |
+
def bbox_safe_rotate(bbox: BoxInternalType, matrix: np.ndarray, cols: int, rows: int) -> BoxInternalType:
|
| 639 |
+
x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4]
|
| 640 |
+
points = np.array(
|
| 641 |
+
[
|
| 642 |
+
[x1, y1, 1],
|
| 643 |
+
[x2, y1, 1],
|
| 644 |
+
[x2, y2, 1],
|
| 645 |
+
[x1, y2, 1],
|
| 646 |
+
]
|
| 647 |
+
)
|
| 648 |
+
points = points @ matrix.T
|
| 649 |
+
x1 = points[:, 0].min()
|
| 650 |
+
x2 = points[:, 0].max()
|
| 651 |
+
y1 = points[:, 1].min()
|
| 652 |
+
y2 = points[:, 1].max()
|
| 653 |
+
|
| 654 |
+
def fix_point(pt1: float, pt2: float, max_val: float) -> Tuple[float, float]:
|
| 655 |
+
# These errors should be small, on the order of 1-2 pixels.
|
| 656 |
+
if pt1 < 0:
|
| 657 |
+
return 0, pt2 + pt1
|
| 658 |
+
if pt2 > max_val:
|
| 659 |
+
return pt1 - (pt2 - max_val), max_val
|
| 660 |
+
return pt1, pt2
|
| 661 |
+
|
| 662 |
+
x1, x2 = fix_point(x1, x2, cols)
|
| 663 |
+
y1, y2 = fix_point(y1, y2, rows)
|
| 664 |
+
|
| 665 |
+
return normalize_bbox((x1, y1, x2, y2), rows, cols)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def keypoint_safe_rotate(
|
| 669 |
+
keypoint: KeypointInternalType,
|
| 670 |
+
matrix: np.ndarray,
|
| 671 |
+
angle: float,
|
| 672 |
+
scale_x: float,
|
| 673 |
+
scale_y: float,
|
| 674 |
+
cols: int,
|
| 675 |
+
rows: int,
|
| 676 |
+
) -> KeypointInternalType:
|
| 677 |
+
x, y, a, s = keypoint[:4]
|
| 678 |
+
point = np.array([[x, y, 1]])
|
| 679 |
+
x, y = (point @ matrix.T)[0]
|
| 680 |
+
|
| 681 |
+
# To avoid problems with float errors
|
| 682 |
+
x = np.clip(x, 0, cols - 1)
|
| 683 |
+
y = np.clip(y, 0, rows - 1)
|
| 684 |
+
|
| 685 |
+
a += angle
|
| 686 |
+
s *= max(scale_x, scale_y)
|
| 687 |
+
return x, y, a, s
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
@clipped
|
| 691 |
+
def piecewise_affine(
|
| 692 |
+
img: np.ndarray,
|
| 693 |
+
matrix: Optional[skimage.transform.PiecewiseAffineTransform],
|
| 694 |
+
interpolation: int,
|
| 695 |
+
mode: str,
|
| 696 |
+
cval: float,
|
| 697 |
+
) -> np.ndarray:
|
| 698 |
+
if matrix is None:
|
| 699 |
+
return img
|
| 700 |
+
return skimage.transform.warp(
|
| 701 |
+
img, matrix, order=interpolation, mode=mode, cval=cval, preserve_range=True, output_shape=img.shape
|
| 702 |
+
)
|
| 703 |
+
|
| 704 |
+
|
| 705 |
+
def to_distance_maps(
|
| 706 |
+
keypoints: Sequence[Tuple[float, float]], height: int, width: int, inverted: bool = False
|
| 707 |
+
) -> np.ndarray:
|
| 708 |
+
"""Generate a ``(H,W,N)`` array of distance maps for ``N`` keypoints.
|
| 709 |
+
|
| 710 |
+
The ``n``-th distance map contains at every location ``(y, x)`` the
|
| 711 |
+
euclidean distance to the ``n``-th keypoint.
|
| 712 |
+
|
| 713 |
+
This function can be used as a helper when augmenting keypoints with a
|
| 714 |
+
method that only supports the augmentation of images.
|
| 715 |
+
|
| 716 |
+
Args:
|
| 717 |
+
keypoints: keypoint coordinates as a sequence of ``(x, y)`` pairs
|
| 718 |
+
height: image height
|
| 719 |
+
width: image width
|
| 720 |
+
inverted (bool): If ``True``, inverted distance maps are returned where each
|
| 721 |
+
distance value d is replaced by ``d/(d+1)``, i.e. the distance
|
| 722 |
+
maps have values in the range ``(0.0, 1.0]`` with ``1.0`` denoting
|
| 723 |
+
exactly the position of the respective keypoint.
|
| 724 |
+
|
| 725 |
+
Returns:
|
| 726 |
+
(H, W, N) ndarray
|
| 727 |
+
A ``float32`` array containing ``N`` distance maps for ``N``
|
| 728 |
+
keypoints. Each location ``(y, x, n)`` in the array denotes the
|
| 729 |
+
euclidean distance at ``(y, x)`` to the ``n``-th keypoint.
|
| 730 |
+
If `inverted` is ``True``, the distance ``d`` is replaced
|
| 731 |
+
by ``d/(d+1)``. The height and width of the array match the
|
| 732 |
+
height and width passed to this function.
|
| 733 |
+
"""
|
| 734 |
+
distance_maps = np.zeros((height, width, len(keypoints)), dtype=np.float32)
|
| 735 |
+
|
| 736 |
+
yy = np.arange(0, height)
|
| 737 |
+
xx = np.arange(0, width)
|
| 738 |
+
grid_xx, grid_yy = np.meshgrid(xx, yy)
|
| 739 |
+
|
| 740 |
+
for i, (x, y) in enumerate(keypoints):
|
| 741 |
+
distance_maps[:, :, i] = (grid_xx - x) ** 2 + (grid_yy - y) ** 2
|
| 742 |
+
|
| 743 |
+
distance_maps = np.sqrt(distance_maps)
|
| 744 |
+
if inverted:
|
| 745 |
+
return 1 / (distance_maps + 1)
|
| 746 |
+
return distance_maps
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def from_distance_maps(
|
| 750 |
+
distance_maps: np.ndarray,
|
| 751 |
+
inverted: bool,
|
| 752 |
+
if_not_found_coords: Optional[Union[Sequence[int], dict]],
|
| 753 |
+
threshold: Optional[float] = None,
|
| 754 |
+
) -> List[Tuple[float, float]]:
|
| 755 |
+
"""Convert outputs of ``to_distance_maps()`` to ``KeypointsOnImage``.
|
| 756 |
+
This is the inverse of `to_distance_maps`.
|
| 757 |
+
|
| 758 |
+
Args:
|
| 759 |
+
distance_maps (np.ndarray): The distance maps. ``N`` is the number of keypoints.
|
| 760 |
+
inverted (bool): Whether the given distance maps were generated in inverted mode
|
| 761 |
+
(i.e. ``to_distance_maps()`` was called with ``inverted=True``) or in non-inverted mode.
|
| 762 |
+
if_not_found_coords (tuple, list, dict or None, optional):
|
| 763 |
+
Coordinates to use for keypoints that cannot be found in `distance_maps`.
|
| 764 |
+
|
| 765 |
+
* If this is a ``list``/``tuple``, it must contain two ``int`` values.
|
| 766 |
+
* If it is a ``dict``, it must contain the keys ``x`` and ``y`` with each containing one ``int`` value.
|
| 767 |
+
* If this is ``None``, then the keypoint will not be added.
|
| 768 |
+
threshold (float): The search for keypoints works by searching for the
|
| 769 |
+
argmin (non-inverted) or argmax (inverted) in each channel. This
|
| 770 |
+
parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit
|
| 771 |
+
as a keypoint. Use ``None`` to use no min/max.
|
| 775 |
+
"""
|
| 776 |
+
if distance_maps.ndim != 3:
|
| 777 |
+
raise ValueError(
|
| 778 |
+
f"Expected three-dimensional input, "
|
| 779 |
+
f"got {distance_maps.ndim} dimensions and shape {distance_maps.shape}."
|
| 780 |
+
)
|
| 781 |
+
height, width, nb_keypoints = distance_maps.shape
|
| 782 |
+
|
| 783 |
+
drop_if_not_found = False
|
| 784 |
+
if if_not_found_coords is None:
|
| 785 |
+
drop_if_not_found = True
|
| 786 |
+
if_not_found_x = -1
|
| 787 |
+
if_not_found_y = -1
|
| 788 |
+
elif isinstance(if_not_found_coords, (tuple, list)):
|
| 789 |
+
if len(if_not_found_coords) != 2:
|
| 790 |
+
raise ValueError(
|
| 791 |
+
f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, "
|
| 792 |
+
f"got {len(if_not_found_coords)}."
|
| 793 |
+
)
|
| 794 |
+
if_not_found_x = if_not_found_coords[0]
|
| 795 |
+
if_not_found_y = if_not_found_coords[1]
|
| 796 |
+
elif isinstance(if_not_found_coords, dict):
|
| 797 |
+
if_not_found_x = if_not_found_coords["x"]
|
| 798 |
+
if_not_found_y = if_not_found_coords["y"]
|
| 799 |
+
else:
|
| 800 |
+
raise ValueError(
|
| 801 |
+
f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}."
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
keypoints = []
|
| 805 |
+
for i in range(nb_keypoints):
|
| 806 |
+
if inverted:
|
| 807 |
+
hitidx_flat = np.argmax(distance_maps[..., i])
|
| 808 |
+
else:
|
| 809 |
+
hitidx_flat = np.argmin(distance_maps[..., i])
|
| 810 |
+
hitidx_ndim = np.unravel_index(hitidx_flat, (height, width))
|
| 811 |
+
if not inverted and threshold is not None:
|
| 812 |
+
found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] < threshold
|
| 813 |
+
elif inverted and threshold is not None:
|
| 814 |
+
found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] >= threshold
|
| 815 |
+
else:
|
| 816 |
+
found = True
|
| 817 |
+
if found:
|
| 818 |
+
keypoints.append((float(hitidx_ndim[1]), float(hitidx_ndim[0])))
|
| 819 |
+
else:
|
| 820 |
+
if not drop_if_not_found:
|
| 821 |
+
keypoints.append((if_not_found_x, if_not_found_y))
|
| 822 |
+
|
| 823 |
+
return keypoints
|
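The two helpers above form a round trip: to_distance_maps encodes keypoints as per-keypoint distance images that can be warped like ordinary images, and from_distance_maps reads the warped keypoints back out. A minimal sketch follows (integer-valued sample coordinates are used so the recovery is exact; the import path is the one from this upload):

from custom_albumentations.augmentations.geometric import functional as F

keypoints = [(10.0, 20.0), (3.0, 40.0)]
maps = F.to_distance_maps(keypoints, height=64, width=64, inverted=True)  # shape (64, 64, 2)
print(F.from_distance_maps(maps, inverted=True, if_not_found_coords=None))
# -> [(10.0, 20.0), (3.0, 40.0)]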
| 824 |
+
|
| 825 |
+
|
| 826 |
+
def keypoint_piecewise_affine(
|
| 827 |
+
keypoint: KeypointInternalType,
|
| 828 |
+
matrix: Optional[skimage.transform.PiecewiseAffineTransform],
|
| 829 |
+
h: int,
|
| 830 |
+
w: int,
|
| 831 |
+
keypoints_threshold: float,
|
| 832 |
+
) -> KeypointInternalType:
|
| 833 |
+
if matrix is None:
|
| 834 |
+
return keypoint
|
| 835 |
+
x, y, a, s = keypoint[:4]
|
| 836 |
+
dist_maps = to_distance_maps([(x, y)], h, w, True)
|
| 837 |
+
dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0)
|
| 838 |
+
x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)[0]
|
| 839 |
+
return x, y, a, s
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
def bbox_piecewise_affine(
|
| 843 |
+
bbox: BoxInternalType,
|
| 844 |
+
matrix: Optional[skimage.transform.PiecewiseAffineTransform],
|
| 845 |
+
h: int,
|
| 846 |
+
w: int,
|
| 847 |
+
keypoints_threshold: float,
|
| 848 |
+
) -> BoxInternalType:
|
| 849 |
+
if matrix is None:
|
| 850 |
+
return bbox
|
| 851 |
+
x1, y1, x2, y2 = denormalize_bbox(bbox, h, w)[:4]
|
| 852 |
+
keypoints = [
|
| 853 |
+
(x1, y1),
|
| 854 |
+
(x2, y1),
|
| 855 |
+
(x2, y2),
|
| 856 |
+
(x1, y2),
|
| 857 |
+
]
|
| 858 |
+
dist_maps = to_distance_maps(keypoints, h, w, True)
|
| 859 |
+
dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0)
|
| 860 |
+
keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)
|
| 861 |
+
keypoints = [i for i in keypoints if 0 <= i[0] < w and 0 <= i[1] < h]
|
| 862 |
+
keypoints_arr = np.array(keypoints)
|
| 863 |
+
x1 = keypoints_arr[:, 0].min()
|
| 864 |
+
y1 = keypoints_arr[:, 1].min()
|
| 865 |
+
x2 = keypoints_arr[:, 0].max()
|
| 866 |
+
y2 = keypoints_arr[:, 1].max()
|
| 867 |
+
return normalize_bbox((x1, y1, x2, y2), h, w)
|
| 868 |
+
|
| 869 |
+
|
| 870 |
+
def vflip(img: np.ndarray) -> np.ndarray:
|
| 871 |
+
return np.ascontiguousarray(img[::-1, ...])
|
| 872 |
+
|
| 873 |
+
|
| 874 |
+
def hflip(img: np.ndarray) -> np.ndarray:
|
| 875 |
+
return np.ascontiguousarray(img[:, ::-1, ...])
|
| 876 |
+
|
| 877 |
+
|
| 878 |
+
def hflip_cv2(img: np.ndarray) -> np.ndarray:
|
| 879 |
+
return cv2.flip(img, 1)
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
@preserve_shape
|
| 883 |
+
def random_flip(img: np.ndarray, code: int) -> np.ndarray:
|
| 884 |
+
return cv2.flip(img, code)
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def transpose(img: np.ndarray) -> np.ndarray:
|
| 888 |
+
return img.transpose(1, 0, 2) if len(img.shape) > 2 else img.transpose(1, 0)
|
| 889 |
+
|
| 890 |
+
|
| 891 |
+
def rot90(img: np.ndarray, factor: int) -> np.ndarray:
|
| 892 |
+
img = np.rot90(img, factor)
|
| 893 |
+
return np.ascontiguousarray(img)
|
| 894 |
+
|
| 895 |
+
|
| 896 |
+
def bbox_vflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
|
| 897 |
+
"""Flip a bounding box vertically around the x-axis.
|
| 898 |
+
|
| 899 |
+
Args:
|
| 900 |
+
bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 901 |
+
rows: Image rows.
|
| 902 |
+
cols: Image cols.
|
| 903 |
+
|
| 904 |
+
Returns:
|
| 905 |
+
tuple: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 906 |
+
|
| 907 |
+
"""
|
| 908 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 909 |
+
return x_min, 1 - y_max, x_max, 1 - y_min
|
| 910 |
+
|
| 911 |
+
|
| 912 |
+
def bbox_hflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
|
| 913 |
+
"""Flip a bounding box horizontally around the y-axis.
|
| 914 |
+
|
| 915 |
+
Args:
|
| 916 |
+
bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 917 |
+
rows: Image rows.
|
| 918 |
+
cols: Image cols.
|
| 919 |
+
|
| 920 |
+
Returns:
|
| 921 |
+
A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 922 |
+
|
| 923 |
+
"""
|
| 924 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 925 |
+
return 1 - x_max, y_min, 1 - x_min, y_max
|
| 926 |
+
|
| 927 |
+
|
| 928 |
+
def bbox_flip(bbox: BoxInternalType, d: int, rows: int, cols: int) -> BoxInternalType:
|
| 929 |
+
"""Flip a bounding box either vertically, horizontally or both depending on the value of `d`.
|
| 930 |
+
|
| 931 |
+
Args:
|
| 932 |
+
bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 933 |
+
d: dimension. 0 for vertical flip, 1 for horizontal, -1 for transpose
|
| 934 |
+
rows: Image rows.
|
| 935 |
+
cols: Image cols.
|
| 936 |
+
|
| 937 |
+
Returns:
|
| 938 |
+
A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 939 |
+
|
| 940 |
+
Raises:
|
| 941 |
+
ValueError: if value of `d` is not -1, 0 or 1.
|
| 942 |
+
|
| 943 |
+
"""
|
| 944 |
+
if d == 0:
|
| 945 |
+
bbox = bbox_vflip(bbox, rows, cols)
|
| 946 |
+
elif d == 1:
|
| 947 |
+
bbox = bbox_hflip(bbox, rows, cols)
|
| 948 |
+
elif d == -1:
|
| 949 |
+
bbox = bbox_hflip(bbox, rows, cols)
|
| 950 |
+
bbox = bbox_vflip(bbox, rows, cols)
|
| 951 |
+
else:
|
| 952 |
+
raise ValueError("Invalid d value {}. Valid values are -1, 0 and 1".format(d))
|
| 953 |
+
return bbox
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
def bbox_transpose(
|
| 957 |
+
bbox: KeypointInternalType, axis: int, rows: int, cols: int
|
| 958 |
+
) -> KeypointInternalType: # skipcq: PYL-W0613
|
| 959 |
+
"""Transposes a bounding box along given axis.
|
| 960 |
+
|
| 961 |
+
Args:
|
| 962 |
+
bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
|
| 963 |
+
axis: 0 - main axis, 1 - secondary axis.
|
| 964 |
+
rows: Image rows.
|
| 965 |
+
cols: Image cols.
|
| 966 |
+
|
| 967 |
+
Returns:
|
| 968 |
+
A bounding box tuple `(x_min, y_min, x_max, y_max)`.
|
| 969 |
+
|
| 970 |
+
Raises:
|
| 971 |
+
ValueError: If axis not equal to 0 or 1.
|
| 972 |
+
|
| 973 |
+
"""
|
| 974 |
+
x_min, y_min, x_max, y_max = bbox[:4]
|
| 975 |
+
if axis not in {0, 1}:
|
| 976 |
+
raise ValueError("Axis must be either 0 or 1.")
|
| 977 |
+
if axis == 0:
|
| 978 |
+
bbox = (y_min, x_min, y_max, x_max)
|
| 979 |
+
if axis == 1:
|
| 980 |
+
bbox = (1 - y_max, 1 - x_max, 1 - y_min, 1 - x_min)
|
| 981 |
+
return bbox
|
| 982 |
+
|
| 983 |
+
|
| 984 |
+
@angle_2pi_range
|
| 985 |
+
def keypoint_vflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType:
|
| 986 |
+
"""Flip a keypoint vertically around the x-axis.
|
| 987 |
+
|
| 988 |
+
Args:
|
| 989 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 990 |
+
rows: Image height.
|
| 991 |
+
cols: Image width.
|
| 992 |
+
|
| 993 |
+
Returns:
|
| 994 |
+
tuple: A keypoint `(x, y, angle, scale)`.
|
| 995 |
+
|
| 996 |
+
"""
|
| 997 |
+
x, y, angle, scale = keypoint[:4]
|
| 998 |
+
angle = -angle
|
| 999 |
+
return x, (rows - 1) - y, angle, scale
|
| 1000 |
+
|
| 1001 |
+
|
| 1002 |
+
@angle_2pi_range
|
| 1003 |
+
def keypoint_hflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType:
|
| 1004 |
+
"""Flip a keypoint horizontally around the y-axis.
|
| 1005 |
+
|
| 1006 |
+
Args:
|
| 1007 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 1008 |
+
rows: Image height.
|
| 1009 |
+
cols: Image width.
|
| 1010 |
+
|
| 1011 |
+
Returns:
|
| 1012 |
+
A keypoint `(x, y, angle, scale)`.
|
| 1013 |
+
|
| 1014 |
+
"""
|
| 1015 |
+
x, y, angle, scale = keypoint[:4]
|
| 1016 |
+
angle = math.pi - angle
|
| 1017 |
+
return (cols - 1) - x, y, angle, scale
|
| 1018 |
+
|
| 1019 |
+
|
| 1020 |
+
def keypoint_flip(keypoint: KeypointInternalType, d: int, rows: int, cols: int) -> KeypointInternalType:
|
| 1021 |
+
"""Flip a keypoint either vertically, horizontally or both depending on the value of `d`.
|
| 1022 |
+
|
| 1023 |
+
Args:
|
| 1024 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 1025 |
+
d: Number of flip. Must be -1, 0 or 1:
|
| 1026 |
+
* 0 - vertical flip,
|
| 1027 |
+
* 1 - horizontal flip,
|
| 1028 |
+
* -1 - vertical and horizontal flip.
|
| 1029 |
+
rows: Image height.
|
| 1030 |
+
cols: Image width.
|
| 1031 |
+
|
| 1032 |
+
Returns:
|
| 1033 |
+
A keypoint `(x, y, angle, scale)`.
|
| 1034 |
+
|
| 1035 |
+
Raises:
|
| 1036 |
+
ValueError: if value of `d` is not -1, 0 or 1.
|
| 1037 |
+
|
| 1038 |
+
"""
|
| 1039 |
+
if d == 0:
|
| 1040 |
+
keypoint = keypoint_vflip(keypoint, rows, cols)
|
| 1041 |
+
elif d == 1:
|
| 1042 |
+
keypoint = keypoint_hflip(keypoint, rows, cols)
|
| 1043 |
+
elif d == -1:
|
| 1044 |
+
keypoint = keypoint_hflip(keypoint, rows, cols)
|
| 1045 |
+
keypoint = keypoint_vflip(keypoint, rows, cols)
|
| 1046 |
+
else:
|
| 1047 |
+
raise ValueError(f"Invalid d value {d}. Valid values are -1, 0 and 1")
|
| 1048 |
+
return keypoint
|
| 1049 |
+
|
| 1050 |
+
|
| 1051 |
+
def keypoint_transpose(keypoint: KeypointInternalType) -> KeypointInternalType:
|
| 1052 |
+
"""Rotate a keypoint by angle.
|
| 1053 |
+
|
| 1054 |
+
Args:
|
| 1055 |
+
keypoint: A keypoint `(x, y, angle, scale)`.
|
| 1056 |
+
|
| 1057 |
+
Returns:
|
| 1058 |
+
A keypoint `(x, y, angle, scale)`.
|
| 1059 |
+
|
| 1060 |
+
"""
|
| 1061 |
+
x, y, angle, scale = keypoint[:4]
|
| 1062 |
+
|
| 1063 |
+
if angle <= np.pi:
|
| 1064 |
+
angle = np.pi - angle
|
| 1065 |
+
else:
|
| 1066 |
+
angle = 3 * np.pi - angle
|
| 1067 |
+
|
| 1068 |
+
return y, x, angle, scale
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
@preserve_channel_dim
|
| 1072 |
+
def pad(
|
| 1073 |
+
img: np.ndarray,
|
| 1074 |
+
min_height: int,
|
| 1075 |
+
min_width: int,
|
| 1076 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 1077 |
+
value: Optional[ImageColorType] = None,
|
| 1078 |
+
) -> np.ndarray:
|
| 1079 |
+
height, width = img.shape[:2]
|
| 1080 |
+
|
| 1081 |
+
if height < min_height:
|
| 1082 |
+
h_pad_top = int((min_height - height) / 2.0)
|
| 1083 |
+
h_pad_bottom = min_height - height - h_pad_top
|
| 1084 |
+
else:
|
| 1085 |
+
h_pad_top = 0
|
| 1086 |
+
h_pad_bottom = 0
|
| 1087 |
+
|
| 1088 |
+
if width < min_width:
|
| 1089 |
+
w_pad_left = int((min_width - width) / 2.0)
|
| 1090 |
+
w_pad_right = min_width - width - w_pad_left
|
| 1091 |
+
else:
|
| 1092 |
+
w_pad_left = 0
|
| 1093 |
+
w_pad_right = 0
|
| 1094 |
+
|
| 1095 |
+
img = pad_with_params(img, h_pad_top, h_pad_bottom, w_pad_left, w_pad_right, border_mode, value)
|
| 1096 |
+
|
| 1097 |
+
if img.shape[:2] != (max(min_height, height), max(min_width, width)):
|
| 1098 |
+
raise RuntimeError(
|
| 1099 |
+
"Invalid result shape. Got: {}. Expected: {}".format(
|
| 1100 |
+
img.shape[:2], (max(min_height, height), max(min_width, width))
|
| 1101 |
+
)
|
| 1102 |
+
)
|
| 1103 |
+
|
| 1104 |
+
return img
|
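pad splits the required padding as evenly as possible between the two sides of each dimension, as the short sketch below shows (illustrative values; the border mode and fill value are given explicitly only to make the call self-contained):

import cv2
import numpy as np
from custom_albumentations.augmentations.geometric import functional as F

img = np.zeros((90, 100, 3), dtype=np.uint8)
padded = F.pad(img, min_height=128, min_width=128, border_mode=cv2.BORDER_CONSTANT, value=0)
print(padded.shape)  # (128, 128, 3): 19/19 rows and 14/14 columns added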
| 1105 |
+
|
| 1106 |
+
|
| 1107 |
+
@preserve_channel_dim
|
| 1108 |
+
def pad_with_params(
|
| 1109 |
+
img: np.ndarray,
|
| 1110 |
+
h_pad_top: int,
|
| 1111 |
+
h_pad_bottom: int,
|
| 1112 |
+
w_pad_left: int,
|
| 1113 |
+
w_pad_right: int,
|
| 1114 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 1115 |
+
value: Optional[ImageColorType] = None,
|
| 1116 |
+
) -> np.ndarray:
|
| 1117 |
+
pad_fn = _maybe_process_in_chunks(
|
| 1118 |
+
cv2.copyMakeBorder,
|
| 1119 |
+
top=h_pad_top,
|
| 1120 |
+
bottom=h_pad_bottom,
|
| 1121 |
+
left=w_pad_left,
|
| 1122 |
+
right=w_pad_right,
|
| 1123 |
+
borderType=border_mode,
|
| 1124 |
+
value=value,
|
| 1125 |
+
)
|
| 1126 |
+
return pad_fn(img)
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
@preserve_shape
|
| 1130 |
+
def optical_distortion(
|
| 1131 |
+
img: np.ndarray,
|
| 1132 |
+
k: int = 0,
|
| 1133 |
+
dx: int = 0,
|
| 1134 |
+
dy: int = 0,
|
| 1135 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 1136 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 1137 |
+
value: Optional[ImageColorType] = None,
|
| 1138 |
+
) -> np.ndarray:
|
| 1139 |
+
"""Barrel / pincushion distortion. Unconventional augment.
|
| 1140 |
+
|
| 1141 |
+
Reference:
|
| 1142 |
+
| https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion
|
| 1143 |
+
| https://stackoverflow.com/questions/10364201/image-transformation-in-opencv
|
| 1144 |
+
| https://stackoverflow.com/questions/2477774/correcting-fisheye-distortion-programmatically
|
| 1145 |
+
| http://www.coldvision.io/2017/03/02/advanced-lane-finding-using-opencv/
|
| 1146 |
+
"""
|
| 1147 |
+
height, width = img.shape[:2]
|
| 1148 |
+
|
| 1149 |
+
fx = width
|
| 1150 |
+
fy = height
|
| 1151 |
+
|
| 1152 |
+
cx = width * 0.5 + dx
|
| 1153 |
+
cy = height * 0.5 + dy
|
| 1154 |
+
|
| 1155 |
+
camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
|
| 1156 |
+
|
| 1157 |
+
distortion = np.array([k, k, 0, 0, 0], dtype=np.float32)
|
| 1158 |
+
map1, map2 = cv2.initUndistortRectifyMap(
|
| 1159 |
+
camera_matrix, distortion, None, None, (width, height), cv2.CV_32FC1 # type: ignore[attr-defined]
|
| 1160 |
+
)
|
| 1161 |
+
return cv2.remap(img, map1, map2, interpolation=interpolation, borderMode=border_mode, borderValue=value)
|
| 1162 |
+
|
| 1163 |
+
|
| 1164 |
+
@preserve_shape
|
| 1165 |
+
def grid_distortion(
|
| 1166 |
+
img: np.ndarray,
|
| 1167 |
+
num_steps: int = 10,
|
| 1168 |
+
xsteps: Tuple = (),
|
| 1169 |
+
ysteps: Tuple = (),
|
| 1170 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 1171 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 1172 |
+
value: Optional[ImageColorType] = None,
|
| 1173 |
+
) -> np.ndarray:
|
| 1174 |
+
"""Perform a grid distortion of an input image.
|
| 1175 |
+
|
| 1176 |
+
Reference:
|
| 1177 |
+
http://pythology.blogspot.sg/2014/03/interpolation-on-regular-distorted-grid.html
|
| 1178 |
+
"""
|
| 1179 |
+
height, width = img.shape[:2]
|
| 1180 |
+
|
| 1181 |
+
x_step = width // num_steps
|
| 1182 |
+
xx = np.zeros(width, np.float32)
|
| 1183 |
+
prev = 0
|
| 1184 |
+
for idx in range(num_steps + 1):
|
| 1185 |
+
x = idx * x_step
|
| 1186 |
+
start = int(x)
|
| 1187 |
+
end = int(x) + x_step
|
| 1188 |
+
if end > width:
|
| 1189 |
+
end = width
|
| 1190 |
+
cur = width
|
| 1191 |
+
else:
|
| 1192 |
+
cur = prev + x_step * xsteps[idx]
|
| 1193 |
+
|
| 1194 |
+
xx[start:end] = np.linspace(prev, cur, end - start)
|
| 1195 |
+
prev = cur
|
| 1196 |
+
|
| 1197 |
+
y_step = height // num_steps
|
| 1198 |
+
yy = np.zeros(height, np.float32)
|
| 1199 |
+
prev = 0
|
| 1200 |
+
for idx in range(num_steps + 1):
|
| 1201 |
+
y = idx * y_step
|
| 1202 |
+
start = int(y)
|
| 1203 |
+
end = int(y) + y_step
|
| 1204 |
+
if end > height:
|
| 1205 |
+
end = height
|
| 1206 |
+
cur = height
|
| 1207 |
+
else:
|
| 1208 |
+
cur = prev + y_step * ysteps[idx]
|
| 1209 |
+
|
| 1210 |
+
yy[start:end] = np.linspace(prev, cur, end - start)
|
| 1211 |
+
prev = cur
|
| 1212 |
+
|
| 1213 |
+
map_x, map_y = np.meshgrid(xx, yy)
|
| 1214 |
+
map_x = map_x.astype(np.float32)
|
| 1215 |
+
map_y = map_y.astype(np.float32)
|
| 1216 |
+
|
| 1217 |
+
remap_fn = _maybe_process_in_chunks(
|
| 1218 |
+
cv2.remap,
|
| 1219 |
+
map1=map_x,
|
| 1220 |
+
map2=map_y,
|
| 1221 |
+
interpolation=interpolation,
|
| 1222 |
+
borderMode=border_mode,
|
| 1223 |
+
borderValue=value,
|
| 1224 |
+
)
|
| 1225 |
+
return remap_fn(img)
|
| 1226 |
+
|
| 1227 |
+
|
| 1228 |
+
@preserve_shape
|
| 1229 |
+
def elastic_transform_approx(
|
| 1230 |
+
img: np.ndarray,
|
| 1231 |
+
alpha: float,
|
| 1232 |
+
sigma: float,
|
| 1233 |
+
alpha_affine: float,
|
| 1234 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 1235 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 1236 |
+
value: Optional[ImageColorType] = None,
|
| 1237 |
+
random_state: Optional[np.random.RandomState] = None,
|
| 1238 |
+
) -> np.ndarray:
|
| 1239 |
+
"""Elastic deformation of images as described in [Simard2003]_ (with modifications for speed).
|
| 1240 |
+
Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
|
| 1241 |
+
|
| 1242 |
+
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
| 1243 |
+
Convolutional Neural Networks applied to Visual Document Analysis", in
|
| 1244 |
+
Proc. of the International Conference on Document Analysis and
|
| 1245 |
+
Recognition, 2003.
|
| 1246 |
+
"""
|
| 1247 |
+
height, width = img.shape[:2]
|
| 1248 |
+
|
| 1249 |
+
# Random affine
|
| 1250 |
+
center_square = np.array((height, width), dtype=np.float32) // 2
|
| 1251 |
+
square_size = min((height, width)) // 3
|
| 1252 |
+
alpha = float(alpha)
|
| 1253 |
+
sigma = float(sigma)
|
| 1254 |
+
alpha_affine = float(alpha_affine)
|
| 1255 |
+
|
| 1256 |
+
pts1 = np.array(
|
| 1257 |
+
[
|
| 1258 |
+
center_square + square_size,
|
| 1259 |
+
[center_square[0] + square_size, center_square[1] - square_size],
|
| 1260 |
+
center_square - square_size,
|
| 1261 |
+
],
|
| 1262 |
+
dtype=np.float32,
|
| 1263 |
+
)
|
| 1264 |
+
pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype(
|
| 1265 |
+
np.float32
|
| 1266 |
+
)
|
| 1267 |
+
matrix = cv2.getAffineTransform(pts1, pts2)
|
| 1268 |
+
|
| 1269 |
+
warp_fn = _maybe_process_in_chunks(
|
| 1270 |
+
cv2.warpAffine,
|
| 1271 |
+
M=matrix,
|
| 1272 |
+
dsize=(width, height),
|
| 1273 |
+
flags=interpolation,
|
| 1274 |
+
borderMode=border_mode,
|
| 1275 |
+
borderValue=value,
|
| 1276 |
+
)
|
| 1277 |
+
img = warp_fn(img)
|
| 1278 |
+
|
| 1279 |
+
dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
|
| 1280 |
+
cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
|
| 1281 |
+
dx *= alpha
|
| 1282 |
+
|
| 1283 |
+
dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
|
| 1284 |
+
cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
|
| 1285 |
+
dy *= alpha
|
| 1286 |
+
|
| 1287 |
+
x, y = np.meshgrid(np.arange(width), np.arange(height))
|
| 1288 |
+
|
| 1289 |
+
map_x = np.float32(x + dx)
|
| 1290 |
+
map_y = np.float32(y + dy)
|
| 1291 |
+
|
| 1292 |
+
remap_fn = _maybe_process_in_chunks(
|
| 1293 |
+
cv2.remap,
|
| 1294 |
+
map1=map_x,
|
| 1295 |
+
map2=map_y,
|
| 1296 |
+
interpolation=interpolation,
|
| 1297 |
+
borderMode=border_mode,
|
| 1298 |
+
borderValue=value,
|
| 1299 |
+
)
|
| 1300 |
+
return remap_fn(img)
|
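Before moving on to the transform classes that wrap these functions, here is a small end-to-end sketch of the approximate elastic transform defined above, with a fixed random state for reproducibility (illustrative values; assumes the package and its dependencies are importable):

import numpy as np
from custom_albumentations.augmentations.geometric import functional as F

img = np.random.randint(0, 256, size=(128, 128, 3), dtype=np.uint8)
out = F.elastic_transform_approx(
    img, alpha=1.0, sigma=50.0, alpha_affine=10.0,
    random_state=np.random.RandomState(42),
)
print(out.shape, out.dtype)  # (128, 128, 3) uint8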
custom_albumentations/augmentations/geometric/resize.py
ADDED
|
@@ -0,0 +1,198 @@
|
| 1 |
+
import random
|
| 2 |
+
from typing import Dict, Sequence, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import cv2
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
from ...core.transforms_interface import (
|
| 8 |
+
BoxInternalType,
|
| 9 |
+
DualTransform,
|
| 10 |
+
KeypointInternalType,
|
| 11 |
+
to_tuple,
|
| 12 |
+
)
|
| 13 |
+
from . import functional as F
|
| 14 |
+
|
| 15 |
+
__all__ = ["RandomScale", "LongestMaxSize", "SmallestMaxSize", "Resize"]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class RandomScale(DualTransform):
|
| 19 |
+
"""Randomly resize the input. Output image size is different from the input image size.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
|
| 23 |
+
range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
|
| 24 |
+
If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
|
| 25 |
+
Default: (-0.1, 0.1).
|
| 26 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 27 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 28 |
+
Default: cv2.INTER_LINEAR.
|
| 29 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 30 |
+
|
| 31 |
+
Targets:
|
| 32 |
+
image, mask, bboxes, keypoints
|
| 33 |
+
|
| 34 |
+
Image types:
|
| 35 |
+
uint8, float32
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self, scale_limit=0.1, interpolation=cv2.INTER_LINEAR, always_apply=False, p=0.5):
|
| 39 |
+
super(RandomScale, self).__init__(always_apply, p)
|
| 40 |
+
self.scale_limit = to_tuple(scale_limit, bias=1.0)
|
| 41 |
+
self.interpolation = interpolation
|
| 42 |
+
|
| 43 |
+
def get_params(self):
|
| 44 |
+
return {"scale": random.uniform(self.scale_limit[0], self.scale_limit[1])}
|
| 45 |
+
|
| 46 |
+
def apply(self, img, scale=0, interpolation=cv2.INTER_LINEAR, **params):
|
| 47 |
+
return F.scale(img, scale, interpolation)
|
| 48 |
+
|
| 49 |
+
def apply_to_bbox(self, bbox, **params):
|
| 50 |
+
# Bounding box coordinates are scale invariant
|
| 51 |
+
return bbox
|
| 52 |
+
|
| 53 |
+
def apply_to_keypoint(self, keypoint, scale=0, **params):
|
| 54 |
+
return F.keypoint_scale(keypoint, scale, scale)
|
| 55 |
+
|
| 56 |
+
def get_transform_init_args(self):
|
| 57 |
+
return {"interpolation": self.interpolation, "scale_limit": to_tuple(self.scale_limit, bias=-1.0)}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class LongestMaxSize(DualTransform):
|
| 61 |
+
"""Rescale an image so that maximum side is equal to max_size, keeping the aspect ratio of the initial image.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
max_size (int, list of int): maximum size of the image after the transformation. When using a list, max size
|
| 65 |
+
will be randomly selected from the values in the list.
|
| 66 |
+
interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
|
| 67 |
+
p (float): probability of applying the transform. Default: 1.
|
| 68 |
+
|
| 69 |
+
Targets:
|
| 70 |
+
image, mask, bboxes, keypoints
|
| 71 |
+
|
| 72 |
+
Image types:
|
| 73 |
+
uint8, float32
|
| 74 |
+
"""
|
| 75 |
+
|
| 76 |
+
def __init__(
|
| 77 |
+
self,
|
| 78 |
+
max_size: Union[int, Sequence[int]] = 1024,
|
| 79 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 80 |
+
always_apply: bool = False,
|
| 81 |
+
p: float = 1,
|
| 82 |
+
):
|
| 83 |
+
super(LongestMaxSize, self).__init__(always_apply, p)
|
| 84 |
+
self.interpolation = interpolation
|
| 85 |
+
self.max_size = max_size
|
| 86 |
+
|
| 87 |
+
def apply(
|
| 88 |
+
self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params
|
| 89 |
+
) -> np.ndarray:
|
| 90 |
+
return F.longest_max_size(img, max_size=max_size, interpolation=interpolation)
|
| 91 |
+
|
| 92 |
+
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
|
| 93 |
+
# Bounding box coordinates are scale invariant
|
| 94 |
+
return bbox
|
| 95 |
+
|
| 96 |
+
def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType:
|
| 97 |
+
height = params["rows"]
|
| 98 |
+
width = params["cols"]
|
| 99 |
+
|
| 100 |
+
scale = max_size / max([height, width])
|
| 101 |
+
return F.keypoint_scale(keypoint, scale, scale)
|
| 102 |
+
|
| 103 |
+
def get_params(self) -> Dict[str, int]:
|
| 104 |
+
return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)}
|
| 105 |
+
|
| 106 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 107 |
+
return ("max_size", "interpolation")
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
class SmallestMaxSize(DualTransform):
|
| 111 |
+
"""Rescale an image so that minimum side is equal to max_size, keeping the aspect ratio of the initial image.
|
| 112 |
+
|
| 113 |
+
Args:
|
| 114 |
+
max_size (int, list of int): maximum size of smallest side of the image after the transformation. When using a
|
| 115 |
+
list, max size will be randomly selected from the values in the list.
|
| 116 |
+
interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
|
| 117 |
+
p (float): probability of applying the transform. Default: 1.
|
| 118 |
+
|
| 119 |
+
Targets:
|
| 120 |
+
image, mask, bboxes, keypoints
|
| 121 |
+
|
| 122 |
+
Image types:
|
| 123 |
+
uint8, float32
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
max_size: Union[int, Sequence[int]] = 1024,
|
| 129 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 130 |
+
always_apply: bool = False,
|
| 131 |
+
p: float = 1,
|
| 132 |
+
):
|
| 133 |
+
super(SmallestMaxSize, self).__init__(always_apply, p)
|
| 134 |
+
self.interpolation = interpolation
|
| 135 |
+
self.max_size = max_size
|
| 136 |
+
|
| 137 |
+
def apply(
|
| 138 |
+
self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params
|
| 139 |
+
) -> np.ndarray:
|
| 140 |
+
return F.smallest_max_size(img, max_size=max_size, interpolation=interpolation)
|
| 141 |
+
|
| 142 |
+
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
|
| 143 |
+
return bbox
|
| 144 |
+
|
| 145 |
+
def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType:
|
| 146 |
+
height = params["rows"]
|
| 147 |
+
width = params["cols"]
|
| 148 |
+
|
| 149 |
+
scale = max_size / min([height, width])
|
| 150 |
+
return F.keypoint_scale(keypoint, scale, scale)
|
| 151 |
+
|
| 152 |
+
def get_params(self) -> Dict[str, int]:
|
| 153 |
+
return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)}
|
| 154 |
+
|
| 155 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 156 |
+
return ("max_size", "interpolation")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
class Resize(DualTransform):
|
| 160 |
+
"""Resize the input to the given height and width.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
height (int): desired height of the output.
|
| 164 |
+
width (int): desired width of the output.
|
| 165 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 166 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 167 |
+
Default: cv2.INTER_LINEAR.
|
| 168 |
+
p (float): probability of applying the transform. Default: 1.
|
| 169 |
+
|
| 170 |
+
Targets:
|
| 171 |
+
image, mask, bboxes, keypoints
|
| 172 |
+
|
| 173 |
+
Image types:
|
| 174 |
+
uint8, float32
|
| 175 |
+
"""
|
| 176 |
+
|
| 177 |
+
def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
|
| 178 |
+
super(Resize, self).__init__(always_apply, p)
|
| 179 |
+
self.height = height
|
| 180 |
+
self.width = width
|
| 181 |
+
self.interpolation = interpolation
|
| 182 |
+
|
| 183 |
+
def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
|
| 184 |
+
return F.resize(img, height=self.height, width=self.width, interpolation=interpolation)
|
| 185 |
+
|
| 186 |
+
def apply_to_bbox(self, bbox, **params):
|
| 187 |
+
# Bounding box coordinates are scale invariant
|
| 188 |
+
return bbox
|
| 189 |
+
|
| 190 |
+
def apply_to_keypoint(self, keypoint, **params):
|
| 191 |
+
height = params["rows"]
|
| 192 |
+
width = params["cols"]
|
| 193 |
+
scale_x = self.width / width
|
| 194 |
+
scale_y = self.height / height
|
| 195 |
+
return F.keypoint_scale(keypoint, scale_x, scale_y)
|
| 196 |
+
|
| 197 |
+
def get_transform_init_args_names(self):
|
| 198 |
+
return ("height", "width", "interpolation")
|
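Usage sketch (not part of the uploaded diff): how the resize transforms defined above are typically composed. It assumes the vendored custom_albumentations package re-exports Compose, SmallestMaxSize and Resize at the top level, the same way upstream albumentations does.

import numpy as np
import custom_albumentations as A  # assumed top-level re-exports, as in upstream albumentations

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

pipeline = A.Compose([
    A.SmallestMaxSize(max_size=512, p=1.0),   # shorter side becomes 512, aspect ratio preserved
    A.Resize(height=256, width=256, p=1.0),   # then force a fixed 256x256 output
])

out = pipeline(image=image)["image"]
print(out.shape)  # (256, 256, 3)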
custom_albumentations/augmentations/geometric/rotate.py
ADDED
|
@@ -0,0 +1,294 @@
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from ...core.transforms_interface import (
|
| 9 |
+
BoxInternalType,
|
| 10 |
+
DualTransform,
|
| 11 |
+
FillValueType,
|
| 12 |
+
KeypointInternalType,
|
| 13 |
+
to_tuple,
|
| 14 |
+
)
|
| 15 |
+
from ..crops import functional as FCrops
|
| 16 |
+
from . import functional as F
|
| 17 |
+
|
| 18 |
+
__all__ = ["Rotate", "RandomRotate90", "SafeRotate"]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class RandomRotate90(DualTransform):
|
| 22 |
+
"""Randomly rotate the input by 90 degrees zero or more times.
|
| 23 |
+
|
| 24 |
+
Args:
|
| 25 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 26 |
+
|
| 27 |
+
Targets:
|
| 28 |
+
image, mask, bboxes, keypoints
|
| 29 |
+
|
| 30 |
+
Image types:
|
| 31 |
+
uint8, float32
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def apply(self, img, factor=0, **params):
|
| 35 |
+
"""
|
| 36 |
+
Args:
|
| 37 |
+
factor (int): number of times the input will be rotated by 90 degrees.
|
| 38 |
+
"""
|
| 39 |
+
return np.ascontiguousarray(np.rot90(img, factor))
|
| 40 |
+
|
| 41 |
+
def get_params(self):
|
| 42 |
+
# Random int in the range [0, 3]
|
| 43 |
+
return {"factor": random.randint(0, 3)}
|
| 44 |
+
|
| 45 |
+
def apply_to_bbox(self, bbox, factor=0, **params):
|
| 46 |
+
return F.bbox_rot90(bbox, factor, **params)
|
| 47 |
+
|
| 48 |
+
def apply_to_keypoint(self, keypoint, factor=0, **params):
|
| 49 |
+
return F.keypoint_rot90(keypoint, factor, **params)
|
| 50 |
+
|
| 51 |
+
def get_transform_init_args_names(self):
|
| 52 |
+
return ()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Rotate(DualTransform):
|
| 56 |
+
"""Rotate the input by an angle selected randomly from the uniform distribution.
|
| 57 |
+
|
| 58 |
+
Args:
|
| 59 |
+
limit ((int, int) or int): range from which a random angle is picked. If limit is a single int
|
| 60 |
+
an angle is picked from (-limit, limit). Default: (-90, 90)
|
| 61 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 62 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 63 |
+
Default: cv2.INTER_LINEAR.
|
| 64 |
+
border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
|
| 65 |
+
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
|
| 66 |
+
Default: cv2.BORDER_REFLECT_101
|
| 67 |
+
value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
|
| 68 |
+
mask_value (int, float,
|
| 69 |
+
list of ints,
|
| 70 |
+
list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
|
| 71 |
+
rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
|
| 72 |
+
Default: "largest_box"
|
| 73 |
+
crop_border (bool): If True, the output is cropped to the largest possible axis-aligned rectangle inside the rotated image.
|
| 74 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 75 |
+
|
| 76 |
+
Targets:
|
| 77 |
+
image, mask, bboxes, keypoints
|
| 78 |
+
|
| 79 |
+
Image types:
|
| 80 |
+
uint8, float32
|
| 81 |
+
"""
|
| 82 |
+
|
| 83 |
+
def __init__(
|
| 84 |
+
self,
|
| 85 |
+
limit=90,
|
| 86 |
+
interpolation=cv2.INTER_LINEAR,
|
| 87 |
+
border_mode=cv2.BORDER_REFLECT_101,
|
| 88 |
+
value=None,
|
| 89 |
+
mask_value=None,
|
| 90 |
+
rotate_method="largest_box",
|
| 91 |
+
crop_border=False,
|
| 92 |
+
always_apply=False,
|
| 93 |
+
p=0.5,
|
| 94 |
+
):
|
| 95 |
+
super(Rotate, self).__init__(always_apply, p)
|
| 96 |
+
self.limit = to_tuple(limit)
|
| 97 |
+
self.interpolation = interpolation
|
| 98 |
+
self.border_mode = border_mode
|
| 99 |
+
self.value = value
|
| 100 |
+
self.mask_value = mask_value
|
| 101 |
+
self.rotate_method = rotate_method
|
| 102 |
+
self.crop_border = crop_border
|
| 103 |
+
|
| 104 |
+
if rotate_method not in ["largest_box", "ellipse"]:
|
| 105 |
+
raise ValueError(f"Rotation method {self.rotate_method} is not valid.")
|
| 106 |
+
|
| 107 |
+
def apply(
|
| 108 |
+
self, img, angle=0, interpolation=cv2.INTER_LINEAR, x_min=None, x_max=None, y_min=None, y_max=None, **params
|
| 109 |
+
):
|
| 110 |
+
img_out = F.rotate(img, angle, interpolation, self.border_mode, self.value)
|
| 111 |
+
if self.crop_border:
|
| 112 |
+
img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max)
|
| 113 |
+
return img_out
|
| 114 |
+
|
| 115 |
+
def apply_to_mask(self, img, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, **params):
|
| 116 |
+
img_out = F.rotate(img, angle, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
|
| 117 |
+
if self.crop_border:
|
| 118 |
+
img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max)
|
| 119 |
+
return img_out
|
| 120 |
+
|
| 121 |
+
def apply_to_bbox(self, bbox, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params):
|
| 122 |
+
bbox_out = F.bbox_rotate(bbox, angle, self.rotate_method, rows, cols)
|
| 123 |
+
if self.crop_border:
|
| 124 |
+
bbox_out = FCrops.bbox_crop(bbox_out, x_min, y_min, x_max, y_max, rows, cols)
|
| 125 |
+
return bbox_out
|
| 126 |
+
|
| 127 |
+
def apply_to_keypoint(
|
| 128 |
+
self, keypoint, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params
|
| 129 |
+
):
|
| 130 |
+
keypoint_out = F.keypoint_rotate(keypoint, angle, rows, cols, **params)
|
| 131 |
+
if self.crop_border:
|
| 132 |
+
keypoint_out = FCrops.crop_keypoint_by_coords(keypoint_out, (x_min, y_min, x_max, y_max))
|
| 133 |
+
return keypoint_out
|
| 134 |
+
|
| 135 |
+
@staticmethod
|
| 136 |
+
def _rotated_rect_with_max_area(h, w, angle):
|
| 137 |
+
"""
|
| 138 |
+
Given a rectangle of size wxh that has been rotated by 'angle' (in
|
| 139 |
+
degrees), computes the width and height of the largest possible
|
| 140 |
+
axis-aligned rectangle (maximal area) within the rotated rectangle.
|
| 141 |
+
|
| 142 |
+
Code from: https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
angle = math.radians(angle)
|
| 146 |
+
width_is_longer = w >= h
|
| 147 |
+
side_long, side_short = (w, h) if width_is_longer else (h, w)
|
| 148 |
+
|
| 149 |
+
# since the solutions for angle, -angle and 180-angle are all the same,
|
| 150 |
+
# it is sufficient to look at the first quadrant and the absolute values of sin,cos:
|
| 151 |
+
sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
|
| 152 |
+
if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
|
| 153 |
+
# half constrained case: two crop corners touch the longer side,
|
| 154 |
+
# the other two corners are on the mid-line parallel to the longer line
|
| 155 |
+
x = 0.5 * side_short
|
| 156 |
+
wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
|
| 157 |
+
else:
|
| 158 |
+
# fully constrained case: crop touches all 4 sides
|
| 159 |
+
cos_2a = cos_a * cos_a - sin_a * sin_a
|
| 160 |
+
wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
|
| 161 |
+
|
| 162 |
+
return dict(
|
| 163 |
+
x_min=max(0, int(w / 2 - wr / 2)),
|
| 164 |
+
x_max=min(w, int(w / 2 + wr / 2)),
|
| 165 |
+
y_min=max(0, int(h / 2 - hr / 2)),
|
| 166 |
+
y_max=min(h, int(h / 2 + hr / 2)),
|
| 167 |
+
)
|
| 168 |
+
|
| 169 |
+
@property
|
| 170 |
+
def targets_as_params(self) -> List[str]:
|
| 171 |
+
return ["image"]
|
| 172 |
+
|
| 173 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 174 |
+
out_params = {"angle": random.uniform(self.limit[0], self.limit[1])}
|
| 175 |
+
if self.crop_border:
|
| 176 |
+
h, w = params["image"].shape[:2]
|
| 177 |
+
out_params.update(self._rotated_rect_with_max_area(h, w, out_params["angle"]))
|
| 178 |
+
return out_params
|
| 179 |
+
|
| 180 |
+
def get_transform_init_args_names(self):
|
| 181 |
+
return ("limit", "interpolation", "border_mode", "value", "mask_value", "rotate_method", "crop_border")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class SafeRotate(DualTransform):
|
| 185 |
+
"""Rotate the input inside the input's frame by an angle selected randomly from the uniform distribution.
|
| 186 |
+
|
| 187 |
+
The resulting image may have artifacts in it. After rotation, the image may have a different aspect ratio, and
|
| 188 |
+
after resizing, it returns to its original shape with the original aspect ratio of the image. For this reason, we
|
| 189 |
+
may see some artifacts.
|
| 190 |
+
|
| 191 |
+
Args:
|
| 192 |
+
limit ((int, int) or int): range from which a random angle is picked. If limit is a single int
|
| 193 |
+
an angle is picked from (-limit, limit). Default: (-90, 90)
|
| 194 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 195 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 196 |
+
Default: cv2.INTER_LINEAR.
|
| 197 |
+
border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
|
| 198 |
+
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
|
| 199 |
+
Default: cv2.BORDER_REFLECT_101
|
| 200 |
+
value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
|
| 201 |
+
mask_value (int, float,
|
| 202 |
+
list of ints,
|
| 203 |
+
list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
|
| 204 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 205 |
+
|
| 206 |
+
Targets:
|
| 207 |
+
image, mask, bboxes, keypoints
|
| 208 |
+
|
| 209 |
+
Image types:
|
| 210 |
+
uint8, float32
|
| 211 |
+
"""
|
| 212 |
+
|
| 213 |
+
def __init__(
|
| 214 |
+
self,
|
| 215 |
+
limit: Union[float, Tuple[float, float]] = 90,
|
| 216 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 217 |
+
border_mode: int = cv2.BORDER_REFLECT_101,
|
| 218 |
+
value: FillValueType = None,
|
| 219 |
+
mask_value: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
|
| 220 |
+
always_apply: bool = False,
|
| 221 |
+
p: float = 0.5,
|
| 222 |
+
):
|
| 223 |
+
super(SafeRotate, self).__init__(always_apply, p)
|
| 224 |
+
self.limit = to_tuple(limit)
|
| 225 |
+
self.interpolation = interpolation
|
| 226 |
+
self.border_mode = border_mode
|
| 227 |
+
self.value = value
|
| 228 |
+
self.mask_value = mask_value
|
| 229 |
+
|
| 230 |
+
def apply(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray:
|
| 231 |
+
return F.safe_rotate(img, matrix, self.interpolation, self.value, self.border_mode)
|
| 232 |
+
|
| 233 |
+
def apply_to_mask(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray:
|
| 234 |
+
return F.safe_rotate(img, matrix, cv2.INTER_NEAREST, self.mask_value, self.border_mode)
|
| 235 |
+
|
| 236 |
+
def apply_to_bbox(self, bbox: BoxInternalType, cols: int = 0, rows: int = 0, **params) -> BoxInternalType:
|
| 237 |
+
return F.bbox_safe_rotate(bbox, params["matrix"], cols, rows)
|
| 238 |
+
|
| 239 |
+
def apply_to_keypoint(
|
| 240 |
+
self,
|
| 241 |
+
keypoint: KeypointInternalType,
|
| 242 |
+
angle: float = 0,
|
| 243 |
+
scale_x: float = 0,
|
| 244 |
+
scale_y: float = 0,
|
| 245 |
+
cols: int = 0,
|
| 246 |
+
rows: int = 0,
|
| 247 |
+
**params
|
| 248 |
+
) -> KeypointInternalType:
|
| 249 |
+
return F.keypoint_safe_rotate(keypoint, params["matrix"], angle, scale_x, scale_y, cols, rows)
|
| 250 |
+
|
| 251 |
+
@property
|
| 252 |
+
def targets_as_params(self) -> List[str]:
|
| 253 |
+
return ["image"]
|
| 254 |
+
|
| 255 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 256 |
+
angle = random.uniform(self.limit[0], self.limit[1])
|
| 257 |
+
|
| 258 |
+
image = params["image"]
|
| 259 |
+
h, w = image.shape[:2]
|
| 260 |
+
|
| 261 |
+
# https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
|
| 262 |
+
image_center = (w / 2, h / 2)
|
| 263 |
+
|
| 264 |
+
# Rotation Matrix
|
| 265 |
+
rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
|
| 266 |
+
|
| 267 |
+
# rotation calculates the cos and sin, taking absolutes of those.
|
| 268 |
+
abs_cos = abs(rotation_mat[0, 0])
|
| 269 |
+
abs_sin = abs(rotation_mat[0, 1])
|
| 270 |
+
|
| 271 |
+
# find the new width and height bounds
|
| 272 |
+
new_w = math.ceil(h * abs_sin + w * abs_cos)
|
| 273 |
+
new_h = math.ceil(h * abs_cos + w * abs_sin)
|
| 274 |
+
|
| 275 |
+
scale_x = w / new_w
|
| 276 |
+
scale_y = h / new_h
|
| 277 |
+
|
| 278 |
+
# Shift the image to create padding
|
| 279 |
+
rotation_mat[0, 2] += new_w / 2 - image_center[0]
|
| 280 |
+
rotation_mat[1, 2] += new_h / 2 - image_center[1]
|
| 281 |
+
|
| 282 |
+
# Rescale to original size
|
| 283 |
+
scale_mat = np.diag(np.ones(3))
|
| 284 |
+
scale_mat[0, 0] *= scale_x
|
| 285 |
+
scale_mat[1, 1] *= scale_y
|
| 286 |
+
_tmp = np.diag(np.ones(3))
|
| 287 |
+
_tmp[:2] = rotation_mat
|
| 288 |
+
_tmp = scale_mat @ _tmp
|
| 289 |
+
rotation_mat = _tmp[:2]
|
| 290 |
+
|
| 291 |
+
return {"matrix": rotation_mat, "angle": angle, "scale_x": scale_x, "scale_y": scale_y}
|
| 292 |
+
|
| 293 |
+
def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str]:
|
| 294 |
+
return ("limit", "interpolation", "border_mode", "value", "mask_value")
|
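Usage sketch (not part of the uploaded diff): Rotate with crop_border=True returns the largest axis-aligned crop of the rotated image, so the output is smaller than the input and contains no border fill; the crop size follows from _rotated_rect_with_max_area above. It assumes the vendored package re-exports Rotate like upstream albumentations.

import numpy as np
import custom_albumentations as A  # assumed top-level re-export, as in upstream albumentations

image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

# limit=(30, 30) fixes the sampled angle at 30 degrees for a reproducible example
rotate = A.Rotate(limit=(30, 30), crop_border=True, p=1.0)
out = rotate(image=image)["image"]
print(image.shape, "->", out.shape)  # (480, 640, 3) -> (277, 480, 3) for a 30 degree rotation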
custom_albumentations/augmentations/geometric/transforms.py
ADDED
|
@@ -0,0 +1,1499 @@
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from enum import Enum
|
| 4 |
+
from typing import Dict, Optional, Sequence, Tuple, Union
|
| 5 |
+
|
| 6 |
+
import cv2
|
| 7 |
+
import numpy as np
|
| 8 |
+
import skimage.transform
|
| 9 |
+
|
| 10 |
+
from custom_albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox
|
| 11 |
+
|
| 12 |
+
from ... import random_utils
|
| 13 |
+
from ...core.transforms_interface import (
|
| 14 |
+
BoxInternalType,
|
| 15 |
+
DualTransform,
|
| 16 |
+
ImageColorType,
|
| 17 |
+
KeypointInternalType,
|
| 18 |
+
ScaleFloatType,
|
| 19 |
+
to_tuple,
|
| 20 |
+
)
|
| 21 |
+
from ..functional import bbox_from_mask
|
| 22 |
+
from . import functional as F
|
| 23 |
+
|
| 24 |
+
__all__ = [
|
| 25 |
+
"ShiftScaleRotate",
|
| 26 |
+
"ElasticTransform",
|
| 27 |
+
"Perspective",
|
| 28 |
+
"Affine",
|
| 29 |
+
"PiecewiseAffine",
|
| 30 |
+
"VerticalFlip",
|
| 31 |
+
"HorizontalFlip",
|
| 32 |
+
"Flip",
|
| 33 |
+
"Transpose",
|
| 34 |
+
"OpticalDistortion",
|
| 35 |
+
"GridDistortion",
|
| 36 |
+
"PadIfNeeded",
|
| 37 |
+
]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ShiftScaleRotate(DualTransform):
|
| 41 |
+
"""Randomly apply affine transforms: translate, scale and rotate the input.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit
|
| 45 |
+
is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and
|
| 46 |
+
upper bounds should lie in range [0, 1]. Default: (-0.0625, 0.0625).
|
| 47 |
+
scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
|
| 48 |
+
range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
|
| 49 |
+
If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
|
| 50 |
+
Default: (-0.1, 0.1).
|
| 51 |
+
rotate_limit ((int, int) or int): rotation range. If rotate_limit is a single int value, the
|
| 52 |
+
range will be (-rotate_limit, rotate_limit). Default: (-45, 45).
|
| 53 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 54 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 55 |
+
Default: cv2.INTER_LINEAR.
|
| 56 |
+
border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
|
| 57 |
+
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
|
| 58 |
+
Default: cv2.BORDER_REFLECT_101
|
| 59 |
+
value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
|
| 60 |
+
mask_value (int, float,
|
| 61 |
+
list of int,
|
| 62 |
+
list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
|
| 63 |
+
shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value
|
| 64 |
+
instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value,
|
| 65 |
+
the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in
|
| 66 |
+
the range [0, 1]. Default: None.
|
| 67 |
+
shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value
|
| 68 |
+
instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value,
|
| 69 |
+
the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie
|
| 70 |
+
in the range [0, 1]. Default: None.
|
| 71 |
+
rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
|
| 72 |
+
Default: "largest_box"
|
| 73 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 74 |
+
|
| 75 |
+
Targets:
|
| 76 |
+
image, mask, keypoints, bboxes
|
| 77 |
+
|
| 78 |
+
Image types:
|
| 79 |
+
uint8, float32
|
| 80 |
+
"""
|
| 81 |
+
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
shift_limit=0.0625,
|
| 85 |
+
scale_limit=0.1,
|
| 86 |
+
rotate_limit=45,
|
| 87 |
+
interpolation=cv2.INTER_LINEAR,
|
| 88 |
+
border_mode=cv2.BORDER_REFLECT_101,
|
| 89 |
+
value=None,
|
| 90 |
+
mask_value=None,
|
| 91 |
+
shift_limit_x=None,
|
| 92 |
+
shift_limit_y=None,
|
| 93 |
+
rotate_method="largest_box",
|
| 94 |
+
always_apply=False,
|
| 95 |
+
p=0.5,
|
| 96 |
+
):
|
| 97 |
+
super(ShiftScaleRotate, self).__init__(always_apply, p)
|
| 98 |
+
self.shift_limit_x = to_tuple(shift_limit_x if shift_limit_x is not None else shift_limit)
|
| 99 |
+
self.shift_limit_y = to_tuple(shift_limit_y if shift_limit_y is not None else shift_limit)
|
| 100 |
+
self.scale_limit = to_tuple(scale_limit, bias=1.0)
|
| 101 |
+
self.rotate_limit = to_tuple(rotate_limit)
|
| 102 |
+
self.interpolation = interpolation
|
| 103 |
+
self.border_mode = border_mode
|
| 104 |
+
self.value = value
|
| 105 |
+
self.mask_value = mask_value
|
| 106 |
+
self.rotate_method = rotate_method
|
| 107 |
+
|
| 108 |
+
if self.rotate_method not in ["largest_box", "ellipse"]:
|
| 109 |
+
raise ValueError(f"Rotation method {self.rotate_method} is not valid.")
|
| 110 |
+
|
| 111 |
+
def apply(self, img, angle=0, scale=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params):
|
| 112 |
+
return F.shift_scale_rotate(img, angle, scale, dx, dy, interpolation, self.border_mode, self.value)
|
| 113 |
+
|
| 114 |
+
def apply_to_mask(self, img, angle=0, scale=0, dx=0, dy=0, **params):
|
| 115 |
+
return F.shift_scale_rotate(img, angle, scale, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
|
| 116 |
+
|
| 117 |
+
def apply_to_keypoint(self, keypoint, angle=0, scale=0, dx=0, dy=0, rows=0, cols=0, **params):
|
| 118 |
+
return F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols)
|
| 119 |
+
|
| 120 |
+
def get_params(self):
|
| 121 |
+
return {
|
| 122 |
+
"angle": random.uniform(self.rotate_limit[0], self.rotate_limit[1]),
|
| 123 |
+
"scale": random.uniform(self.scale_limit[0], self.scale_limit[1]),
|
| 124 |
+
"dx": random.uniform(self.shift_limit_x[0], self.shift_limit_x[1]),
|
| 125 |
+
"dy": random.uniform(self.shift_limit_y[0], self.shift_limit_y[1]),
|
| 126 |
+
}
|
| 127 |
+
|
| 128 |
+
def apply_to_bbox(self, bbox, angle, scale, dx, dy, **params):
|
| 129 |
+
return F.bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, self.rotate_method, **params)
|
| 130 |
+
|
| 131 |
+
def get_transform_init_args(self):
|
| 132 |
+
return {
|
| 133 |
+
"shift_limit_x": self.shift_limit_x,
|
| 134 |
+
"shift_limit_y": self.shift_limit_y,
|
| 135 |
+
"scale_limit": to_tuple(self.scale_limit, bias=-1.0),
|
| 136 |
+
"rotate_limit": self.rotate_limit,
|
| 137 |
+
"interpolation": self.interpolation,
|
| 138 |
+
"border_mode": self.border_mode,
|
| 139 |
+
"value": self.value,
|
| 140 |
+
"mask_value": self.mask_value,
|
| 141 |
+
"rotate_method": self.rotate_method,
|
| 142 |
+
}
|
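Usage sketch (not part of the uploaded diff): ShiftScaleRotate combines translation, scaling and rotation in a single affine warp while keeping the input resolution. With scale_limit=0.1 the scale factor is sampled from (0.9, 1.1), as described in the docstring above. It assumes the vendored package re-exports the class like upstream albumentations.

import numpy as np
import custom_albumentations as A  # assumed top-level re-export, as in upstream albumentations

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

ssr = A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=1.0)
out = ssr(image=image)["image"]
print(out.shape)  # (256, 256, 3): the warp does not change the image resolution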
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class ElasticTransform(DualTransform):
|
| 146 |
+
"""Elastic deformation of images as described in [Simard2003]_ (with modifications).
|
| 147 |
+
Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
|
| 148 |
+
|
| 149 |
+
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
|
| 150 |
+
Convolutional Neural Networks applied to Visual Document Analysis", in
|
| 151 |
+
Proc. of the International Conference on Document Analysis and
|
| 152 |
+
Recognition, 2003.
|
| 153 |
+
|
| 154 |
+
Args:
|
| 155 |
+
alpha (float):
|
| 156 |
+
sigma (float): Gaussian filter parameter.
|
| 157 |
+
alpha_affine (float): The range will be (-alpha_affine, alpha_affine)
|
| 158 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 159 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 160 |
+
Default: cv2.INTER_LINEAR.
|
| 161 |
+
border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
|
| 162 |
+
cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
|
| 163 |
+
Default: cv2.BORDER_REFLECT_101
|
| 164 |
+
value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
|
| 165 |
+
mask_value (int, float,
|
| 166 |
+
list of ints,
|
| 167 |
+
list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
|
| 168 |
+
approximate (boolean): Whether to smooth displacement map with fixed kernel size.
|
| 169 |
+
Enabling this option gives ~2X speedup on large images.
|
| 170 |
+
same_dxdy (boolean): Whether to use same random generated shift for x and y.
|
| 171 |
+
Enabling this option gives ~2X speedup.
|
| 172 |
+
|
| 173 |
+
Targets:
|
| 174 |
+
image, mask, bboxes
|
| 175 |
+
|
| 176 |
+
Image types:
|
| 177 |
+
uint8, float32
|
| 178 |
+
"""
|
| 179 |
+
|
| 180 |
+
def __init__(
|
| 181 |
+
self,
|
| 182 |
+
alpha=1,
|
| 183 |
+
sigma=50,
|
| 184 |
+
alpha_affine=50,
|
| 185 |
+
interpolation=cv2.INTER_LINEAR,
|
| 186 |
+
border_mode=cv2.BORDER_REFLECT_101,
|
| 187 |
+
value=None,
|
| 188 |
+
mask_value=None,
|
| 189 |
+
always_apply=False,
|
| 190 |
+
approximate=False,
|
| 191 |
+
same_dxdy=False,
|
| 192 |
+
p=0.5,
|
| 193 |
+
):
|
| 194 |
+
super(ElasticTransform, self).__init__(always_apply, p)
|
| 195 |
+
self.alpha = alpha
|
| 196 |
+
self.alpha_affine = alpha_affine
|
| 197 |
+
self.sigma = sigma
|
| 198 |
+
self.interpolation = interpolation
|
| 199 |
+
self.border_mode = border_mode
|
| 200 |
+
self.value = value
|
| 201 |
+
self.mask_value = mask_value
|
| 202 |
+
self.approximate = approximate
|
| 203 |
+
self.same_dxdy = same_dxdy
|
| 204 |
+
|
| 205 |
+
def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params):
|
| 206 |
+
return F.elastic_transform(
|
| 207 |
+
img,
|
| 208 |
+
self.alpha,
|
| 209 |
+
self.sigma,
|
| 210 |
+
self.alpha_affine,
|
| 211 |
+
interpolation,
|
| 212 |
+
self.border_mode,
|
| 213 |
+
self.value,
|
| 214 |
+
np.random.RandomState(random_state),
|
| 215 |
+
self.approximate,
|
| 216 |
+
self.same_dxdy,
|
| 217 |
+
)
|
| 218 |
+
|
| 219 |
+
def apply_to_mask(self, img, random_state=None, **params):
|
| 220 |
+
return F.elastic_transform(
|
| 221 |
+
img,
|
| 222 |
+
self.alpha,
|
| 223 |
+
self.sigma,
|
| 224 |
+
self.alpha_affine,
|
| 225 |
+
cv2.INTER_NEAREST,
|
| 226 |
+
self.border_mode,
|
| 227 |
+
self.mask_value,
|
| 228 |
+
np.random.RandomState(random_state),
|
| 229 |
+
self.approximate,
|
| 230 |
+
self.same_dxdy,
|
| 231 |
+
)
|
| 232 |
+
|
| 233 |
+
def apply_to_bbox(self, bbox, random_state=None, **params):
|
| 234 |
+
rows, cols = params["rows"], params["cols"]
|
| 235 |
+
mask = np.zeros((rows, cols), dtype=np.uint8)
|
| 236 |
+
bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
|
| 237 |
+
x_min, y_min, x_max, y_max = bbox_denorm[:4]
|
| 238 |
+
x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
|
| 239 |
+
mask[y_min:y_max, x_min:x_max] = 1
|
| 240 |
+
mask = F.elastic_transform(
|
| 241 |
+
mask,
|
| 242 |
+
self.alpha,
|
| 243 |
+
self.sigma,
|
| 244 |
+
self.alpha_affine,
|
| 245 |
+
cv2.INTER_NEAREST,
|
| 246 |
+
self.border_mode,
|
| 247 |
+
self.mask_value,
|
| 248 |
+
np.random.RandomState(random_state),
|
| 249 |
+
self.approximate,
|
| 250 |
+
)
|
| 251 |
+
bbox_returned = bbox_from_mask(mask)
|
| 252 |
+
bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
|
| 253 |
+
return bbox_returned
|
| 254 |
+
|
| 255 |
+
def get_params(self):
|
| 256 |
+
return {"random_state": random.randint(0, 10000)}
|
| 257 |
+
|
| 258 |
+
def get_transform_init_args_names(self):
|
| 259 |
+
return (
|
| 260 |
+
"alpha",
|
| 261 |
+
"sigma",
|
| 262 |
+
"alpha_affine",
|
| 263 |
+
"interpolation",
|
| 264 |
+
"border_mode",
|
| 265 |
+
"value",
|
| 266 |
+
"mask_value",
|
| 267 |
+
"approximate",
|
| 268 |
+
"same_dxdy",
|
| 269 |
+
)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class Perspective(DualTransform):
|
| 273 |
+
"""Perform a random four point perspective transform of the input.
|
| 274 |
+
|
| 275 |
+
Args:
|
| 276 |
+
scale (float or (float, float)): standard deviation of the normal distributions. These are used to sample
|
| 277 |
+
the random distances of the subimage's corners from the full image's corners.
|
| 278 |
+
If scale is a single float value, the range will be (0, scale). Default: (0.05, 0.1).
|
| 279 |
+
keep_size (bool): Whether to resize image’s back to their original size after applying the perspective
|
| 280 |
+
transform. If set to False, the resulting images may end up having different shapes
|
| 281 |
+
and will always be a list, never an array. Default: True
|
| 282 |
+
pad_mode (OpenCV flag): OpenCV border mode.
|
| 283 |
+
pad_val (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
|
| 284 |
+
Default: 0
|
| 285 |
+
mask_pad_val (int, float, list of int, list of float): padding value for mask
|
| 286 |
+
if border_mode is cv2.BORDER_CONSTANT. Default: 0
|
| 287 |
+
fit_output (bool): If True, the image plane size and position will be adjusted to still capture
|
| 288 |
+
the whole image after perspective transformation. (Followed by image resizing if keep_size is set to True.)
|
| 289 |
+
Otherwise, parts of the transformed image may be outside of the image plane.
|
| 290 |
+
This setting should not be set to True when using large scale values as it could lead to very large images.
|
| 291 |
+
Default: False
|
| 292 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 293 |
+
|
| 294 |
+
Targets:
|
| 295 |
+
image, mask, keypoints, bboxes
|
| 296 |
+
|
| 297 |
+
Image types:
|
| 298 |
+
uint8, float32
|
| 299 |
+
"""
|
| 300 |
+
|
| 301 |
+
def __init__(
|
| 302 |
+
self,
|
| 303 |
+
scale=(0.05, 0.1),
|
| 304 |
+
keep_size=True,
|
| 305 |
+
pad_mode=cv2.BORDER_CONSTANT,
|
| 306 |
+
pad_val=0,
|
| 307 |
+
mask_pad_val=0,
|
| 308 |
+
fit_output=False,
|
| 309 |
+
interpolation=cv2.INTER_LINEAR,
|
| 310 |
+
always_apply=False,
|
| 311 |
+
p=0.5,
|
| 312 |
+
):
|
| 313 |
+
super().__init__(always_apply, p)
|
| 314 |
+
self.scale = to_tuple(scale, 0)
|
| 315 |
+
self.keep_size = keep_size
|
| 316 |
+
self.pad_mode = pad_mode
|
| 317 |
+
self.pad_val = pad_val
|
| 318 |
+
self.mask_pad_val = mask_pad_val
|
| 319 |
+
self.fit_output = fit_output
|
| 320 |
+
self.interpolation = interpolation
|
| 321 |
+
|
| 322 |
+
def apply(self, img, matrix=None, max_height=None, max_width=None, **params):
|
| 323 |
+
return F.perspective(
|
| 324 |
+
img, matrix, max_width, max_height, self.pad_val, self.pad_mode, self.keep_size, params["interpolation"]
|
| 325 |
+
)
|
| 326 |
+
|
| 327 |
+
def apply_to_bbox(self, bbox, matrix=None, max_height=None, max_width=None, **params):
|
| 328 |
+
return F.perspective_bbox(bbox, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size)
|
| 329 |
+
|
| 330 |
+
def apply_to_keypoint(self, keypoint, matrix=None, max_height=None, max_width=None, **params):
|
| 331 |
+
return F.perspective_keypoint(
|
| 332 |
+
keypoint, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
@property
|
| 336 |
+
def targets_as_params(self):
|
| 337 |
+
return ["image"]
|
| 338 |
+
|
| 339 |
+
def get_params_dependent_on_targets(self, params):
|
| 340 |
+
h, w = params["image"].shape[:2]
|
| 341 |
+
|
| 342 |
+
scale = random_utils.uniform(*self.scale)
|
| 343 |
+
points = random_utils.normal(0, scale, [4, 2])
|
| 344 |
+
points = np.mod(np.abs(points), 0.32)
|
| 345 |
+
|
| 346 |
+
# top left -- no changes needed, just use jitter
|
| 347 |
+
# top right
|
| 348 |
+
points[1, 0] = 1.0 - points[1, 0] # w = 1.0 - jitter
|
| 349 |
+
# bottom right
|
| 350 |
+
points[2] = 1.0 - points[2] # w = 1.0 - jitter
|
| 351 |
+
# bottom left
|
| 352 |
+
points[3, 1] = 1.0 - points[3, 1] # h = 1.0 - jitter
|
| 353 |
+
|
| 354 |
+
points[:, 0] *= w
|
| 355 |
+
points[:, 1] *= h
|
| 356 |
+
|
| 357 |
+
# Obtain a consistent order of the points and unpack them individually.
|
| 358 |
+
# Warning: don't just do (tl, tr, br, bl) = _order_points(...)
|
| 359 |
+
# here, because the reordered points is used further below.
|
| 360 |
+
points = self._order_points(points)
|
| 361 |
+
tl, tr, br, bl = points
|
| 362 |
+
|
| 363 |
+
# compute the width of the new image, which will be the
|
| 364 |
+
# maximum distance between bottom-right and bottom-left
|
| 365 |
+
# x-coordinates or the top-right and top-left x-coordinates
|
| 366 |
+
min_width = None
|
| 367 |
+
max_width = None
|
| 368 |
+
while min_width is None or min_width < 2:
|
| 369 |
+
width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
|
| 370 |
+
width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
|
| 371 |
+
max_width = int(max(width_top, width_bottom))
|
| 372 |
+
min_width = int(min(width_top, width_bottom))
|
| 373 |
+
if min_width < 2:
|
| 374 |
+
step_size = (2 - min_width) / 2
|
| 375 |
+
tl[0] -= step_size
|
| 376 |
+
tr[0] += step_size
|
| 377 |
+
bl[0] -= step_size
|
| 378 |
+
br[0] += step_size
|
| 379 |
+
|
| 380 |
+
# compute the height of the new image, which will be the maximum distance between the top-right
|
| 381 |
+
# and bottom-right y-coordinates or the top-left and bottom-left y-coordinates
|
| 382 |
+
min_height = None
|
| 383 |
+
max_height = None
|
| 384 |
+
while min_height is None or min_height < 2:
|
| 385 |
+
height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
|
| 386 |
+
height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
|
| 387 |
+
max_height = int(max(height_right, height_left))
|
| 388 |
+
min_height = int(min(height_right, height_left))
|
| 389 |
+
if min_height < 2:
|
| 390 |
+
step_size = (2 - min_height) / 2
|
| 391 |
+
tl[1] -= step_size
|
| 392 |
+
tr[1] -= step_size
|
| 393 |
+
bl[1] += step_size
|
| 394 |
+
br[1] += step_size
|
| 395 |
+
|
| 396 |
+
# now that we have the dimensions of the new image, construct
|
| 397 |
+
# the set of destination points to obtain a "birds eye view",
|
| 398 |
+
# (i.e. top-down view) of the image, again specifying points
|
| 399 |
+
# in the top-left, top-right, bottom-right, and bottom-left order
|
| 400 |
+
# do not use width-1 or height-1 here, as for e.g. width=3, height=2
|
| 401 |
+
# the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
|
| 402 |
+
dst = np.array([[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]], dtype=np.float32)
|
| 403 |
+
|
| 404 |
+
# compute the perspective transform matrix and then apply it
|
| 405 |
+
m = cv2.getPerspectiveTransform(points, dst)
|
| 406 |
+
|
| 407 |
+
if self.fit_output:
|
| 408 |
+
m, max_width, max_height = self._expand_transform(m, (h, w))
|
| 409 |
+
|
| 410 |
+
return {"matrix": m, "max_height": max_height, "max_width": max_width, "interpolation": self.interpolation}
|
| 411 |
+
|
| 412 |
+
@classmethod
|
| 413 |
+
def _expand_transform(cls, matrix, shape):
|
| 414 |
+
height, width = shape
|
| 415 |
+
# do not use width-1 or height-1 here, as for e.g. width=3, height=2
|
| 416 |
+
# the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
|
| 417 |
+
rect = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)
|
| 418 |
+
dst = cv2.perspectiveTransform(np.array([rect]), matrix)[0]
|
| 419 |
+
|
| 420 |
+
# get min x, y over transformed 4 points
|
| 421 |
+
# then modify target points by subtracting these minima => shift to (0, 0)
|
| 422 |
+
dst -= dst.min(axis=0, keepdims=True)
|
| 423 |
+
dst = np.around(dst, decimals=0)
|
| 424 |
+
|
| 425 |
+
matrix_expanded = cv2.getPerspectiveTransform(rect, dst)
|
| 426 |
+
max_width, max_height = dst.max(axis=0)
|
| 427 |
+
return matrix_expanded, int(max_width), int(max_height)
|
| 428 |
+
|
| 429 |
+
@staticmethod
|
| 430 |
+
def _order_points(pts: np.ndarray) -> np.ndarray:
|
| 431 |
+
pts = np.array(sorted(pts, key=lambda x: x[0]))
|
| 432 |
+
left = pts[:2] # points with smallest x coordinate - left points
|
| 433 |
+
right = pts[2:] # points with greatest x coordinate - right points
|
| 434 |
+
|
| 435 |
+
if left[0][1] < left[1][1]:
|
| 436 |
+
tl, bl = left
|
| 437 |
+
else:
|
| 438 |
+
bl, tl = left
|
| 439 |
+
|
| 440 |
+
if right[0][1] < right[1][1]:
|
| 441 |
+
tr, br = right
|
| 442 |
+
else:
|
| 443 |
+
br, tr = right
|
| 444 |
+
|
| 445 |
+
return np.array([tl, tr, br, bl], dtype=np.float32)
|
| 446 |
+
|
| 447 |
+
def get_transform_init_args_names(self):
|
| 448 |
+
return "scale", "keep_size", "pad_mode", "pad_val", "mask_pad_val", "fit_output", "interpolation"
|
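Usage sketch (not part of the uploaded diff): a random four-point perspective warp. With fit_output=True the warped plane is enlarged so that no corner is lost, and keep_size=True then resizes the result back to the input resolution. It assumes the vendored package re-exports Perspective like upstream albumentations.

import numpy as np
import custom_albumentations as A  # assumed top-level re-export, as in upstream albumentations

image = np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)

warp = A.Perspective(scale=(0.05, 0.1), keep_size=True, fit_output=True, p=1.0)
out = warp(image=image)["image"]
print(out.shape)  # (240, 320, 3) because keep_size=True resizes back to the input shape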
| 449 |
+
|
| 450 |
+
|
| 451 |
+
class Affine(DualTransform):
|
| 452 |
+
"""Augmentation to apply affine transformations to images.
|
| 453 |
+
This is mostly a wrapper around the corresponding classes and functions in OpenCV.
|
| 454 |
+
|
| 455 |
+
Affine transformations involve:
|
| 456 |
+
|
| 457 |
+
- Translation ("move" image on the x-/y-axis)
|
| 458 |
+
- Rotation
|
| 459 |
+
- Scaling ("zoom" in/out)
|
| 460 |
+
- Shear (move one side of the image, turning a square into a trapezoid)
|
| 461 |
+
|
| 462 |
+
All such transformations can create "new" pixels in the image without a defined content, e.g.
|
| 463 |
+
if the image is translated to the left, pixels are created on the right.
|
| 464 |
+
A method has to be defined to deal with these pixel values.
|
| 465 |
+
The parameters `cval` and `mode` of this class deal with this.
|
| 466 |
+
|
| 467 |
+
Some transformations involve interpolations between several pixels
|
| 468 |
+
of the input image to generate output pixel values. The parameters `interpolation` and
|
| 469 |
+
`mask_interpolation` deals with the method of interpolation used for this.
|
| 470 |
+
|
| 471 |
+
Args:
|
| 472 |
+
scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and
|
| 473 |
+
``0.5`` is zoomed out to ``50`` percent of the original size.
|
| 474 |
+
* If a single number, then that value will be used for all images.
|
| 475 |
+
* If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
|
| 476 |
+
The same range will be used for both the x- and y-axis. To keep the aspect ratio, set
|
| 477 |
+
``keep_ratio=True``, then the same value will be used for both x- and y-axis.
|
| 478 |
+
* If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
|
| 479 |
+
Each of these keys can have the same values as described above.
|
| 480 |
+
Using a dictionary allows setting different values for the two axes, and sampling will then happen
|
| 481 |
+
*independently* per axis, resulting in samples that differ between the axes. Note that when
|
| 482 |
+
the ``keep_ratio=True``, the x- and y-axis ranges should be the same.
|
| 483 |
+
translate_percent (None, number, tuple of number or dict): Translation as a fraction of the image height/width
|
| 484 |
+
(x-translation, y-translation), where ``0`` denotes "no change"
|
| 485 |
+
and ``0.5`` denotes "half of the axis size".
|
| 486 |
+
* If ``None`` then equivalent to ``0.0`` unless `translate_px` has a value other than ``None``.
|
| 487 |
+
* If a single number, then that value will be used for all images.
|
| 488 |
+
* If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
|
| 489 |
+
That sampled fraction value will be used identically for both x- and y-axis.
|
| 490 |
+
* If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
|
| 491 |
+
Each of these keys can have the same values as described above.
|
| 492 |
+
Using a dictionary allows setting different values for the two axes, and sampling will then happen
|
| 493 |
+
*independently* per axis, resulting in samples that differ between the axes.
|
| 494 |
+
translate_px (None, int, tuple of int or dict): Translation in pixels.
|
| 495 |
+
* If ``None`` then equivalent to ``0`` unless `translate_percent` has a value other than ``None``.
|
| 496 |
+
* If a single int, then that value will be used for all images.
|
| 497 |
+
* If a tuple ``(a, b)``, then a value will be uniformly sampled per image from
|
| 498 |
+
the discrete interval ``[a..b]``. That number will be used identically for both x- and y-axis.
|
| 499 |
+
* If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
|
| 500 |
+
Each of these keys can have the same values as described above.
|
| 501 |
+
Using a dictionary allows setting different values for the two axes, and sampling will then happen
|
| 502 |
+
*independently* per axis, resulting in samples that differ between the axes.
|
| 503 |
+
rotate (number or tuple of number): Rotation in degrees (**NOT** radians), i.e. expected value range is
|
| 504 |
+
around ``[-360, 360]``. Rotation happens around the *center* of the image,
|
| 505 |
+
not the top left corner as in some other frameworks.
|
| 506 |
+
* If a number, then that value will be used for all images.
|
| 507 |
+
* If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``
|
| 508 |
+
and used as the rotation value.
|
| 509 |
+
shear (number, tuple of number or dict): Shear in degrees (**NOT** radians), i.e. expected value range is
|
| 510 |
+
around ``[-360, 360]``, with reasonable values being in the range of ``[-45, 45]``.
|
| 511 |
+
* If a number, then that value will be used for all images as
|
| 512 |
+
the shear on the x-axis (no shear on the y-axis will be done).
|
| 513 |
+
* If a tuple ``(a, b)``, then two values will be uniformly sampled per image
|
| 514 |
+
from the interval ``[a, b]`` and be used as the x- and y-shear value.
|
| 515 |
+
* If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
|
| 516 |
+
Each of these keys can have the same values as described above.
|
| 517 |
+
Using a dictionary allows setting different values for the two axes, and sampling will then happen
|
| 518 |
+
*independently* per axis, resulting in samples that differ between the axes.
|
| 519 |
+
interpolation (int): OpenCV interpolation flag.
|
| 520 |
+
mask_interpolation (int): OpenCV interpolation flag.
|
| 521 |
+
cval (number or sequence of number): The constant value to use when filling in newly created pixels.
|
| 522 |
+
(E.g. translating by 1px to the right will create a new 1px-wide column of pixels
|
| 523 |
+
on the left of the image).
|
| 524 |
+
The value is only used when `mode=constant`. The expected value range is ``[0, 255]`` for ``uint8`` images.
|
| 525 |
+
cval_mask (number or tuple of number): Same as cval but only for masks.
|
| 526 |
+
mode (int): OpenCV border flag.
|
| 527 |
+
fit_output (bool): If True, the image plane size and position will be adjusted to tightly capture
|
| 528 |
+
the whole image after affine transformation (`translate_percent` and `translate_px` are ignored).
|
| 529 |
+
Otherwise (``False``), parts of the transformed image may end up outside the image plane.
|
| 530 |
+
Fitting the output shape can be useful to avoid corners of the image being outside the image plane
|
| 531 |
+
after applying rotations. Default: False
|
| 532 |
+
keep_ratio (bool): When True, the original aspect ratio will be kept when the random scale is applied.
|
| 533 |
+
Default: False.
|
| 534 |
+
rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or
|
| 535 |
+
"ellipse"[1].
|
| 536 |
+
Default: "largest_box"
|
| 537 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 538 |
+
|
| 539 |
+
Targets:
|
| 540 |
+
image, mask, keypoints, bboxes
|
| 541 |
+
|
| 542 |
+
Image types:
|
| 543 |
+
uint8, float32
|
| 544 |
+
|
| 545 |
+
Reference:
|
| 546 |
+
[1] https://arxiv.org/abs/2109.13488
|
| 547 |
+
"""
|
| 548 |
+
|
| 549 |
+
def __init__(
|
| 550 |
+
self,
|
| 551 |
+
scale: Optional[Union[float, Sequence[float], dict]] = None,
|
| 552 |
+
        translate_percent: Optional[Union[float, Sequence[float], dict]] = None,
        translate_px: Optional[Union[int, Sequence[int], dict]] = None,
        rotate: Optional[Union[float, Sequence[float]]] = None,
        shear: Optional[Union[float, Sequence[float], dict]] = None,
        interpolation: int = cv2.INTER_LINEAR,
        mask_interpolation: int = cv2.INTER_NEAREST,
        cval: Union[int, float, Sequence[int], Sequence[float]] = 0,
        cval_mask: Union[int, float, Sequence[int], Sequence[float]] = 0,
        mode: int = cv2.BORDER_CONSTANT,
        fit_output: bool = False,
        keep_ratio: bool = False,
        rotate_method: str = "largest_box",
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply=always_apply, p=p)

        params = [scale, translate_percent, translate_px, rotate, shear]
        if all([p is None for p in params]):
            scale = {"x": (0.9, 1.1), "y": (0.9, 1.1)}
            translate_percent = {"x": (-0.1, 0.1), "y": (-0.1, 0.1)}
            rotate = (-15, 15)
            shear = {"x": (-10, 10), "y": (-10, 10)}
        else:
            scale = scale if scale is not None else 1.0
            rotate = rotate if rotate is not None else 0.0
            shear = shear if shear is not None else 0.0

        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation
        self.cval = cval
        self.cval_mask = cval_mask
        self.mode = mode
        self.scale = self._handle_dict_arg(scale, "scale")
        self.translate_percent, self.translate_px = self._handle_translate_arg(translate_px, translate_percent)
        self.rotate = to_tuple(rotate, rotate)
        self.fit_output = fit_output
        self.shear = self._handle_dict_arg(shear, "shear")
        self.keep_ratio = keep_ratio
        self.rotate_method = rotate_method

        if self.keep_ratio and self.scale["x"] != self.scale["y"]:
            raise ValueError(
                "When keep_ratio is True, the x and y scale range should be identical. got {}".format(self.scale)
            )

    def get_transform_init_args_names(self):
        return (
            "interpolation",
            "mask_interpolation",
            "cval",
            "mode",
            "scale",
            "translate_percent",
            "translate_px",
            "rotate",
            "fit_output",
            "shear",
            "cval_mask",
            "keep_ratio",
            "rotate_method",
        )

    @staticmethod
    def _handle_dict_arg(val: Union[float, Sequence[float], dict], name: str, default: float = 1.0):
        if isinstance(val, dict):
            if "x" not in val and "y" not in val:
                raise ValueError(
                    f'Expected {name} dictionary to contain at least key "x" or key "y". Found neither of them.'
                )
            x = val.get("x", default)
            y = val.get("y", default)
            return {"x": to_tuple(x, x), "y": to_tuple(y, y)}
        return {"x": to_tuple(val, val), "y": to_tuple(val, val)}

    @classmethod
    def _handle_translate_arg(
        cls,
        translate_px: Optional[Union[float, Sequence[float], dict]],
        translate_percent: Optional[Union[float, Sequence[float], dict]],
    ):
        if translate_percent is None and translate_px is None:
            translate_px = 0

        if translate_percent is not None and translate_px is not None:
            raise ValueError(
                "Expected either translate_percent or translate_px to be provided, but neither of them was."
            )

        if translate_percent is not None:
            # translate by percent
            return cls._handle_dict_arg(translate_percent, "translate_percent", default=0.0), translate_px

        if translate_px is None:
            raise ValueError("translate_px is None.")
        # translate by pixels
        return translate_percent, cls._handle_dict_arg(translate_px, "translate_px")

    def apply(
        self,
        img: np.ndarray,
        matrix: skimage.transform.ProjectiveTransform = None,
        output_shape: Sequence[int] = (),
        **params
    ) -> np.ndarray:
        return F.warp_affine(
            img,
            matrix,
            interpolation=self.interpolation,
            cval=self.cval,
            mode=self.mode,
            output_shape=output_shape,
        )

    def apply_to_mask(
        self,
        img: np.ndarray,
        matrix: skimage.transform.ProjectiveTransform = None,
        output_shape: Sequence[int] = (),
        **params
    ) -> np.ndarray:
        return F.warp_affine(
            img,
            matrix,
            interpolation=self.mask_interpolation,
            cval=self.cval_mask,
            mode=self.mode,
            output_shape=output_shape,
        )

    def apply_to_bbox(
        self,
        bbox: BoxInternalType,
        matrix: skimage.transform.ProjectiveTransform = None,
        rows: int = 0,
        cols: int = 0,
        output_shape: Sequence[int] = (),
        **params
    ) -> BoxInternalType:
        return F.bbox_affine(bbox, matrix, self.rotate_method, rows, cols, output_shape)

    def apply_to_keypoint(
        self,
        keypoint: KeypointInternalType,
        matrix: Optional[skimage.transform.ProjectiveTransform] = None,
        scale: Optional[dict] = None,
        **params
    ) -> KeypointInternalType:
        assert scale is not None and matrix is not None
        return F.keypoint_affine(keypoint, matrix=matrix, scale=scale)

    @property
    def targets_as_params(self):
        return ["image"]

    def get_params_dependent_on_targets(self, params: dict) -> dict:
        h, w = params["image"].shape[:2]

        translate: Dict[str, Union[int, float]]
        if self.translate_px is not None:
            translate = {key: random.randint(*value) for key, value in self.translate_px.items()}
        elif self.translate_percent is not None:
            translate = {key: random.uniform(*value) for key, value in self.translate_percent.items()}
            translate["x"] = translate["x"] * w
            translate["y"] = translate["y"] * h
        else:
            translate = {"x": 0, "y": 0}

        # Look to issue https://github.com/albumentations-team/albumentations/issues/1079
        shear = {key: -random.uniform(*value) for key, value in self.shear.items()}
        scale = {key: random.uniform(*value) for key, value in self.scale.items()}
        if self.keep_ratio:
            scale["y"] = scale["x"]

        # Look to issue https://github.com/albumentations-team/albumentations/issues/1079
        rotate = -random.uniform(*self.rotate)

        # for images we use additional shifts of (0.5, 0.5) as otherwise
        # we get an ugly black border for 90deg rotations
        shift_x = w / 2 - 0.5
        shift_y = h / 2 - 0.5

        matrix_to_topleft = skimage.transform.SimilarityTransform(translation=[-shift_x, -shift_y])
        matrix_shear_y_rot = skimage.transform.AffineTransform(rotation=-np.pi / 2)
        matrix_shear_y = skimage.transform.AffineTransform(shear=np.deg2rad(shear["y"]))
        matrix_shear_y_rot_inv = skimage.transform.AffineTransform(rotation=np.pi / 2)
        matrix_transforms = skimage.transform.AffineTransform(
            scale=(scale["x"], scale["y"]),
            translation=(translate["x"], translate["y"]),
            rotation=np.deg2rad(rotate),
            shear=np.deg2rad(shear["x"]),
        )
        matrix_to_center = skimage.transform.SimilarityTransform(translation=[shift_x, shift_y])
        matrix = (
            matrix_to_topleft
            + matrix_shear_y_rot
            + matrix_shear_y
            + matrix_shear_y_rot_inv
            + matrix_transforms
            + matrix_to_center
        )
        if self.fit_output:
            matrix, output_shape = self._compute_affine_warp_output_shape(matrix, params["image"].shape)
        else:
            output_shape = params["image"].shape

        return {
            "rotate": rotate,
            "scale": scale,
            "matrix": matrix,
            "output_shape": output_shape,
        }

    @staticmethod
    def _compute_affine_warp_output_shape(
        matrix: skimage.transform.ProjectiveTransform, input_shape: Sequence[int]
    ) -> Tuple[skimage.transform.ProjectiveTransform, Sequence[int]]:
        height, width = input_shape[:2]

        if height == 0 or width == 0:
            return matrix, input_shape

        # determine shape of output image
        corners = np.array([[0, 0], [0, height - 1], [width - 1, height - 1], [width - 1, 0]])
        corners = matrix(corners)
        minc = corners[:, 0].min()
        minr = corners[:, 1].min()
        maxc = corners[:, 0].max()
        maxr = corners[:, 1].max()
        out_height = maxr - minr + 1
        out_width = maxc - minc + 1
        if len(input_shape) == 3:
            output_shape = np.ceil((out_height, out_width, input_shape[2]))
        else:
            output_shape = np.ceil((out_height, out_width))
        output_shape_tuple = tuple([int(v) for v in output_shape.tolist()])
        # fit output image in new shape
        translation = (-minc, -minr)
        matrix_to_fit = skimage.transform.SimilarityTransform(translation=translation)
        matrix = matrix + matrix_to_fit
        return matrix, output_shape_tuple

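A minimal usage sketch for the `Affine` transform defined above; it assumes `custom_albumentations` re-exports `Affine` and `Compose` at package level the way upstream albumentations does, and the image and parameter values are illustrative only.

# Illustrative sketch, not part of the diffed file.
import numpy as np
import custom_albumentations as A  # assumed package-level exports, mirroring upstream albumentations

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)

# Scale/translate/rotate/shear ranges are sampled per call inside get_params_dependent_on_targets.
transform = A.Compose(
    [
        A.Affine(
            scale=(0.9, 1.1),
            translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
            rotate=(-15, 15),
            shear={"x": (-10, 10), "y": (-10, 10)},
            fit_output=False,
            keep_ratio=False,
            p=1.0,
        )
    ]
)

augmented = transform(image=image)["image"]
print(augmented.shape)
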
class PiecewiseAffine(DualTransform):
    """Apply affine transformations that differ between local neighbourhoods.
    This augmentation places a regular grid of points on an image and randomly moves the neighbourhood of these point
    around via affine transformations. This leads to local distortions.

    This is mostly a wrapper around scikit-image's ``PiecewiseAffine``.
    See also ``Affine`` for a similar technique.

    Note:
        This augmenter is very slow. Try to use ``ElasticTransformation`` instead, which is at least 10x faster.

    Note:
        For coordinate-based inputs (keypoints, bounding boxes, polygons, ...),
        this augmenter still has to perform an image-based augmentation,
        which will make it significantly slower and not fully correct for such inputs than other transforms.

    Args:
        scale (float, tuple of float): Each point on the regular grid is moved around via a normal distribution.
            This scale factor is equivalent to the normal distribution's sigma.
            Note that the jitter (how far each point is moved in which direction) is multiplied by the height/width of
            the image if ``absolute_scale=False`` (default), so this scale can be the same for different sized images.
            Recommended values are in the range ``0.01`` to ``0.05`` (weak to strong augmentations).
                * If a single ``float``, then that value will always be used as the scale.
                * If a tuple ``(a, b)`` of ``float`` s, then a random value will
                  be uniformly sampled per image from the interval ``[a, b]``.
        nb_rows (int, tuple of int): Number of rows of points that the regular grid should have.
            Must be at least ``2``. For large images, you might want to pick a higher value than ``4``.
            You might have to then adjust scale to lower values.
                * If a single ``int``, then that value will always be used as the number of rows.
                * If a tuple ``(a, b)``, then a value from the discrete interval
                  ``[a..b]`` will be uniformly sampled per image.
        nb_cols (int, tuple of int): Number of columns. Analogous to `nb_rows`.
        interpolation (int): The order of interpolation. The order has to be in the range 0-5:
             - 0: Nearest-neighbor
             - 1: Bi-linear (default)
             - 2: Bi-quadratic
             - 3: Bi-cubic
             - 4: Bi-quartic
             - 5: Bi-quintic
        mask_interpolation (int): same as interpolation but for mask.
        cval (number): The constant value to use when filling in newly created pixels.
        cval_mask (number): Same as cval but only for masks.
        mode (str): {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}, optional
            Points outside the boundaries of the input are filled according
            to the given mode. Modes match the behaviour of `numpy.pad`.
        absolute_scale (bool): Take `scale` as an absolute value rather than a relative value.
        keypoints_threshold (float): Used as threshold in conversion from distance maps to keypoints.
            The search for keypoints works by searching for the
            argmin (non-inverted) or argmax (inverted) in each channel. This
            parameters contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit
            as a keypoint. Use ``None`` to use no min/max. Default: 0.01

    Targets:
        image, mask, keypoints, bboxes

    Image types:
        uint8, float32

    """

    def __init__(
        self,
        scale: ScaleFloatType = (0.03, 0.05),
        nb_rows: Union[int, Sequence[int]] = 4,
        nb_cols: Union[int, Sequence[int]] = 4,
        interpolation: int = 1,
        mask_interpolation: int = 0,
        cval: int = 0,
        cval_mask: int = 0,
        mode: str = "constant",
        absolute_scale: bool = False,
        always_apply: bool = False,
        keypoints_threshold: float = 0.01,
        p: float = 0.5,
    ):
        super(PiecewiseAffine, self).__init__(always_apply, p)

        self.scale = to_tuple(scale, scale)
        self.nb_rows = to_tuple(nb_rows, nb_rows)
        self.nb_cols = to_tuple(nb_cols, nb_cols)
        self.interpolation = interpolation
        self.mask_interpolation = mask_interpolation
        self.cval = cval
        self.cval_mask = cval_mask
        self.mode = mode
        self.absolute_scale = absolute_scale
        self.keypoints_threshold = keypoints_threshold

    def get_transform_init_args_names(self):
        return (
            "scale",
            "nb_rows",
            "nb_cols",
            "interpolation",
            "mask_interpolation",
            "cval",
            "cval_mask",
            "mode",
            "absolute_scale",
            "keypoints_threshold",
        )

    @property
    def targets_as_params(self):
        return ["image"]

    def get_params_dependent_on_targets(self, params) -> dict:
        h, w = params["image"].shape[:2]

        nb_rows = np.clip(random.randint(*self.nb_rows), 2, None)
        nb_cols = np.clip(random.randint(*self.nb_cols), 2, None)
        nb_cells = nb_cols * nb_rows
        scale = random.uniform(*self.scale)

        jitter: np.ndarray = random_utils.normal(0, scale, (nb_cells, 2))
        if not np.any(jitter > 0):
            for i in range(10):  # See: https://github.com/albumentations-team/albumentations/issues/1442
                jitter = random_utils.normal(0, scale, (nb_cells, 2))
                if np.any(jitter > 0):
                    break
            if not np.any(jitter > 0):
                return {"matrix": None}

        y = np.linspace(0, h, nb_rows)
        x = np.linspace(0, w, nb_cols)

        # (H, W) and (H, W) for H=rows, W=cols
        xx_src, yy_src = np.meshgrid(x, y)

        # (1, HW, 2) => (HW, 2) for H=rows, W=cols
        points_src = np.dstack([yy_src.flat, xx_src.flat])[0]

        if self.absolute_scale:
            jitter[:, 0] = jitter[:, 0] / h if h > 0 else 0.0
            jitter[:, 1] = jitter[:, 1] / w if w > 0 else 0.0

        jitter[:, 0] = jitter[:, 0] * h
        jitter[:, 1] = jitter[:, 1] * w

        points_dest = np.copy(points_src)
        points_dest[:, 0] = points_dest[:, 0] + jitter[:, 0]
        points_dest[:, 1] = points_dest[:, 1] + jitter[:, 1]

        # Restrict all destination points to be inside the image plane.
        # This is necessary, as otherwise keypoints could be augmented
        # outside of the image plane and these would be replaced by
        # (-1, -1), which would not conform with the behaviour of the other augmenters.
        points_dest[:, 0] = np.clip(points_dest[:, 0], 0, h - 1)
        points_dest[:, 1] = np.clip(points_dest[:, 1], 0, w - 1)

        matrix = skimage.transform.PiecewiseAffineTransform()
        matrix.estimate(points_src[:, ::-1], points_dest[:, ::-1])

        return {
            "matrix": matrix,
        }

    def apply(
        self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params
    ) -> np.ndarray:
        return F.piecewise_affine(img, matrix, self.interpolation, self.mode, self.cval)

    def apply_to_mask(
        self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params
    ) -> np.ndarray:
        return F.piecewise_affine(img, matrix, self.mask_interpolation, self.mode, self.cval_mask)

    def apply_to_bbox(
        self,
        bbox: BoxInternalType,
        rows: int = 0,
        cols: int = 0,
        matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None,
        **params
    ) -> BoxInternalType:
        return F.bbox_piecewise_affine(bbox, matrix, rows, cols, self.keypoints_threshold)

    def apply_to_keypoint(
        self,
        keypoint: KeypointInternalType,
        rows: int = 0,
        cols: int = 0,
        matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None,
        **params
    ):
        return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols, self.keypoints_threshold)

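A short usage sketch for `PiecewiseAffine`; the import path and data are illustrative assumptions, and a small grid plus a low scale keep the (slow) scikit-image warp manageable.

# Illustrative sketch, not part of the diffed file.
import numpy as np
import custom_albumentations as A  # assumed package-level export, mirroring upstream albumentations

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
mask = np.zeros((128, 128), dtype=np.uint8)

aug = A.PiecewiseAffine(scale=(0.03, 0.05), nb_rows=4, nb_cols=4, p=1.0)
out = aug(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)
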
class PadIfNeeded(DualTransform):
    """Pad side of the image / max if side is less than desired number.

    Args:
        min_height (int): minimal result image height.
        min_width (int): minimal result image width.
        pad_height_divisor (int): if not None, ensures image height is dividable by value of this argument.
        pad_width_divisor (int): if not None, ensures image width is dividable by value of this argument.
        position (Union[str, PositionType]): Position of the image. should be PositionType.CENTER or
            PositionType.TOP_LEFT or PositionType.TOP_RIGHT or PositionType.BOTTOM_LEFT or PositionType.BOTTOM_RIGHT.
            or PositionType.RANDOM. Default: PositionType.CENTER.
        border_mode (OpenCV flag): OpenCV border mode.
        value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
        mask_value (int, float,
                    list of int,
                    list of float): padding value for mask if border_mode is cv2.BORDER_CONSTANT.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image, mask, bbox, keypoints

    Image types:
        uint8, float32
    """

    class PositionType(Enum):
        CENTER = "center"
        TOP_LEFT = "top_left"
        TOP_RIGHT = "top_right"
        BOTTOM_LEFT = "bottom_left"
        BOTTOM_RIGHT = "bottom_right"
        RANDOM = "random"

    def __init__(
        self,
        min_height: Optional[int] = 1024,
        min_width: Optional[int] = 1024,
        pad_height_divisor: Optional[int] = None,
        pad_width_divisor: Optional[int] = None,
        position: Union[PositionType, str] = PositionType.CENTER,
        border_mode: int = cv2.BORDER_REFLECT_101,
        value: Optional[ImageColorType] = None,
        mask_value: Optional[ImageColorType] = None,
        always_apply: bool = False,
        p: float = 1.0,
    ):
        if (min_height is None) == (pad_height_divisor is None):
            raise ValueError("Only one of 'min_height' and 'pad_height_divisor' parameters must be set")

        if (min_width is None) == (pad_width_divisor is None):
            raise ValueError("Only one of 'min_width' and 'pad_width_divisor' parameters must be set")

        super(PadIfNeeded, self).__init__(always_apply, p)
        self.min_height = min_height
        self.min_width = min_width
        self.pad_width_divisor = pad_width_divisor
        self.pad_height_divisor = pad_height_divisor
        self.position = PadIfNeeded.PositionType(position)
        self.border_mode = border_mode
        self.value = value
        self.mask_value = mask_value

    def update_params(self, params, **kwargs):
        params = super(PadIfNeeded, self).update_params(params, **kwargs)
        rows = params["rows"]
        cols = params["cols"]

        if self.min_height is not None:
            if rows < self.min_height:
                h_pad_top = int((self.min_height - rows) / 2.0)
                h_pad_bottom = self.min_height - rows - h_pad_top
            else:
                h_pad_top = 0
                h_pad_bottom = 0
        else:
            pad_remained = rows % self.pad_height_divisor
            pad_rows = self.pad_height_divisor - pad_remained if pad_remained > 0 else 0

            h_pad_top = pad_rows // 2
            h_pad_bottom = pad_rows - h_pad_top

        if self.min_width is not None:
            if cols < self.min_width:
                w_pad_left = int((self.min_width - cols) / 2.0)
                w_pad_right = self.min_width - cols - w_pad_left
            else:
                w_pad_left = 0
                w_pad_right = 0
        else:
            pad_remainder = cols % self.pad_width_divisor
            pad_cols = self.pad_width_divisor - pad_remainder if pad_remainder > 0 else 0

            w_pad_left = pad_cols // 2
            w_pad_right = pad_cols - w_pad_left

        h_pad_top, h_pad_bottom, w_pad_left, w_pad_right = self.__update_position_params(
            h_top=h_pad_top, h_bottom=h_pad_bottom, w_left=w_pad_left, w_right=w_pad_right
        )

        params.update(
            {
                "pad_top": h_pad_top,
                "pad_bottom": h_pad_bottom,
                "pad_left": w_pad_left,
                "pad_right": w_pad_right,
            }
        )
        return params

    def apply(
        self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params
    ) -> np.ndarray:
        return F.pad_with_params(
            img,
            pad_top,
            pad_bottom,
            pad_left,
            pad_right,
            border_mode=self.border_mode,
            value=self.value,
        )

    def apply_to_mask(
        self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params
    ) -> np.ndarray:
        return F.pad_with_params(
            img,
            pad_top,
            pad_bottom,
            pad_left,
            pad_right,
            border_mode=self.border_mode,
            value=self.mask_value,
        )

    def apply_to_bbox(
        self,
        bbox: BoxInternalType,
        pad_top: int = 0,
        pad_bottom: int = 0,
        pad_left: int = 0,
        pad_right: int = 0,
        rows: int = 0,
        cols: int = 0,
        **params
    ) -> BoxInternalType:
        x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
        bbox = x_min + pad_left, y_min + pad_top, x_max + pad_left, y_max + pad_top
        return normalize_bbox(bbox, rows + pad_top + pad_bottom, cols + pad_left + pad_right)

    def apply_to_keypoint(
        self,
        keypoint: KeypointInternalType,
        pad_top: int = 0,
        pad_bottom: int = 0,
        pad_left: int = 0,
        pad_right: int = 0,
        **params
    ) -> KeypointInternalType:
        x, y, angle, scale = keypoint[:4]
        return x + pad_left, y + pad_top, angle, scale

    def get_transform_init_args_names(self):
        return (
            "min_height",
            "min_width",
            "pad_height_divisor",
            "pad_width_divisor",
            "border_mode",
            "value",
            "mask_value",
        )

    def __update_position_params(
        self, h_top: int, h_bottom: int, w_left: int, w_right: int
    ) -> Tuple[int, int, int, int]:
        if self.position == PadIfNeeded.PositionType.TOP_LEFT:
            h_bottom += h_top
            w_right += w_left
            h_top = 0
            w_left = 0

        elif self.position == PadIfNeeded.PositionType.TOP_RIGHT:
            h_bottom += h_top
            w_left += w_right
            h_top = 0
            w_right = 0

        elif self.position == PadIfNeeded.PositionType.BOTTOM_LEFT:
            h_top += h_bottom
            w_right += w_left
            h_bottom = 0
            w_left = 0

        elif self.position == PadIfNeeded.PositionType.BOTTOM_RIGHT:
            h_top += h_bottom
            w_left += w_right
            h_bottom = 0
            w_right = 0

        elif self.position == PadIfNeeded.PositionType.RANDOM:
            h_pad = h_top + h_bottom
            w_pad = w_left + w_right
            h_top = random.randint(0, h_pad)
            h_bottom = h_pad - h_top
            w_left = random.randint(0, w_pad)
            w_right = w_pad - w_left

        return h_top, h_bottom, w_left, w_right

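A sketch of the two mutually exclusive padding modes of `PadIfNeeded` (fixed minimum size vs. pad-to-divisor); imports and shapes are illustrative assumptions.

# Illustrative sketch, not part of the diffed file.
import cv2
import numpy as np
import custom_albumentations as A  # assumed package-level export, mirroring upstream albumentations

image = np.random.randint(0, 256, (300, 500, 3), dtype=np.uint8)

# Pad up to a fixed minimum size, centred, with a constant black border.
pad_to_size = A.PadIfNeeded(
    min_height=512, min_width=512, border_mode=cv2.BORDER_CONSTANT, value=0, p=1.0
)

# Or pad so each side becomes divisible by 32 (min_height/min_width must then be None).
pad_to_multiple = A.PadIfNeeded(
    min_height=None, min_width=None, pad_height_divisor=32, pad_width_divisor=32, p=1.0
)

print(pad_to_size(image=image)["image"].shape)      # (512, 512, 3)
print(pad_to_multiple(image=image)["image"].shape)  # (320, 512, 3)
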
class VerticalFlip(DualTransform):
    """Flip the input vertically around the x-axis.

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        return F.vflip(img)

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return F.bbox_vflip(bbox, **params)

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        return F.keypoint_vflip(keypoint, **params)

    def get_transform_init_args_names(self):
        return ()


class HorizontalFlip(DualTransform):
    """Flip the input horizontally around the y-axis.

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        if img.ndim == 3 and img.shape[2] > 1 and img.dtype == np.uint8:
            # Opencv is faster than numpy only in case of
            # non-gray scale 8bits images
            return F.hflip_cv2(img)

        return F.hflip(img)

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return F.bbox_hflip(bbox, **params)

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        return F.keypoint_hflip(keypoint, **params)

    def get_transform_init_args_names(self):
        return ()


class Flip(DualTransform):
    """Flip the input either horizontally, vertically or both horizontally and vertically.

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def apply(self, img: np.ndarray, d: int = 0, **params) -> np.ndarray:
        """Args:
        d (int): code that specifies how to flip the input. 0 for vertical flipping, 1 for horizontal flipping,
                -1 for both vertical and horizontal flipping (which is also could be seen as rotating the input by
                180 degrees).
        """
        return F.random_flip(img, d)

    def get_params(self):
        # Random int in the range [-1, 1]
        return {"d": random.randint(-1, 1)}

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return F.bbox_flip(bbox, **params)

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        return F.keypoint_flip(keypoint, **params)

    def get_transform_init_args_names(self):
        return ()


class Transpose(DualTransform):
    """Transpose the input by swapping rows and columns.

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        uint8, float32
    """

    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        return F.transpose(img)

    def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
        return F.bbox_transpose(bbox, 0, **params)

    def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
        return F.keypoint_transpose(keypoint)

    def get_transform_init_args_names(self):
        return ()

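A quick sketch of the flip family above; package import and shapes are illustrative assumptions.

# Illustrative sketch, not part of the diffed file.
import numpy as np
import custom_albumentations as A  # assumed package-level export, mirroring upstream albumentations

image = np.random.randint(0, 256, (100, 200, 3), dtype=np.uint8)

print(A.HorizontalFlip(p=1.0)(image=image)["image"].shape)  # (100, 200, 3)
print(A.VerticalFlip(p=1.0)(image=image)["image"].shape)    # (100, 200, 3)
print(A.Transpose(p=1.0)(image=image)["image"].shape)       # (200, 100, 3) - rows and columns swapped
# Flip samples d from {-1, 0, 1}: both axes (-1), vertical (0), or horizontal (1).
print(A.Flip(p=1.0)(image=image)["image"].shape)            # (100, 200, 3)
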
class OpticalDistortion(DualTransform):
    """
    Args:
        distort_limit (float, (float, float)): If distort_limit is a single float, the range
            will be (-distort_limit, distort_limit). Default: (-0.05, 0.05).
        shift_limit (float, (float, float))): If shift_limit is a single float, the range
            will be (-shift_limit, shift_limit). Default: (-0.05, 0.05).
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
            cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
            Default: cv2.BORDER_REFLECT_101
        value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
        mask_value (int, float,
                    list of ints,
                    list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.

    Targets:
        image, mask, bbox

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        distort_limit: ScaleFloatType = 0.05,
        shift_limit: ScaleFloatType = 0.05,
        interpolation: int = cv2.INTER_LINEAR,
        border_mode: int = cv2.BORDER_REFLECT_101,
        value: Optional[ImageColorType] = None,
        mask_value: Optional[ImageColorType] = None,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super(OpticalDistortion, self).__init__(always_apply, p)
        self.shift_limit = to_tuple(shift_limit)
        self.distort_limit = to_tuple(distort_limit)
        self.interpolation = interpolation
        self.border_mode = border_mode
        self.value = value
        self.mask_value = mask_value

    def apply(
        self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, interpolation: int = cv2.INTER_LINEAR, **params
    ) -> np.ndarray:
        return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value)

    def apply_to_mask(self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, **params) -> np.ndarray:
        return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)

    def apply_to_bbox(self, bbox: BoxInternalType, k: int = 0, dx: int = 0, dy: int = 0, **params) -> BoxInternalType:
        rows, cols = params["rows"], params["cols"]
        mask = np.zeros((rows, cols), dtype=np.uint8)
        bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
        x_min, y_min, x_max, y_max = bbox_denorm[:4]
        x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
        mask[y_min:y_max, x_min:x_max] = 1
        mask = F.optical_distortion(mask, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
        bbox_returned = bbox_from_mask(mask)
        bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
        return bbox_returned

    def get_params(self):
        return {
            "k": random.uniform(self.distort_limit[0], self.distort_limit[1]),
            "dx": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
            "dy": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
        }

    def get_transform_init_args_names(self):
        return (
            "distort_limit",
            "shift_limit",
            "interpolation",
            "border_mode",
            "value",
            "mask_value",
        )

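A sketch of `OpticalDistortion` applied to an image/mask pair; the import path and data are illustrative assumptions, and note that masks are always warped with nearest-neighbour interpolation.

# Illustrative sketch, not part of the diffed file.
import cv2
import numpy as np
import custom_albumentations as A  # assumed package-level export, mirroring upstream albumentations

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)

# k is drawn from distort_limit, dx/dy from shift_limit (see get_params above).
aug = A.OpticalDistortion(
    distort_limit=0.2, shift_limit=0.05, border_mode=cv2.BORDER_CONSTANT, value=0, mask_value=0, p=1.0
)
out = aug(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)
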
class GridDistortion(DualTransform):
    """
    Args:
        num_steps (int): count of grid cells on each side.
        distort_limit (float, (float, float)): If distort_limit is a single float, the range
            will be (-distort_limit, distort_limit). Default: (-0.03, 0.03).
        interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
            cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
            Default: cv2.INTER_LINEAR.
        border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
            cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
            Default: cv2.BORDER_REFLECT_101
        value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
        mask_value (int, float,
                    list of ints,
                    list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
        normalized (bool): if true, distortion will be normalized to do not go outside the image. Default: False
            See for more information: https://github.com/albumentations-team/albumentations/pull/722

    Targets:
        image, mask

    Image types:
        uint8, float32
    """

    def __init__(
        self,
        num_steps: int = 5,
        distort_limit: ScaleFloatType = 0.3,
        interpolation: int = cv2.INTER_LINEAR,
        border_mode: int = cv2.BORDER_REFLECT_101,
        value: Optional[ImageColorType] = None,
        mask_value: Optional[ImageColorType] = None,
        normalized: bool = False,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super(GridDistortion, self).__init__(always_apply, p)
        self.num_steps = num_steps
        self.distort_limit = to_tuple(distort_limit)
        self.interpolation = interpolation
        self.border_mode = border_mode
        self.value = value
        self.mask_value = mask_value
        self.normalized = normalized

    def apply(
        self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), interpolation: int = cv2.INTER_LINEAR, **params
    ) -> np.ndarray:
        return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value)

    def apply_to_mask(self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> np.ndarray:
        return F.grid_distortion(
            img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value
        )

    def apply_to_bbox(self, bbox: BoxInternalType, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> BoxInternalType:
        rows, cols = params["rows"], params["cols"]
        mask = np.zeros((rows, cols), dtype=np.uint8)
        bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
        x_min, y_min, x_max, y_max = bbox_denorm[:4]
        x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
        mask[y_min:y_max, x_min:x_max] = 1
        mask = F.grid_distortion(
            mask, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value
        )
        bbox_returned = bbox_from_mask(mask)
        bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
        return bbox_returned

    def _normalize(self, h, w, xsteps, ysteps):
        # compensate for smaller last steps in source image.
        x_step = w // self.num_steps
        last_x_step = min(w, ((self.num_steps + 1) * x_step)) - (self.num_steps * x_step)
        xsteps[-1] *= last_x_step / x_step

        y_step = h // self.num_steps
        last_y_step = min(h, ((self.num_steps + 1) * y_step)) - (self.num_steps * y_step)
        ysteps[-1] *= last_y_step / y_step

        # now normalize such that distortion never leaves image bounds.
        tx = w / math.floor(w / self.num_steps)
        ty = h / math.floor(h / self.num_steps)
        xsteps = np.array(xsteps) * (tx / np.sum(xsteps))
        ysteps = np.array(ysteps) * (ty / np.sum(ysteps))

        return {"stepsx": xsteps, "stepsy": ysteps}

    @property
    def targets_as_params(self):
        return ["image"]

    def get_params_dependent_on_targets(self, params):
        h, w = params["image"].shape[:2]

        stepsx = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)]
        stepsy = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)]

        if self.normalized:
            return self._normalize(h, w, stepsx, stepsy)

        return {"stepsx": stepsx, "stepsy": stepsy}

    def get_transform_init_args_names(self):
        return "num_steps", "distort_limit", "interpolation", "border_mode", "value", "mask_value", "normalized"
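A sketch of `GridDistortion`, including the `normalized` option referenced in the docstring (albumentations PR #722); the import path and data are illustrative assumptions.

# Illustrative sketch, not part of the diffed file.
import numpy as np
import custom_albumentations as A  # assumed package-level export, mirroring upstream albumentations

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
mask = np.zeros((256, 256), dtype=np.uint8)

# normalized=True rescales the sampled per-cell steps so the distortion stays inside the image.
aug = A.GridDistortion(num_steps=5, distort_limit=0.3, normalized=True, p=1.0)
out = aug(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)
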
custom_albumentations/augmentations/transforms.py
ADDED
@@ -0,0 +1,2667 @@
| 1 |
+
from __future__ import absolute_import, division
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import numbers
|
| 5 |
+
import random
|
| 6 |
+
import warnings
|
| 7 |
+
from enum import IntEnum
|
| 8 |
+
from types import LambdaType
|
| 9 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import numpy as np
|
| 13 |
+
from scipy import special
|
| 14 |
+
from scipy.ndimage import gaussian_filter
|
| 15 |
+
|
| 16 |
+
from custom_albumentations import random_utils
|
| 17 |
+
from custom_albumentations.augmentations.blur.functional import blur
|
| 18 |
+
from custom_albumentations.augmentations.utils import (
|
| 19 |
+
get_num_channels,
|
| 20 |
+
is_grayscale_image,
|
| 21 |
+
is_rgb_image,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
from ..core.transforms_interface import (
|
| 25 |
+
DualTransform,
|
| 26 |
+
ImageOnlyTransform,
|
| 27 |
+
NoOp,
|
| 28 |
+
ScaleFloatType,
|
| 29 |
+
to_tuple,
|
| 30 |
+
)
|
| 31 |
+
from ..core.utils import format_args
|
| 32 |
+
from . import functional as F
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
"Normalize",
|
| 36 |
+
"RandomGamma",
|
| 37 |
+
"RandomGridShuffle",
|
| 38 |
+
"HueSaturationValue",
|
| 39 |
+
"RGBShift",
|
| 40 |
+
"RandomBrightness",
|
| 41 |
+
"RandomContrast",
|
| 42 |
+
"GaussNoise",
|
| 43 |
+
"CLAHE",
|
| 44 |
+
"ChannelShuffle",
|
| 45 |
+
"InvertImg",
|
| 46 |
+
"ToGray",
|
| 47 |
+
"ToRGB",
|
| 48 |
+
"ToSepia",
|
| 49 |
+
"JpegCompression",
|
| 50 |
+
"ImageCompression",
|
| 51 |
+
"ToFloat",
|
| 52 |
+
"FromFloat",
|
| 53 |
+
"RandomBrightnessContrast",
|
| 54 |
+
"RandomSnow",
|
| 55 |
+
"RandomGravel",
|
| 56 |
+
"RandomRain",
|
| 57 |
+
"RandomFog",
|
| 58 |
+
"RandomSunFlare",
|
| 59 |
+
"RandomShadow",
|
| 60 |
+
"RandomToneCurve",
|
| 61 |
+
"Lambda",
|
| 62 |
+
"ISONoise",
|
| 63 |
+
"Solarize",
|
| 64 |
+
"Equalize",
|
| 65 |
+
"Posterize",
|
| 66 |
+
"Downscale",
|
| 67 |
+
"MultiplicativeNoise",
|
| 68 |
+
"FancyPCA",
|
| 69 |
+
"ColorJitter",
|
| 70 |
+
"Sharpen",
|
| 71 |
+
"Emboss",
|
| 72 |
+
"Superpixels",
|
| 73 |
+
"TemplateTransform",
|
| 74 |
+
"RingingOvershoot",
|
| 75 |
+
"UnsharpMask",
|
| 76 |
+
"PixelDropout",
|
| 77 |
+
"Spatter",
|
| 78 |
+
]
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class RandomGridShuffle(DualTransform):
|
| 82 |
+
"""
|
| 83 |
+
Random shuffle grid's cells on image.
|
| 84 |
+
|
| 85 |
+
Args:
|
| 86 |
+
grid ((int, int)): size of grid for splitting image.
|
| 87 |
+
|
| 88 |
+
Targets:
|
| 89 |
+
image, mask, keypoints
|
| 90 |
+
|
| 91 |
+
Image types:
|
| 92 |
+
uint8, float32
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
def __init__(self, grid: Tuple[int, int] = (3, 3), always_apply: bool = False, p: float = 0.5):
|
| 96 |
+
super(RandomGridShuffle, self).__init__(always_apply, p)
|
| 97 |
+
self.grid = grid
|
| 98 |
+
|
| 99 |
+
def apply(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params):
|
| 100 |
+
return F.swap_tiles_on_image(img, tiles)
|
| 101 |
+
|
| 102 |
+
def apply_to_mask(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params):
|
| 103 |
+
return F.swap_tiles_on_image(img, tiles)
|
| 104 |
+
|
| 105 |
+
def apply_to_keypoint(
|
| 106 |
+
self, keypoint: Tuple[float, ...], tiles: np.ndarray = np.array(None), rows: int = 0, cols: int = 0, **params
|
| 107 |
+
):
|
| 108 |
+
for (
|
| 109 |
+
current_left_up_corner_row,
|
| 110 |
+
current_left_up_corner_col,
|
| 111 |
+
old_left_up_corner_row,
|
| 112 |
+
old_left_up_corner_col,
|
| 113 |
+
height_tile,
|
| 114 |
+
width_tile,
|
| 115 |
+
) in tiles:
|
| 116 |
+
x, y = keypoint[:2]
|
| 117 |
+
|
| 118 |
+
if (old_left_up_corner_row <= y < (old_left_up_corner_row + height_tile)) and (
|
| 119 |
+
old_left_up_corner_col <= x < (old_left_up_corner_col + width_tile)
|
| 120 |
+
):
|
| 121 |
+
x = x - old_left_up_corner_col + current_left_up_corner_col
|
| 122 |
+
y = y - old_left_up_corner_row + current_left_up_corner_row
|
| 123 |
+
keypoint = (x, y) + tuple(keypoint[2:])
|
| 124 |
+
break
|
| 125 |
+
|
| 126 |
+
return keypoint
|
| 127 |
+
|
| 128 |
+
def get_params_dependent_on_targets(self, params):
|
| 129 |
+
height, width = params["image"].shape[:2]
|
| 130 |
+
n, m = self.grid
|
| 131 |
+
|
| 132 |
+
if n <= 0 or m <= 0:
|
| 133 |
+
raise ValueError("Grid's values must be positive. Current grid [%s, %s]" % (n, m))
|
| 134 |
+
|
| 135 |
+
if n > height // 2 or m > width // 2:
|
| 136 |
+
raise ValueError("Incorrect size cell of grid. Just shuffle pixels of image")
|
| 137 |
+
|
| 138 |
+
height_split = np.linspace(0, height, n + 1, dtype=np.int32)
|
| 139 |
+
width_split = np.linspace(0, width, m + 1, dtype=np.int32)
|
| 140 |
+
|
| 141 |
+
height_matrix, width_matrix = np.meshgrid(height_split, width_split, indexing="ij")
|
| 142 |
+
|
| 143 |
+
index_height_matrix = height_matrix[:-1, :-1]
|
| 144 |
+
index_width_matrix = width_matrix[:-1, :-1]
|
| 145 |
+
|
| 146 |
+
shifted_index_height_matrix = height_matrix[1:, 1:]
|
| 147 |
+
shifted_index_width_matrix = width_matrix[1:, 1:]
|
| 148 |
+
|
| 149 |
+
height_tile_sizes = shifted_index_height_matrix - index_height_matrix
|
| 150 |
+
width_tile_sizes = shifted_index_width_matrix - index_width_matrix
|
| 151 |
+
|
| 152 |
+
tiles_sizes = np.stack((height_tile_sizes, width_tile_sizes), axis=2)
|
| 153 |
+
|
| 154 |
+
index_matrix = np.indices((n, m))
|
| 155 |
+
new_index_matrix = np.stack(index_matrix, axis=2)
|
| 156 |
+
|
| 157 |
+
for bbox_size in np.unique(tiles_sizes.reshape(-1, 2), axis=0):
|
| 158 |
+
eq_mat = np.all(tiles_sizes == bbox_size, axis=2)
|
| 159 |
+
new_index_matrix[eq_mat] = random_utils.permutation(new_index_matrix[eq_mat])
|
| 160 |
+
|
| 161 |
+
new_index_matrix = np.split(new_index_matrix, 2, axis=2)
|
| 162 |
+
|
| 163 |
+
old_x = index_height_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)
|
| 164 |
+
old_y = index_width_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)
|
| 165 |
+
|
| 166 |
+
shift_x = height_tile_sizes.reshape(-1)
|
| 167 |
+
shift_y = width_tile_sizes.reshape(-1)
|
| 168 |
+
|
| 169 |
+
curr_x = index_height_matrix.reshape(-1)
|
| 170 |
+
curr_y = index_width_matrix.reshape(-1)
|
| 171 |
+
|
| 172 |
+
tiles = np.stack([curr_x, curr_y, old_x, old_y, shift_x, shift_y], axis=1)
|
| 173 |
+
|
| 174 |
+
return {"tiles": tiles}
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def targets_as_params(self):
|
| 178 |
+
return ["image"]
|
| 179 |
+
|
| 180 |
+
def get_transform_init_args_names(self):
|
| 181 |
+
return ("grid",)
|
| 182 |
+
|
| 183 |
+
|
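# Illustrative usage sketch (not part of the uploaded diff). It assumes this file is
# custom_albumentations/augmentations/transforms.py and that the vendored Compose in
# custom_albumentations.core.composition mirrors the upstream albumentations call API.
import numpy as np
from custom_albumentations.core.composition import Compose
from custom_albumentations.augmentations.transforms import RandomGridShuffle

image = np.random.randint(0, 256, (96, 96, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (96, 96), dtype=np.uint8)
# RandomGridShuffle is a DualTransform: image and mask tiles are swapped consistently.
aug = Compose([RandomGridShuffle(grid=(3, 3), p=1.0)])
out = aug(image=image, mask=mask)
shuffled_image, shuffled_mask = out["image"], out["mask"]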
| 184 |
+
class Normalize(ImageOnlyTransform):
|
| 185 |
+
"""Normalization is applied by the formula: `img = (img - mean * max_pixel_value) / (std * max_pixel_value)`
|
| 186 |
+
|
| 187 |
+
Args:
|
| 188 |
+
mean (float, list of float): mean values
|
| 189 |
+
std (float, list of float): std values
|
| 190 |
+
max_pixel_value (float): maximum possible pixel value
|
| 191 |
+
|
| 192 |
+
Targets:
|
| 193 |
+
image
|
| 194 |
+
|
| 195 |
+
Image types:
|
| 196 |
+
uint8, float32
|
| 197 |
+
"""
|
| 198 |
+
|
| 199 |
+
def __init__(
|
| 200 |
+
self,
|
| 201 |
+
mean=(0.485, 0.456, 0.406),
|
| 202 |
+
std=(0.229, 0.224, 0.225),
|
| 203 |
+
max_pixel_value=255.0,
|
| 204 |
+
always_apply=False,
|
| 205 |
+
p=1.0,
|
| 206 |
+
):
|
| 207 |
+
super(Normalize, self).__init__(always_apply, p)
|
| 208 |
+
self.mean = mean
|
| 209 |
+
self.std = std
|
| 210 |
+
self.max_pixel_value = max_pixel_value
|
| 211 |
+
|
| 212 |
+
def apply(self, image, **params):
|
| 213 |
+
return F.normalize(image, self.mean, self.std, self.max_pixel_value)
|
| 214 |
+
|
| 215 |
+
def get_transform_init_args_names(self):
|
| 216 |
+
return ("mean", "std", "max_pixel_value")
|
| 217 |
+
|
| 218 |
+
|
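# Illustrative usage sketch (not part of the uploaded diff), assuming the vendored
# package mirrors the upstream albumentations API.
import numpy as np
from custom_albumentations.augmentations.transforms import Normalize

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
# Applies (img - mean * max_pixel_value) / (std * max_pixel_value); output is float32.
norm = Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), p=1.0)
normalized = norm(image=image)["image"]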
| 219 |
+
class ImageCompression(ImageOnlyTransform):
|
| 220 |
+
"""Decreases image quality by Jpeg, WebP compression of an image.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
quality_lower (float): lower bound on the image quality.
|
| 224 |
+
Should be in [0, 100] range for jpeg and [1, 100] for webp.
|
| 225 |
+
quality_upper (float): upper bound on the image quality.
|
| 226 |
+
Should be in [0, 100] range for jpeg and [1, 100] for webp.
|
| 227 |
+
compression_type (ImageCompressionType): should be ImageCompressionType.JPEG or ImageCompressionType.WEBP.
|
| 228 |
+
Default: ImageCompressionType.JPEG
|
| 229 |
+
|
| 230 |
+
Targets:
|
| 231 |
+
image
|
| 232 |
+
|
| 233 |
+
Image types:
|
| 234 |
+
uint8, float32
|
| 235 |
+
"""
|
| 236 |
+
|
| 237 |
+
class ImageCompressionType(IntEnum):
|
| 238 |
+
JPEG = 0
|
| 239 |
+
WEBP = 1
|
| 240 |
+
|
| 241 |
+
def __init__(
|
| 242 |
+
self,
|
| 243 |
+
quality_lower=99,
|
| 244 |
+
quality_upper=100,
|
| 245 |
+
compression_type=ImageCompressionType.JPEG,
|
| 246 |
+
always_apply=False,
|
| 247 |
+
p=0.5,
|
| 248 |
+
):
|
| 249 |
+
super(ImageCompression, self).__init__(always_apply, p)
|
| 250 |
+
|
| 251 |
+
self.compression_type = ImageCompression.ImageCompressionType(compression_type)
|
| 252 |
+
low_thresh_quality_assert = 0
|
| 253 |
+
|
| 254 |
+
if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
|
| 255 |
+
low_thresh_quality_assert = 1
|
| 256 |
+
|
| 257 |
+
if not low_thresh_quality_assert <= quality_lower <= 100:
|
| 258 |
+
raise ValueError("Invalid quality_lower. Got: {}".format(quality_lower))
|
| 259 |
+
if not low_thresh_quality_assert <= quality_upper <= 100:
|
| 260 |
+
raise ValueError("Invalid quality_upper. Got: {}".format(quality_upper))
|
| 261 |
+
|
| 262 |
+
self.quality_lower = quality_lower
|
| 263 |
+
self.quality_upper = quality_upper
|
| 264 |
+
|
| 265 |
+
def apply(self, image, quality=100, image_type=".jpg", **params):
|
| 266 |
+
if not image.ndim == 2 and image.shape[-1] not in (1, 3, 4):
|
| 267 |
+
raise TypeError("ImageCompression transformation expects 1, 3 or 4 channel images.")
|
| 268 |
+
return F.image_compression(image, quality, image_type)
|
| 269 |
+
|
| 270 |
+
def get_params(self):
|
| 271 |
+
image_type = ".jpg"
|
| 272 |
+
|
| 273 |
+
if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
|
| 274 |
+
image_type = ".webp"
|
| 275 |
+
|
| 276 |
+
return {
|
| 277 |
+
"quality": random.randint(self.quality_lower, self.quality_upper),
|
| 278 |
+
"image_type": image_type,
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
def get_transform_init_args(self):
|
| 282 |
+
return {
|
| 283 |
+
"quality_lower": self.quality_lower,
|
| 284 |
+
"quality_upper": self.quality_upper,
|
| 285 |
+
"compression_type": self.compression_type.value,
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
|
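# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import ImageCompression

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
# Lower the quality bounds to make the compression artifacts clearly visible.
jpeg = ImageCompression(quality_lower=20, quality_upper=40, p=1.0)
degraded = jpeg(image=image)["image"]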
| 289 |
+
class JpegCompression(ImageCompression):
|
| 290 |
+
"""Decreases image quality by Jpeg compression of an image.
|
| 291 |
+
|
| 292 |
+
Args:
|
| 293 |
+
quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range
|
| 294 |
+
quality_upper (float): upper bound on the jpeg quality. Should be in [0, 100] range
|
| 295 |
+
|
| 296 |
+
Targets:
|
| 297 |
+
image
|
| 298 |
+
|
| 299 |
+
Image types:
|
| 300 |
+
uint8, float32
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
def __init__(self, quality_lower=99, quality_upper=100, always_apply=False, p=0.5):
|
| 304 |
+
super(JpegCompression, self).__init__(
|
| 305 |
+
quality_lower=quality_lower,
|
| 306 |
+
quality_upper=quality_upper,
|
| 307 |
+
compression_type=ImageCompression.ImageCompressionType.JPEG,
|
| 308 |
+
always_apply=always_apply,
|
| 309 |
+
p=p,
|
| 310 |
+
)
|
| 311 |
+
warnings.warn(
|
| 312 |
+
f"{self.__class__.__name__} has been deprecated. Please use ImageCompression",
|
| 313 |
+
FutureWarning,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
def get_transform_init_args(self):
|
| 317 |
+
return {
|
| 318 |
+
"quality_lower": self.quality_lower,
|
| 319 |
+
"quality_upper": self.quality_upper,
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
|
| 323 |
+
class RandomSnow(ImageOnlyTransform):
|
| 324 |
+
"""Bleach out some pixel values simulating snow.
|
| 325 |
+
|
| 326 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 327 |
+
|
| 328 |
+
Args:
|
| 329 |
+
snow_point_lower (float): lower bound of the amount of snow. Should be in [0, 1] range
|
| 330 |
+
snow_point_upper (float): upper bound of the amount of snow. Should be in [0, 1] range
|
| 331 |
+
brightness_coeff (float): a larger value leads to more snow on the image. Should be >= 0
|
| 332 |
+
|
| 333 |
+
Targets:
|
| 334 |
+
image
|
| 335 |
+
|
| 336 |
+
Image types:
|
| 337 |
+
uint8, float32
|
| 338 |
+
"""
|
| 339 |
+
|
| 340 |
+
def __init__(
|
| 341 |
+
self,
|
| 342 |
+
snow_point_lower=0.1,
|
| 343 |
+
snow_point_upper=0.3,
|
| 344 |
+
brightness_coeff=2.5,
|
| 345 |
+
always_apply=False,
|
| 346 |
+
p=0.5,
|
| 347 |
+
):
|
| 348 |
+
super(RandomSnow, self).__init__(always_apply, p)
|
| 349 |
+
|
| 350 |
+
if not 0 <= snow_point_lower <= snow_point_upper <= 1:
|
| 351 |
+
raise ValueError(
|
| 352 |
+
"Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
|
| 353 |
+
(snow_point_lower, snow_point_upper)
|
| 354 |
+
)
|
| 355 |
+
)
|
| 356 |
+
if brightness_coeff < 0:
|
| 357 |
+
raise ValueError("brightness_coeff must be greater than 0. Got: {}".format(brightness_coeff))
|
| 358 |
+
|
| 359 |
+
self.snow_point_lower = snow_point_lower
|
| 360 |
+
self.snow_point_upper = snow_point_upper
|
| 361 |
+
self.brightness_coeff = brightness_coeff
|
| 362 |
+
|
| 363 |
+
def apply(self, image, snow_point=0.1, **params):
|
| 364 |
+
return F.add_snow(image, snow_point, self.brightness_coeff)
|
| 365 |
+
|
| 366 |
+
def get_params(self):
|
| 367 |
+
return {"snow_point": random.uniform(self.snow_point_lower, self.snow_point_upper)}
|
| 368 |
+
|
| 369 |
+
def get_transform_init_args_names(self):
|
| 370 |
+
return ("snow_point_lower", "snow_point_upper", "brightness_coeff")
|
| 371 |
+
|
| 372 |
+
|
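# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import RandomSnow

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
# snow_point is drawn uniformly from [snow_point_lower, snow_point_upper].
snow = RandomSnow(snow_point_lower=0.2, snow_point_upper=0.4, brightness_coeff=2.0, p=1.0)
snowy = snow(image=image)["image"]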
| 373 |
+
class RandomGravel(ImageOnlyTransform):
|
| 374 |
+
"""Add gravels.
|
| 375 |
+
|
| 376 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 377 |
+
|
| 378 |
+
Args:
|
| 379 |
+
gravel_roi (float, float, float, float): (top-left x, top-left y,
|
| 380 |
+
bottom-right x, bottom-right y). Should be in [0, 1] range
|
| 381 |
+
number_of_patches (int): number of gravel patches required
|
| 382 |
+
|
| 383 |
+
Targets:
|
| 384 |
+
image
|
| 385 |
+
|
| 386 |
+
Image types:
|
| 387 |
+
uint8, float32
|
| 388 |
+
"""
|
| 389 |
+
|
| 390 |
+
def __init__(
|
| 391 |
+
self,
|
| 392 |
+
gravel_roi: tuple = (0.1, 0.4, 0.9, 0.9),
|
| 393 |
+
number_of_patches: int = 2,
|
| 394 |
+
always_apply: bool = False,
|
| 395 |
+
p: float = 0.5,
|
| 396 |
+
):
|
| 397 |
+
super(RandomGravel, self).__init__(always_apply, p)
|
| 398 |
+
|
| 399 |
+
(gravel_lower_x, gravel_lower_y, gravel_upper_x, gravel_upper_y) = gravel_roi
|
| 400 |
+
|
| 401 |
+
if not 0 <= gravel_lower_x < gravel_upper_x <= 1 or not 0 <= gravel_lower_y < gravel_upper_y <= 1:
|
| 402 |
+
raise ValueError("Invalid gravel_roi. Got: %s." % gravel_roi)
|
| 403 |
+
if number_of_patches < 1:
|
| 404 |
+
raise ValueError("Invalid gravel number_of_patches. Got: %s." % number_of_patches)
|
| 405 |
+
|
| 406 |
+
self.gravel_roi = gravel_roi
|
| 407 |
+
self.number_of_patches = number_of_patches
|
| 408 |
+
|
| 409 |
+
def generate_gravel_patch(self, rectangular_roi):
|
| 410 |
+
x1, y1, x2, y2 = rectangular_roi
|
| 411 |
+
gravels = []
|
| 412 |
+
area = abs((x2 - x1) * (y2 - y1))
|
| 413 |
+
count = area // 10
|
| 414 |
+
gravels = np.empty([count, 2], dtype=np.int64)
|
| 415 |
+
gravels[:, 0] = random_utils.randint(x1, x2, count)
|
| 416 |
+
gravels[:, 1] = random_utils.randint(y1, y2, count)
|
| 417 |
+
return gravels
|
| 418 |
+
|
| 419 |
+
def apply(self, image, gravels_infos=(), **params):
|
| 420 |
+
return F.add_gravel(image, gravels_infos)
|
| 421 |
+
|
| 422 |
+
@property
|
| 423 |
+
def targets_as_params(self):
|
| 424 |
+
return ["image"]
|
| 425 |
+
|
| 426 |
+
def get_params_dependent_on_targets(self, params):
|
| 427 |
+
img = params["image"]
|
| 428 |
+
height, width = img.shape[:2]
|
| 429 |
+
|
| 430 |
+
x_min, y_min, x_max, y_max = self.gravel_roi
|
| 431 |
+
x_min = int(x_min * width)
|
| 432 |
+
x_max = int(x_max * width)
|
| 433 |
+
y_min = int(y_min * height)
|
| 434 |
+
y_max = int(y_max * height)
|
| 435 |
+
|
| 436 |
+
max_height = 200
|
| 437 |
+
max_width = 30
|
| 438 |
+
|
| 439 |
+
rectangular_rois = np.zeros([self.number_of_patches, 4], dtype=np.int64)
|
| 440 |
+
xx1 = random_utils.randint(x_min + 1, x_max, self.number_of_patches) # xmax
|
| 441 |
+
xx2 = random_utils.randint(x_min, xx1) # xmin
|
| 442 |
+
yy1 = random_utils.randint(y_min + 1, y_max, self.number_of_patches) # ymax
|
| 443 |
+
yy2 = random_utils.randint(y_min, yy1) # ymin
|
| 444 |
+
|
| 445 |
+
rectangular_rois[:, 0] = xx2
|
| 446 |
+
rectangular_rois[:, 1] = yy2
|
| 447 |
+
rectangular_rois[:, 2] = [min(tup) for tup in zip(xx1, xx2 + max_height)]
|
| 448 |
+
rectangular_rois[:, 3] = [min(tup) for tup in zip(yy1, yy2 + max_width)]
|
| 449 |
+
|
| 450 |
+
minx = []
|
| 451 |
+
maxx = []
|
| 452 |
+
miny = []
|
| 453 |
+
maxy = []
|
| 454 |
+
val = []
|
| 455 |
+
for roi in rectangular_rois:
|
| 456 |
+
gravels = self.generate_gravel_patch(roi)
|
| 457 |
+
x = gravels[:, 0]
|
| 458 |
+
y = gravels[:, 1]
|
| 459 |
+
r = random_utils.randint(1, 4, len(gravels))
|
| 460 |
+
sat = random_utils.randint(0, 255, len(gravels))
|
| 461 |
+
miny.append(np.maximum(y - r, 0))
|
| 462 |
+
maxy.append(np.minimum(y + r, y))
|
| 463 |
+
minx.append(np.maximum(x - r, 0))
|
| 464 |
+
maxx.append(np.minimum(x + r, x))
|
| 465 |
+
val.append(sat)
|
| 466 |
+
|
| 467 |
+
return {
|
| 468 |
+
"gravels_infos": np.stack(
|
| 469 |
+
[
|
| 470 |
+
np.concatenate(miny),
|
| 471 |
+
np.concatenate(maxy),
|
| 472 |
+
np.concatenate(minx),
|
| 473 |
+
np.concatenate(maxx),
|
| 474 |
+
np.concatenate(val),
|
| 475 |
+
],
|
| 476 |
+
1,
|
| 477 |
+
)
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
def get_transform_init_args_names(self):
|
| 481 |
+
return {"gravel_roi": self.gravel_roi, "number_of_patches": self.number_of_patches}
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
class RandomRain(ImageOnlyTransform):
|
| 485 |
+
"""Adds rain effects.
|
| 486 |
+
|
| 487 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 488 |
+
|
| 489 |
+
Args:
|
| 490 |
+
slant_lower: should be in range [-20, 20].
|
| 491 |
+
slant_upper: should be in range [-20, 20].
|
| 492 |
+
drop_length: should be in range [0, 100].
|
| 493 |
+
drop_width: should be in range [1, 5].
|
| 494 |
+
drop_color (list of (r, g, b)): color of the rain lines.
|
| 495 |
+
blur_value (int): amount of blur applied, since rainy views are blurry
|
| 496 |
+
brightness_coefficient (float): rainy days are usually shady. Should be in range [0, 1].
|
| 497 |
+
rain_type: One of [None, "drizzle", "heavy", "torrential"]
|
| 498 |
+
|
| 499 |
+
Targets:
|
| 500 |
+
image
|
| 501 |
+
|
| 502 |
+
Image types:
|
| 503 |
+
uint8, float32
|
| 504 |
+
"""
|
| 505 |
+
|
| 506 |
+
def __init__(
|
| 507 |
+
self,
|
| 508 |
+
slant_lower=-10,
|
| 509 |
+
slant_upper=10,
|
| 510 |
+
drop_length=20,
|
| 511 |
+
drop_width=1,
|
| 512 |
+
drop_color=(200, 200, 200),
|
| 513 |
+
blur_value=7,
|
| 514 |
+
brightness_coefficient=0.7,
|
| 515 |
+
rain_type=None,
|
| 516 |
+
always_apply=False,
|
| 517 |
+
p=0.5,
|
| 518 |
+
):
|
| 519 |
+
super(RandomRain, self).__init__(always_apply, p)
|
| 520 |
+
|
| 521 |
+
if rain_type not in ["drizzle", "heavy", "torrential", None]:
|
| 522 |
+
raise ValueError(
|
| 523 |
+
"raint_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type)
|
| 524 |
+
)
|
| 525 |
+
if not -20 <= slant_lower <= slant_upper <= 20:
|
| 526 |
+
raise ValueError(
|
| 527 |
+
"Invalid combination of slant_lower and slant_upper. Got: {}".format((slant_lower, slant_upper))
|
| 528 |
+
)
|
| 529 |
+
if not 1 <= drop_width <= 5:
|
| 530 |
+
raise ValueError("drop_width must be in range [1, 5]. Got: {}".format(drop_width))
|
| 531 |
+
if not 0 <= drop_length <= 100:
|
| 532 |
+
raise ValueError("drop_length must be in range [0, 100]. Got: {}".format(drop_length))
|
| 533 |
+
if not 0 <= brightness_coefficient <= 1:
|
| 534 |
+
raise ValueError("brightness_coefficient must be in range [0, 1]. Got: {}".format(brightness_coefficient))
|
| 535 |
+
|
| 536 |
+
self.slant_lower = slant_lower
|
| 537 |
+
self.slant_upper = slant_upper
|
| 538 |
+
|
| 539 |
+
self.drop_length = drop_length
|
| 540 |
+
self.drop_width = drop_width
|
| 541 |
+
self.drop_color = drop_color
|
| 542 |
+
self.blur_value = blur_value
|
| 543 |
+
self.brightness_coefficient = brightness_coefficient
|
| 544 |
+
self.rain_type = rain_type
|
| 545 |
+
|
| 546 |
+
def apply(self, image, slant=10, drop_length=20, rain_drops=(), **params):
|
| 547 |
+
return F.add_rain(
|
| 548 |
+
image,
|
| 549 |
+
slant,
|
| 550 |
+
drop_length,
|
| 551 |
+
self.drop_width,
|
| 552 |
+
self.drop_color,
|
| 553 |
+
self.blur_value,
|
| 554 |
+
self.brightness_coefficient,
|
| 555 |
+
rain_drops,
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
@property
|
| 559 |
+
def targets_as_params(self):
|
| 560 |
+
return ["image"]
|
| 561 |
+
|
| 562 |
+
def get_params_dependent_on_targets(self, params):
|
| 563 |
+
img = params["image"]
|
| 564 |
+
slant = int(random.uniform(self.slant_lower, self.slant_upper))
|
| 565 |
+
|
| 566 |
+
height, width = img.shape[:2]
|
| 567 |
+
area = height * width
|
| 568 |
+
|
| 569 |
+
if self.rain_type == "drizzle":
|
| 570 |
+
num_drops = area // 770
|
| 571 |
+
drop_length = 10
|
| 572 |
+
elif self.rain_type == "heavy":
|
| 573 |
+
num_drops = width * height // 600
|
| 574 |
+
drop_length = 30
|
| 575 |
+
elif self.rain_type == "torrential":
|
| 576 |
+
num_drops = area // 500
|
| 577 |
+
drop_length = 60
|
| 578 |
+
else:
|
| 579 |
+
drop_length = self.drop_length
|
| 580 |
+
num_drops = area // 600
|
| 581 |
+
|
| 582 |
+
rain_drops = []
|
| 583 |
+
|
| 584 |
+
for _i in range(num_drops):  # for heavier rain, increase num_drops
|
| 585 |
+
if slant < 0:
|
| 586 |
+
x = random.randint(slant, width)
|
| 587 |
+
else:
|
| 588 |
+
x = random.randint(0, width - slant)
|
| 589 |
+
|
| 590 |
+
y = random.randint(0, height - drop_length)
|
| 591 |
+
|
| 592 |
+
rain_drops.append((x, y))
|
| 593 |
+
|
| 594 |
+
return {"drop_length": drop_length, "slant": slant, "rain_drops": rain_drops}
|
| 595 |
+
|
| 596 |
+
def get_transform_init_args_names(self):
|
| 597 |
+
return (
|
| 598 |
+
"slant_lower",
|
| 599 |
+
"slant_upper",
|
| 600 |
+
"drop_length",
|
| 601 |
+
"drop_width",
|
| 602 |
+
"drop_color",
|
| 603 |
+
"blur_value",
|
| 604 |
+
"brightness_coefficient",
|
| 605 |
+
"rain_type",
|
| 606 |
+
)
|
| 607 |
+
|
| 608 |
+
|
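# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import RandomRain

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
# rain_type presets ("drizzle", "heavy", "torrential") override drop_length and the drop count.
rain = RandomRain(rain_type="heavy", blur_value=5, brightness_coefficient=0.8, p=1.0)
rainy = rain(image=image)["image"]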
| 609 |
+
class RandomFog(ImageOnlyTransform):
|
| 610 |
+
"""Simulates fog for the image
|
| 611 |
+
|
| 612 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 613 |
+
|
| 614 |
+
Args:
|
| 615 |
+
fog_coef_lower (float): lower limit for fog intensity coefficient. Should be in [0, 1] range.
|
| 616 |
+
fog_coef_upper (float): upper limit for fog intensity coefficient. Should be in [0, 1] range.
|
| 617 |
+
alpha_coef (float): transparency of the fog circles. Should be in [0, 1] range.
|
| 618 |
+
|
| 619 |
+
Targets:
|
| 620 |
+
image
|
| 621 |
+
|
| 622 |
+
Image types:
|
| 623 |
+
uint8, float32
|
| 624 |
+
"""
|
| 625 |
+
|
| 626 |
+
def __init__(
|
| 627 |
+
self,
|
| 628 |
+
fog_coef_lower=0.3,
|
| 629 |
+
fog_coef_upper=1,
|
| 630 |
+
alpha_coef=0.08,
|
| 631 |
+
always_apply=False,
|
| 632 |
+
p=0.5,
|
| 633 |
+
):
|
| 634 |
+
super(RandomFog, self).__init__(always_apply, p)
|
| 635 |
+
|
| 636 |
+
if not 0 <= fog_coef_lower <= fog_coef_upper <= 1:
|
| 637 |
+
raise ValueError(
|
| 638 |
+
"Invalid combination if fog_coef_lower and fog_coef_upper. Got: {}".format(
|
| 639 |
+
(fog_coef_lower, fog_coef_upper)
|
| 640 |
+
)
|
| 641 |
+
)
|
| 642 |
+
if not 0 <= alpha_coef <= 1:
|
| 643 |
+
raise ValueError("alpha_coef must be in range [0, 1]. Got: {}".format(alpha_coef))
|
| 644 |
+
|
| 645 |
+
self.fog_coef_lower = fog_coef_lower
|
| 646 |
+
self.fog_coef_upper = fog_coef_upper
|
| 647 |
+
self.alpha_coef = alpha_coef
|
| 648 |
+
|
| 649 |
+
def apply(self, image, fog_coef=0.1, haze_list=(), **params):
|
| 650 |
+
return F.add_fog(image, fog_coef, self.alpha_coef, haze_list)
|
| 651 |
+
|
| 652 |
+
@property
|
| 653 |
+
def targets_as_params(self):
|
| 654 |
+
return ["image"]
|
| 655 |
+
|
| 656 |
+
def get_params_dependent_on_targets(self, params):
|
| 657 |
+
img = params["image"]
|
| 658 |
+
fog_coef = random.uniform(self.fog_coef_lower, self.fog_coef_upper)
|
| 659 |
+
|
| 660 |
+
height, width = imshape = img.shape[:2]
|
| 661 |
+
|
| 662 |
+
hw = max(1, int(width // 3 * fog_coef))
|
| 663 |
+
|
| 664 |
+
haze_list = []
|
| 665 |
+
midx = width // 2 - 2 * hw
|
| 666 |
+
midy = height // 2 - hw
|
| 667 |
+
index = 1
|
| 668 |
+
|
| 669 |
+
while midx > -hw or midy > -hw:
|
| 670 |
+
for _i in range(hw // 10 * index):
|
| 671 |
+
x = random.randint(midx, width - midx - hw)
|
| 672 |
+
y = random.randint(midy, height - midy - hw)
|
| 673 |
+
haze_list.append((x, y))
|
| 674 |
+
|
| 675 |
+
midx -= 3 * hw * width // sum(imshape)
|
| 676 |
+
midy -= 3 * hw * height // sum(imshape)
|
| 677 |
+
index += 1
|
| 678 |
+
|
| 679 |
+
return {"haze_list": haze_list, "fog_coef": fog_coef}
|
| 680 |
+
|
| 681 |
+
def get_transform_init_args_names(self):
|
| 682 |
+
return ("fog_coef_lower", "fog_coef_upper", "alpha_coef")
|
| 683 |
+
|
| 684 |
+
|
| 685 |
+
class RandomSunFlare(ImageOnlyTransform):
|
| 686 |
+
"""Simulates Sun Flare for the image
|
| 687 |
+
|
| 688 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 689 |
+
|
| 690 |
+
Args:
|
| 691 |
+
flare_roi (float, float, float, float): region of the image where flare will
|
| 692 |
+
appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1].
|
| 693 |
+
angle_lower (float): should be in range [0, `angle_upper`].
|
| 694 |
+
angle_upper (float): should be in range [`angle_lower`, 1].
|
| 695 |
+
num_flare_circles_lower (int): lower limit for the number of flare circles.
|
| 696 |
+
Should be in range [0, `num_flare_circles_upper`].
|
| 697 |
+
num_flare_circles_upper (int): upper limit for the number of flare circles.
|
| 698 |
+
Should be in range [`num_flare_circles_lower`, inf].
|
| 699 |
+
src_radius (int): radius of the flare's light source circle.
|
| 700 |
+
src_color ((int, int, int)): color of the flare
|
| 701 |
+
|
| 702 |
+
Targets:
|
| 703 |
+
image
|
| 704 |
+
|
| 705 |
+
Image types:
|
| 706 |
+
uint8, float32
|
| 707 |
+
"""
|
| 708 |
+
|
| 709 |
+
def __init__(
|
| 710 |
+
self,
|
| 711 |
+
flare_roi=(0, 0, 1, 0.5),
|
| 712 |
+
angle_lower=0,
|
| 713 |
+
angle_upper=1,
|
| 714 |
+
num_flare_circles_lower=6,
|
| 715 |
+
num_flare_circles_upper=10,
|
| 716 |
+
src_radius=400,
|
| 717 |
+
src_color=(255, 255, 255),
|
| 718 |
+
always_apply=False,
|
| 719 |
+
p=0.5,
|
| 720 |
+
):
|
| 721 |
+
super(RandomSunFlare, self).__init__(always_apply, p)
|
| 722 |
+
|
| 723 |
+
(
|
| 724 |
+
flare_center_lower_x,
|
| 725 |
+
flare_center_lower_y,
|
| 726 |
+
flare_center_upper_x,
|
| 727 |
+
flare_center_upper_y,
|
| 728 |
+
) = flare_roi
|
| 729 |
+
|
| 730 |
+
if (
|
| 731 |
+
not 0 <= flare_center_lower_x < flare_center_upper_x <= 1
|
| 732 |
+
or not 0 <= flare_center_lower_y < flare_center_upper_y <= 1
|
| 733 |
+
):
|
| 734 |
+
raise ValueError("Invalid flare_roi. Got: {}".format(flare_roi))
|
| 735 |
+
if not 0 <= angle_lower < angle_upper <= 1:
|
| 736 |
+
raise ValueError(
|
| 737 |
+
"Invalid combination of angle_lower nad angle_upper. Got: {}".format((angle_lower, angle_upper))
|
| 738 |
+
)
|
| 739 |
+
if not 0 <= num_flare_circles_lower < num_flare_circles_upper:
|
| 740 |
+
raise ValueError(
|
| 741 |
+
"Invalid combination of num_flare_circles_lower nad num_flare_circles_upper. Got: {}".format(
|
| 742 |
+
(num_flare_circles_lower, num_flare_circles_upper)
|
| 743 |
+
)
|
| 744 |
+
)
|
| 745 |
+
|
| 746 |
+
self.flare_center_lower_x = flare_center_lower_x
|
| 747 |
+
self.flare_center_upper_x = flare_center_upper_x
|
| 748 |
+
|
| 749 |
+
self.flare_center_lower_y = flare_center_lower_y
|
| 750 |
+
self.flare_center_upper_y = flare_center_upper_y
|
| 751 |
+
|
| 752 |
+
self.angle_lower = angle_lower
|
| 753 |
+
self.angle_upper = angle_upper
|
| 754 |
+
self.num_flare_circles_lower = num_flare_circles_lower
|
| 755 |
+
self.num_flare_circles_upper = num_flare_circles_upper
|
| 756 |
+
|
| 757 |
+
self.src_radius = src_radius
|
| 758 |
+
self.src_color = src_color
|
| 759 |
+
|
| 760 |
+
def apply(self, image, flare_center_x=0.5, flare_center_y=0.5, circles=(), **params):
|
| 761 |
+
return F.add_sun_flare(
|
| 762 |
+
image,
|
| 763 |
+
flare_center_x,
|
| 764 |
+
flare_center_y,
|
| 765 |
+
self.src_radius,
|
| 766 |
+
self.src_color,
|
| 767 |
+
circles,
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
@property
|
| 771 |
+
def targets_as_params(self):
|
| 772 |
+
return ["image"]
|
| 773 |
+
|
| 774 |
+
def get_params_dependent_on_targets(self, params):
|
| 775 |
+
img = params["image"]
|
| 776 |
+
height, width = img.shape[:2]
|
| 777 |
+
|
| 778 |
+
angle = 2 * math.pi * random.uniform(self.angle_lower, self.angle_upper)
|
| 779 |
+
|
| 780 |
+
flare_center_x = random.uniform(self.flare_center_lower_x, self.flare_center_upper_x)
|
| 781 |
+
flare_center_y = random.uniform(self.flare_center_lower_y, self.flare_center_upper_y)
|
| 782 |
+
|
| 783 |
+
flare_center_x = int(width * flare_center_x)
|
| 784 |
+
flare_center_y = int(height * flare_center_y)
|
| 785 |
+
|
| 786 |
+
num_circles = random.randint(self.num_flare_circles_lower, self.num_flare_circles_upper)
|
| 787 |
+
|
| 788 |
+
circles = []
|
| 789 |
+
|
| 790 |
+
x = []
|
| 791 |
+
y = []
|
| 792 |
+
|
| 793 |
+
def line(t):
|
| 794 |
+
return (flare_center_x + t * math.cos(angle), flare_center_y + t * math.sin(angle))
|
| 795 |
+
|
| 796 |
+
for t_val in range(-flare_center_x, width - flare_center_x, 10):
|
| 797 |
+
rand_x, rand_y = line(t_val)
|
| 798 |
+
x.append(rand_x)
|
| 799 |
+
y.append(rand_y)
|
| 800 |
+
|
| 801 |
+
for _i in range(num_circles):
|
| 802 |
+
alpha = random.uniform(0.05, 0.2)
|
| 803 |
+
r = random.randint(0, len(x) - 1)
|
| 804 |
+
rad = random.randint(1, max(height // 100 - 2, 2))
|
| 805 |
+
|
| 806 |
+
r_color = random.randint(max(self.src_color[0] - 50, 0), self.src_color[0])
|
| 807 |
+
g_color = random.randint(max(self.src_color[1] - 50, 0), self.src_color[1])
|
| 808 |
+
b_color = random.randint(max(self.src_color[2] - 50, 0), self.src_color[2])
|
| 809 |
+
|
| 810 |
+
circles += [
|
| 811 |
+
(
|
| 812 |
+
alpha,
|
| 813 |
+
(int(x[r]), int(y[r])),
|
| 814 |
+
pow(rad, 3),
|
| 815 |
+
(r_color, g_color, b_color),
|
| 816 |
+
)
|
| 817 |
+
]
|
| 818 |
+
|
| 819 |
+
return {
|
| 820 |
+
"circles": circles,
|
| 821 |
+
"flare_center_x": flare_center_x,
|
| 822 |
+
"flare_center_y": flare_center_y,
|
| 823 |
+
}
|
| 824 |
+
|
| 825 |
+
def get_transform_init_args(self):
|
| 826 |
+
return {
|
| 827 |
+
"flare_roi": (
|
| 828 |
+
self.flare_center_lower_x,
|
| 829 |
+
self.flare_center_lower_y,
|
| 830 |
+
self.flare_center_upper_x,
|
| 831 |
+
self.flare_center_upper_y,
|
| 832 |
+
),
|
| 833 |
+
"angle_lower": self.angle_lower,
|
| 834 |
+
"angle_upper": self.angle_upper,
|
| 835 |
+
"num_flare_circles_lower": self.num_flare_circles_lower,
|
| 836 |
+
"num_flare_circles_upper": self.num_flare_circles_upper,
|
| 837 |
+
"src_radius": self.src_radius,
|
| 838 |
+
"src_color": self.src_color,
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
|
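# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import RandomSunFlare

image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
# Restrict the flare centre to the top-left region of the image.
flare = RandomSunFlare(flare_roi=(0.0, 0.0, 0.5, 0.25), src_radius=150, p=1.0)
flared = flare(image=image)["image"]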
| 842 |
+
class RandomShadow(ImageOnlyTransform):
|
| 843 |
+
"""Simulates shadows for the image
|
| 844 |
+
|
| 845 |
+
From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
|
| 846 |
+
|
| 847 |
+
Args:
|
| 848 |
+
shadow_roi (float, float, float, float): region of the image where shadows
|
| 849 |
+
will appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1].
|
| 850 |
+
num_shadows_lower (int): Lower limit for the possible number of shadows.
|
| 851 |
+
Should be in range [0, `num_shadows_upper`].
|
| 852 |
+
num_shadows_upper (int): Upper limit for the possible number of shadows.
|
| 853 |
+
Should be in range [`num_shadows_lower`, inf].
|
| 854 |
+
shadow_dimension (int): number of edges in the shadow polygons
|
| 855 |
+
|
| 856 |
+
Targets:
|
| 857 |
+
image
|
| 858 |
+
|
| 859 |
+
Image types:
|
| 860 |
+
uint8, float32
|
| 861 |
+
"""
|
| 862 |
+
|
| 863 |
+
def __init__(
|
| 864 |
+
self,
|
| 865 |
+
shadow_roi=(0, 0.5, 1, 1),
|
| 866 |
+
num_shadows_lower=1,
|
| 867 |
+
num_shadows_upper=2,
|
| 868 |
+
shadow_dimension=5,
|
| 869 |
+
always_apply=False,
|
| 870 |
+
p=0.5,
|
| 871 |
+
):
|
| 872 |
+
super(RandomShadow, self).__init__(always_apply, p)
|
| 873 |
+
|
| 874 |
+
(shadow_lower_x, shadow_lower_y, shadow_upper_x, shadow_upper_y) = shadow_roi
|
| 875 |
+
|
| 876 |
+
if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
|
| 877 |
+
raise ValueError("Invalid shadow_roi. Got: {}".format(shadow_roi))
|
| 878 |
+
if not 0 <= num_shadows_lower <= num_shadows_upper:
|
| 879 |
+
raise ValueError(
|
| 880 |
+
"Invalid combination of num_shadows_lower nad num_shadows_upper. Got: {}".format(
|
| 881 |
+
(num_shadows_lower, num_shadows_upper)
|
| 882 |
+
)
|
| 883 |
+
)
|
| 884 |
+
|
| 885 |
+
self.shadow_roi = shadow_roi
|
| 886 |
+
|
| 887 |
+
self.num_shadows_lower = num_shadows_lower
|
| 888 |
+
self.num_shadows_upper = num_shadows_upper
|
| 889 |
+
|
| 890 |
+
self.shadow_dimension = shadow_dimension
|
| 891 |
+
|
| 892 |
+
def apply(self, image, vertices_list=(), **params):
|
| 893 |
+
return F.add_shadow(image, vertices_list)
|
| 894 |
+
|
| 895 |
+
@property
|
| 896 |
+
def targets_as_params(self):
|
| 897 |
+
return ["image"]
|
| 898 |
+
|
| 899 |
+
def get_params_dependent_on_targets(self, params):
|
| 900 |
+
img = params["image"]
|
| 901 |
+
height, width = img.shape[:2]
|
| 902 |
+
|
| 903 |
+
num_shadows = random.randint(self.num_shadows_lower, self.num_shadows_upper)
|
| 904 |
+
|
| 905 |
+
x_min, y_min, x_max, y_max = self.shadow_roi
|
| 906 |
+
|
| 907 |
+
x_min = int(x_min * width)
|
| 908 |
+
x_max = int(x_max * width)
|
| 909 |
+
y_min = int(y_min * height)
|
| 910 |
+
y_max = int(y_max * height)
|
| 911 |
+
|
| 912 |
+
vertices_list = []
|
| 913 |
+
|
| 914 |
+
for _index in range(num_shadows):
|
| 915 |
+
vertex = []
|
| 916 |
+
for _dimension in range(self.shadow_dimension):
|
| 917 |
+
vertex.append((random.randint(x_min, x_max), random.randint(y_min, y_max)))
|
| 918 |
+
|
| 919 |
+
vertices = np.array([vertex], dtype=np.int32)
|
| 920 |
+
vertices_list.append(vertices)
|
| 921 |
+
|
| 922 |
+
return {"vertices_list": vertices_list}
|
| 923 |
+
|
| 924 |
+
def get_transform_init_args_names(self):
|
| 925 |
+
return (
|
| 926 |
+
"shadow_roi",
|
| 927 |
+
"num_shadows_lower",
|
| 928 |
+
"num_shadows_upper",
|
| 929 |
+
"shadow_dimension",
|
| 930 |
+
)
|
| 931 |
+
|
| 932 |
+
|
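# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import RandomShadow

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
# Cast one to three five-sided shadow polygons onto the lower half of the image.
shadow = RandomShadow(shadow_roi=(0, 0.5, 1, 1), num_shadows_lower=1, num_shadows_upper=3, p=1.0)
shadowed = shadow(image=image)["image"]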
| 933 |
+
class RandomToneCurve(ImageOnlyTransform):
|
| 934 |
+
"""Randomly change the relationship between bright and dark areas of the image by manipulating its tone curve.
|
| 935 |
+
|
| 936 |
+
Args:
|
| 937 |
+
scale (float): standard deviation of the normal distribution.
|
| 938 |
+
Used to sample random distances to move two control points that modify the image's curve.
|
| 939 |
+
Values should be in range [0, 1]. Default: 0.1
|
| 940 |
+
|
| 941 |
+
|
| 942 |
+
Targets:
|
| 943 |
+
image
|
| 944 |
+
|
| 945 |
+
Image types:
|
| 946 |
+
uint8
|
| 947 |
+
"""
|
| 948 |
+
|
| 949 |
+
def __init__(
|
| 950 |
+
self,
|
| 951 |
+
scale=0.1,
|
| 952 |
+
always_apply=False,
|
| 953 |
+
p=0.5,
|
| 954 |
+
):
|
| 955 |
+
super(RandomToneCurve, self).__init__(always_apply, p)
|
| 956 |
+
self.scale = scale
|
| 957 |
+
|
| 958 |
+
def apply(self, image, low_y, high_y, **params):
|
| 959 |
+
return F.move_tone_curve(image, low_y, high_y)
|
| 960 |
+
|
| 961 |
+
def get_params(self):
|
| 962 |
+
return {
|
| 963 |
+
"low_y": np.clip(random_utils.normal(loc=0.25, scale=self.scale), 0, 1),
|
| 964 |
+
"high_y": np.clip(random_utils.normal(loc=0.75, scale=self.scale), 0, 1),
|
| 965 |
+
}
|
| 966 |
+
|
| 967 |
+
def get_transform_init_args_names(self):
|
| 968 |
+
return ("scale",)
|
| 969 |
+
|
| 970 |
+
|
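# Illustrative usage sketch (not part of the uploaded diff). Per the docstring above,
# RandomToneCurve supports uint8 images only.
import numpy as np
from custom_albumentations.augmentations.transforms import RandomToneCurve

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
# Larger scale moves the two control points further, giving stronger tone changes.
tone = RandomToneCurve(scale=0.2, p=1.0)
toned = tone(image=image)["image"]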
| 971 |
+
class HueSaturationValue(ImageOnlyTransform):
|
| 972 |
+
"""Randomly change hue, saturation and value of the input image.
|
| 973 |
+
|
| 974 |
+
Args:
|
| 975 |
+
hue_shift_limit ((int, int) or int): range for changing hue. If hue_shift_limit is a single int, the range
|
| 976 |
+
will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20).
|
| 977 |
+
sat_shift_limit ((int, int) or int): range for changing saturation. If sat_shift_limit is a single int,
|
| 978 |
+
the range will be (-sat_shift_limit, sat_shift_limit). Default: (-30, 30).
|
| 979 |
+
val_shift_limit ((int, int) or int): range for changing value. If val_shift_limit is a single int, the range
|
| 980 |
+
will be (-val_shift_limit, val_shift_limit). Default: (-20, 20).
|
| 981 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 982 |
+
|
| 983 |
+
Targets:
|
| 984 |
+
image
|
| 985 |
+
|
| 986 |
+
Image types:
|
| 987 |
+
uint8, float32
|
| 988 |
+
"""
|
| 989 |
+
|
| 990 |
+
def __init__(
|
| 991 |
+
self,
|
| 992 |
+
hue_shift_limit=20,
|
| 993 |
+
sat_shift_limit=30,
|
| 994 |
+
val_shift_limit=20,
|
| 995 |
+
always_apply=False,
|
| 996 |
+
p=0.5,
|
| 997 |
+
):
|
| 998 |
+
super(HueSaturationValue, self).__init__(always_apply, p)
|
| 999 |
+
self.hue_shift_limit = to_tuple(hue_shift_limit)
|
| 1000 |
+
self.sat_shift_limit = to_tuple(sat_shift_limit)
|
| 1001 |
+
self.val_shift_limit = to_tuple(val_shift_limit)
|
| 1002 |
+
|
| 1003 |
+
def apply(self, image, hue_shift=0, sat_shift=0, val_shift=0, **params):
|
| 1004 |
+
if not is_rgb_image(image) and not is_grayscale_image(image):
|
| 1005 |
+
raise TypeError("HueSaturationValue transformation expects 1-channel or 3-channel images.")
|
| 1006 |
+
return F.shift_hsv(image, hue_shift, sat_shift, val_shift)
|
| 1007 |
+
|
| 1008 |
+
def get_params(self):
|
| 1009 |
+
return {
|
| 1010 |
+
"hue_shift": random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1]),
|
| 1011 |
+
"sat_shift": random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]),
|
| 1012 |
+
"val_shift": random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]),
|
| 1013 |
+
}
|
| 1014 |
+
|
| 1015 |
+
def get_transform_init_args_names(self):
|
| 1016 |
+
return ("hue_shift_limit", "sat_shift_limit", "val_shift_limit")
|
| 1017 |
+
|
| 1018 |
+
|
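# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import HueSaturationValue

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
# A single int limit is expanded to a symmetric range, e.g. 15 -> (-15, 15).
hsv = HueSaturationValue(hue_shift_limit=15, sat_shift_limit=25, val_shift_limit=15, p=1.0)
shifted = hsv(image=image)["image"]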
| 1019 |
+
class Solarize(ImageOnlyTransform):
|
| 1020 |
+
"""Invert all pixel values above a threshold.
|
| 1021 |
+
|
| 1022 |
+
Args:
|
| 1023 |
+
threshold ((int, int) or int, or (float, float) or float): range for solarizing threshold.
|
| 1024 |
+
If threshold is a single value, the range will be [threshold, threshold]. Default: 128.
|
| 1025 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1026 |
+
|
| 1027 |
+
Targets:
|
| 1028 |
+
image
|
| 1029 |
+
|
| 1030 |
+
Image types:
|
| 1031 |
+
any
|
| 1032 |
+
"""
|
| 1033 |
+
|
| 1034 |
+
def __init__(self, threshold=128, always_apply=False, p=0.5):
|
| 1035 |
+
super(Solarize, self).__init__(always_apply, p)
|
| 1036 |
+
|
| 1037 |
+
if isinstance(threshold, (int, float)):
|
| 1038 |
+
self.threshold = to_tuple(threshold, low=threshold)
|
| 1039 |
+
else:
|
| 1040 |
+
self.threshold = to_tuple(threshold, low=0)
|
| 1041 |
+
|
| 1042 |
+
def apply(self, image, threshold=0, **params):
|
| 1043 |
+
return F.solarize(image, threshold)
|
| 1044 |
+
|
| 1045 |
+
def get_params(self):
|
| 1046 |
+
return {"threshold": random.uniform(self.threshold[0], self.threshold[1])}
|
| 1047 |
+
|
| 1048 |
+
def get_transform_init_args_names(self):
|
| 1049 |
+
return ("threshold",)
|
| 1050 |
+
|
| 1051 |
+
|
| 1052 |
+
class Posterize(ImageOnlyTransform):
|
| 1053 |
+
"""Reduce the number of bits for each color channel.
|
| 1054 |
+
|
| 1055 |
+
Args:
|
| 1056 |
+
num_bits ((int, int) or int,
|
| 1057 |
+
or list of ints [r, g, b],
|
| 1058 |
+
or list of ints [[r1, r2], [g1, g2], [b1, b2]]): number of high bits.
|
| 1059 |
+
If num_bits is a single value, the range will be [num_bits, num_bits].
|
| 1060 |
+
Must be in range [0, 8]. Default: 4.
|
| 1061 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1062 |
+
|
| 1063 |
+
Targets:
|
| 1064 |
+
image
|
| 1065 |
+
|
| 1066 |
+
Image types:
|
| 1067 |
+
uint8
|
| 1068 |
+
"""
|
| 1069 |
+
|
| 1070 |
+
def __init__(self, num_bits=4, always_apply=False, p=0.5):
|
| 1071 |
+
super(Posterize, self).__init__(always_apply, p)
|
| 1072 |
+
|
| 1073 |
+
if isinstance(num_bits, (list, tuple)):
|
| 1074 |
+
if len(num_bits) == 3:
|
| 1075 |
+
self.num_bits = [to_tuple(i, 0) for i in num_bits]
|
| 1076 |
+
else:
|
| 1077 |
+
self.num_bits = to_tuple(num_bits, 0)
|
| 1078 |
+
else:
|
| 1079 |
+
self.num_bits = to_tuple(num_bits, num_bits)
|
| 1080 |
+
|
| 1081 |
+
def apply(self, image, num_bits=1, **params):
|
| 1082 |
+
return F.posterize(image, num_bits)
|
| 1083 |
+
|
| 1084 |
+
def get_params(self):
|
| 1085 |
+
if len(self.num_bits) == 3:
|
| 1086 |
+
return {"num_bits": [random.randint(i[0], i[1]) for i in self.num_bits]}
|
| 1087 |
+
return {"num_bits": random.randint(self.num_bits[0], self.num_bits[1])}
|
| 1088 |
+
|
| 1089 |
+
def get_transform_init_args_names(self):
|
| 1090 |
+
return ("num_bits",)
|
| 1091 |
+
|
| 1092 |
+
|
| 1093 |
+
class Equalize(ImageOnlyTransform):
|
| 1094 |
+
"""Equalize the image histogram.
|
| 1095 |
+
|
| 1096 |
+
Args:
|
| 1097 |
+
mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method.
|
| 1098 |
+
by_channels (bool): If True, use equalization by channels separately,
|
| 1099 |
+
else convert image to YCbCr representation and use equalization by `Y` channel.
|
| 1100 |
+
mask (np.ndarray, callable): If given, only the pixels selected by
|
| 1101 |
+
the mask are included in the analysis. May be a 1-channel or 3-channel array, or a callable.
|
| 1102 |
+
Function signature must include `image` argument.
|
| 1103 |
+
mask_params (list of str): Params for mask function.
|
| 1104 |
+
|
| 1105 |
+
Targets:
|
| 1106 |
+
image
|
| 1107 |
+
|
| 1108 |
+
Image types:
|
| 1109 |
+
uint8
|
| 1110 |
+
"""
|
| 1111 |
+
|
| 1112 |
+
def __init__(
|
| 1113 |
+
self,
|
| 1114 |
+
mode="cv",
|
| 1115 |
+
by_channels=True,
|
| 1116 |
+
mask=None,
|
| 1117 |
+
mask_params=(),
|
| 1118 |
+
always_apply=False,
|
| 1119 |
+
p=0.5,
|
| 1120 |
+
):
|
| 1121 |
+
modes = ["cv", "pil"]
|
| 1122 |
+
if mode not in modes:
|
| 1123 |
+
raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode))
|
| 1124 |
+
|
| 1125 |
+
super(Equalize, self).__init__(always_apply, p)
|
| 1126 |
+
self.mode = mode
|
| 1127 |
+
self.by_channels = by_channels
|
| 1128 |
+
self.mask = mask
|
| 1129 |
+
self.mask_params = mask_params
|
| 1130 |
+
|
| 1131 |
+
def apply(self, image, mask=None, **params):
|
| 1132 |
+
return F.equalize(image, mode=self.mode, by_channels=self.by_channels, mask=mask)
|
| 1133 |
+
|
| 1134 |
+
def get_params_dependent_on_targets(self, params):
|
| 1135 |
+
if not callable(self.mask):
|
| 1136 |
+
return {"mask": self.mask}
|
| 1137 |
+
|
| 1138 |
+
return {"mask": self.mask(**params)}
|
| 1139 |
+
|
| 1140 |
+
@property
|
| 1141 |
+
def targets_as_params(self):
|
| 1142 |
+
return ["image"] + list(self.mask_params)
|
| 1143 |
+
|
| 1144 |
+
def get_transform_init_args_names(self):
|
| 1145 |
+
return ("mode", "by_channels")
|
| 1146 |
+
|
| 1147 |
+
|
| 1148 |
+
class RGBShift(ImageOnlyTransform):
|
| 1149 |
+
"""Randomly shift values for each channel of the input RGB image.
|
| 1150 |
+
|
| 1151 |
+
Args:
|
| 1152 |
+
r_shift_limit ((int, int) or int): range for changing values for the red channel. If r_shift_limit is a single
|
| 1153 |
+
int, the range will be (-r_shift_limit, r_shift_limit). Default: (-20, 20).
|
| 1154 |
+
g_shift_limit ((int, int) or int): range for changing values for the green channel. If g_shift_limit is a
|
| 1155 |
+
single int, the range will be (-g_shift_limit, g_shift_limit). Default: (-20, 20).
|
| 1156 |
+
b_shift_limit ((int, int) or int): range for changing values for the blue channel. If b_shift_limit is a single
|
| 1157 |
+
int, the range will be (-b_shift_limit, b_shift_limit). Default: (-20, 20).
|
| 1158 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1159 |
+
|
| 1160 |
+
Targets:
|
| 1161 |
+
image
|
| 1162 |
+
|
| 1163 |
+
Image types:
|
| 1164 |
+
uint8, float32
|
| 1165 |
+
"""
|
| 1166 |
+
|
| 1167 |
+
def __init__(
|
| 1168 |
+
self,
|
| 1169 |
+
r_shift_limit=20,
|
| 1170 |
+
g_shift_limit=20,
|
| 1171 |
+
b_shift_limit=20,
|
| 1172 |
+
always_apply=False,
|
| 1173 |
+
p=0.5,
|
| 1174 |
+
):
|
| 1175 |
+
super(RGBShift, self).__init__(always_apply, p)
|
| 1176 |
+
self.r_shift_limit = to_tuple(r_shift_limit)
|
| 1177 |
+
self.g_shift_limit = to_tuple(g_shift_limit)
|
| 1178 |
+
self.b_shift_limit = to_tuple(b_shift_limit)
|
| 1179 |
+
|
| 1180 |
+
def apply(self, image, r_shift=0, g_shift=0, b_shift=0, **params):
|
| 1181 |
+
if not is_rgb_image(image):
|
| 1182 |
+
raise TypeError("RGBShift transformation expects 3-channel images.")
|
| 1183 |
+
return F.shift_rgb(image, r_shift, g_shift, b_shift)
|
| 1184 |
+
|
| 1185 |
+
def get_params(self):
|
| 1186 |
+
return {
|
| 1187 |
+
"r_shift": random.uniform(self.r_shift_limit[0], self.r_shift_limit[1]),
|
| 1188 |
+
"g_shift": random.uniform(self.g_shift_limit[0], self.g_shift_limit[1]),
|
| 1189 |
+
"b_shift": random.uniform(self.b_shift_limit[0], self.b_shift_limit[1]),
|
| 1190 |
+
}
|
| 1191 |
+
|
| 1192 |
+
def get_transform_init_args_names(self):
|
| 1193 |
+
return ("r_shift_limit", "g_shift_limit", "b_shift_limit")
|
| 1194 |
+
|
| 1195 |
+
|
| 1196 |
+
class RandomBrightnessContrast(ImageOnlyTransform):
|
| 1197 |
+
"""Randomly change brightness and contrast of the input image.
|
| 1198 |
+
|
| 1199 |
+
Args:
|
| 1200 |
+
brightness_limit ((float, float) or float): factor range for changing brightness.
|
| 1201 |
+
If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
|
| 1202 |
+
contrast_limit ((float, float) or float): factor range for changing contrast.
|
| 1203 |
+
If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
|
| 1204 |
+
brightness_by_max (Boolean): If True adjust contrast by image dtype maximum,
|
| 1205 |
+
else adjust contrast by image mean.
|
| 1206 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1207 |
+
|
| 1208 |
+
Targets:
|
| 1209 |
+
image
|
| 1210 |
+
|
| 1211 |
+
Image types:
|
| 1212 |
+
uint8, float32
|
| 1213 |
+
"""
|
| 1214 |
+
|
| 1215 |
+
def __init__(
|
| 1216 |
+
self,
|
| 1217 |
+
brightness_limit=0.2,
|
| 1218 |
+
contrast_limit=0.2,
|
| 1219 |
+
brightness_by_max=True,
|
| 1220 |
+
always_apply=False,
|
| 1221 |
+
p=0.5,
|
| 1222 |
+
):
|
| 1223 |
+
super(RandomBrightnessContrast, self).__init__(always_apply, p)
|
| 1224 |
+
self.brightness_limit = to_tuple(brightness_limit)
|
| 1225 |
+
self.contrast_limit = to_tuple(contrast_limit)
|
| 1226 |
+
self.brightness_by_max = brightness_by_max
|
| 1227 |
+
|
| 1228 |
+
def apply(self, img, alpha=1.0, beta=0.0, **params):
|
| 1229 |
+
return F.brightness_contrast_adjust(img, alpha, beta, self.brightness_by_max)
|
| 1230 |
+
|
| 1231 |
+
def get_params(self):
|
| 1232 |
+
return {
|
| 1233 |
+
"alpha": 1.0 + random.uniform(self.contrast_limit[0], self.contrast_limit[1]),
|
| 1234 |
+
"beta": 0.0 + random.uniform(self.brightness_limit[0], self.brightness_limit[1]),
|
| 1235 |
+
}
|
| 1236 |
+
|
| 1237 |
+
def get_transform_init_args_names(self):
|
| 1238 |
+
return ("brightness_limit", "contrast_limit", "brightness_by_max")
|
| 1239 |
+
|
| 1240 |
+
|
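# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import RandomBrightnessContrast

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
# alpha (contrast) and beta (brightness) are sampled from the symmetric limits below.
bc = RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, brightness_by_max=True, p=1.0)
adjusted = bc(image=image)["image"]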
| 1241 |
+
class RandomBrightness(RandomBrightnessContrast):
|
| 1242 |
+
"""Randomly change brightness of the input image.
|
| 1243 |
+
|
| 1244 |
+
Args:
|
| 1245 |
+
limit ((float, float) or float): factor range for changing brightness.
|
| 1246 |
+
If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
|
| 1247 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1248 |
+
|
| 1249 |
+
Targets:
|
| 1250 |
+
image
|
| 1251 |
+
|
| 1252 |
+
Image types:
|
| 1253 |
+
uint8, float32
|
| 1254 |
+
"""
|
| 1255 |
+
|
| 1256 |
+
def __init__(self, limit=0.2, always_apply=False, p=0.5):
|
| 1257 |
+
super(RandomBrightness, self).__init__(brightness_limit=limit, contrast_limit=0, always_apply=always_apply, p=p)
|
| 1258 |
+
warnings.warn(
|
| 1259 |
+
"This class has been deprecated. Please use RandomBrightnessContrast",
|
| 1260 |
+
FutureWarning,
|
| 1261 |
+
)
|
| 1262 |
+
|
| 1263 |
+
def get_transform_init_args(self):
|
| 1264 |
+
return {"limit": self.brightness_limit}
|
| 1265 |
+
|
| 1266 |
+
|
| 1267 |
+
class RandomContrast(RandomBrightnessContrast):
|
| 1268 |
+
"""Randomly change contrast of the input image.
|
| 1269 |
+
|
| 1270 |
+
Args:
|
| 1271 |
+
limit ((float, float) or float): factor range for changing contrast.
|
| 1272 |
+
If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
|
| 1273 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1274 |
+
|
| 1275 |
+
Targets:
|
| 1276 |
+
image
|
| 1277 |
+
|
| 1278 |
+
Image types:
|
| 1279 |
+
uint8, float32
|
| 1280 |
+
"""
|
| 1281 |
+
|
| 1282 |
+
def __init__(self, limit=0.2, always_apply=False, p=0.5):
|
| 1283 |
+
super(RandomContrast, self).__init__(brightness_limit=0, contrast_limit=limit, always_apply=always_apply, p=p)
|
| 1284 |
+
warnings.warn(
|
| 1285 |
+
f"{self.__class__.__name__} has been deprecated. Please use RandomBrightnessContrast",
|
| 1286 |
+
FutureWarning,
|
| 1287 |
+
)
|
| 1288 |
+
|
| 1289 |
+
def get_transform_init_args(self):
|
| 1290 |
+
return {"limit": self.contrast_limit}
|
| 1291 |
+
|
| 1292 |
+
|
| 1293 |
+
class GaussNoise(ImageOnlyTransform):
|
| 1294 |
+
"""Apply gaussian noise to the input image.
|
| 1295 |
+
|
| 1296 |
+
Args:
|
| 1297 |
+
var_limit ((float, float) or float): variance range for noise. If var_limit is a single float, the range
|
| 1298 |
+
will be (0, var_limit). Default: (10.0, 50.0).
|
| 1299 |
+
mean (float): mean of the noise. Default: 0
|
| 1300 |
+
per_channel (bool): if set to True, noise will be sampled for each channel independently.
|
| 1301 |
+
Otherwise, the noise will be sampled once for all channels. Default: True
|
| 1302 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1303 |
+
|
| 1304 |
+
Targets:
|
| 1305 |
+
image
|
| 1306 |
+
|
| 1307 |
+
Image types:
|
| 1308 |
+
uint8, float32
|
| 1309 |
+
"""
|
| 1310 |
+
|
| 1311 |
+
def __init__(self, var_limit=(10.0, 50.0), mean=0, per_channel=True, always_apply=False, p=0.5):
|
| 1312 |
+
super(GaussNoise, self).__init__(always_apply, p)
|
| 1313 |
+
if isinstance(var_limit, (tuple, list)):
|
| 1314 |
+
if var_limit[0] < 0:
|
| 1315 |
+
raise ValueError("Lower var_limit should be non negative.")
|
| 1316 |
+
if var_limit[1] < 0:
|
| 1317 |
+
raise ValueError("Upper var_limit should be non negative.")
|
| 1318 |
+
self.var_limit = var_limit
|
| 1319 |
+
elif isinstance(var_limit, (int, float)):
|
| 1320 |
+
if var_limit < 0:
|
| 1321 |
+
raise ValueError("var_limit should be non negative.")
|
| 1322 |
+
|
| 1323 |
+
self.var_limit = (0, var_limit)
|
| 1324 |
+
else:
|
| 1325 |
+
raise TypeError(
|
| 1326 |
+
"Expected var_limit type to be one of (int, float, tuple, list), got {}".format(type(var_limit))
|
| 1327 |
+
)
|
| 1328 |
+
|
| 1329 |
+
self.mean = mean
|
| 1330 |
+
self.per_channel = per_channel
|
| 1331 |
+
|
| 1332 |
+
def apply(self, img, gauss=None, **params):
|
| 1333 |
+
return F.gauss_noise(img, gauss=gauss)
|
| 1334 |
+
|
| 1335 |
+
def get_params_dependent_on_targets(self, params):
|
| 1336 |
+
image = params["image"]
|
| 1337 |
+
var = random.uniform(self.var_limit[0], self.var_limit[1])
|
| 1338 |
+
sigma = var**0.5
|
| 1339 |
+
|
| 1340 |
+
if self.per_channel:
|
| 1341 |
+
gauss = random_utils.normal(self.mean, sigma, image.shape)
|
| 1342 |
+
else:
|
| 1343 |
+
gauss = random_utils.normal(self.mean, sigma, image.shape[:2])
|
| 1344 |
+
if len(image.shape) == 3:
|
| 1345 |
+
gauss = np.expand_dims(gauss, -1)
|
| 1346 |
+
|
| 1347 |
+
return {"gauss": gauss}
|
| 1348 |
+
|
| 1349 |
+
@property
|
| 1350 |
+
def targets_as_params(self):
|
| 1351 |
+
return ["image"]
|
| 1352 |
+
|
| 1353 |
+
def get_transform_init_args_names(self):
|
| 1354 |
+
return ("var_limit", "per_channel", "mean")
|
| 1355 |
+
|
| 1356 |
+
|
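# Illustrative usage sketch (not part of the uploaded diff).
import numpy as np
from custom_albumentations.augmentations.transforms import GaussNoise

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
# var_limit is a variance range; sigma = sqrt(var). per_channel=True samples noise per channel.
noise = GaussNoise(var_limit=(10.0, 50.0), mean=0, per_channel=True, p=1.0)
noisy = noise(image=image)["image"]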
| 1357 |
+
class ISONoise(ImageOnlyTransform):
|
| 1358 |
+
"""
|
| 1359 |
+
Apply camera sensor noise.
|
| 1360 |
+
|
| 1361 |
+
Args:
|
| 1362 |
+
color_shift (float, float): variance range for color hue change.
|
| 1363 |
+
Measured as a fraction of 360 degree Hue angle in HLS colorspace.
|
| 1364 |
+
intensity ((float, float)): Multiplicative factor that controls the strength
|
| 1365 |
+
of color and luminance noise.
|
| 1366 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1367 |
+
|
| 1368 |
+
Targets:
|
| 1369 |
+
image
|
| 1370 |
+
|
| 1371 |
+
Image types:
|
| 1372 |
+
uint8
|
| 1373 |
+
"""
|
| 1374 |
+
|
| 1375 |
+
def __init__(self, color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=0.5):
|
| 1376 |
+
super(ISONoise, self).__init__(always_apply, p)
|
| 1377 |
+
self.intensity = intensity
|
| 1378 |
+
self.color_shift = color_shift
|
| 1379 |
+
|
| 1380 |
+
def apply(self, img, color_shift=0.05, intensity=1.0, random_state=None, **params):
|
| 1381 |
+
return F.iso_noise(img, color_shift, intensity, np.random.RandomState(random_state))
|
| 1382 |
+
|
| 1383 |
+
def get_params(self):
|
| 1384 |
+
return {
|
| 1385 |
+
"color_shift": random.uniform(self.color_shift[0], self.color_shift[1]),
|
| 1386 |
+
"intensity": random.uniform(self.intensity[0], self.intensity[1]),
|
| 1387 |
+
"random_state": random.randint(0, 65536),
|
| 1388 |
+
}
|
| 1389 |
+
|
| 1390 |
+
def get_transform_init_args_names(self):
|
| 1391 |
+
return ("intensity", "color_shift")
|
| 1392 |
+
|
| 1393 |
+
|
| 1394 |
+
class CLAHE(ImageOnlyTransform):
|
| 1395 |
+
"""Apply Contrast Limited Adaptive Histogram Equalization to the input image.
|
| 1396 |
+
|
| 1397 |
+
Args:
|
| 1398 |
+
clip_limit (float or (float, float)): upper threshold value for contrast limiting.
|
| 1399 |
+
If clip_limit is a single float value, the range will be (1, clip_limit). Default: (1, 4).
|
| 1400 |
+
tile_grid_size ((int, int)): size of grid for histogram equalization. Default: (8, 8).
|
| 1401 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1402 |
+
|
| 1403 |
+
Targets:
|
| 1404 |
+
image
|
| 1405 |
+
|
| 1406 |
+
Image types:
|
| 1407 |
+
uint8
|
| 1408 |
+
"""
|
| 1409 |
+
|
| 1410 |
+
def __init__(self, clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5):
|
| 1411 |
+
super(CLAHE, self).__init__(always_apply, p)
|
| 1412 |
+
self.clip_limit = to_tuple(clip_limit, 1)
|
| 1413 |
+
self.tile_grid_size = tuple(tile_grid_size)
|
| 1414 |
+
|
| 1415 |
+
def apply(self, img, clip_limit=2, **params):
|
| 1416 |
+
if not is_rgb_image(img) and not is_grayscale_image(img):
|
| 1417 |
+
raise TypeError("CLAHE transformation expects 1-channel or 3-channel images.")
|
| 1418 |
+
|
| 1419 |
+
return F.clahe(img, clip_limit, self.tile_grid_size)
|
| 1420 |
+
|
| 1421 |
+
def get_params(self):
|
| 1422 |
+
return {"clip_limit": random.uniform(self.clip_limit[0], self.clip_limit[1])}
|
| 1423 |
+
|
| 1424 |
+
def get_transform_init_args_names(self):
|
| 1425 |
+
return ("clip_limit", "tile_grid_size")
|
| 1426 |
+
|
| 1427 |
+
|
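# Illustrative usage sketch (not part of the uploaded diff). Per the docstring above,
# CLAHE expects uint8 input with 1 or 3 channels.
import numpy as np
from custom_albumentations.augmentations.transforms import CLAHE

image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
clahe = CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=1.0)
equalized = clahe(image=image)["image"]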
| 1428 |
+
class ChannelShuffle(ImageOnlyTransform):
|
| 1429 |
+
"""Randomly rearrange channels of the input RGB image.
|
| 1430 |
+
|
| 1431 |
+
Args:
|
| 1432 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1433 |
+
|
| 1434 |
+
Targets:
|
| 1435 |
+
image
|
| 1436 |
+
|
| 1437 |
+
Image types:
|
| 1438 |
+
uint8, float32
|
| 1439 |
+
"""
|
| 1440 |
+
|
| 1441 |
+
@property
|
| 1442 |
+
def targets_as_params(self):
|
| 1443 |
+
return ["image"]
|
| 1444 |
+
|
| 1445 |
+
def apply(self, img, channels_shuffled=(0, 1, 2), **params):
|
| 1446 |
+
return F.channel_shuffle(img, channels_shuffled)
|
| 1447 |
+
|
| 1448 |
+
def get_params_dependent_on_targets(self, params):
|
| 1449 |
+
img = params["image"]
|
| 1450 |
+
ch_arr = list(range(img.shape[2]))
|
| 1451 |
+
random.shuffle(ch_arr)
|
| 1452 |
+
return {"channels_shuffled": ch_arr}
|
| 1453 |
+
|
| 1454 |
+
def get_transform_init_args_names(self):
|
| 1455 |
+
return ()
|
| 1456 |
+
|
| 1457 |
+
|
| 1458 |
+
class InvertImg(ImageOnlyTransform):
|
| 1459 |
+
"""Invert the input image by subtracting pixel values from max values of the image types,
|
| 1460 |
+
i.e., 255 for uint8 and 1.0 for float32.
|
| 1461 |
+
|
| 1462 |
+
Args:
|
| 1463 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 1464 |
+
|
| 1465 |
+
Targets:
|
| 1466 |
+
image
|
| 1467 |
+
|
| 1468 |
+
Image types:
|
| 1469 |
+
uint8, float32
|
| 1470 |
+
"""
|
| 1471 |
+
|
| 1472 |
+
def apply(self, img, **params):
|
| 1473 |
+
return F.invert(img)
|
| 1474 |
+
|
| 1475 |
+
def get_transform_init_args_names(self):
|
| 1476 |
+
return ()
|
| 1477 |
+
|
| 1478 |
+
|
| 1479 |
+
class RandomGamma(ImageOnlyTransform):
|
| 1480 |
+
"""
|
| 1481 |
+
Args:
|
| 1482 |
+
gamma_limit (float or (float, float)): If gamma_limit is a single float value,
|
| 1483 |
+
the range will be (-gamma_limit, gamma_limit). Default: (80, 120).
|
| 1484 |
+
eps: Deprecated.
|
| 1485 |
+
|
| 1486 |
+
Targets:
|
| 1487 |
+
image
|
| 1488 |
+
|
| 1489 |
+
Image types:
|
| 1490 |
+
uint8, float32
|
| 1491 |
+
"""
|
| 1492 |
+
|
| 1493 |
+
def __init__(self, gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5):
|
| 1494 |
+
super(RandomGamma, self).__init__(always_apply, p)
|
| 1495 |
+
self.gamma_limit = to_tuple(gamma_limit)
|
| 1496 |
+
self.eps = eps
|
| 1497 |
+
|
| 1498 |
+
def apply(self, img, gamma=1, **params):
|
| 1499 |
+
return F.gamma_transform(img, gamma=gamma)
|
| 1500 |
+
|
| 1501 |
+
def get_params(self):
|
| 1502 |
+
return {"gamma": random.uniform(self.gamma_limit[0], self.gamma_limit[1]) / 100.0}
|
| 1503 |
+
|
| 1504 |
+
def get_transform_init_args_names(self):
|
| 1505 |
+
return ("gamma_limit", "eps")


class ToGray(ImageOnlyTransform):
    """Convert the input RGB image to grayscale. If the mean pixel value for the resulting image is greater
    than 127, invert the resulting grayscale image.

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def apply(self, img, **params):
        if is_grayscale_image(img):
            warnings.warn("The image is already gray.")
            return img
        if not is_rgb_image(img):
            raise TypeError("ToGray transformation expects 3-channel images.")

        return F.to_gray(img)

    def get_transform_init_args_names(self):
        return ()


class ToRGB(ImageOnlyTransform):
    """Convert the input grayscale image to RGB.

    Args:
        p (float): probability of applying the transform. Default: 1.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, always_apply=True, p=1.0):
        super(ToRGB, self).__init__(always_apply=always_apply, p=p)

    def apply(self, img, **params):
        if is_rgb_image(img):
            warnings.warn("The image is already an RGB.")
            return img
        if not is_grayscale_image(img):
            raise TypeError("ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1.")

        return F.gray_to_rgb(img)

    def get_transform_init_args_names(self):
        return ()


class ToSepia(ImageOnlyTransform):
    """Applies sepia filter to the input RGB image

    Args:
        p (float): probability of applying the transform. Default: 0.5.

    Targets:
        image

    Image types:
        uint8, float32
    """

    def __init__(self, always_apply=False, p=0.5):
        super(ToSepia, self).__init__(always_apply, p)
        self.sepia_transformation_matrix = np.array(
            [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
        )

    def apply(self, image, **params):
        if not is_rgb_image(image):
            raise TypeError("ToSepia transformation expects 3-channel images.")
        return F.linear_transformation_rgb(image, self.sepia_transformation_matrix)

    def get_transform_init_args_names(self):
        return ()


class ToFloat(ImageOnlyTransform):
    """Divide pixel values by `max_value` to get a float32 output array where all values lie in the range [0, 1.0].
    If `max_value` is None the transform will try to infer the maximum value by inspecting the data type of the input
    image.

    See Also:
        :class:`~albumentations.augmentations.transforms.FromFloat`

    Args:
        max_value (float): maximum possible input value. Default: None.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image

    Image types:
        any type

    """

    def __init__(self, max_value=None, always_apply=False, p=1.0):
        super(ToFloat, self).__init__(always_apply, p)
        self.max_value = max_value

    def apply(self, img, **params):
        return F.to_float(img, self.max_value)

    def get_transform_init_args_names(self):
        return ("max_value",)


class FromFloat(ImageOnlyTransform):
    """Take an input array where all values should lie in the range [0, 1.0], multiply them by `max_value` and then
    cast the resulting value to a type specified by `dtype`. If `max_value` is None the transform will try to infer
    the maximum value for the data type from the `dtype` argument.

    This is the inverse transform for :class:`~albumentations.augmentations.transforms.ToFloat`.

    Args:
        max_value (float): maximum possible input value. Default: None.
        dtype (string or numpy data type): data type of the output. See the `'Data types' page from the NumPy docs`_.
            Default: 'uint16'.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image

    Image types:
        float32

    .. _'Data types' page from the NumPy docs:
       https://docs.scipy.org/doc/numpy/user/basics.types.html
    """

    def __init__(self, dtype="uint16", max_value=None, always_apply=False, p=1.0):
        super(FromFloat, self).__init__(always_apply, p)
        self.dtype = np.dtype(dtype)
        self.max_value = max_value

    def apply(self, img, **params):
        return F.from_float(img, self.dtype, self.max_value)

    def get_transform_init_args(self):
        return {"dtype": self.dtype.name, "max_value": self.max_value}


class Downscale(ImageOnlyTransform):
    """Decreases image quality by downscaling and upscaling back.

    Args:
        scale_min (float): lower bound on the image scale. Should be < 1.
        scale_max (float): upper bound on the image scale. Should be < 1.
        interpolation: cv2 interpolation method. Could be:
            - single cv2 interpolation flag - selected method will be used for downscale and upscale.
            - dict(downscale=flag, upscale=flag)
            - Downscale.Interpolation(downscale=flag, upscale=flag)
            Default: Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)

    Targets:
        image

    Image types:
        uint8, float32
    """

    class Interpolation:
        def __init__(self, *, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_NEAREST):
            self.downscale = downscale
            self.upscale = upscale

    def __init__(
        self,
        scale_min: float = 0.25,
        scale_max: float = 0.25,
        interpolation: Optional[Union[int, Interpolation, Dict[str, int]]] = None,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super(Downscale, self).__init__(always_apply, p)
        if interpolation is None:
            self.interpolation = self.Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)
            warnings.warn(
                "Using default interpolation INTER_NEAREST, which is sub-optimal."
                "Please specify interpolation mode for downscale and upscale explicitly."
                "For additional information see this PR https://github.com/albumentations-team/albumentations/pull/584"
            )
        elif isinstance(interpolation, int):
            self.interpolation = self.Interpolation(downscale=interpolation, upscale=interpolation)
        elif isinstance(interpolation, self.Interpolation):
            self.interpolation = interpolation
        elif isinstance(interpolation, dict):
            self.interpolation = self.Interpolation(**interpolation)
        else:
            raise ValueError(
                "Wrong interpolation data type. Supported types: `Optional[Union[int, Interpolation, Dict[str, int]]]`."
                f" Got: {type(interpolation)}"
            )

        if scale_min > scale_max:
            raise ValueError("Expected scale_min be less or equal scale_max, got {} {}".format(scale_min, scale_max))
        if scale_max >= 1:
            raise ValueError("Expected scale_max to be less than 1, got {}".format(scale_max))
        self.scale_min = scale_min
        self.scale_max = scale_max

    def apply(self, img: np.ndarray, scale: Optional[float] = None, **params) -> np.ndarray:
        return F.downscale(
            img,
            scale=scale,
            down_interpolation=self.interpolation.downscale,
            up_interpolation=self.interpolation.upscale,
        )

    def get_params(self) -> Dict[str, Any]:
        return {"scale": random.uniform(self.scale_min, self.scale_max)}

    def get_transform_init_args_names(self) -> Tuple[str, str]:
        return "scale_min", "scale_max"

    def _to_dict(self) -> Dict[str, Any]:
        result = super()._to_dict()
        result["interpolation"] = {"upscale": self.interpolation.upscale, "downscale": self.interpolation.downscale}
        return result
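# Usage sketch (editorial addition): passing an explicit interpolation mapping avoids the
# INTER_NEAREST warning emitted when the constructor falls back to its default.
#
#     aug = Downscale(
#         scale_min=0.25,
#         scale_max=0.5,
#         interpolation={"downscale": cv2.INTER_AREA, "upscale": cv2.INTER_LINEAR},
#         p=1.0,
#     )
#     degraded = aug(image=img)["image"]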


class Lambda(NoOp):
    """A flexible transformation class for using user-defined transformation functions per targets.
    Function signature must include **kwargs to accept optional arguments like interpolation method, image size, etc.:

    Args:
        image (callable): Image transformation function.
        mask (callable): Mask transformation function.
        keypoint (callable): Keypoint transformation function.
        bbox (callable): BBox transformation function.
        always_apply (bool): Indicates whether this transformation should be always applied.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image, mask, bboxes, keypoints

    Image types:
        Any
    """

    def __init__(
        self,
        image=None,
        mask=None,
        keypoint=None,
        bbox=None,
        name=None,
        always_apply=False,
        p=1.0,
    ):
        super(Lambda, self).__init__(always_apply, p)

        self.name = name
        self.custom_apply_fns = {target_name: F.noop for target_name in ("image", "mask", "keypoint", "bbox")}
        for target_name, custom_apply_fn in {
            "image": image,
            "mask": mask,
            "keypoint": keypoint,
            "bbox": bbox,
        }.items():
            if custom_apply_fn is not None:
                if isinstance(custom_apply_fn, LambdaType) and custom_apply_fn.__name__ == "<lambda>":
                    warnings.warn(
                        "Using lambda is incompatible with multiprocessing. "
                        "Consider using regular functions or partial()."
                    )

                self.custom_apply_fns[target_name] = custom_apply_fn

    def apply(self, img, **params):
        fn = self.custom_apply_fns["image"]
        return fn(img, **params)

    def apply_to_mask(self, mask, **params):
        fn = self.custom_apply_fns["mask"]
        return fn(mask, **params)

    def apply_to_bbox(self, bbox, **params):
        fn = self.custom_apply_fns["bbox"]
        return fn(bbox, **params)

    def apply_to_keypoint(self, keypoint, **params):
        fn = self.custom_apply_fns["keypoint"]
        return fn(keypoint, **params)

    @classmethod
    def is_serializable(cls):
        return False

    def _to_dict(self):
        if self.name is None:
            raise ValueError(
                "To make a Lambda transform serializable you should provide the `name` argument, "
                "e.g. `Lambda(name='my_transform', image=<some func>, ...)`."
            )
        return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}

    def __repr__(self):
        state = {"name": self.name}
        state.update(self.custom_apply_fns.items())
        state.update(self.get_base_init_args())
        return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state))


class MultiplicativeNoise(ImageOnlyTransform):
    """Multiply the image by a random number or array of numbers.

    Args:
        multiplier (float or tuple of floats): If a single float, the image will be multiplied by this number.
            If a tuple of floats, the multiplier will be sampled from the range `[multiplier[0], multiplier[1])`.
            Default: (0.9, 1.1).
        per_channel (bool): If `False`, the same value is used for all channels.
            If `True`, a separate value is sampled for each channel. Default: False.
        elementwise (bool): If `False`, multiply all pixels in the image by a single random value sampled once.
            If `True`, multiply image pixels by values that are sampled independently per pixel. Default: False.

    Targets:
        image

    Image types:
        Any
    """

    def __init__(
        self,
        multiplier=(0.9, 1.1),
        per_channel=False,
        elementwise=False,
        always_apply=False,
        p=0.5,
    ):
        super(MultiplicativeNoise, self).__init__(always_apply, p)
        self.multiplier = to_tuple(multiplier, multiplier)
        self.per_channel = per_channel
        self.elementwise = elementwise

    def apply(self, img, multiplier=np.array([1]), **kwargs):
        return F.multiply(img, multiplier)

    def get_params_dependent_on_targets(self, params):
        if self.multiplier[0] == self.multiplier[1]:
            return {"multiplier": np.array([self.multiplier[0]])}

        img = params["image"]

        h, w = img.shape[:2]

        if self.per_channel:
            c = 1 if is_grayscale_image(img) else img.shape[-1]
        else:
            c = 1

        if self.elementwise:
            shape = [h, w, c]
        else:
            shape = [c]

        multiplier = random_utils.uniform(self.multiplier[0], self.multiplier[1], shape)
        if is_grayscale_image(img) and img.ndim == 2:
            multiplier = np.squeeze(multiplier)

        return {"multiplier": multiplier}

    @property
    def targets_as_params(self):
        return ["image"]

    def get_transform_init_args_names(self):
        return "multiplier", "per_channel", "elementwise"


class FancyPCA(ImageOnlyTransform):
    """Augment RGB image using FancyPCA from Krizhevsky's paper
    "ImageNet Classification with Deep Convolutional Neural Networks"

    Args:
        alpha (float): how much to perturb/scale the eigen vecs and vals.
            The scale is sampled from a gaussian distribution (mu=0, sigma=alpha).

    Targets:
        image

    Image types:
        3-channel uint8 images only

    Credit:
        http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
        https://deshanadesai.github.io/notes/Fancy-PCA-with-Scikit-Image
        https://pixelatedbrian.github.io/2018-04-29-fancy_pca/
    """

    def __init__(self, alpha=0.1, always_apply=False, p=0.5):
        super(FancyPCA, self).__init__(always_apply=always_apply, p=p)
        self.alpha = alpha

    def apply(self, img, alpha=0.1, **params):
        img = F.fancy_pca(img, alpha)
        return img

    def get_params(self):
        return {"alpha": random.gauss(0, self.alpha)}

    def get_transform_init_args_names(self):
        return ("alpha",)
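# Composition sketch (editorial addition); it assumes this vendored package re-exports Compose
# and the transforms at the top level, as upstream albumentations does:
#
#     import custom_albumentations as A
#     pipeline = A.Compose([A.RandomGamma(p=0.5), A.FancyPCA(alpha=0.1, p=0.5)])
#     augmented = pipeline(image=img)["image"]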
class ColorJitter(ImageOnlyTransform):
|
| 1922 |
+
"""Randomly changes the brightness, contrast, and saturation of an image. Compared to ColorJitter from torchvision,
|
| 1923 |
+
this transform gives a little bit different results because Pillow (used in torchvision) and OpenCV (used in
|
| 1924 |
+
Albumentations) transform an image to HSV format by different formulas. Another difference - Pillow uses uint8
|
| 1925 |
+
overflow, but we use value saturation.
|
| 1926 |
+
|
| 1927 |
+
Args:
|
| 1928 |
+
brightness (float or tuple of float (min, max)): How much to jitter brightness.
|
| 1929 |
+
brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
|
| 1930 |
+
or the given [min, max]. Should be non negative numbers.
|
| 1931 |
+
contrast (float or tuple of float (min, max)): How much to jitter contrast.
|
| 1932 |
+
contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
|
| 1933 |
+
or the given [min, max]. Should be non negative numbers.
|
| 1934 |
+
saturation (float or tuple of float (min, max)): How much to jitter saturation.
|
| 1935 |
+
saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
|
| 1936 |
+
or the given [min, max]. Should be non negative numbers.
|
| 1937 |
+
hue (float or tuple of float (min, max)): How much to jitter hue.
|
| 1938 |
+
hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
|
| 1939 |
+
Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
|
| 1940 |
+
"""
|
| 1941 |
+
|
| 1942 |
+
def __init__(
|
| 1943 |
+
self,
|
| 1944 |
+
brightness=0.2,
|
| 1945 |
+
contrast=0.2,
|
| 1946 |
+
saturation=0.2,
|
| 1947 |
+
hue=0.2,
|
| 1948 |
+
always_apply=False,
|
| 1949 |
+
p=0.5,
|
| 1950 |
+
):
|
| 1951 |
+
super(ColorJitter, self).__init__(always_apply=always_apply, p=p)
|
| 1952 |
+
|
| 1953 |
+
self.brightness = self.__check_values(brightness, "brightness")
|
| 1954 |
+
self.contrast = self.__check_values(contrast, "contrast")
|
| 1955 |
+
self.saturation = self.__check_values(saturation, "saturation")
|
| 1956 |
+
self.hue = self.__check_values(hue, "hue", offset=0, bounds=[-0.5, 0.5], clip=False)
|
| 1957 |
+
|
| 1958 |
+
self.transforms = [
|
| 1959 |
+
F.adjust_brightness_torchvision,
|
| 1960 |
+
F.adjust_contrast_torchvision,
|
| 1961 |
+
F.adjust_saturation_torchvision,
|
| 1962 |
+
F.adjust_hue_torchvision,
|
| 1963 |
+
]
|
| 1964 |
+
|
| 1965 |
+
@staticmethod
|
| 1966 |
+
def __check_values(value, name, offset=1, bounds=(0, float("inf")), clip=True):
|
| 1967 |
+
if isinstance(value, numbers.Number):
|
| 1968 |
+
if value < 0:
|
| 1969 |
+
raise ValueError("If {} is a single number, it must be non negative.".format(name))
|
| 1970 |
+
value = [offset - value, offset + value]
|
| 1971 |
+
if clip:
|
| 1972 |
+
value[0] = max(value[0], 0)
|
| 1973 |
+
elif isinstance(value, (tuple, list)) and len(value) == 2:
|
| 1974 |
+
if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
|
| 1975 |
+
raise ValueError("{} values should be between {}".format(name, bounds))
|
| 1976 |
+
else:
|
| 1977 |
+
raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
|
| 1978 |
+
|
| 1979 |
+
return value
|
| 1980 |
+
|
| 1981 |
+
def get_params(self):
|
| 1982 |
+
brightness = random.uniform(self.brightness[0], self.brightness[1])
|
| 1983 |
+
contrast = random.uniform(self.contrast[0], self.contrast[1])
|
| 1984 |
+
saturation = random.uniform(self.saturation[0], self.saturation[1])
|
| 1985 |
+
hue = random.uniform(self.hue[0], self.hue[1])
|
| 1986 |
+
|
| 1987 |
+
order = [0, 1, 2, 3]
|
| 1988 |
+
random.shuffle(order)
|
| 1989 |
+
|
| 1990 |
+
return {
|
| 1991 |
+
"brightness": brightness,
|
| 1992 |
+
"contrast": contrast,
|
| 1993 |
+
"saturation": saturation,
|
| 1994 |
+
"hue": hue,
|
| 1995 |
+
"order": order,
|
| 1996 |
+
}
|
| 1997 |
+
|
| 1998 |
+
def apply(self, img, brightness=1.0, contrast=1.0, saturation=1.0, hue=0, order=[0, 1, 2, 3], **params):
|
| 1999 |
+
if not is_rgb_image(img) and not is_grayscale_image(img):
|
| 2000 |
+
raise TypeError("ColorJitter transformation expects 1-channel or 3-channel images.")
|
| 2001 |
+
params = [brightness, contrast, saturation, hue]
|
| 2002 |
+
for i in order:
|
| 2003 |
+
img = self.transforms[i](img, params[i])
|
| 2004 |
+
return img
|
| 2005 |
+
|
| 2006 |
+
def get_transform_init_args_names(self):
|
| 2007 |
+
return ("brightness", "contrast", "saturation", "hue")
|
| 2008 |
+
|
| 2009 |
+
|
| 2010 |
+
class Sharpen(ImageOnlyTransform):
|
| 2011 |
+
"""Sharpen the input image and overlays the result with the original image.
|
| 2012 |
+
|
| 2013 |
+
Args:
|
| 2014 |
+
alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is
|
| 2015 |
+
visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5).
|
| 2016 |
+
lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0).
|
| 2017 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2018 |
+
|
| 2019 |
+
Targets:
|
| 2020 |
+
image
|
| 2021 |
+
"""
|
| 2022 |
+
|
| 2023 |
+
def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5):
|
| 2024 |
+
super(Sharpen, self).__init__(always_apply, p)
|
| 2025 |
+
self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
|
| 2026 |
+
self.lightness = self.__check_values(to_tuple(lightness, 0.0), name="lightness")
|
| 2027 |
+
|
| 2028 |
+
@staticmethod
|
| 2029 |
+
def __check_values(value, name, bounds=(0, float("inf"))):
|
| 2030 |
+
if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
|
| 2031 |
+
raise ValueError("{} values should be between {}".format(name, bounds))
|
| 2032 |
+
return value
|
| 2033 |
+
|
| 2034 |
+
@staticmethod
|
| 2035 |
+
def __generate_sharpening_matrix(alpha_sample, lightness_sample):
|
| 2036 |
+
matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
|
| 2037 |
+
matrix_effect = np.array(
|
| 2038 |
+
[[-1, -1, -1], [-1, 8 + lightness_sample, -1], [-1, -1, -1]],
|
| 2039 |
+
dtype=np.float32,
|
| 2040 |
+
)
|
| 2041 |
+
|
| 2042 |
+
matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect
|
| 2043 |
+
return matrix
|
| 2044 |
+
|
| 2045 |
+
def get_params(self):
|
| 2046 |
+
alpha = random.uniform(*self.alpha)
|
| 2047 |
+
lightness = random.uniform(*self.lightness)
|
| 2048 |
+
sharpening_matrix = self.__generate_sharpening_matrix(alpha_sample=alpha, lightness_sample=lightness)
|
| 2049 |
+
return {"sharpening_matrix": sharpening_matrix}
|
| 2050 |
+
|
| 2051 |
+
def apply(self, img, sharpening_matrix=None, **params):
|
| 2052 |
+
return F.convolve(img, sharpening_matrix)
|
| 2053 |
+
|
| 2054 |
+
def get_transform_init_args_names(self):
|
| 2055 |
+
return ("alpha", "lightness")
|
| 2056 |
+
|
| 2057 |
+
|
| 2058 |
+
class Emboss(ImageOnlyTransform):
|
| 2059 |
+
"""Emboss the input image and overlays the result with the original image.
|
| 2060 |
+
|
| 2061 |
+
Args:
|
| 2062 |
+
alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is
|
| 2063 |
+
            visible, at 1.0 only its embossed version is visible. Default: (0.2, 0.5).
|
| 2064 |
+
strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7).
|
| 2065 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2066 |
+
|
| 2067 |
+
Targets:
|
| 2068 |
+
image
|
| 2069 |
+
"""
|
| 2070 |
+
|
| 2071 |
+
def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5):
|
| 2072 |
+
super(Emboss, self).__init__(always_apply, p)
|
| 2073 |
+
self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
|
| 2074 |
+
self.strength = self.__check_values(to_tuple(strength, 0.0), name="strength")
|
| 2075 |
+
|
| 2076 |
+
@staticmethod
|
| 2077 |
+
def __check_values(value, name, bounds=(0, float("inf"))):
|
| 2078 |
+
if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
|
| 2079 |
+
raise ValueError("{} values should be between {}".format(name, bounds))
|
| 2080 |
+
return value
|
| 2081 |
+
|
| 2082 |
+
@staticmethod
|
| 2083 |
+
def __generate_emboss_matrix(alpha_sample, strength_sample):
|
| 2084 |
+
matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
|
| 2085 |
+
matrix_effect = np.array(
|
| 2086 |
+
[
|
| 2087 |
+
[-1 - strength_sample, 0 - strength_sample, 0],
|
| 2088 |
+
[0 - strength_sample, 1, 0 + strength_sample],
|
| 2089 |
+
[0, 0 + strength_sample, 1 + strength_sample],
|
| 2090 |
+
],
|
| 2091 |
+
dtype=np.float32,
|
| 2092 |
+
)
|
| 2093 |
+
matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect
|
| 2094 |
+
return matrix
|
| 2095 |
+
|
| 2096 |
+
def get_params(self):
|
| 2097 |
+
alpha = random.uniform(*self.alpha)
|
| 2098 |
+
strength = random.uniform(*self.strength)
|
| 2099 |
+
emboss_matrix = self.__generate_emboss_matrix(alpha_sample=alpha, strength_sample=strength)
|
| 2100 |
+
return {"emboss_matrix": emboss_matrix}
|
| 2101 |
+
|
| 2102 |
+
def apply(self, img, emboss_matrix=None, **params):
|
| 2103 |
+
return F.convolve(img, emboss_matrix)
|
| 2104 |
+
|
| 2105 |
+
def get_transform_init_args_names(self):
|
| 2106 |
+
return ("alpha", "strength")
|
| 2107 |
+
|
| 2108 |
+
|
| 2109 |
+
class Superpixels(ImageOnlyTransform):
|
| 2110 |
+
"""Transform images partially/completely to their superpixel representation.
|
| 2111 |
+
This implementation uses skimage's version of the SLIC algorithm.
|
| 2112 |
+
|
| 2113 |
+
Args:
|
| 2114 |
+
p_replace (float or tuple of float): Defines for any segment the probability that the pixels within that
|
| 2115 |
+
segment are replaced by their average color (otherwise, the pixels are not changed).
|
| 2116 |
+
Examples:
|
| 2117 |
+
* A probability of ``0.0`` would mean, that the pixels in no
|
| 2118 |
+
segment are replaced by their average color (image is not
|
| 2119 |
+
changed at all).
|
| 2120 |
+
* A probability of ``0.5`` would mean, that around half of all
|
| 2121 |
+
segments are replaced by their average color.
|
| 2122 |
+
* A probability of ``1.0`` would mean, that all segments are
|
| 2123 |
+
replaced by their average color (resulting in a voronoi
|
| 2124 |
+
image).
|
| 2125 |
+
Behaviour based on chosen data types for this parameter:
|
| 2126 |
+
                * If a ``float``, then that ``float`` will always be used.
|
| 2127 |
+
* If ``tuple`` ``(a, b)``, then a random probability will be
|
| 2128 |
+
sampled from the interval ``[a, b]`` per image.
|
| 2129 |
+
n_segments (int, or tuple of int): Rough target number of how many superpixels to generate (the algorithm
|
| 2130 |
+
may deviate from this number). Lower value will lead to coarser superpixels.
|
| 2131 |
+
            Higher values are computationally more intensive and will hence lead to a slowdown.
|
| 2132 |
+
* If a single ``int``, then that value will always be used as the
|
| 2133 |
+
number of segments.
|
| 2134 |
+
* If a ``tuple`` ``(a, b)``, then a value from the discrete
|
| 2135 |
+
interval ``[a..b]`` will be sampled per image.
|
| 2136 |
+
max_size (int or None): Maximum image size at which the augmentation is performed.
|
| 2137 |
+
If the width or height of an image exceeds this value, it will be
|
| 2138 |
+
downscaled before the augmentation so that the longest side matches `max_size`.
|
| 2139 |
+
This is done to speed up the process. The final output image has the same size as the input image.
|
| 2140 |
+
Note that in case `p_replace` is below ``1.0``,
|
| 2141 |
+
the down-/upscaling will affect the not-replaced pixels too.
|
| 2142 |
+
Use ``None`` to apply no down-/upscaling.
|
| 2143 |
+
interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
|
| 2144 |
+
cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
|
| 2145 |
+
Default: cv2.INTER_LINEAR.
|
| 2146 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2147 |
+
|
| 2148 |
+
Targets:
|
| 2149 |
+
image
|
| 2150 |
+
"""
|
| 2151 |
+
|
| 2152 |
+
def __init__(
|
| 2153 |
+
self,
|
| 2154 |
+
p_replace: Union[float, Sequence[float]] = 0.1,
|
| 2155 |
+
n_segments: Union[int, Sequence[int]] = 100,
|
| 2156 |
+
max_size: Optional[int] = 128,
|
| 2157 |
+
interpolation: int = cv2.INTER_LINEAR,
|
| 2158 |
+
always_apply: bool = False,
|
| 2159 |
+
p: float = 0.5,
|
| 2160 |
+
):
|
| 2161 |
+
super().__init__(always_apply=always_apply, p=p)
|
| 2162 |
+
self.p_replace = to_tuple(p_replace, p_replace)
|
| 2163 |
+
self.n_segments = to_tuple(n_segments, n_segments)
|
| 2164 |
+
self.max_size = max_size
|
| 2165 |
+
self.interpolation = interpolation
|
| 2166 |
+
|
| 2167 |
+
if min(self.n_segments) < 1:
|
| 2168 |
+
raise ValueError(f"n_segments must be >= 1. Got: {n_segments}")
|
| 2169 |
+
|
| 2170 |
+
def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
|
| 2171 |
+
return ("p_replace", "n_segments", "max_size", "interpolation")
|
| 2172 |
+
|
| 2173 |
+
def get_params(self) -> dict:
|
| 2174 |
+
n_segments = random.randint(*self.n_segments)
|
| 2175 |
+
p = random.uniform(*self.p_replace)
|
| 2176 |
+
return {"replace_samples": random_utils.random(n_segments) < p, "n_segments": n_segments}
|
| 2177 |
+
|
| 2178 |
+
def apply(self, img: np.ndarray, replace_samples: Sequence[bool] = (False,), n_segments: int = 1, **kwargs):
|
| 2179 |
+
return F.superpixels(img, n_segments, replace_samples, self.max_size, self.interpolation)
|
| 2180 |
+
|
| 2181 |
+
|
| 2182 |
+
class TemplateTransform(ImageOnlyTransform):
|
| 2183 |
+
"""
|
| 2184 |
+
Apply blending of input image with specified templates
|
| 2185 |
+
Args:
|
| 2186 |
+
templates (numpy array or list of numpy arrays): Images as template for transform.
|
| 2187 |
+
img_weight ((float, float) or float): If single float will be used as weight for input image.
|
| 2188 |
+
If tuple of float img_weight will be in range `[img_weight[0], img_weight[1])`. Default: 0.5.
|
| 2189 |
+
template_weight ((float, float) or float): If single float will be used as weight for template.
|
| 2190 |
+
If tuple of float template_weight will be in range `[template_weight[0], template_weight[1])`.
|
| 2191 |
+
Default: 0.5.
|
| 2192 |
+
template_transform: transformation object which could be applied to template,
|
| 2193 |
+
must produce template the same size as input image.
|
| 2194 |
+
name (string): (Optional) Name of transform, used only for deserialization.
|
| 2195 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2196 |
+
Targets:
|
| 2197 |
+
image
|
| 2198 |
+
Image types:
|
| 2199 |
+
uint8, float32
|
| 2200 |
+
"""
|
| 2201 |
+
|
| 2202 |
+
def __init__(
|
| 2203 |
+
self,
|
| 2204 |
+
templates,
|
| 2205 |
+
img_weight=0.5,
|
| 2206 |
+
template_weight=0.5,
|
| 2207 |
+
template_transform=None,
|
| 2208 |
+
name=None,
|
| 2209 |
+
always_apply=False,
|
| 2210 |
+
p=0.5,
|
| 2211 |
+
):
|
| 2212 |
+
super().__init__(always_apply, p)
|
| 2213 |
+
|
| 2214 |
+
self.templates = templates if isinstance(templates, (list, tuple)) else [templates]
|
| 2215 |
+
self.img_weight = to_tuple(img_weight, img_weight)
|
| 2216 |
+
self.template_weight = to_tuple(template_weight, template_weight)
|
| 2217 |
+
self.template_transform = template_transform
|
| 2218 |
+
self.name = name
|
| 2219 |
+
|
| 2220 |
+
def apply(self, img, template=None, img_weight=0.5, template_weight=0.5, **params):
|
| 2221 |
+
return F.add_weighted(img, img_weight, template, template_weight)
|
| 2222 |
+
|
| 2223 |
+
def get_params(self):
|
| 2224 |
+
return {
|
| 2225 |
+
"img_weight": random.uniform(self.img_weight[0], self.img_weight[1]),
|
| 2226 |
+
"template_weight": random.uniform(self.template_weight[0], self.template_weight[1]),
|
| 2227 |
+
}
|
| 2228 |
+
|
| 2229 |
+
def get_params_dependent_on_targets(self, params):
|
| 2230 |
+
img = params["image"]
|
| 2231 |
+
template = random.choice(self.templates)
|
| 2232 |
+
|
| 2233 |
+
if self.template_transform is not None:
|
| 2234 |
+
template = self.template_transform(image=template)["image"]
|
| 2235 |
+
|
| 2236 |
+
if get_num_channels(template) not in [1, get_num_channels(img)]:
|
| 2237 |
+
raise ValueError(
|
| 2238 |
+
"Template must be a single channel or "
|
| 2239 |
+
"has the same number of channels as input image ({}), got {}".format(
|
| 2240 |
+
get_num_channels(img), get_num_channels(template)
|
| 2241 |
+
)
|
| 2242 |
+
)
|
| 2243 |
+
|
| 2244 |
+
if template.dtype != img.dtype:
|
| 2245 |
+
raise ValueError("Image and template must be the same image type")
|
| 2246 |
+
|
| 2247 |
+
if img.shape[:2] != template.shape[:2]:
|
| 2248 |
+
raise ValueError(
|
| 2249 |
+
"Image and template must be the same size, got {} and {}".format(img.shape[:2], template.shape[:2])
|
| 2250 |
+
)
|
| 2251 |
+
|
| 2252 |
+
if get_num_channels(template) == 1 and get_num_channels(img) > 1:
|
| 2253 |
+
template = np.stack((template,) * get_num_channels(img), axis=-1)
|
| 2254 |
+
|
| 2255 |
+
# in order to support grayscale image with dummy dim
|
| 2256 |
+
template = template.reshape(img.shape)
|
| 2257 |
+
|
| 2258 |
+
return {"template": template}
|
| 2259 |
+
|
| 2260 |
+
@classmethod
|
| 2261 |
+
def is_serializable(cls):
|
| 2262 |
+
return False
|
| 2263 |
+
|
| 2264 |
+
@property
|
| 2265 |
+
def targets_as_params(self):
|
| 2266 |
+
return ["image"]
|
| 2267 |
+
|
| 2268 |
+
def _to_dict(self):
|
| 2269 |
+
if self.name is None:
|
| 2270 |
+
raise ValueError(
|
| 2271 |
+
"To make a TemplateTransform serializable you should provide the `name` argument, "
|
| 2272 |
+
"e.g. `TemplateTransform(name='my_transform', ...)`."
|
| 2273 |
+
)
|
| 2274 |
+
return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}
|
| 2275 |
+
|
| 2276 |
+
|
| 2277 |
+
class RingingOvershoot(ImageOnlyTransform):
|
| 2278 |
+
"""Create ringing or overshoot artefacts by conlvolving image with 2D sinc filter.
|
| 2279 |
+
|
| 2280 |
+
Args:
|
| 2281 |
+
blur_limit (int, (int, int)): maximum kernel size for sinc filter.
|
| 2282 |
+
Should be in range [3, inf). Default: (7, 15).
|
| 2283 |
+
cutoff (float, (float, float)): range to choose the cutoff frequency in radians.
|
| 2284 |
+
Should be in range (0, np.pi)
|
| 2285 |
+
Default: (np.pi / 4, np.pi / 2).
|
| 2286 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2287 |
+
|
| 2288 |
+
Reference:
|
| 2289 |
+
dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
|
| 2290 |
+
https://arxiv.org/abs/2107.10833
|
| 2291 |
+
|
| 2292 |
+
Targets:
|
| 2293 |
+
image
|
| 2294 |
+
"""
|
| 2295 |
+
|
| 2296 |
+
def __init__(
|
| 2297 |
+
self,
|
| 2298 |
+
blur_limit: Union[int, Sequence[int]] = (7, 15),
|
| 2299 |
+
cutoff: Union[float, Sequence[float]] = (np.pi / 4, np.pi / 2),
|
| 2300 |
+
always_apply=False,
|
| 2301 |
+
p=0.5,
|
| 2302 |
+
):
|
| 2303 |
+
super(RingingOvershoot, self).__init__(always_apply, p)
|
| 2304 |
+
self.blur_limit = to_tuple(blur_limit, 3)
|
| 2305 |
+
self.cutoff = self.__check_values(to_tuple(cutoff, np.pi / 2), name="cutoff", bounds=(0, np.pi))
|
| 2306 |
+
|
| 2307 |
+
@staticmethod
|
| 2308 |
+
def __check_values(value, name, bounds=(0, float("inf"))):
|
| 2309 |
+
if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
|
| 2310 |
+
raise ValueError(f"{name} values should be between {bounds}")
|
| 2311 |
+
return value
|
| 2312 |
+
|
| 2313 |
+
def get_params(self):
|
| 2314 |
+
ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2)
|
| 2315 |
+
if ksize % 2 == 0:
|
| 2316 |
+
raise ValueError(f"Kernel size must be odd. Got: {ksize}")
|
| 2317 |
+
|
| 2318 |
+
cutoff = random.uniform(*self.cutoff)
|
| 2319 |
+
|
| 2320 |
+
# From dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
|
| 2321 |
+
with np.errstate(divide="ignore", invalid="ignore"):
|
| 2322 |
+
kernel = np.fromfunction(
|
| 2323 |
+
lambda x, y: cutoff
|
| 2324 |
+
* special.j1(cutoff * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2))
|
| 2325 |
+
/ (2 * np.pi * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2)),
|
| 2326 |
+
[ksize, ksize],
|
| 2327 |
+
)
|
| 2328 |
+
kernel[(ksize - 1) // 2, (ksize - 1) // 2] = cutoff**2 / (4 * np.pi)
|
| 2329 |
+
|
| 2330 |
+
# Normalize kernel
|
| 2331 |
+
kernel = kernel.astype(np.float32) / np.sum(kernel)
|
| 2332 |
+
|
| 2333 |
+
return {"kernel": kernel}
|
| 2334 |
+
|
| 2335 |
+
def apply(self, img, kernel=None, **params):
|
| 2336 |
+
return F.convolve(img, kernel)
|
| 2337 |
+
|
| 2338 |
+
def get_transform_init_args_names(self):
|
| 2339 |
+
return ("blur_limit", "cutoff")
|
| 2340 |
+
|
| 2341 |
+
|
| 2342 |
+
class UnsharpMask(ImageOnlyTransform):
|
| 2343 |
+
"""
|
| 2344 |
+
    Sharpen the input image using Unsharp Masking processing and overlay the result with the original image.
|
| 2345 |
+
|
| 2346 |
+
Args:
|
| 2347 |
+
blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image.
|
| 2348 |
+
Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
|
| 2349 |
+
as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
|
| 2350 |
+
If set single value `blur_limit` will be in range (0, blur_limit).
|
| 2351 |
+
Default: (3, 7).
|
| 2352 |
+
sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf).
|
| 2353 |
+
If set single value `sigma_limit` will be in range (0, sigma_limit).
|
| 2354 |
+
If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
|
| 2355 |
+
alpha (float, (float, float)): range to choose the visibility of the sharpened image.
|
| 2356 |
+
At 0, only the original image is visible, at 1.0 only its sharpened version is visible.
|
| 2357 |
+
Default: (0.2, 0.5).
|
| 2358 |
+
threshold (int): Value to limit sharpening only for areas with high pixel difference between original image
|
| 2359 |
+
            and its smoothed version. Higher threshold means less sharpening on flat areas.
|
| 2360 |
+
Must be in range [0, 255]. Default: 10.
|
| 2361 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2362 |
+
|
| 2363 |
+
Reference:
|
| 2364 |
+
arxiv.org/pdf/2107.10833.pdf
|
| 2365 |
+
|
| 2366 |
+
Targets:
|
| 2367 |
+
image
|
| 2368 |
+
"""
|
| 2369 |
+
|
| 2370 |
+
def __init__(
|
| 2371 |
+
self,
|
| 2372 |
+
blur_limit: Union[int, Sequence[int]] = (3, 7),
|
| 2373 |
+
sigma_limit: Union[float, Sequence[float]] = 0.0,
|
| 2374 |
+
alpha: Union[float, Sequence[float]] = (0.2, 0.5),
|
| 2375 |
+
threshold: int = 10,
|
| 2376 |
+
always_apply=False,
|
| 2377 |
+
p=0.5,
|
| 2378 |
+
):
|
| 2379 |
+
super(UnsharpMask, self).__init__(always_apply, p)
|
| 2380 |
+
self.blur_limit = to_tuple(blur_limit, 3)
|
| 2381 |
+
self.sigma_limit = self.__check_values(to_tuple(sigma_limit, 0.0), name="sigma_limit")
|
| 2382 |
+
self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
|
| 2383 |
+
self.threshold = threshold
|
| 2384 |
+
|
| 2385 |
+
if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0:
|
| 2386 |
+
self.blur_limit = 3, max(3, self.blur_limit[1])
|
| 2387 |
+
raise ValueError("blur_limit and sigma_limit minimum value can not be both equal to 0.")
|
| 2388 |
+
|
| 2389 |
+
if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
|
| 2390 |
+
self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
|
| 2391 |
+
):
|
| 2392 |
+
raise ValueError("UnsharpMask supports only odd blur limits.")
|
| 2393 |
+
|
| 2394 |
+
@staticmethod
|
| 2395 |
+
def __check_values(value, name, bounds=(0, float("inf"))):
|
| 2396 |
+
if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
|
| 2397 |
+
raise ValueError(f"{name} values should be between {bounds}")
|
| 2398 |
+
return value
|
| 2399 |
+
|
| 2400 |
+
def get_params(self):
|
| 2401 |
+
return {
|
| 2402 |
+
"ksize": random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2),
|
| 2403 |
+
"sigma": random.uniform(*self.sigma_limit),
|
| 2404 |
+
"alpha": random.uniform(*self.alpha),
|
| 2405 |
+
}
|
| 2406 |
+
|
| 2407 |
+
def apply(self, img, ksize=3, sigma=0, alpha=0.2, **params):
|
| 2408 |
+
return F.unsharp_mask(img, ksize, sigma=sigma, alpha=alpha, threshold=self.threshold)
|
| 2409 |
+
|
| 2410 |
+
def get_transform_init_args_names(self):
|
| 2411 |
+
return ("blur_limit", "sigma_limit", "alpha", "threshold")
|
| 2412 |
+
|
| 2413 |
+
|
| 2414 |
+
class PixelDropout(DualTransform):
|
| 2415 |
+
"""Set pixels to 0 with some probability.
|
| 2416 |
+
|
| 2417 |
+
Args:
|
| 2418 |
+
dropout_prob (float): pixel drop probability. Default: 0.01
|
| 2419 |
+
        per_channel (bool): if set to `True` drop mask will be sampled for each channel,
|
| 2420 |
+
otherwise the same mask will be sampled for all channels. Default: False
|
| 2421 |
+
drop_value (number or sequence of numbers or None): Value that will be set in dropped place.
|
| 2422 |
+
If set to None value will be sampled randomly, default ranges will be used:
|
| 2423 |
+
- uint8 - [0, 255]
|
| 2424 |
+
- uint16 - [0, 65535]
|
| 2425 |
+
- uint32 - [0, 4294967295]
|
| 2426 |
+
- float, double - [0, 1]
|
| 2427 |
+
Default: 0
|
| 2428 |
+
mask_drop_value (number or sequence of numbers or None): Value that will be set in dropped place in masks.
|
| 2429 |
+
If set to None masks will be unchanged. Default: 0
|
| 2430 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2431 |
+
|
| 2432 |
+
Targets:
|
| 2433 |
+
image, mask
|
| 2434 |
+
Image types:
|
| 2435 |
+
any
|
| 2436 |
+
"""
|
| 2437 |
+
|
| 2438 |
+
def __init__(
|
| 2439 |
+
self,
|
| 2440 |
+
dropout_prob: float = 0.01,
|
| 2441 |
+
per_channel: bool = False,
|
| 2442 |
+
drop_value: Optional[Union[float, Sequence[float]]] = 0,
|
| 2443 |
+
mask_drop_value: Optional[Union[float, Sequence[float]]] = None,
|
| 2444 |
+
always_apply: bool = False,
|
| 2445 |
+
p: float = 0.5,
|
| 2446 |
+
):
|
| 2447 |
+
super().__init__(always_apply, p)
|
| 2448 |
+
self.dropout_prob = dropout_prob
|
| 2449 |
+
self.per_channel = per_channel
|
| 2450 |
+
self.drop_value = drop_value
|
| 2451 |
+
self.mask_drop_value = mask_drop_value
|
| 2452 |
+
|
| 2453 |
+
if self.mask_drop_value is not None and self.per_channel:
|
| 2454 |
+
raise ValueError("PixelDropout supports mask only with per_channel=False")
|
| 2455 |
+
|
| 2456 |
+
def apply(
|
| 2457 |
+
self,
|
| 2458 |
+
img: np.ndarray,
|
| 2459 |
+
drop_mask: np.ndarray = np.array(None),
|
| 2460 |
+
drop_value: Union[float, Sequence[float]] = (),
|
| 2461 |
+
**params
|
| 2462 |
+
) -> np.ndarray:
|
| 2463 |
+
return F.pixel_dropout(img, drop_mask, drop_value)
|
| 2464 |
+
|
| 2465 |
+
def apply_to_mask(self, img: np.ndarray, drop_mask: np.ndarray = np.array(None), **params) -> np.ndarray:
|
| 2466 |
+
if self.mask_drop_value is None:
|
| 2467 |
+
return img
|
| 2468 |
+
|
| 2469 |
+
if img.ndim == 2:
|
| 2470 |
+
drop_mask = np.squeeze(drop_mask)
|
| 2471 |
+
|
| 2472 |
+
return F.pixel_dropout(img, drop_mask, self.mask_drop_value)
|
| 2473 |
+
|
| 2474 |
+
def apply_to_bbox(self, bbox, **params):
|
| 2475 |
+
return bbox
|
| 2476 |
+
|
| 2477 |
+
def apply_to_keypoint(self, keypoint, **params):
|
| 2478 |
+
return keypoint
|
| 2479 |
+
|
| 2480 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 2481 |
+
img = params["image"]
|
| 2482 |
+
shape = img.shape if self.per_channel else img.shape[:2]
|
| 2483 |
+
|
| 2484 |
+
rnd = np.random.RandomState(random.randint(0, 1 << 31))
|
| 2485 |
+
# Use choice to create boolean matrix, if we will use binomial after that we will need type conversion
|
| 2486 |
+
drop_mask = rnd.choice([True, False], shape, p=[self.dropout_prob, 1 - self.dropout_prob])
|
| 2487 |
+
|
| 2488 |
+
drop_value: Union[float, Sequence[float], np.ndarray]
|
| 2489 |
+
if drop_mask.ndim != img.ndim:
|
| 2490 |
+
drop_mask = np.expand_dims(drop_mask, -1)
|
| 2491 |
+
if self.drop_value is None:
|
| 2492 |
+
drop_shape = 1 if is_grayscale_image(img) else int(img.shape[-1])
|
| 2493 |
+
|
| 2494 |
+
if img.dtype in (np.uint8, np.uint16, np.uint32):
|
| 2495 |
+
drop_value = rnd.randint(0, int(F.MAX_VALUES_BY_DTYPE[img.dtype]), drop_shape, img.dtype)
|
| 2496 |
+
elif img.dtype in [np.float32, np.double]:
|
| 2497 |
+
drop_value = rnd.uniform(0, 1, drop_shape).astype(img.dtype)
|
| 2498 |
+
else:
|
| 2499 |
+
raise ValueError(f"Unsupported dtype: {img.dtype}")
|
| 2500 |
+
else:
|
| 2501 |
+
drop_value = self.drop_value
|
| 2502 |
+
|
| 2503 |
+
return {"drop_mask": drop_mask, "drop_value": drop_value}
|
| 2504 |
+
|
| 2505 |
+
@property
|
| 2506 |
+
def targets_as_params(self) -> List[str]:
|
| 2507 |
+
return ["image"]
|
| 2508 |
+
|
| 2509 |
+
def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
|
| 2510 |
+
return ("dropout_prob", "per_channel", "drop_value", "mask_drop_value")
|
| 2511 |
+
|
| 2512 |
+
|
| 2513 |
+
class Spatter(ImageOnlyTransform):
|
| 2514 |
+
"""
|
| 2515 |
+
Apply spatter transform. It simulates corruption which can occlude a lens in the form of rain or mud.
|
| 2516 |
+
|
| 2517 |
+
Args:
|
| 2518 |
+
mean (float, or tuple of floats): Mean value of normal distribution for generating liquid layer.
|
| 2519 |
+
If single float it will be used as mean.
|
| 2520 |
+
If tuple of float mean will be sampled from range `[mean[0], mean[1])`. Default: (0.65).
|
| 2521 |
+
std (float, or tuple of floats): Standard deviation value of normal distribution for generating liquid layer.
|
| 2522 |
+
If single float it will be used as std.
|
| 2523 |
+
If tuple of float std will be sampled from range `[std[0], std[1])`. Default: (0.3).
|
| 2524 |
+
gauss_sigma (float, or tuple of floats): Sigma value for gaussian filtering of liquid layer.
|
| 2525 |
+
If single float it will be used as gauss_sigma.
|
| 2526 |
+
If tuple of float gauss_sigma will be sampled from range `[sigma[0], sigma[1])`. Default: (2).
|
| 2527 |
+
        cutout_threshold (float, or tuple of floats): Threshold for filtering liquid layer
|
| 2528 |
+
            (determines number of drops). If single float it will be used as cutout_threshold.
|
| 2529 |
+
If tuple of float cutout_threshold will be sampled from range `[cutout_threshold[0], cutout_threshold[1])`.
|
| 2530 |
+
Default: (0.68).
|
| 2531 |
+
intensity (float, or tuple of floats): Intensity of corruption.
|
| 2532 |
+
If single float it will be used as intensity.
|
| 2533 |
+
If tuple of float intensity will be sampled from range `[intensity[0], intensity[1])`. Default: (0.6).
|
| 2534 |
+
mode (string, or list of strings): Type of corruption. Currently, supported options are 'rain' and 'mud'.
|
| 2535 |
+
            If a list is provided, the type of corruption will be sampled from the list. Default: ("rain").
|
| 2536 |
+
color (list of (r, g, b) or dict or None): Corruption elements color.
|
| 2537 |
+
If list uses provided list as color for specified mode.
|
| 2538 |
+
If dict uses provided color for specified mode. Color for each specified mode should be provided in dict.
|
| 2539 |
+
If None uses default colors (rain: (238, 238, 175), mud: (20, 42, 63)).
|
| 2540 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 2541 |
+
|
| 2542 |
+
Targets:
|
| 2543 |
+
image
|
| 2544 |
+
|
| 2545 |
+
Image types:
|
| 2546 |
+
uint8, float32
|
| 2547 |
+
|
| 2548 |
+
Reference:
|
| 2549 |
+
| https://arxiv.org/pdf/1903.12261.pdf
|
| 2550 |
+
| https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py
|
| 2551 |
+
"""
|
| 2552 |
+
|
| 2553 |
+
def __init__(
|
| 2554 |
+
self,
|
| 2555 |
+
mean: ScaleFloatType = 0.65,
|
| 2556 |
+
std: ScaleFloatType = 0.3,
|
| 2557 |
+
gauss_sigma: ScaleFloatType = 2,
|
| 2558 |
+
cutout_threshold: ScaleFloatType = 0.68,
|
| 2559 |
+
intensity: ScaleFloatType = 0.6,
|
| 2560 |
+
mode: Union[str, Sequence[str]] = "rain",
|
| 2561 |
+
color: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None,
|
| 2562 |
+
always_apply: bool = False,
|
| 2563 |
+
p: float = 0.5,
|
| 2564 |
+
):
|
| 2565 |
+
super().__init__(always_apply=always_apply, p=p)
|
| 2566 |
+
|
| 2567 |
+
self.mean = to_tuple(mean, mean)
|
| 2568 |
+
self.std = to_tuple(std, std)
|
| 2569 |
+
self.gauss_sigma = to_tuple(gauss_sigma, gauss_sigma)
|
| 2570 |
+
self.intensity = to_tuple(intensity, intensity)
|
| 2571 |
+
self.cutout_threshold = to_tuple(cutout_threshold, cutout_threshold)
|
| 2572 |
+
self.color = (
|
| 2573 |
+
color
|
| 2574 |
+
if color is not None
|
| 2575 |
+
else {
|
| 2576 |
+
"rain": [238, 238, 175],
|
| 2577 |
+
"mud": [20, 42, 63],
|
| 2578 |
+
}
|
| 2579 |
+
)
|
| 2580 |
+
self.mode = mode if isinstance(mode, (list, tuple)) else [mode]
|
| 2581 |
+
|
| 2582 |
+
if len(set(self.mode)) > 1 and not isinstance(self.color, dict):
|
| 2583 |
+
raise ValueError(f"Unsupported color: {self.color}. Please specify color for each mode (use dict for it).")
|
| 2584 |
+
|
| 2585 |
+
for i in self.mode:
|
| 2586 |
+
if i not in ["rain", "mud"]:
|
| 2587 |
+
raise ValueError(f"Unsupported color mode: {mode}. Transform supports only `rain` and `mud` mods.")
|
| 2588 |
+
if isinstance(self.color, dict):
|
| 2589 |
+
if i not in self.color:
|
| 2590 |
+
raise ValueError(f"Wrong color definition: {self.color}. Color for mode: {i} not specified.")
|
| 2591 |
+
if len(self.color[i]) != 3:
|
| 2592 |
+
raise ValueError(
|
| 2593 |
+
f"Unsupported color: {self.color[i]} for mode {i}. Color should be presented in RGB format."
|
| 2594 |
+
)
|
| 2595 |
+
|
| 2596 |
+
if isinstance(self.color, (list, tuple)):
|
| 2597 |
+
if len(self.color) != 3:
|
| 2598 |
+
raise ValueError(f"Unsupported color: {self.color}. Color should be presented in RGB format.")
|
| 2599 |
+
self.color = {self.mode[0]: self.color}
|
| 2600 |
+
|
| 2601 |
+
def apply(
|
| 2602 |
+
self,
|
| 2603 |
+
img: np.ndarray,
|
| 2604 |
+
non_mud: Optional[np.ndarray] = None,
|
| 2605 |
+
mud: Optional[np.ndarray] = None,
|
| 2606 |
+
drops: Optional[np.ndarray] = None,
|
| 2607 |
+
mode: str = "",
|
| 2608 |
+
**params
|
| 2609 |
+
) -> np.ndarray:
|
| 2610 |
+
return F.spatter(img, non_mud, mud, drops, mode)
|
| 2611 |
+
|
| 2612 |
+
@property
|
| 2613 |
+
def targets_as_params(self) -> List[str]:
|
| 2614 |
+
return ["image"]
|
| 2615 |
+
|
| 2616 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 2617 |
+
h, w = params["image"].shape[:2]
|
| 2618 |
+
|
| 2619 |
+
mean = random.uniform(self.mean[0], self.mean[1])
|
| 2620 |
+
std = random.uniform(self.std[0], self.std[1])
|
| 2621 |
+
cutout_threshold = random.uniform(self.cutout_threshold[0], self.cutout_threshold[1])
|
| 2622 |
+
sigma = random.uniform(self.gauss_sigma[0], self.gauss_sigma[1])
|
| 2623 |
+
mode = random.choice(self.mode)
|
| 2624 |
+
intensity = random.uniform(self.intensity[0], self.intensity[1])
|
| 2625 |
+
color = np.array(self.color[mode]) / 255.0
|
| 2626 |
+
|
| 2627 |
+
liquid_layer = random_utils.normal(size=(h, w), loc=mean, scale=std)
|
| 2628 |
+
liquid_layer = gaussian_filter(liquid_layer, sigma=sigma, mode="nearest")
|
| 2629 |
+
liquid_layer[liquid_layer < cutout_threshold] = 0
|
| 2630 |
+
|
| 2631 |
+
if mode == "rain":
|
| 2632 |
+
liquid_layer = (liquid_layer * 255).astype(np.uint8)
|
| 2633 |
+
dist = 255 - cv2.Canny(liquid_layer, 50, 150)
|
| 2634 |
+
dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
|
| 2635 |
+
_, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
|
| 2636 |
+
dist = blur(dist, 3).astype(np.uint8)
|
| 2637 |
+
dist = F.equalize(dist)
|
| 2638 |
+
|
| 2639 |
+
ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
|
| 2640 |
+
dist = F.convolve(dist, ker)
|
| 2641 |
+
dist = blur(dist, 3).astype(np.float32)
|
| 2642 |
+
|
| 2643 |
+
m = liquid_layer * dist
|
| 2644 |
+
m *= 1 / np.max(m, axis=(0, 1))
|
| 2645 |
+
|
| 2646 |
+
drops = m[:, :, None] * color * intensity
|
| 2647 |
+
mud = None
|
| 2648 |
+
non_mud = None
|
| 2649 |
+
else:
|
| 2650 |
+
m = np.where(liquid_layer > cutout_threshold, 1, 0)
|
| 2651 |
+
m = gaussian_filter(m.astype(np.float32), sigma=sigma, mode="nearest")
|
| 2652 |
+
m[m < 1.2 * cutout_threshold] = 0
|
| 2653 |
+
m = m[..., np.newaxis]
|
| 2654 |
+
|
| 2655 |
+
mud = m * color
|
| 2656 |
+
non_mud = 1 - m
|
| 2657 |
+
drops = None
|
| 2658 |
+
|
| 2659 |
+
return {
|
| 2660 |
+
"non_mud": non_mud,
|
| 2661 |
+
"mud": mud,
|
| 2662 |
+
"drops": drops,
|
| 2663 |
+
"mode": mode,
|
| 2664 |
+
}
|
| 2665 |
+
|
| 2666 |
+
def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]:
|
| 2667 |
+
return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color"
custom_albumentations/augmentations/utils.py
ADDED
@@ -0,0 +1,211 @@
from functools import wraps
from typing import Callable, Union

import cv2
import numpy as np
from typing_extensions import Concatenate, ParamSpec

from custom_albumentations.core.keypoints_utils import angle_to_2pi_range
from custom_albumentations.core.transforms_interface import KeypointInternalType

__all__ = [
    "read_bgr_image",
    "read_rgb_image",
    "MAX_VALUES_BY_DTYPE",
    "NPDTYPE_TO_OPENCV_DTYPE",
    "clipped",
    "get_opencv_dtype_from_numpy",
    "angle_2pi_range",
    "clip",
    "preserve_shape",
    "preserve_channel_dim",
    "ensure_contiguous",
    "is_rgb_image",
    "is_grayscale_image",
    "is_multispectral_image",
    "get_num_channels",
    "non_rgb_warning",
    "_maybe_process_in_chunks",
]

P = ParamSpec("P")

MAX_VALUES_BY_DTYPE = {
    np.dtype("uint8"): 255,
    np.dtype("uint16"): 65535,
    np.dtype("uint32"): 4294967295,
    np.dtype("float32"): 1.0,
}

NPDTYPE_TO_OPENCV_DTYPE = {
    np.uint8: cv2.CV_8U,  # type: ignore[attr-defined]
    np.uint16: cv2.CV_16U,  # type: ignore[attr-defined]
    np.int32: cv2.CV_32S,  # type: ignore[attr-defined]
    np.float32: cv2.CV_32F,  # type: ignore[attr-defined]
    np.float64: cv2.CV_64F,  # type: ignore[attr-defined]
    np.dtype("uint8"): cv2.CV_8U,  # type: ignore[attr-defined]
    np.dtype("uint16"): cv2.CV_16U,  # type: ignore[attr-defined]
    np.dtype("int32"): cv2.CV_32S,  # type: ignore[attr-defined]
    np.dtype("float32"): cv2.CV_32F,  # type: ignore[attr-defined]
    np.dtype("float64"): cv2.CV_64F,  # type: ignore[attr-defined]
}


def read_bgr_image(path):
    return cv2.imread(path, cv2.IMREAD_COLOR)


def read_rgb_image(path):
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def clipped(func: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    @wraps(func)
    def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
        dtype = img.dtype
        maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0)
        return clip(func(img, *args, **kwargs), dtype, maxval)

    return wrapped_function


def clip(img: np.ndarray, dtype: np.dtype, maxval: float) -> np.ndarray:
    return np.clip(img, 0, maxval).astype(dtype)


def get_opencv_dtype_from_numpy(value: Union[np.ndarray, int, np.dtype, object]) -> int:
    """
    Return a corresponding OpenCV dtype for a numpy's dtype
    :param value: Input dtype of numpy array
    :return: Corresponding dtype for OpenCV
    """
    if isinstance(value, np.ndarray):
        value = value.dtype
    return NPDTYPE_TO_OPENCV_DTYPE[value]


def angle_2pi_range(
    func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]
) -> Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]:
    @wraps(func)
    def wrapped_function(keypoint: KeypointInternalType, *args: P.args, **kwargs: P.kwargs) -> KeypointInternalType:
        (x, y, a, s) = func(keypoint, *args, **kwargs)[:4]
        return (x, y, angle_to_2pi_range(a), s)

    return wrapped_function


def preserve_shape(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Preserve shape of the image"""

    @wraps(func)
    def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
        shape = img.shape
        result = func(img, *args, **kwargs)
        result = result.reshape(shape)
        return result

    return wrapped_function


def preserve_channel_dim(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Preserve dummy channel dim."""

    @wraps(func)
    def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
        shape = img.shape
        result = func(img, *args, **kwargs)
        if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2:
            result = np.expand_dims(result, axis=-1)
        return result

    return wrapped_function


def ensure_contiguous(
    func: Callable[Concatenate[np.ndarray, P], np.ndarray]
) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
    """Ensure that input img is contiguous."""

    @wraps(func)
    def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
        img = np.require(img, requirements=["C_CONTIGUOUS"])
        result = func(img, *args, **kwargs)
        return result

    return wrapped_function


def is_rgb_image(image: np.ndarray) -> bool:
    return len(image.shape) == 3 and image.shape[-1] == 3


def is_grayscale_image(image: np.ndarray) -> bool:
    return (len(image.shape) == 2) or (len(image.shape) == 3 and image.shape[-1] == 1)


def is_multispectral_image(image: np.ndarray) -> bool:
    return len(image.shape) == 3 and image.shape[-1] not in [1, 3]


def get_num_channels(image: np.ndarray) -> int:
    return image.shape[2] if len(image.shape) == 3 else 1


def non_rgb_warning(image: np.ndarray) -> None:
    if not is_rgb_image(image):
        message = "This transformation expects 3-channel images"
        if is_grayscale_image(image):
            message += "\nYou can convert your grayscale image to RGB using cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))"
        if is_multispectral_image(image):  # Any image with a number of channels other than 1 and 3
            message += "\nThis transformation cannot be applied to multi-spectral images"

        raise ValueError(message)


def _maybe_process_in_chunks(
    process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs
) -> Callable[[np.ndarray], np.ndarray]:
    """
    Wrap OpenCV function to enable processing images with more than 4 channels.

    Limitations:
        This wrapper requires image to be the first argument and rest must be sent via named arguments.

    Args:
        process_fn: Transform function (e.g cv2.resize).
        kwargs: Additional parameters.

    Returns:
        numpy.ndarray: Transformed image.

    """

    @wraps(process_fn)
    def __process_fn(img: np.ndarray) -> np.ndarray:
        num_channels = get_num_channels(img)
        if num_channels > 4:
            chunks = []
            for index in range(0, num_channels, 4):
                if num_channels - index == 2:
                    # Many OpenCV functions cannot work with 2-channel images
                    for i in range(2):
                        chunk = img[:, :, index + i : index + i + 1]
                        chunk = process_fn(chunk, **kwargs)
                        chunk = np.expand_dims(chunk, -1)
                        chunks.append(chunk)
                else:
                    chunk = img[:, :, index : index + 4]
                    chunk = process_fn(chunk, **kwargs)
                    chunks.append(chunk)
            img = np.dstack(chunks)
        else:
            img = process_fn(img, **kwargs)
        return img

    return __process_fn
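The helpers above are meant to be composed around OpenCV calls elsewhere in the package. A minimal usage sketch, assuming the module is importable under the path shown in this diff; the `resize` wrapper below is an illustrative example modeled on how the library's geometric helpers use these utilities, not a function from this file:

import cv2
import numpy as np

from custom_albumentations.augmentations.utils import _maybe_process_in_chunks, preserve_channel_dim


@preserve_channel_dim
def resize(img: np.ndarray, height: int, width: int, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
    # Bind cv2.resize and its named arguments first; the returned callable takes only the image.
    resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
    return resize_fn(img)


# Works for images with more than 4 channels: the wrapper splits them into chunks of at most
# 4 channels, applies cv2.resize to each chunk, and stacks the results back together.
multispectral = np.random.randint(0, 256, size=(64, 64, 10), dtype=np.uint8)
resized = resize(multispectral, height=32, width=32)
assert resized.shape == (32, 32, 10)

Binding the keyword arguments up front is what lets `_maybe_process_in_chunks` satisfy its stated limitation that the image must be the first and only positional argument.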
custom_albumentations/core/__init__.py
ADDED
File without changes

custom_albumentations/core/bbox_utils.py
ADDED
@@ -0,0 +1,522 @@
from __future__ import division

from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, cast

import numpy as np

from .transforms_interface import BoxInternalType, BoxType
from .utils import DataProcessor, Params

__all__ = [
    "normalize_bbox",
    "denormalize_bbox",
    "normalize_bboxes",
    "denormalize_bboxes",
    "calculate_bbox_area",
    "filter_bboxes_by_visibility",
    "convert_bbox_to_albumentations",
    "convert_bbox_from_albumentations",
    "convert_bboxes_to_albumentations",
    "convert_bboxes_from_albumentations",
    "check_bbox",
    "check_bboxes",
    "filter_bboxes",
    "union_of_bboxes",
    "BboxProcessor",
    "BboxParams",
]

TBox = TypeVar("TBox", BoxType, BoxInternalType)


class BboxParams(Params):
    """
    Parameters of bounding boxes

    Args:
        format (str): format of bounding boxes. Should be 'coco', 'pascal_voc', 'albumentations' or 'yolo'.

            The `coco` format
                `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
            The `pascal_voc` format
                `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212].
            The `albumentations` format
                is like `pascal_voc`, but normalized,
                in other words: `[x_min, y_min, x_max, y_max]`, e.g. [0.2, 0.3, 0.4, 0.5].
            The `yolo` format
                `[x, y, width, height]`, e.g. [0.1, 0.2, 0.3, 0.4];
                `x`, `y` - normalized bbox center; `width`, `height` - normalized bbox width and height.
        label_fields (list): list of fields that are joined with boxes, e.g labels.
            Should be same type as boxes.
        min_area (float): minimum area of a bounding box. All bounding boxes whose
            visible area in pixels is less than this value will be removed. Default: 0.0.
        min_visibility (float): minimum fraction of area for a bounding box
            to remain this box in list. Default: 0.0.
        min_width (float): Minimum width of a bounding box. All bounding boxes whose width is
            less than this value will be removed. Default: 0.0.
        min_height (float): Minimum height of a bounding box. All bounding boxes whose height is
            less than this value will be removed. Default: 0.0.
        check_each_transform (bool): if `True`, then bboxes will be checked after each dual transform.
            Default: `True`
    """

    def __init__(
        self,
        format: str,
        label_fields: Optional[Sequence[str]] = None,
        min_area: float = 0.0,
        min_visibility: float = 0.0,
        min_width: float = 0.0,
        min_height: float = 0.0,
        check_each_transform: bool = True,
    ):
        super(BboxParams, self).__init__(format, label_fields)
        self.min_area = min_area
        self.min_visibility = min_visibility
        self.min_width = min_width
        self.min_height = min_height
        self.check_each_transform = check_each_transform

    def _to_dict(self) -> Dict[str, Any]:
        data = super(BboxParams, self)._to_dict()
        data.update(
            {
                "min_area": self.min_area,
                "min_visibility": self.min_visibility,
                "min_width": self.min_width,
                "min_height": self.min_height,
                "check_each_transform": self.check_each_transform,
            }
        )
        return data

    @classmethod
    def is_serializable(cls) -> bool:
        return True

    @classmethod
    def get_class_fullname(cls) -> str:
        return "BboxParams"


class BboxProcessor(DataProcessor):
    def __init__(self, params: BboxParams, additional_targets: Optional[Dict[str, str]] = None):
        super().__init__(params, additional_targets)

    @property
    def default_data_name(self) -> str:
        return "bboxes"

    def ensure_data_valid(self, data: Dict[str, Any]) -> None:
        for data_name in self.data_fields:
            data_exists = data_name in data and len(data[data_name])
            if data_exists and len(data[data_name][0]) < 5:
                if self.params.label_fields is None:
                    raise ValueError(
                        "Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox "
                        "because bboxes must have labels"
                    )
        if self.params.label_fields:
            if not all(i in data.keys() for i in self.params.label_fields):
                raise ValueError("Your 'label_fields' are not valid - them must have same names as params in dict")

    def filter(self, data: Sequence, rows: int, cols: int) -> List:
        self.params: BboxParams
        return filter_bboxes(
            data,
            rows,
            cols,
            min_area=self.params.min_area,
            min_visibility=self.params.min_visibility,
            min_width=self.params.min_width,
            min_height=self.params.min_height,
        )

    def check(self, data: Sequence, rows: int, cols: int) -> None:
        check_bboxes(data)

    def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> List[BoxType]:
        return convert_bboxes_from_albumentations(data, self.params.format, rows, cols, check_validity=True)

    def convert_to_albumentations(self, data: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
        return convert_bboxes_to_albumentations(data, self.params.format, rows, cols, check_validity=True)


def normalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox:
    """Normalize coordinates of a bounding box. Divide x-coordinates by image width and y-coordinates
    by image height.

    Args:
        bbox: Denormalized bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.

    Returns:
        Normalized bounding box `(x_min, y_min, x_max, y_max)`.

    Raises:
        ValueError: If rows or cols is less or equal zero

    """

    if rows <= 0:
        raise ValueError("Argument rows must be positive integer")
    if cols <= 0:
        raise ValueError("Argument cols must be positive integer")

    tail: Tuple[Any, ...]
    (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])

    x_min, x_max = x_min / cols, x_max / cols
    y_min, y_max = y_min / rows, y_max / rows

    return cast(BoxType, (x_min, y_min, x_max, y_max) + tail)  # type: ignore


def denormalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox:
    """Denormalize coordinates of a bounding box. Multiply x-coordinates by image width and y-coordinates
    by image height. This is an inverse operation for :func:`~albumentations.augmentations.bbox.normalize_bbox`.

    Args:
        bbox: Normalized bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.

    Returns:
        Denormalized bounding box `(x_min, y_min, x_max, y_max)`.

    Raises:
        ValueError: If rows or cols is less or equal zero

    """
    tail: Tuple[Any, ...]
    (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])

    if rows <= 0:
        raise ValueError("Argument rows must be positive integer")
    if cols <= 0:
        raise ValueError("Argument cols must be positive integer")

    x_min, x_max = x_min * cols, x_max * cols
    y_min, y_max = y_min * rows, y_max * rows

    return cast(BoxType, (x_min, y_min, x_max, y_max) + tail)  # type: ignore


def normalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
    """Normalize a list of bounding boxes.

    Args:
        bboxes: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        rows: Image height.
        cols: Image width.

    Returns:
        Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.

    """
    return [normalize_bbox(bbox, rows, cols) for bbox in bboxes]


def denormalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
    """Denormalize a list of bounding boxes.

    Args:
        bboxes: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        rows: Image height.
        cols: Image width.

    Returns:
        List: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.

    """
    return [denormalize_bbox(bbox, rows, cols) for bbox in bboxes]


def calculate_bbox_area(bbox: BoxType, rows: int, cols: int) -> float:
    """Calculate the area of a bounding box in (fractional) pixels.

    Args:
        bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.

    Return:
        Area in (fractional) pixels of the (denormalized) bounding box.

    """
    bbox = denormalize_bbox(bbox, rows, cols)
    x_min, y_min, x_max, y_max = bbox[:4]
    area = (x_max - x_min) * (y_max - y_min)
    return area


def filter_bboxes_by_visibility(
    original_shape: Sequence[int],
    bboxes: Sequence[BoxType],
    transformed_shape: Sequence[int],
    transformed_bboxes: Sequence[BoxType],
    threshold: float = 0.0,
    min_area: float = 0.0,
) -> List[BoxType]:
    """Filter bounding boxes and return only those boxes whose visibility after transformation is above
    the threshold and minimal area of bounding box in pixels is more then min_area.

    Args:
        original_shape: Original image shape `(height, width, ...)`.
        bboxes: Original bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        transformed_shape: Transformed image shape `(height, width)`.
        transformed_bboxes: Transformed bounding boxes `[(x_min, y_min, x_max, y_max)]`.
        threshold: visibility threshold. Should be a value in the range [0.0, 1.0].
        min_area: Minimal area threshold.

    Returns:
        Filtered bounding boxes `[(x_min, y_min, x_max, y_max)]`.

    """
    img_height, img_width = original_shape[:2]
    transformed_img_height, transformed_img_width = transformed_shape[:2]

    visible_bboxes = []
    for bbox, transformed_bbox in zip(bboxes, transformed_bboxes):
        if not all(0.0 <= value <= 1.0 for value in transformed_bbox[:4]):
            continue
        bbox_area = calculate_bbox_area(bbox, img_height, img_width)
        transformed_bbox_area = calculate_bbox_area(transformed_bbox, transformed_img_height, transformed_img_width)
        if transformed_bbox_area < min_area:
            continue
        visibility = transformed_bbox_area / bbox_area
        if visibility >= threshold:
            visible_bboxes.append(transformed_bbox)
    return visible_bboxes


def convert_bbox_to_albumentations(
    bbox: BoxType, source_format: str, rows: int, cols: int, check_validity: bool = False
) -> BoxType:
    """Convert a bounding box from a format specified in `source_format` to the format used by albumentations:
    normalized coordinates of top-left and bottom-right corners of the bounding box in a form of
    `(x_min, y_min, x_max, y_max)` e.g. `(0.15, 0.27, 0.67, 0.5)`.

    Args:
        bbox: A bounding box tuple.
        source_format: format of the bounding box. Should be 'coco', 'pascal_voc', or 'yolo'.
        check_validity: Check if all boxes are valid boxes.
        rows: Image height.
        cols: Image width.

    Returns:
        tuple: A bounding box `(x_min, y_min, x_max, y_max)`.

    Note:
        The `coco` format of a bounding box looks like `(x_min, y_min, width, height)`, e.g. (97, 12, 150, 200).
        The `pascal_voc` format of a bounding box looks like `(x_min, y_min, x_max, y_max)`, e.g. (97, 12, 247, 212).
        The `yolo` format of a bounding box looks like `(x, y, width, height)`, e.g. (0.3, 0.1, 0.05, 0.07);
        where `x`, `y` coordinates of the center of the box, all values normalized to 1 by image height and width.

    Raises:
        ValueError: if `target_format` is not equal to `coco` or `pascal_voc`, or `yolo`.
        ValueError: If in YOLO format all labels not in range (0, 1).

    """
    if source_format not in {"coco", "pascal_voc", "yolo"}:
        raise ValueError(
            f"Unknown source_format {source_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'"
        )

    if source_format == "coco":
        (x_min, y_min, width, height), tail = bbox[:4], bbox[4:]
        x_max = x_min + width
        y_max = y_min + height
    elif source_format == "yolo":
        # https://github.com/pjreddie/darknet/blob/f6d861736038da22c9eb0739dca84003c5a5e275/scripts/voc_label.py#L12
        _bbox = np.array(bbox[:4])
        if check_validity and np.any((_bbox <= 0) | (_bbox > 1)):
            raise ValueError("In YOLO format all coordinates must be float and in range (0, 1]")

        (x, y, w, h), tail = bbox[:4], bbox[4:]

        w_half, h_half = w / 2, h / 2
        x_min = x - w_half
        y_min = y - h_half
        x_max = x_min + w
        y_max = y_min + h
    else:
        (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:]

    bbox = (x_min, y_min, x_max, y_max) + tuple(tail)  # type: ignore

    if source_format != "yolo":
        bbox = normalize_bbox(bbox, rows, cols)
    if check_validity:
        check_bbox(bbox)
    return bbox


def convert_bbox_from_albumentations(
    bbox: BoxType, target_format: str, rows: int, cols: int, check_validity: bool = False
) -> BoxType:
    """Convert a bounding box from the format used by albumentations to a format, specified in `target_format`.

    Args:
        bbox: An albumentations bounding box `(x_min, y_min, x_max, y_max)`.
        target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'.
        rows: Image height.
        cols: Image width.
        check_validity: Check if all boxes are valid boxes.

    Returns:
        tuple: A bounding box.

    Note:
        The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
        The `pascal_voc` format of a bounding box looks like `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212].
        The `yolo` format of a bounding box looks like `[x, y, width, height]`, e.g. [0.3, 0.1, 0.05, 0.07].

    Raises:
        ValueError: if `target_format` is not equal to `coco`, `pascal_voc` or `yolo`.

    """
    if target_format not in {"coco", "pascal_voc", "yolo"}:
        raise ValueError(
            f"Unknown target_format {target_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'"
        )
    if check_validity:
        check_bbox(bbox)

    if target_format != "yolo":
        bbox = denormalize_bbox(bbox, rows, cols)
    if target_format == "coco":
        (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])
        width = x_max - x_min
        height = y_max - y_min
        bbox = cast(BoxType, (x_min, y_min, width, height) + tail)
    elif target_format == "yolo":
        (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:]
        x = (x_min + x_max) / 2.0
        y = (y_min + y_max) / 2.0
        w = x_max - x_min
        h = y_max - y_min
        bbox = cast(BoxType, (x, y, w, h) + tail)
    return bbox


def convert_bboxes_to_albumentations(
    bboxes: Sequence[BoxType], source_format, rows, cols, check_validity=False
) -> List[BoxType]:
    """Convert a list bounding boxes from a format specified in `source_format` to the format used by albumentations"""
    return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes]


def convert_bboxes_from_albumentations(
    bboxes: Sequence[BoxType], target_format: str, rows: int, cols: int, check_validity: bool = False
) -> List[BoxType]:
    """Convert a list of bounding boxes from the format used by albumentations to a format, specified
    in `target_format`.

    Args:
        bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
        target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'.
        rows: Image height.
        cols: Image width.
        check_validity: Check if all boxes are valid boxes.

    Returns:
        List of bounding boxes.

    """
    return [convert_bbox_from_albumentations(bbox, target_format, rows, cols, check_validity) for bbox in bboxes]


def check_bbox(bbox: BoxType) -> None:
    """Check if bbox boundaries are in range 0, 1 and minimums are lesser then maximums"""
    for name, value in zip(["x_min", "y_min", "x_max", "y_max"], bbox[:4]):
        if not 0 <= value <= 1 and not np.isclose(value, 0) and not np.isclose(value, 1):
            raise ValueError(f"Expected {name} for bbox {bbox} to be in the range [0.0, 1.0], got {value}.")
    x_min, y_min, x_max, y_max = bbox[:4]
    if x_max <= x_min:
        raise ValueError(f"x_max is less than or equal to x_min for bbox {bbox}.")
    if y_max <= y_min:
        raise ValueError(f"y_max is less than or equal to y_min for bbox {bbox}.")


def check_bboxes(bboxes: Sequence[BoxType]) -> None:
    """Check if bboxes boundaries are in range 0, 1 and minimums are lesser then maximums"""
    for bbox in bboxes:
        check_bbox(bbox)


def filter_bboxes(
    bboxes: Sequence[BoxType],
    rows: int,
    cols: int,
    min_area: float = 0.0,
    min_visibility: float = 0.0,
    min_width: float = 0.0,
    min_height: float = 0.0,
) -> List[BoxType]:
    """Remove bounding boxes that either lie outside of the visible area by more then min_visibility
    or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size.

    Args:
        bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
        rows: Image height.
        cols: Image width.
        min_area: Minimum area of a bounding box. All bounding boxes whose visible area in pixels.
            is less than this value will be removed. Default: 0.0.
        min_visibility: Minimum fraction of area for a bounding box to remain this box in list. Default: 0.0.
        min_width: Minimum width of a bounding box. All bounding boxes whose width is
            less than this value will be removed. Default: 0.0.
        min_height: Minimum height of a bounding box. All bounding boxes whose height is
            less than this value will be removed. Default: 0.0.

    Returns:
        List of bounding boxes.

    """
    resulting_boxes: List[BoxType] = []
    for bbox in bboxes:
        # Calculate areas of bounding box before and after clipping.
        transformed_box_area = calculate_bbox_area(bbox, rows, cols)
        bbox, tail = cast(BoxType, tuple(np.clip(bbox[:4], 0, 1.0))), tuple(bbox[4:])
        clipped_box_area = calculate_bbox_area(bbox, rows, cols)

        # Calculate width and height of the clipped bounding box.
        x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
        clipped_width, clipped_height = x_max - x_min, y_max - y_min

        if (
            clipped_box_area != 0  # to ensure transformed_box_area!=0 and to handle min_area=0 or min_visibility=0
            and clipped_box_area >= min_area
            and clipped_box_area / transformed_box_area >= min_visibility
            and clipped_width >= min_width
            and clipped_height >= min_height
        ):
            resulting_boxes.append(cast(BoxType, bbox + tail))
    return resulting_boxes


def union_of_bboxes(height: int, width: int, bboxes: Sequence[BoxType], erosion_rate: float = 0.0) -> BoxType:
    """Calculate union of bounding boxes.

    Args:
        height (float): Height of image or space.
        width (float): Width of image or space.
        bboxes (List[tuple]): List like bounding boxes. Format is `[(x_min, y_min, x_max, y_max)]`.
        erosion_rate (float): How much each bounding box can be shrinked, useful for erosive cropping.
            Set this in range [0, 1]. 0 will not be erosive at all, 1.0 can make any bbox to lose its volume.

    Returns:
        tuple: A bounding box `(x_min, y_min, x_max, y_max)`.

    """
    x1, y1 = width, height
    x2, y2 = 0, 0
    for bbox in bboxes:
        x_min, y_min, x_max, y_max = bbox[:4]
        w, h = x_max - x_min, y_max - y_min
        lim_x1, lim_y1 = x_min + erosion_rate * w, y_min + erosion_rate * h
        lim_x2, lim_y2 = x_max - erosion_rate * w, y_max - erosion_rate * h
        x1, y1 = np.min([x1, lim_x1]), np.min([y1, lim_y1])
        x2, y2 = np.max([x2, lim_x2]), np.max([y2, lim_y2])
    return x1, y1, x2, y2
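A minimal usage sketch for the conversion and filtering helpers above, assuming the package layout shown in this diff; the box values and the label are made up for illustration:

from custom_albumentations.core.bbox_utils import (
    convert_bbox_to_albumentations,
    convert_bbox_from_albumentations,
    filter_bboxes,
)

rows, cols = 480, 640  # image height and width

# A COCO box (x_min, y_min, width, height) with a label appended to the tuple.
coco_box = (97, 12, 150, 200, "dog")

# Convert to the normalized albumentations format (x_min, y_min, x_max, y_max, "dog").
alb_box = convert_bbox_to_albumentations(coco_box, "coco", rows, cols, check_validity=True)

# Convert back out, here to the YOLO format (x_center, y_center, width, height), still normalized.
yolo_box = convert_bbox_from_albumentations(alb_box, "yolo", rows, cols)

# Keep only boxes whose clipped area is at least 16 pixels and whose visibility is at least 10%.
kept = filter_bboxes([alb_box], rows, cols, min_area=16, min_visibility=0.1)

Conversion keeps any trailing elements (such as the label) attached to the box tuple, and `filter_bboxes` clips each box to the image before applying the area, visibility, width, and height thresholds.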
custom_albumentations/core/composition.py
ADDED
@@ -0,0 +1,552 @@
| 1 |
+
from __future__ import division
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
import typing
|
| 5 |
+
import warnings
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from .. import random_utils
|
| 11 |
+
from .bbox_utils import BboxParams, BboxProcessor
|
| 12 |
+
from .keypoints_utils import KeypointParams, KeypointsProcessor
|
| 13 |
+
from .serialization import (
|
| 14 |
+
SERIALIZABLE_REGISTRY,
|
| 15 |
+
Serializable,
|
| 16 |
+
get_shortest_class_fullname,
|
| 17 |
+
instantiate_nonserializable,
|
| 18 |
+
)
|
| 19 |
+
from .transforms_interface import BasicTransform
|
| 20 |
+
from .utils import format_args, get_shape
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
"BaseCompose",
|
| 24 |
+
"Compose",
|
| 25 |
+
"SomeOf",
|
| 26 |
+
"OneOf",
|
| 27 |
+
"OneOrOther",
|
| 28 |
+
"BboxParams",
|
| 29 |
+
"KeypointParams",
|
| 30 |
+
"ReplayCompose",
|
| 31 |
+
"Sequential",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
REPR_INDENT_STEP = 2
|
| 36 |
+
TransformType = typing.Union[BasicTransform, "BaseCompose"]
|
| 37 |
+
TransformsSeqType = typing.Sequence[TransformType]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) -> TransformsSeqType:
|
| 41 |
+
new_transforms: typing.List[TransformType] = []
|
| 42 |
+
for transform in transforms: # type: ignore
|
| 43 |
+
if isinstance(transform, BaseCompose):
|
| 44 |
+
new_transforms.extend(get_always_apply(transform))
|
| 45 |
+
elif transform.always_apply:
|
| 46 |
+
new_transforms.append(transform)
|
| 47 |
+
return new_transforms
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BaseCompose(Serializable):
|
| 51 |
+
def __init__(self, transforms: TransformsSeqType, p: float):
|
| 52 |
+
if isinstance(transforms, (BaseCompose, BasicTransform)):
|
| 53 |
+
warnings.warn(
|
| 54 |
+
"transforms is single transform, but a sequence is expected! Transform will be wrapped into list."
|
| 55 |
+
)
|
| 56 |
+
transforms = [transforms]
|
| 57 |
+
|
| 58 |
+
self.transforms = transforms
|
| 59 |
+
self.p = p
|
| 60 |
+
|
| 61 |
+
self.replay_mode = False
|
| 62 |
+
self.applied_in_replay = False
|
| 63 |
+
|
| 64 |
+
def __len__(self) -> int:
|
| 65 |
+
return len(self.transforms)
|
| 66 |
+
|
| 67 |
+
def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
def __getitem__(self, item: int) -> TransformType: # type: ignore
|
| 71 |
+
return self.transforms[item]
|
| 72 |
+
|
| 73 |
+
def __repr__(self) -> str:
|
| 74 |
+
return self.indented_repr()
|
| 75 |
+
|
| 76 |
+
def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
|
| 77 |
+
args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")}
|
| 78 |
+
repr_string = self.__class__.__name__ + "(["
|
| 79 |
+
for t in self.transforms:
|
| 80 |
+
repr_string += "\n"
|
| 81 |
+
if hasattr(t, "indented_repr"):
|
| 82 |
+
t_repr = t.indented_repr(indent + REPR_INDENT_STEP) # type: ignore
|
| 83 |
+
else:
|
| 84 |
+
t_repr = repr(t)
|
| 85 |
+
repr_string += " " * indent + t_repr + ","
|
| 86 |
+
repr_string += "\n" + " " * (indent - REPR_INDENT_STEP) + "], {args})".format(args=format_args(args))
|
| 87 |
+
return repr_string
|
| 88 |
+
|
| 89 |
+
@classmethod
|
| 90 |
+
def get_class_fullname(cls) -> str:
|
| 91 |
+
return get_shortest_class_fullname(cls)
|
| 92 |
+
|
| 93 |
+
@classmethod
|
| 94 |
+
def is_serializable(cls) -> bool:
|
| 95 |
+
return True
|
| 96 |
+
|
| 97 |
+
def _to_dict(self) -> typing.Dict[str, typing.Any]:
|
| 98 |
+
return {
|
| 99 |
+
"__class_fullname__": self.get_class_fullname(),
|
| 100 |
+
"p": self.p,
|
| 101 |
+
"transforms": [t._to_dict() for t in self.transforms], # skipcq: PYL-W0212
|
| 102 |
+
}
|
| 103 |
+
|
| 104 |
+
def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
|
| 105 |
+
return {
|
| 106 |
+
"__class_fullname__": self.get_class_fullname(),
|
| 107 |
+
"id": id(self),
|
| 108 |
+
"params": None,
|
| 109 |
+
"transforms": [t.get_dict_with_id() for t in self.transforms],
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def add_targets(self, additional_targets: typing.Optional[typing.Dict[str, str]]) -> None:
|
| 113 |
+
if additional_targets:
|
| 114 |
+
for t in self.transforms:
|
| 115 |
+
t.add_targets(additional_targets)
|
| 116 |
+
|
| 117 |
+
def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
|
| 118 |
+
for t in self.transforms:
|
| 119 |
+
t.set_deterministic(flag, save_key)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class Compose(BaseCompose):
|
| 123 |
+
"""Compose transforms and handle all transformations regarding bounding boxes
|
| 124 |
+
|
| 125 |
+
Args:
|
| 126 |
+
transforms (list): list of transformations to compose.
|
| 127 |
+
bbox_params (BboxParams): Parameters for bounding boxes transforms
|
| 128 |
+
keypoint_params (KeypointParams): Parameters for keypoints transforms
|
| 129 |
+
additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
|
| 130 |
+
p (float): probability of applying all list of transforms. Default: 1.0.
|
| 131 |
+
is_check_shapes (bool): If True shapes consistency of images/mask/masks would be checked on each call. If you
|
| 132 |
+
would like to disable this check - pass False (do it only if you are sure in your data consistency).
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
transforms: TransformsSeqType,
|
| 138 |
+
bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
|
| 139 |
+
keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
|
| 140 |
+
additional_targets: typing.Optional[typing.Dict[str, str]] = None,
|
| 141 |
+
p: float = 1.0,
|
| 142 |
+
is_check_shapes: bool = True,
|
| 143 |
+
):
|
| 144 |
+
super(Compose, self).__init__(transforms, p)
|
| 145 |
+
|
| 146 |
+
self.processors: typing.Dict[str, typing.Union[BboxProcessor, KeypointsProcessor]] = {}
|
| 147 |
+
if bbox_params:
|
| 148 |
+
if isinstance(bbox_params, dict):
|
| 149 |
+
b_params = BboxParams(**bbox_params)
|
| 150 |
+
elif isinstance(bbox_params, BboxParams):
|
| 151 |
+
b_params = bbox_params
|
| 152 |
+
else:
|
| 153 |
+
raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`")
|
| 154 |
+
self.processors["bboxes"] = BboxProcessor(b_params, additional_targets)
|
| 155 |
+
|
| 156 |
+
if keypoint_params:
|
| 157 |
+
if isinstance(keypoint_params, dict):
|
| 158 |
+
k_params = KeypointParams(**keypoint_params)
|
| 159 |
+
elif isinstance(keypoint_params, KeypointParams):
|
| 160 |
+
k_params = keypoint_params
|
| 161 |
+
else:
|
| 162 |
+
raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`")
|
| 163 |
+
self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets)
|
| 164 |
+
|
| 165 |
+
if additional_targets is None:
|
| 166 |
+
additional_targets = {}
|
| 167 |
+
|
| 168 |
+
self.additional_targets = additional_targets
|
| 169 |
+
|
| 170 |
+
for proc in self.processors.values():
|
| 171 |
+
proc.ensure_transforms_valid(self.transforms)
|
| 172 |
+
|
| 173 |
+
self.add_targets(additional_targets)
|
| 174 |
+
|
| 175 |
+
self.is_check_args = True
|
| 176 |
+
self._disable_check_args_for_transforms(self.transforms)
|
| 177 |
+
|
| 178 |
+
self.is_check_shapes = is_check_shapes
|
| 179 |
+
|
| 180 |
+
@staticmethod
|
| 181 |
+
def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
|
| 182 |
+
for transform in transforms:
|
| 183 |
+
if isinstance(transform, BaseCompose):
|
| 184 |
+
Compose._disable_check_args_for_transforms(transform.transforms)
|
| 185 |
+
if isinstance(transform, Compose):
|
| 186 |
+
transform._disable_check_args()
|
| 187 |
+
|
| 188 |
+
def _disable_check_args(self) -> None:
|
| 189 |
+
self.is_check_args = False
|
| 190 |
+
|
| 191 |
+
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
|
| 192 |
+
if args:
|
| 193 |
+
raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
|
| 194 |
+
if self.is_check_args:
|
| 195 |
+
self._check_args(**data)
|
| 196 |
+
assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type"
|
| 197 |
+
need_to_run = force_apply or random.random() < self.p
|
| 198 |
+
for p in self.processors.values():
|
| 199 |
+
p.ensure_data_valid(data)
|
| 200 |
+
transforms = self.transforms if need_to_run else get_always_apply(self.transforms)
|
| 201 |
+
|
| 202 |
+
check_each_transform = any(
|
| 203 |
+
getattr(item.params, "check_each_transform", False) for item in self.processors.values()
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
for p in self.processors.values():
|
| 207 |
+
p.preprocess(data)
|
| 208 |
+
|
| 209 |
+
for idx, t in enumerate(transforms):
|
| 210 |
+
data = t(**data)
|
| 211 |
+
|
| 212 |
+
if check_each_transform:
|
| 213 |
+
data = self._check_data_post_transform(data)
|
| 214 |
+
data = Compose._make_targets_contiguous(data) # ensure output targets are contiguous
|
| 215 |
+
|
| 216 |
+
for p in self.processors.values():
|
| 217 |
+
p.postprocess(data)
|
| 218 |
+
|
| 219 |
+
return data
|
| 220 |
+
|
| 221 |
+
def _check_data_post_transform(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
| 222 |
+
rows, cols = get_shape(data["image"])
|
| 223 |
+
|
| 224 |
+
for p in self.processors.values():
|
| 225 |
+
if not getattr(p.params, "check_each_transform", False):
|
| 226 |
+
continue
|
| 227 |
+
|
| 228 |
+
for data_name in p.data_fields:
|
| 229 |
+
data[data_name] = p.filter(data[data_name], rows, cols)
|
| 230 |
+
return data
|
| 231 |
+
|
| 232 |
+
def _to_dict(self) -> typing.Dict[str, typing.Any]:
|
| 233 |
+
dictionary = super(Compose, self)._to_dict()
|
| 234 |
+
bbox_processor = self.processors.get("bboxes")
|
| 235 |
+
keypoints_processor = self.processors.get("keypoints")
|
| 236 |
+
dictionary.update(
|
| 237 |
+
{
|
| 238 |
+
"bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212
|
| 239 |
+
"keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212
|
| 240 |
+
if keypoints_processor
|
| 241 |
+
else None,
|
| 242 |
+
"additional_targets": self.additional_targets,
|
| 243 |
+
"is_check_shapes": self.is_check_shapes,
|
| 244 |
+
}
|
| 245 |
+
)
|
| 246 |
+
return dictionary
|
| 247 |
+
|
| 248 |
+
def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
|
| 249 |
+
dictionary = super().get_dict_with_id()
|
| 250 |
+
bbox_processor = self.processors.get("bboxes")
|
| 251 |
+
keypoints_processor = self.processors.get("keypoints")
|
| 252 |
+
dictionary.update(
|
| 253 |
+
{
|
| 254 |
+
"bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212
|
| 255 |
+
"keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212
|
| 256 |
+
if keypoints_processor
|
| 257 |
+
else None,
|
| 258 |
+
"additional_targets": self.additional_targets,
|
| 259 |
+
"params": None,
|
| 260 |
+
"is_check_shapes": self.is_check_shapes,
|
| 261 |
+
}
|
| 262 |
+
)
|
| 263 |
+
return dictionary
|
| 264 |
+
|
| 265 |
+
def _check_args(self, **kwargs) -> None:
|
| 266 |
+
checked_single = ["image", "mask"]
|
| 267 |
+
checked_multi = ["masks"]
|
| 268 |
+
check_bbox_param = ["bboxes"]
|
| 269 |
+
# ["bboxes", "keypoints"] could be almost any type, no need to check them
|
| 270 |
+
shapes = []
|
| 271 |
+
for data_name, data in kwargs.items():
|
| 272 |
+
internal_data_name = self.additional_targets.get(data_name, data_name)
|
| 273 |
+
if internal_data_name in checked_single:
|
| 274 |
+
if not isinstance(data, np.ndarray):
|
| 275 |
+
raise TypeError("{} must be numpy array type".format(data_name))
|
| 276 |
+
shapes.append(data.shape[:2])
|
| 277 |
+
if internal_data_name in checked_multi:
|
| 278 |
+
if data is not None and len(data):
|
| 279 |
+
if not isinstance(data[0], np.ndarray):
|
| 280 |
+
raise TypeError("{} must be list of numpy arrays".format(data_name))
|
| 281 |
+
shapes.append(data[0].shape[:2])
|
| 282 |
+
if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
|
| 283 |
+
raise ValueError("bbox_params must be specified for bbox transformations")
|
| 284 |
+
|
| 285 |
+
if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
|
| 286 |
+
raise ValueError(
|
| 287 |
+
"Height and Width of image, mask or masks should be equal. You can disable shapes check "
|
| 288 |
+
"by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
|
| 289 |
+
"about your data consistency)."
|
| 290 |
+
)
|
| 291 |
+
|
| 292 |
+
@staticmethod
|
| 293 |
+
def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
|
| 294 |
+
result = {}
|
| 295 |
+
for key, value in data.items():
|
| 296 |
+
if isinstance(value, np.ndarray):
|
| 297 |
+
value = np.ascontiguousarray(value)
|
| 298 |
+
result[key] = value
|
| 299 |
+
return result
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
class OneOf(BaseCompose):
|
| 303 |
+
"""Select one of transforms to apply. Selected transform will be called with `force_apply=True`.
|
| 304 |
+
Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.
|
| 305 |
+
|
| 306 |
+
Args:
|
| 307 |
+
transforms (list): list of transformations to compose.
|
| 308 |
+
p (float): probability of applying selected transform. Default: 0.5.
|
| 309 |
+
"""
|
| 310 |
+
|
| 311 |
+
def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
|
| 312 |
+
super(OneOf, self).__init__(transforms, p)
|
| 313 |
+
transforms_ps = [t.p for t in self.transforms]
|
| 314 |
+
s = sum(transforms_ps)
|
| 315 |
+
self.transforms_ps = [t / s for t in transforms_ps]
|
| 316 |
+
|
| 317 |
+
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
|
| 318 |
+
if self.replay_mode:
|
| 319 |
+
for t in self.transforms:
|
| 320 |
+
data = t(**data)
|
| 321 |
+
return data
|
| 322 |
+
|
| 323 |
+
if self.transforms_ps and (force_apply or random.random() < self.p):
|
| 324 |
+
idx: int = random_utils.choice(len(self.transforms), p=self.transforms_ps)
|
| 325 |
+
t = self.transforms[idx]
|
| 326 |
+
data = t(force_apply=True, **data)
|
| 327 |
+
return data
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
class SomeOf(BaseCompose):
|
| 331 |
+
"""Select N transforms to apply. Selected transforms will be called with `force_apply=True`.
|
| 332 |
+
Transforms probabilities will be normalized to one 1, so in this case transforms probabilities works as weights.
|
| 333 |
+
|
| 334 |
+
Args:
|
| 335 |
+
transforms (list): list of transformations to compose.
|
| 336 |
+
n (int): number of transforms to apply.
|
| 337 |
+
replace (bool): Whether the sampled transforms are with or without replacement. Default: True.
|
| 338 |
+
p (float): probability of applying selected transform. Default: 1.
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1):
|
| 342 |
+
super(SomeOf, self).__init__(transforms, p)
|
| 343 |
+
self.n = n
|
| 344 |
+
self.replace = replace
|
| 345 |
+
transforms_ps = [t.p for t in self.transforms]
|
| 346 |
+
s = sum(transforms_ps)
|
| 347 |
+
self.transforms_ps = [t / s for t in transforms_ps]
|
| 348 |
+
|
| 349 |
+
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
|
| 350 |
+
if self.replay_mode:
|
| 351 |
+
for t in self.transforms:
|
| 352 |
+
data = t(**data)
|
| 353 |
+
return data
|
| 354 |
+
|
| 355 |
+
if self.transforms_ps and (force_apply or random.random() < self.p):
|
| 356 |
+
idx = random_utils.choice(len(self.transforms), size=self.n, replace=self.replace, p=self.transforms_ps)
|
| 357 |
+
for i in idx: # type: ignore
|
| 358 |
+
t = self.transforms[i]
|
| 359 |
+
data = t(force_apply=True, **data)
|
| 360 |
+
return data
|
| 361 |
+
|
| 362 |
+
def _to_dict(self) -> typing.Dict[str, typing.Any]:
|
| 363 |
+
dictionary = super(SomeOf, self)._to_dict()
|
| 364 |
+
dictionary.update({"n": self.n, "replace": self.replace})
|
| 365 |
+
return dictionary
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
class OneOrOther(BaseCompose):
|
| 369 |
+
"""Select one or another transform to apply. Selected transform will be called with `force_apply=True`."""
|
| 370 |
+
|
| 371 |
+
def __init__(
|
| 372 |
+
self,
|
| 373 |
+
first: typing.Optional[TransformType] = None,
|
| 374 |
+
second: typing.Optional[TransformType] = None,
|
| 375 |
+
transforms: typing.Optional[TransformsSeqType] = None,
|
| 376 |
+
p: float = 0.5,
|
| 377 |
+
):
|
| 378 |
+
if transforms is None:
|
| 379 |
+
if first is None or second is None:
|
| 380 |
+
raise ValueError("You must set both first and second or set transforms argument.")
|
| 381 |
+
transforms = [first, second]
|
| 382 |
+
super(OneOrOther, self).__init__(transforms, p)
|
| 383 |
+
if len(self.transforms) != 2:
|
| 384 |
+
warnings.warn("Length of transforms is not equal to 2.")
|
| 385 |
+
|
| 386 |
+
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
|
| 387 |
+
if self.replay_mode:
|
| 388 |
+
for t in self.transforms:
|
| 389 |
+
data = t(**data)
|
| 390 |
+
return data
|
| 391 |
+
|
| 392 |
+
if random.random() < self.p:
|
| 393 |
+
return self.transforms[0](force_apply=True, **data)
|
| 394 |
+
|
| 395 |
+
return self.transforms[-1](force_apply=True, **data)
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
class PerChannel(BaseCompose):
|
| 399 |
+
"""Apply transformations per-channel
|
| 400 |
+
|
| 401 |
+
Args:
|
| 402 |
+
transforms (list): list of transformations to compose.
|
| 403 |
+
channels (sequence): channels to apply the transform to. Pass None to apply to all.
|
| 404 |
+
Default: None (apply to all)
|
| 405 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 406 |
+
"""
|
| 407 |
+
|
| 408 |
+
def __init__(
|
| 409 |
+
self, transforms: TransformsSeqType, channels: typing.Optional[typing.Sequence[int]] = None, p: float = 0.5
|
| 410 |
+
):
|
| 411 |
+
super(PerChannel, self).__init__(transforms, p)
|
| 412 |
+
self.channels = channels
|
| 413 |
+
|
| 414 |
+
def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
|
| 415 |
+
if force_apply or random.random() < self.p:
|
| 416 |
+
image = data["image"]
|
| 417 |
+
|
| 418 |
+
# Expand mono images to have a single channel
|
| 419 |
+
if len(image.shape) == 2:
|
| 420 |
+
image = np.expand_dims(image, -1)
|
| 421 |
+
|
| 422 |
+
if self.channels is None:
|
| 423 |
+
self.channels = range(image.shape[2])
|
| 424 |
+
|
| 425 |
+
for c in self.channels:
|
| 426 |
+
for t in self.transforms:
|
| 427 |
+
image[:, :, c] = t(image=image[:, :, c])["image"]
|
| 428 |
+
|
| 429 |
+
data["image"] = image
|
| 430 |
+
|
| 431 |
+
return data
|
| 432 |
+
|
| 433 |
+
|
| 434 |
+
class ReplayCompose(Compose):
|
| 435 |
+
def __init__(
|
| 436 |
+
self,
|
| 437 |
+
transforms: TransformsSeqType,
|
| 438 |
+
bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
|
| 439 |
+
keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
|
| 440 |
+
additional_targets: typing.Optional[typing.Dict[str, str]] = None,
|
| 441 |
+
p: float = 1.0,
|
| 442 |
+
is_check_shapes: bool = True,
|
| 443 |
+
save_key: str = "replay",
|
| 444 |
+
):
|
| 445 |
+
super(ReplayCompose, self).__init__(
|
| 446 |
+
transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes
|
| 447 |
+
)
|
| 448 |
+
self.set_deterministic(True, save_key=save_key)
|
| 449 |
+
self.save_key = save_key
|
| 450 |
+
|
| 451 |
+
def __call__(self, *args, force_apply: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
|
| 452 |
+
kwargs[self.save_key] = defaultdict(dict)
|
| 453 |
+
result = super(ReplayCompose, self).__call__(force_apply=force_apply, **kwargs)
|
| 454 |
+
serialized = self.get_dict_with_id()
|
| 455 |
+
self.fill_with_params(serialized, result[self.save_key])
|
| 456 |
+
self.fill_applied(serialized)
|
| 457 |
+
result[self.save_key] = serialized
|
| 458 |
+
return result
|
| 459 |
+
|
| 460 |
+
@staticmethod
|
| 461 |
+
def replay(saved_augmentations: typing.Dict[str, typing.Any], **kwargs) -> typing.Dict[str, typing.Any]:
|
| 462 |
+
augs = ReplayCompose._restore_for_replay(saved_augmentations)
|
| 463 |
+
return augs(force_apply=True, **kwargs)
|
| 464 |
+
|
| 465 |
+
@staticmethod
|
| 466 |
+
def _restore_for_replay(
|
| 467 |
+
transform_dict: typing.Dict[str, typing.Any], lambda_transforms: typing.Optional[dict] = None
|
| 468 |
+
) -> TransformType:
|
| 469 |
+
"""
|
| 470 |
+
Args:
|
| 471 |
+
lambda_transforms (dict): A dictionary that contains lambda transforms, that
|
| 472 |
+
is, instances of the Lambda class.
|
| 473 |
+
This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
|
| 474 |
+
in that dictionary should be named same as `name` arguments in respective lambda transforms from
|
| 475 |
+
a serialized pipeline.
|
| 476 |
+
"""
|
| 477 |
+
applied = transform_dict["applied"]
|
| 478 |
+
params = transform_dict["params"]
|
| 479 |
+
lmbd = instantiate_nonserializable(transform_dict, lambda_transforms)
|
| 480 |
+
if lmbd:
|
| 481 |
+
transform = lmbd
|
| 482 |
+
else:
|
| 483 |
+
name = transform_dict["__class_fullname__"]
|
| 484 |
+
args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]}
|
| 485 |
+
cls = SERIALIZABLE_REGISTRY[name]
|
| 486 |
+
if "transforms" in args:
|
| 487 |
+
args["transforms"] = [
|
| 488 |
+
ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
|
| 489 |
+
for t in args["transforms"]
|
| 490 |
+
]
|
| 491 |
+
transform = cls(**args)
|
| 492 |
+
|
| 493 |
+
transform = typing.cast(BasicTransform, transform)
|
| 494 |
+
if isinstance(transform, BasicTransform):
|
| 495 |
+
transform.params = params
|
| 496 |
+
transform.replay_mode = True
|
| 497 |
+
transform.applied_in_replay = applied
|
| 498 |
+
return transform
|
| 499 |
+
|
| 500 |
+
def fill_with_params(self, serialized: dict, all_params: dict) -> None:
|
| 501 |
+
params = all_params.get(serialized.get("id"))
|
| 502 |
+
serialized["params"] = params
|
| 503 |
+
del serialized["id"]
|
| 504 |
+
for transform in serialized.get("transforms", []):
|
| 505 |
+
self.fill_with_params(transform, all_params)
|
| 506 |
+
|
| 507 |
+
def fill_applied(self, serialized: typing.Dict[str, typing.Any]) -> bool:
|
| 508 |
+
if "transforms" in serialized:
|
| 509 |
+
applied = [self.fill_applied(t) for t in serialized["transforms"]]
|
| 510 |
+
serialized["applied"] = any(applied)
|
| 511 |
+
else:
|
| 512 |
+
serialized["applied"] = serialized.get("params") is not None
|
| 513 |
+
return serialized["applied"]
|
| 514 |
+
|
| 515 |
+
def _to_dict(self) -> typing.Dict[str, typing.Any]:
|
| 516 |
+
dictionary = super(ReplayCompose, self)._to_dict()
|
| 517 |
+
dictionary.update({"save_key": self.save_key})
|
| 518 |
+
return dictionary
|
| 519 |
+
|
| 520 |
+
|
| 521 |
+
class Sequential(BaseCompose):
|
| 522 |
+
"""Sequentially applies all transforms to targets.
|
| 523 |
+
|
| 524 |
+
Note:
|
| 525 |
+
This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
|
| 526 |
+
the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
|
| 527 |
+
create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
|
| 528 |
+
chosen sequence to the input data (see the `Example` section for an example definition of such a pipeline).
|
| 529 |
+
|
| 530 |
+
Example:
|
| 531 |
+
>>> import custom_albumentations as A
|
| 532 |
+
>>> transform = A.Compose([
|
| 533 |
+
>>> A.OneOf([
|
| 534 |
+
>>> A.Sequential([
|
| 535 |
+
>>> A.HorizontalFlip(p=0.5),
|
| 536 |
+
>>> A.ShiftScaleRotate(p=0.5),
|
| 537 |
+
>>> ]),
|
| 538 |
+
>>> A.Sequential([
|
| 539 |
+
>>> A.VerticalFlip(p=0.5),
|
| 540 |
+
>>> A.RandomBrightnessContrast(p=0.5),
|
| 541 |
+
>>> ]),
|
| 542 |
+
>>> ], p=1)
|
| 543 |
+
>>> ])
|
| 544 |
+
"""
|
| 545 |
+
|
| 546 |
+
def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
|
| 547 |
+
super().__init__(transforms, p)
|
| 548 |
+
|
| 549 |
+
def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
|
| 550 |
+
for t in self.transforms:
|
| 551 |
+
data = t(**data)
|
| 552 |
+
return data
|
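A minimal usage sketch for the ReplayCompose class defined above (editor's addition, not part of the diff): one call records the sampled parameters under `save_key` ("replay" by default), and the static `replay` method re-applies them verbatim. It assumes the package re-exports these classes at the top level, as the `Sequential` docstring example does.

    import numpy as np
    import custom_albumentations as A

    aug = A.ReplayCompose([A.HorizontalFlip(p=0.5), A.RandomBrightnessContrast(p=0.5)])
    image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
    first = aug(image=image)                                       # parameters saved in first["replay"]
    second = A.ReplayCompose.replay(first["replay"], image=image)  # same augmentation, re-applied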
custom_albumentations/core/keypoints_utils.py
ADDED
|
@@ -0,0 +1,286 @@
|
| 1 |
+
from __future__ import division
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
import typing
|
| 5 |
+
import warnings
|
| 6 |
+
from typing import Any, Dict, List, Optional, Sequence, Tuple
|
| 7 |
+
|
| 8 |
+
from .utils import DataProcessor, Params
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"angle_to_2pi_range",
|
| 12 |
+
"check_keypoints",
|
| 13 |
+
"convert_keypoints_from_albumentations",
|
| 14 |
+
"convert_keypoints_to_albumentations",
|
| 15 |
+
"filter_keypoints",
|
| 16 |
+
"KeypointsProcessor",
|
| 17 |
+
"KeypointParams",
|
| 18 |
+
]
|
| 19 |
+
|
| 20 |
+
keypoint_formats = {"xy", "yx", "xya", "xys", "xyas", "xysa"}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def angle_to_2pi_range(angle: float) -> float:
|
| 24 |
+
two_pi = 2 * math.pi
|
| 25 |
+
return angle % two_pi
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class KeypointParams(Params):
|
| 29 |
+
"""
|
| 30 |
+
Parameters of keypoints
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
format (str): format of keypoints. Should be 'xy', 'yx', 'xya', 'xys', 'xyas', 'xysa'.
|
| 34 |
+
|
| 35 |
+
x - X coordinate,
|
| 36 |
+
|
| 37 |
+
y - Y coordinate
|
| 38 |
+
|
| 39 |
+
s - Keypoint scale
|
| 40 |
+
|
| 41 |
+
a - Keypoint orientation in radians or degrees (depending on KeypointParams.angle_in_degrees)
|
| 42 |
+
label_fields (list): list of fields that are joined with keypoints, e.g labels.
|
| 43 |
+
Should be same type as keypoints.
|
| 44 |
+
remove_invisible (bool): to remove invisible points after transform or not
|
| 45 |
+
angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints
|
| 46 |
+
check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform.
|
| 47 |
+
Default: `True`
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
format: str, # skipcq: PYL-W0622
|
| 53 |
+
label_fields: Optional[Sequence[str]] = None,
|
| 54 |
+
remove_invisible: bool = True,
|
| 55 |
+
angle_in_degrees: bool = True,
|
| 56 |
+
check_each_transform: bool = True,
|
| 57 |
+
):
|
| 58 |
+
super(KeypointParams, self).__init__(format, label_fields)
|
| 59 |
+
self.remove_invisible = remove_invisible
|
| 60 |
+
self.angle_in_degrees = angle_in_degrees
|
| 61 |
+
self.check_each_transform = check_each_transform
|
| 62 |
+
|
| 63 |
+
def _to_dict(self) -> Dict[str, Any]:
|
| 64 |
+
data = super(KeypointParams, self)._to_dict()
|
| 65 |
+
data.update(
|
| 66 |
+
{
|
| 67 |
+
"remove_invisible": self.remove_invisible,
|
| 68 |
+
"angle_in_degrees": self.angle_in_degrees,
|
| 69 |
+
"check_each_transform": self.check_each_transform,
|
| 70 |
+
}
|
| 71 |
+
)
|
| 72 |
+
return data
|
| 73 |
+
|
| 74 |
+
@classmethod
|
| 75 |
+
def is_serializable(cls) -> bool:
|
| 76 |
+
return True
|
| 77 |
+
|
| 78 |
+
@classmethod
|
| 79 |
+
def get_class_fullname(cls) -> str:
|
| 80 |
+
return "KeypointParams"
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class KeypointsProcessor(DataProcessor):
|
| 84 |
+
def __init__(self, params: KeypointParams, additional_targets: Optional[Dict[str, str]] = None):
|
| 85 |
+
super().__init__(params, additional_targets)
|
| 86 |
+
|
| 87 |
+
@property
|
| 88 |
+
def default_data_name(self) -> str:
|
| 89 |
+
return "keypoints"
|
| 90 |
+
|
| 91 |
+
def ensure_data_valid(self, data: Dict[str, Any]) -> None:
|
| 92 |
+
if self.params.label_fields:
|
| 93 |
+
if not all(i in data.keys() for i in self.params.label_fields):
|
| 94 |
+
raise ValueError(
|
| 95 |
+
"Your 'label_fields' are not valid - them must have same names as params in "
|
| 96 |
+
"'keypoint_params' dict"
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
|
| 100 |
+
# IAA-based augmentations support only transformation of xy keypoints.
|
| 101 |
+
# If your keypoint format is other than 'xy', we emit a warning to let the user
|
| 102 |
+
# be aware that angle and size will not be modified.
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
from custom_albumentations.imgaug.transforms import DualIAATransform
|
| 106 |
+
except ImportError:
|
| 107 |
+
# imgaug is not installed so we skip imgaug checks.
|
| 108 |
+
return
|
| 109 |
+
|
| 110 |
+
if self.params.format is not None and self.params.format != "xy":
|
| 111 |
+
for transform in transforms:
|
| 112 |
+
if isinstance(transform, DualIAATransform):
|
| 113 |
+
warnings.warn(
|
| 114 |
+
"{} transformation supports only 'xy' keypoints "
|
| 115 |
+
"augmentation. You have '{}' keypoints format. Scale "
|
| 116 |
+
"and angle WILL NOT BE transformed.".format(transform.__class__.__name__, self.params.format)
|
| 117 |
+
)
|
| 118 |
+
break
|
| 119 |
+
|
| 120 |
+
def filter(self, data: Sequence[Sequence], rows: int, cols: int) -> Sequence[Sequence]:
|
| 121 |
+
self.params: KeypointParams
|
| 122 |
+
return filter_keypoints(data, rows, cols, remove_invisible=self.params.remove_invisible)
|
| 123 |
+
|
| 124 |
+
def check(self, data: Sequence[Sequence], rows: int, cols: int) -> None:
|
| 125 |
+
check_keypoints(data, rows, cols)
|
| 126 |
+
|
| 127 |
+
def convert_from_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
|
| 128 |
+
params = self.params
|
| 129 |
+
return convert_keypoints_from_albumentations(
|
| 130 |
+
data,
|
| 131 |
+
params.format,
|
| 132 |
+
rows,
|
| 133 |
+
cols,
|
| 134 |
+
check_validity=params.remove_invisible,
|
| 135 |
+
angle_in_degrees=params.angle_in_degrees,
|
| 136 |
+
)
|
| 137 |
+
|
| 138 |
+
def convert_to_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
|
| 139 |
+
params = self.params
|
| 140 |
+
return convert_keypoints_to_albumentations(
|
| 141 |
+
data,
|
| 142 |
+
params.format,
|
| 143 |
+
rows,
|
| 144 |
+
cols,
|
| 145 |
+
check_validity=params.remove_invisible,
|
| 146 |
+
angle_in_degrees=params.angle_in_degrees,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def check_keypoint(kp: Sequence, rows: int, cols: int) -> None:
|
| 151 |
+
"""Check if keypoint coordinates are less than image shapes"""
|
| 152 |
+
for name, value, size in zip(["x", "y"], kp[:2], [cols, rows]):
|
| 153 |
+
if not 0 <= value < size:
|
| 154 |
+
raise ValueError(
|
| 155 |
+
"Expected {name} for keypoint {kp} "
|
| 156 |
+
"to be in the range [0.0, {size}], got {value}.".format(kp=kp, name=name, value=value, size=size)
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
angle = kp[2]
|
| 160 |
+
if not (0 <= angle < 2 * math.pi):
|
| 161 |
+
raise ValueError("Keypoint angle must be in range [0, 2 * PI). Got: {angle}".format(angle=angle))
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def check_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int) -> None:
|
| 165 |
+
"""Check if keypoints boundaries are less than image shapes"""
|
| 166 |
+
for kp in keypoints:
|
| 167 |
+
check_keypoint(kp, rows, cols)
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def filter_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int, remove_invisible: bool) -> Sequence[Sequence]:
|
| 171 |
+
if not remove_invisible:
|
| 172 |
+
return keypoints
|
| 173 |
+
|
| 174 |
+
resulting_keypoints = []
|
| 175 |
+
for kp in keypoints:
|
| 176 |
+
x, y = kp[:2]
|
| 177 |
+
if x < 0 or x >= cols:
|
| 178 |
+
continue
|
| 179 |
+
if y < 0 or y >= rows:
|
| 180 |
+
continue
|
| 181 |
+
resulting_keypoints.append(kp)
|
| 182 |
+
return resulting_keypoints
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def convert_keypoint_to_albumentations(
|
| 186 |
+
keypoint: Sequence,
|
| 187 |
+
source_format: str,
|
| 188 |
+
rows: int,
|
| 189 |
+
cols: int,
|
| 190 |
+
check_validity: bool = False,
|
| 191 |
+
angle_in_degrees: bool = True,
|
| 192 |
+
) -> Tuple:
|
| 193 |
+
if source_format not in keypoint_formats:
|
| 194 |
+
raise ValueError("Unknown target_format {}. Supported formats are: {}".format(source_format, keypoint_formats))
|
| 195 |
+
|
| 196 |
+
if source_format == "xy":
|
| 197 |
+
(x, y), tail = keypoint[:2], tuple(keypoint[2:])
|
| 198 |
+
a, s = 0.0, 0.0
|
| 199 |
+
elif source_format == "yx":
|
| 200 |
+
(y, x), tail = keypoint[:2], tuple(keypoint[2:])
|
| 201 |
+
a, s = 0.0, 0.0
|
| 202 |
+
elif source_format == "xya":
|
| 203 |
+
(x, y, a), tail = keypoint[:3], tuple(keypoint[3:])
|
| 204 |
+
s = 0.0
|
| 205 |
+
elif source_format == "xys":
|
| 206 |
+
(x, y, s), tail = keypoint[:3], tuple(keypoint[3:])
|
| 207 |
+
a = 0.0
|
| 208 |
+
elif source_format == "xyas":
|
| 209 |
+
(x, y, a, s), tail = keypoint[:4], tuple(keypoint[4:])
|
| 210 |
+
elif source_format == "xysa":
|
| 211 |
+
(x, y, s, a), tail = keypoint[:4], tuple(keypoint[4:])
|
| 212 |
+
else:
|
| 213 |
+
raise ValueError(f"Unsupported source format. Got {source_format}")
|
| 214 |
+
|
| 215 |
+
if angle_in_degrees:
|
| 216 |
+
a = math.radians(a)
|
| 217 |
+
|
| 218 |
+
keypoint = (x, y, angle_to_2pi_range(a), s) + tail
|
| 219 |
+
if check_validity:
|
| 220 |
+
check_keypoint(keypoint, rows, cols)
|
| 221 |
+
return keypoint
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
def convert_keypoint_from_albumentations(
|
| 225 |
+
keypoint: Sequence,
|
| 226 |
+
target_format: str,
|
| 227 |
+
rows: int,
|
| 228 |
+
cols: int,
|
| 229 |
+
check_validity: bool = False,
|
| 230 |
+
angle_in_degrees: bool = True,
|
| 231 |
+
) -> Tuple:
|
| 232 |
+
if target_format not in keypoint_formats:
|
| 233 |
+
raise ValueError("Unknown target_format {}. Supported formats are: {}".format(target_format, keypoint_formats))
|
| 234 |
+
|
| 235 |
+
(x, y, angle, scale), tail = keypoint[:4], tuple(keypoint[4:])
|
| 236 |
+
angle = angle_to_2pi_range(angle)
|
| 237 |
+
if check_validity:
|
| 238 |
+
check_keypoint((x, y, angle, scale), rows, cols)
|
| 239 |
+
if angle_in_degrees:
|
| 240 |
+
angle = math.degrees(angle)
|
| 241 |
+
|
| 242 |
+
kp: Tuple
|
| 243 |
+
if target_format == "xy":
|
| 244 |
+
kp = (x, y)
|
| 245 |
+
elif target_format == "yx":
|
| 246 |
+
kp = (y, x)
|
| 247 |
+
elif target_format == "xya":
|
| 248 |
+
kp = (x, y, angle)
|
| 249 |
+
elif target_format == "xys":
|
| 250 |
+
kp = (x, y, scale)
|
| 251 |
+
elif target_format == "xyas":
|
| 252 |
+
kp = (x, y, angle, scale)
|
| 253 |
+
elif target_format == "xysa":
|
| 254 |
+
kp = (x, y, scale, angle)
|
| 255 |
+
else:
|
| 256 |
+
raise ValueError(f"Invalid target format. Got: {target_format}")
|
| 257 |
+
|
| 258 |
+
return kp + tail
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def convert_keypoints_to_albumentations(
|
| 262 |
+
keypoints: Sequence[Sequence],
|
| 263 |
+
source_format: str,
|
| 264 |
+
rows: int,
|
| 265 |
+
cols: int,
|
| 266 |
+
check_validity: bool = False,
|
| 267 |
+
angle_in_degrees: bool = True,
|
| 268 |
+
) -> List[Tuple]:
|
| 269 |
+
return [
|
| 270 |
+
convert_keypoint_to_albumentations(kp, source_format, rows, cols, check_validity, angle_in_degrees)
|
| 271 |
+
for kp in keypoints
|
| 272 |
+
]
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def convert_keypoints_from_albumentations(
|
| 276 |
+
keypoints: Sequence[Sequence],
|
| 277 |
+
target_format: str,
|
| 278 |
+
rows: int,
|
| 279 |
+
cols: int,
|
| 280 |
+
check_validity: bool = False,
|
| 281 |
+
angle_in_degrees: bool = True,
|
| 282 |
+
) -> List[Tuple]:
|
| 283 |
+
return [
|
| 284 |
+
convert_keypoint_from_albumentations(kp, target_format, rows, cols, check_validity, angle_in_degrees)
|
| 285 |
+
for kp in keypoints
|
| 286 |
+
]
|
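A short round-trip sketch (editor's addition) for the conversion helpers above: an 'xyas' keypoint with a 90-degree angle is converted to the internal representation, which stores the angle in radians in [0, 2*pi), and then converted back.

    import math
    from custom_albumentations.core.keypoints_utils import (
        convert_keypoint_from_albumentations,
        convert_keypoint_to_albumentations,
    )

    kp = (50.0, 25.0, 90.0, 2.0)  # x, y, angle in degrees, scale
    internal = convert_keypoint_to_albumentations(kp, "xyas", rows=100, cols=200)
    # internal is approximately (50.0, 25.0, math.pi / 2, 2.0)
    restored = convert_keypoint_from_albumentations(internal, "xyas", rows=100, cols=200)
    assert math.isclose(restored[2], 90.0)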
custom_albumentations/core/serialization.py
ADDED
|
@@ -0,0 +1,247 @@
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import typing
|
| 5 |
+
import warnings
|
| 6 |
+
from abc import ABC, ABCMeta, abstractmethod
|
| 7 |
+
from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
import yaml
|
| 11 |
+
|
| 12 |
+
yaml_available = True
|
| 13 |
+
except ImportError:
|
| 14 |
+
yaml_available = False
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from custom_albumentations import __version__
|
| 18 |
+
|
| 19 |
+
__all__ = ["to_dict", "from_dict", "save", "load"]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
|
| 23 |
+
NON_SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def shorten_class_name(class_fullname: str) -> str:
|
| 27 |
+
splitted = class_fullname.split(".")
|
| 28 |
+
if len(splitted) == 1:
|
| 29 |
+
return class_fullname
|
| 30 |
+
top_module, *_, class_name = splitted
|
| 31 |
+
if top_module == "albumentations":
|
| 32 |
+
return class_name
|
| 33 |
+
return class_fullname
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def get_shortest_class_fullname(cls: Type) -> str:
|
| 37 |
+
class_fullname = "{cls.__module__}.{cls.__name__}".format(cls=cls)
|
| 38 |
+
return shorten_class_name(class_fullname)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class SerializableMeta(ABCMeta):
|
| 42 |
+
"""
|
| 43 |
+
A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY`
|
| 44 |
+
so they can be found later while deserializing transformation pipeline using classes full names.
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def __new__(mcs, name: str, bases: Tuple[type, ...], *args, **kwargs) -> "SerializableMeta":
|
| 48 |
+
cls_obj = super().__new__(mcs, name, bases, *args, **kwargs)
|
| 49 |
+
if name != "Serializable" and ABC not in bases:
|
| 50 |
+
if cls_obj.is_serializable():
|
| 51 |
+
SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
|
| 52 |
+
else:
|
| 53 |
+
NON_SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
|
| 54 |
+
return cls_obj
|
| 55 |
+
|
| 56 |
+
@classmethod
|
| 57 |
+
def is_serializable(mcs) -> bool:
|
| 58 |
+
return False
|
| 59 |
+
|
| 60 |
+
@classmethod
|
| 61 |
+
def get_class_fullname(mcs) -> str:
|
| 62 |
+
return get_shortest_class_fullname(mcs)
|
| 63 |
+
|
| 64 |
+
@classmethod
|
| 65 |
+
def _to_dict(mcs) -> Dict[str, Any]:
|
| 66 |
+
return {}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Serializable(metaclass=SerializableMeta):
|
| 70 |
+
@classmethod
|
| 71 |
+
@abstractmethod
|
| 72 |
+
def is_serializable(cls) -> bool:
|
| 73 |
+
raise NotImplementedError
|
| 74 |
+
|
| 75 |
+
@classmethod
|
| 76 |
+
@abstractmethod
|
| 77 |
+
def get_class_fullname(cls) -> str:
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
|
| 80 |
+
@abstractmethod
|
| 81 |
+
def _to_dict(self) -> Dict[str, Any]:
|
| 82 |
+
raise NotImplementedError
|
| 83 |
+
|
| 84 |
+
def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
|
| 85 |
+
"""
|
| 86 |
+
Take a transform pipeline and convert it to a serializable representation that uses only standard
|
| 87 |
+
python data types: dictionaries, lists, strings, integers, and floats.
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
self: A transform that should be serialized. If the transform doesn't implement the `to_dict`
|
| 91 |
+
method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised.
|
| 92 |
+
If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored
|
| 93 |
+
but no transform parameters will be serialized.
|
| 94 |
+
on_not_implemented_error (str): `raise` or `warn`.
|
| 95 |
+
"""
|
| 96 |
+
if on_not_implemented_error not in {"raise", "warn"}:
|
| 97 |
+
raise ValueError(
|
| 98 |
+
"Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format(
|
| 99 |
+
on_not_implemented_error
|
| 100 |
+
)
|
| 101 |
+
)
|
| 102 |
+
try:
|
| 103 |
+
transform_dict = self._to_dict()
|
| 104 |
+
except NotImplementedError as e:
|
| 105 |
+
if on_not_implemented_error == "raise":
|
| 106 |
+
raise e
|
| 107 |
+
|
| 108 |
+
transform_dict = {}
|
| 109 |
+
warnings.warn(
|
| 110 |
+
"Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
|
| 111 |
+
"Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
|
| 112 |
+
"method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
|
| 113 |
+
)
|
| 114 |
+
return {"__version__": __version__, "transform": transform_dict}
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
|
| 118 |
+
"""
|
| 119 |
+
Take a transform pipeline and convert it to a serializable representation that uses only standard
|
| 120 |
+
python data types: dictionaries, lists, strings, integers, and floats.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
transform: A transform that should be serialized. If the transform doesn't implement the `to_dict`
|
| 124 |
+
method and `on_not_implemented_error` equals to 'raise' then `NotImplementedError` is raised.
|
| 125 |
+
If `on_not_implemented_error` equals to 'warn' then `NotImplementedError` will be ignored
|
| 126 |
+
but no transform parameters will be serialized.
|
| 127 |
+
on_not_implemented_error (str): `raise` or `warn`.
|
| 128 |
+
"""
|
| 129 |
+
return transform.to_dict(on_not_implemented_error)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def instantiate_nonserializable(
|
| 133 |
+
transform: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None
|
| 134 |
+
) -> Optional[Serializable]:
|
| 135 |
+
if transform.get("__class_fullname__") in NON_SERIALIZABLE_REGISTRY:
|
| 136 |
+
name = transform["__name__"]
|
| 137 |
+
if nonserializable is None:
|
| 138 |
+
raise ValueError(
|
| 139 |
+
"To deserialize a non-serializable transform with name {name} you need to pass a dict with"
|
| 140 |
+
"this transform as the `lambda_transforms` argument".format(name=name)
|
| 141 |
+
)
|
| 142 |
+
result_transform = nonserializable.get(name)
|
| 143 |
+
if result_transform is None:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
"Non-serializable transform with {name} was not found in `nonserializable`".format(name=name)
|
| 146 |
+
)
|
| 147 |
+
return result_transform
|
| 148 |
+
return None
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def from_dict(
|
| 152 |
+
transform_dict: Dict[str, Any],
|
| 153 |
+
nonserializable: Optional[Dict[str, Any]] = None,
|
| 154 |
+
lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
|
| 155 |
+
) -> Optional[Serializable]:
|
| 156 |
+
"""
|
| 157 |
+
Args:
|
| 158 |
+
transform_dict (dict): A dictionary with serialized transform pipeline.
|
| 159 |
+
nonserializable (dict): A dictionary that contains non-serializable transforms.
|
| 160 |
+
This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
|
| 161 |
+
Keys in that dictionary should be named same as `name` arguments in respective transforms from
|
| 162 |
+
a serialized pipeline.
|
| 163 |
+
lambda_transforms (dict): Deprecated. Use 'nonserializable' instead.
|
| 164 |
+
"""
|
| 165 |
+
if lambda_transforms != "deprecated":
|
| 166 |
+
warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
|
| 167 |
+
nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)
|
| 168 |
+
|
| 169 |
+
register_additional_transforms()
|
| 170 |
+
transform = transform_dict["transform"]
|
| 171 |
+
lmbd = instantiate_nonserializable(transform, nonserializable)
|
| 172 |
+
if lmbd:
|
| 173 |
+
return lmbd
|
| 174 |
+
name = transform["__class_fullname__"]
|
| 175 |
+
args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
|
| 176 |
+
cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]
|
| 177 |
+
if "transforms" in args:
|
| 178 |
+
args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]
|
| 179 |
+
return cls(**args)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def check_data_format(data_format: str) -> None:
|
| 183 |
+
if data_format not in {"json", "yaml"}:
|
| 184 |
+
raise ValueError("Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(data_format))
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def save(
|
| 188 |
+
transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise"
|
| 189 |
+
) -> None:
|
| 190 |
+
"""
|
| 191 |
+
Take a transform pipeline, serialize it and save a serialized version to a file
|
| 192 |
+
using either json or yaml format.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
transform (obj): Transform to serialize.
|
| 196 |
+
filepath (str): Filepath to write to.
|
| 197 |
+
data_format (str): Serialization format. Should be either `json` or 'yaml'.
|
| 198 |
+
on_not_implemented_error (str): Parameter that describes what to do if a transform doesn't implement
|
| 199 |
+
the `to_dict` method. If 'raise' then `NotImplementedError` is raised, if `warn` then the exception will be
|
| 200 |
+
ignored and no transform arguments will be saved.
|
| 201 |
+
"""
|
| 202 |
+
check_data_format(data_format)
|
| 203 |
+
transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
|
| 204 |
+
dump_fn = json.dump if data_format == "json" else yaml.safe_dump
|
| 205 |
+
with open(filepath, "w") as f:
|
| 206 |
+
dump_fn(transform_dict, f) # type: ignore
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def load(
|
| 210 |
+
filepath: str,
|
| 211 |
+
data_format: str = "json",
|
| 212 |
+
nonserializable: Optional[Dict[str, Any]] = None,
|
| 213 |
+
lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
|
| 214 |
+
) -> object:
|
| 215 |
+
"""
|
| 216 |
+
Load a serialized pipeline from a json or yaml file and construct a transform pipeline.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
filepath (str): Filepath to read from.
|
| 220 |
+
data_format (str): Serialization format. Should be either `json` or 'yaml'.
|
| 221 |
+
nonserializable (dict): A dictionary that contains non-serializable transforms.
|
| 222 |
+
This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
|
| 223 |
+
Keys in that dictionary should be named same as `name` arguments in respective transforms from
|
| 224 |
+
a serialized pipeline.
|
| 225 |
+
lambda_transforms (dict): Deprecated. Use 'nonserializable' instead.
|
| 226 |
+
"""
|
| 227 |
+
if lambda_transforms != "deprecated":
|
| 228 |
+
warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
|
| 229 |
+
nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)
|
| 230 |
+
|
| 231 |
+
check_data_format(data_format)
|
| 232 |
+
load_fn = json.load if data_format == "json" else yaml.safe_load
|
| 233 |
+
with open(filepath) as f:
|
| 234 |
+
transform_dict = load_fn(f) # type: ignore
|
| 235 |
+
|
| 236 |
+
return from_dict(transform_dict, nonserializable=nonserializable)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def register_additional_transforms() -> None:
|
| 240 |
+
"""
|
| 241 |
+
Register transforms that are not imported directly into the `albumentations` module.
|
| 242 |
+
"""
|
| 243 |
+
try:
|
| 244 |
+
# This import will result in ImportError if `torch` is not installed
|
| 245 |
+
import custom_albumentations.pytorch
|
| 246 |
+
except ImportError:
|
| 247 |
+
pass
|
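A brief sketch (editor's addition) of the serialization round trip implemented above, assuming the top-level package exposes `Compose` and `HorizontalFlip`: `to_dict` reduces a pipeline to plain Python data keyed by `__class_fullname__`, and `from_dict` rebuilds it through `SERIALIZABLE_REGISTRY`.

    import custom_albumentations as A
    from custom_albumentations.core.serialization import from_dict, to_dict

    pipeline = A.Compose([A.HorizontalFlip(p=0.5)])
    serialized = to_dict(pipeline)    # {"__version__": ..., "transform": {...}}
    restored = from_dict(serialized)  # equivalent pipeline, rebuilt from the registry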
custom_albumentations/core/transforms_interface.py
ADDED
|
@@ -0,0 +1,293 @@
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
import random
|
| 4 |
+
from copy import deepcopy
|
| 5 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
|
| 6 |
+
from warnings import warn
|
| 7 |
+
|
| 8 |
+
import cv2
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
from .serialization import Serializable, get_shortest_class_fullname
|
| 12 |
+
from .utils import format_args
|
| 13 |
+
|
| 14 |
+
__all__ = [
|
| 15 |
+
"to_tuple",
|
| 16 |
+
"BasicTransform",
|
| 17 |
+
"DualTransform",
|
| 18 |
+
"ImageOnlyTransform",
|
| 19 |
+
"NoOp",
|
| 20 |
+
"BoxType",
|
| 21 |
+
"KeypointType",
|
| 22 |
+
"ImageColorType",
|
| 23 |
+
"ScaleFloatType",
|
| 24 |
+
"ScaleIntType",
|
| 25 |
+
"ImageColorType",
|
| 26 |
+
]
|
| 27 |
+
|
| 28 |
+
NumType = Union[int, float, np.ndarray]
|
| 29 |
+
BoxInternalType = Tuple[float, float, float, float]
|
| 30 |
+
BoxType = Union[BoxInternalType, Tuple[float, float, float, float, Any]]
|
| 31 |
+
KeypointInternalType = Tuple[float, float, float, float]
|
| 32 |
+
KeypointType = Union[KeypointInternalType, Tuple[float, float, float, float, Any]]
|
| 33 |
+
ImageColorType = Union[float, Sequence[float]]
|
| 34 |
+
|
| 35 |
+
ScaleFloatType = Union[float, Tuple[float, float]]
|
| 36 |
+
ScaleIntType = Union[int, Tuple[int, int]]
|
| 37 |
+
|
| 38 |
+
FillValueType = Optional[Union[int, float, Sequence[int], Sequence[float]]]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def to_tuple(param, low=None, bias=None):
|
| 42 |
+
"""Convert input argument to min-max tuple
|
| 43 |
+
Args:
|
| 44 |
+
param (scalar, tuple or list of 2+ elements): Input value.
|
| 45 |
+
If value is scalar, return value would be (offset - value, offset + value).
|
| 46 |
+
If value is tuple, return value would be value + offset (broadcasted).
|
| 47 |
+
low: Second element of tuple can be passed as optional argument
|
| 48 |
+
bias: An offset factor added to each element
|
| 49 |
+
"""
|
| 50 |
+
if low is not None and bias is not None:
|
| 51 |
+
raise ValueError("Arguments low and bias are mutually exclusive")
|
| 52 |
+
|
| 53 |
+
if param is None:
|
| 54 |
+
return param
|
| 55 |
+
|
| 56 |
+
if isinstance(param, (int, float)):
|
| 57 |
+
if low is None:
|
| 58 |
+
param = -param, +param
|
| 59 |
+
else:
|
| 60 |
+
param = (low, param) if low < param else (param, low)
|
| 61 |
+
elif isinstance(param, Sequence):
|
| 62 |
+
if len(param) != 2:
|
| 63 |
+
raise ValueError("to_tuple expects 1 or 2 values")
|
| 64 |
+
param = tuple(param)
|
| 65 |
+
else:
|
| 66 |
+
raise ValueError("Argument param must be either scalar (int, float) or tuple")
|
| 67 |
+
|
| 68 |
+
if bias is not None:
|
| 69 |
+
return tuple(bias + x for x in param)
|
| 70 |
+
|
| 71 |
+
return tuple(param)
|
| 72 |
+
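# Editor's note (not part of the module): a few concrete values for to_tuple above.
#   to_tuple(0.2)             -> (-0.2, 0.2)
#   to_tuple(10, low=5)       -> (5, 10)
#   to_tuple((1, 3), bias=1)  -> (2, 4)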
|
| 73 |
+
|
| 74 |
+
class BasicTransform(Serializable):
|
| 75 |
+
call_backup = None
|
| 76 |
+
interpolation: Any
|
| 77 |
+
fill_value: Any
|
| 78 |
+
mask_fill_value: Any
|
| 79 |
+
|
| 80 |
+
def __init__(self, always_apply: bool = False, p: float = 0.5):
|
| 81 |
+
self.p = p
|
| 82 |
+
self.always_apply = always_apply
|
| 83 |
+
self._additional_targets: Dict[str, str] = {}
|
| 84 |
+
|
| 85 |
+
# replay mode params
|
| 86 |
+
self.deterministic = False
|
| 87 |
+
self.save_key = "replay"
|
| 88 |
+
self.params: Dict[Any, Any] = {}
|
| 89 |
+
self.replay_mode = False
|
| 90 |
+
self.applied_in_replay = False
|
| 91 |
+
|
| 92 |
+
def __call__(self, *args, force_apply: bool = False, **kwargs) -> Dict[str, Any]:
|
| 93 |
+
if args:
|
| 94 |
+
raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
|
| 95 |
+
if self.replay_mode:
|
| 96 |
+
if self.applied_in_replay:
|
| 97 |
+
return self.apply_with_params(self.params, **kwargs)
|
| 98 |
+
|
| 99 |
+
return kwargs
|
| 100 |
+
|
| 101 |
+
if (random.random() < self.p) or self.always_apply or force_apply:
|
| 102 |
+
params = self.get_params()
|
| 103 |
+
|
| 104 |
+
if self.targets_as_params:
|
| 105 |
+
assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format(
|
| 106 |
+
self.__class__.__name__, self.targets_as_params
|
| 107 |
+
)
|
| 108 |
+
targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
|
| 109 |
+
params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
|
| 110 |
+
params.update(params_dependent_on_targets)
|
| 111 |
+
if self.deterministic:
|
| 112 |
+
if self.targets_as_params:
|
| 113 |
+
warn(
|
| 114 |
+
self.get_class_fullname() + " could work incorrectly in ReplayMode for other input data"
|
| 115 |
+
" because its' params depend on targets."
|
| 116 |
+
)
|
| 117 |
+
kwargs[self.save_key][id(self)] = deepcopy(params)
|
| 118 |
+
return self.apply_with_params(params, **kwargs)
|
| 119 |
+
|
| 120 |
+
return kwargs
|
| 121 |
+
|
| 122 |
+
def apply_with_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]: # skipcq: PYL-W0613
|
| 123 |
+
if params is None:
|
| 124 |
+
return kwargs
|
| 125 |
+
params = self.update_params(params, **kwargs)
|
| 126 |
+
res = {}
|
| 127 |
+
for key, arg in kwargs.items():
|
| 128 |
+
if arg is not None:
|
| 129 |
+
target_function = self._get_target_function(key)
|
| 130 |
+
target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
|
| 131 |
+
res[key] = target_function(arg, **dict(params, **target_dependencies))
|
| 132 |
+
else:
|
| 133 |
+
res[key] = None
|
| 134 |
+
return res
|
| 135 |
+
|
| 136 |
+
def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
|
| 137 |
+
assert save_key != "params", "params save_key is reserved"
|
| 138 |
+
self.deterministic = flag
|
| 139 |
+
self.save_key = save_key
|
| 140 |
+
return self
|
| 141 |
+
|
| 142 |
+
def __repr__(self) -> str:
|
| 143 |
+
state = self.get_base_init_args()
|
| 144 |
+
state.update(self.get_transform_init_args())
|
| 145 |
+
return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state))
|
| 146 |
+
|
| 147 |
+
def _get_target_function(self, key: str) -> Callable:
|
| 148 |
+
transform_key = key
|
| 149 |
+
if key in self._additional_targets:
|
| 150 |
+
transform_key = self._additional_targets.get(key, key)
|
| 151 |
+
|
| 152 |
+
target_function = self.targets.get(transform_key, lambda x, **p: x)
|
| 153 |
+
return target_function
|
| 154 |
+
|
| 155 |
+
def apply(self, img: np.ndarray, **params) -> np.ndarray:
|
| 156 |
+
raise NotImplementedError
|
| 157 |
+
|
| 158 |
+
def get_params(self) -> Dict:
|
| 159 |
+
return {}
|
| 160 |
+
|
| 161 |
+
@property
|
| 162 |
+
def targets(self) -> Dict[str, Callable]:
|
| 163 |
+
# you must specify targets in subclass
|
| 164 |
+
# for example: ('image', 'mask')
|
| 165 |
+
# ('image', 'boxes')
|
| 166 |
+
raise NotImplementedError
|
| 167 |
+
|
| 168 |
+
def update_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]:
|
| 169 |
+
if hasattr(self, "interpolation"):
|
| 170 |
+
params["interpolation"] = self.interpolation
|
| 171 |
+
if hasattr(self, "fill_value"):
|
| 172 |
+
params["fill_value"] = self.fill_value
|
| 173 |
+
if hasattr(self, "mask_fill_value"):
|
| 174 |
+
params["mask_fill_value"] = self.mask_fill_value
|
| 175 |
+
params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
|
| 176 |
+
return params
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def target_dependence(self) -> Dict:
|
| 180 |
+
return {}
|
| 181 |
+
|
| 182 |
+
def add_targets(self, additional_targets: Dict[str, str]):
|
| 183 |
+
"""Add targets to transform them the same way as one of existing targets
|
| 184 |
+
ex: {'target_image': 'image'}
|
| 185 |
+
ex: {'obj1_mask': 'mask', 'obj2_mask': 'mask'}
|
| 186 |
+
note that you must pass at least one target with the key 'image'
|
| 187 |
+
|
| 188 |
+
Args:
|
| 189 |
+
additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
|
| 190 |
+
"""
|
| 191 |
+
self._additional_targets = additional_targets
|
| 192 |
+
|
| 193 |
+
@property
|
| 194 |
+
def targets_as_params(self) -> List[str]:
|
| 195 |
+
return []
|
| 196 |
+
|
| 197 |
+
def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
|
| 198 |
+
raise NotImplementedError(
|
| 199 |
+
"Method get_params_dependent_on_targets is not implemented in class " + self.__class__.__name__
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
@classmethod
|
| 203 |
+
def get_class_fullname(cls) -> str:
|
| 204 |
+
return get_shortest_class_fullname(cls)
|
| 205 |
+
|
| 206 |
+
@classmethod
|
| 207 |
+
def is_serializable(cls):
|
| 208 |
+
return True
|
| 209 |
+
|
| 210 |
+
def get_transform_init_args_names(self) -> Tuple[str, ...]:
|
| 211 |
+
raise NotImplementedError(
|
| 212 |
+
"Class {name} is not serializable because the `get_transform_init_args_names` method is not "
|
| 213 |
+
"implemented".format(name=self.get_class_fullname())
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
def get_base_init_args(self) -> Dict[str, Any]:
|
| 217 |
+
return {"always_apply": self.always_apply, "p": self.p}
|
| 218 |
+
|
| 219 |
+
def get_transform_init_args(self) -> Dict[str, Any]:
|
| 220 |
+
return {k: getattr(self, k) for k in self.get_transform_init_args_names()}
|
| 221 |
+
|
| 222 |
+
def _to_dict(self) -> Dict[str, Any]:
|
| 223 |
+
state = {"__class_fullname__": self.get_class_fullname()}
|
| 224 |
+
state.update(self.get_base_init_args())
|
| 225 |
+
state.update(self.get_transform_init_args())
|
| 226 |
+
return state
|
| 227 |
+
|
| 228 |
+
def get_dict_with_id(self) -> Dict[str, Any]:
|
| 229 |
+
d = self._to_dict()
|
| 230 |
+
d["id"] = id(self)
|
| 231 |
+
return d
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
class DualTransform(BasicTransform):
|
| 235 |
+
"""Transform for segmentation task."""
|
| 236 |
+
|
| 237 |
+
@property
|
| 238 |
+
def targets(self) -> Dict[str, Callable]:
|
| 239 |
+
return {
|
| 240 |
+
"image": self.apply,
|
| 241 |
+
"mask": self.apply_to_mask,
|
| 242 |
+
"masks": self.apply_to_masks,
|
| 243 |
+
"bboxes": self.apply_to_bboxes,
|
| 244 |
+
"keypoints": self.apply_to_keypoints,
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
|
| 248 |
+
raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__)
|
| 249 |
+
|
| 250 |
+
def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
|
| 251 |
+
raise NotImplementedError("Method apply_to_keypoint is not implemented in class " + self.__class__.__name__)
|
| 252 |
+
|
| 253 |
+
def apply_to_bboxes(self, bboxes: Sequence[BoxType], **params) -> List[BoxType]:
|
| 254 |
+
return [self.apply_to_bbox(tuple(bbox[:4]), **params) + tuple(bbox[4:]) for bbox in bboxes] # type: ignore
|
| 255 |
+
|
| 256 |
+
def apply_to_keypoints(self, keypoints: Sequence[KeypointType], **params) -> List[KeypointType]:
|
| 257 |
+
return [ # type: ignore
|
| 258 |
+
self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) # type: ignore
|
| 259 |
+
for keypoint in keypoints
|
| 260 |
+
]
|
| 261 |
+
|
| 262 |
+
def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
|
| 263 |
+
return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()})
|
| 264 |
+
|
| 265 |
+
def apply_to_masks(self, masks: Sequence[np.ndarray], **params) -> List[np.ndarray]:
|
| 266 |
+
return [self.apply_to_mask(mask, **params) for mask in masks]
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
class ImageOnlyTransform(BasicTransform):
|
| 270 |
+
"""Transform applied to image only."""
|
| 271 |
+
|
| 272 |
+
@property
|
| 273 |
+
def targets(self) -> Dict[str, Callable]:
|
| 274 |
+
return {"image": self.apply}
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class NoOp(DualTransform):
|
| 278 |
+
"""Does nothing"""
|
| 279 |
+
|
| 280 |
+
def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
|
| 281 |
+
return keypoint
|
| 282 |
+
|
| 283 |
+
def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
|
| 284 |
+
return bbox
|
| 285 |
+
|
| 286 |
+
def apply(self, img: np.ndarray, **params) -> np.ndarray:
|
| 287 |
+
return img
|
| 288 |
+
|
| 289 |
+
def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
|
| 290 |
+
return img
|
| 291 |
+
|
| 292 |
+
def get_transform_init_args_names(self) -> Tuple:
|
| 293 |
+
return ()
|
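A small sketch (editor's addition, with a hypothetical `InvertColors` class) of subclassing the interface above: an image-only transform only needs `apply`, plus `get_transform_init_args_names` if it should be serializable.

    import numpy as np
    from custom_albumentations.core.transforms_interface import ImageOnlyTransform

    class InvertColors(ImageOnlyTransform):
        # Toy transform: invert an 8-bit image.
        def apply(self, img: np.ndarray, **params) -> np.ndarray:
            return 255 - img

        def get_transform_init_args_names(self):
            return ()  # no extra constructor arguments beyond always_apply / p

    out = InvertColors(p=1.0)(image=np.zeros((8, 8, 3), dtype=np.uint8))["image"]  # all pixels become 255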
custom_albumentations/core/utils.py
ADDED
|
@@ -0,0 +1,137 @@
|
| 1 |
+
from __future__ import absolute_import
|
| 2 |
+
|
| 3 |
+
from abc import ABC, abstractmethod
|
| 4 |
+
from typing import Any, Dict, Optional, Sequence, Tuple
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from .serialization import Serializable
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def get_shape(img: Any) -> Tuple[int, int]:
|
| 12 |
+
if isinstance(img, np.ndarray):
|
| 13 |
+
rows, cols = img.shape[:2]
|
| 14 |
+
return rows, cols
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
import torch
|
| 18 |
+
|
| 19 |
+
if torch.is_tensor(img):
|
| 20 |
+
rows, cols = img.shape[-2:]
|
| 21 |
+
return rows, cols
|
| 22 |
+
except ImportError:
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
raise RuntimeError(
|
| 26 |
+
f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}"
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def format_args(args_dict: Dict):
|
| 31 |
+
formatted_args = []
|
| 32 |
+
for k, v in args_dict.items():
|
| 33 |
+
if isinstance(v, str):
|
| 34 |
+
v = f"'{v}'"
|
| 35 |
+
formatted_args.append(f"{k}={v}")
|
| 36 |
+
return ", ".join(formatted_args)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Params(Serializable, ABC):
|
| 40 |
+
def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
|
| 41 |
+
self.format = format
|
| 42 |
+
self.label_fields = label_fields
|
| 43 |
+
|
| 44 |
+
def _to_dict(self) -> Dict[str, Any]:
|
| 45 |
+
return {"format": self.format, "label_fields": self.label_fields}
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DataProcessor(ABC):
|
| 49 |
+
def __init__(self, params: Params, additional_targets: Optional[Dict[str, str]] = None):
|
| 50 |
+
self.params = params
|
| 51 |
+
self.data_fields = [self.default_data_name]
|
| 52 |
+
if additional_targets is not None:
|
| 53 |
+
for k, v in additional_targets.items():
|
| 54 |
+
if v == self.default_data_name:
|
| 55 |
+
self.data_fields.append(k)
|
| 56 |
+
|
| 57 |
+
@property
|
| 58 |
+
@abstractmethod
|
| 59 |
+
def default_data_name(self) -> str:
|
| 60 |
+
raise NotImplementedError
|
| 61 |
+
|
| 62 |
+
def ensure_data_valid(self, data: Dict[str, Any]) -> None:
|
| 63 |
+
pass
|
| 64 |
+
|
| 65 |
+
def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
|
| 66 |
+
pass
|
| 67 |
+
|
| 68 |
+
def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 69 |
+
rows, cols = get_shape(data["image"])
|
| 70 |
+
|
| 71 |
+
for data_name in self.data_fields:
|
| 72 |
+
data[data_name] = self.filter(data[data_name], rows, cols)
|
| 73 |
+
data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")
|
| 74 |
+
|
| 75 |
+
data = self.remove_label_fields_from_data(data)
|
| 76 |
+
return data
|
| 77 |
+
|
| 78 |
+
def preprocess(self, data: Dict[str, Any]) -> None:
|
| 79 |
+
data = self.add_label_fields_to_data(data)
|
| 80 |
+
|
| 81 |
+
rows, cols = data["image"].shape[:2]
|
| 82 |
+
for data_name in self.data_fields:
|
| 83 |
+
data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")
|
| 84 |
+
|
| 85 |
+
def check_and_convert(self, data: Sequence, rows: int, cols: int, direction: str = "to") -> Sequence:
|
| 86 |
+
if self.params.format == "albumentations":
|
| 87 |
+
self.check(data, rows, cols)
|
| 88 |
+
return data
|
| 89 |
+
|
| 90 |
+
if direction == "to":
|
| 91 |
+
return self.convert_to_albumentations(data, rows, cols)
|
| 92 |
+
elif direction == "from":
|
| 93 |
+
return self.convert_from_albumentations(data, rows, cols)
|
| 94 |
+
else:
|
| 95 |
+
raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")
|
| 96 |
+
|
| 97 |
+
@abstractmethod
|
| 98 |
+
def filter(self, data: Sequence, rows: int, cols: int) -> Sequence:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
@abstractmethod
|
| 102 |
+
def check(self, data: Sequence, rows: int, cols: int) -> None:
|
| 103 |
+
pass
|
| 104 |
+
|
| 105 |
+
@abstractmethod
|
| 106 |
+
def convert_to_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
|
| 107 |
+
pass
|
| 108 |
+
|
| 109 |
+
@abstractmethod
|
| 110 |
+
def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
|
| 111 |
+
pass
|
| 112 |
+
|
| 113 |
+
def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 114 |
+
if self.params.label_fields is None:
|
| 115 |
+
return data
|
| 116 |
+
for data_name in self.data_fields:
|
| 117 |
+
for field in self.params.label_fields:
|
| 118 |
+
assert len(data[data_name]) == len(data[field])
|
| 119 |
+
data_with_added_field = []
|
| 120 |
+
for d, field_value in zip(data[data_name], data[field]):
|
| 121 |
+
data_with_added_field.append(list(d) + [field_value])
|
| 122 |
+
data[data_name] = data_with_added_field
|
| 123 |
+
return data
|
| 124 |
+
|
| 125 |
+
def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
|
| 126 |
+
if self.params.label_fields is None:
|
| 127 |
+
return data
|
| 128 |
+
for data_name in self.data_fields:
|
| 129 |
+
label_fields_len = len(self.params.label_fields)
|
| 130 |
+
for idx, field in enumerate(self.params.label_fields):
|
| 131 |
+
field_values = []
|
| 132 |
+
for bbox in data[data_name]:
|
| 133 |
+
field_values.append(bbox[-label_fields_len + idx])
|
| 134 |
+
data[field] = field_values
|
| 135 |
+
if label_fields_len:
|
| 136 |
+
data[data_name] = [d[:-label_fields_len] for d in data[data_name]]
|
| 137 |
+
return data
|
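An illustrative sketch (editor's addition) of how `DataProcessor` merges `label_fields` into each data row and strips them out again, using the concrete `KeypointsProcessor` defined earlier in this diff.

    from custom_albumentations.core.keypoints_utils import KeypointParams, KeypointsProcessor

    proc = KeypointsProcessor(KeypointParams(format="xy", label_fields=["labels"]))
    data = {"keypoints": [(10.0, 20.0), (30.0, 40.0)], "labels": ["nose", "ear"]}

    merged = proc.add_label_fields_to_data(data)
    # merged["keypoints"] == [[10.0, 20.0, "nose"], [30.0, 40.0, "ear"]]
    restored = proc.remove_label_fields_from_data(merged)
    # restored["labels"] == ["nose", "ear"], and the appended label column is removed again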
custom_albumentations/imgaug/__init__.py
ADDED
|
File without changes
|
custom_albumentations/imgaug/stubs.py
ADDED
|
@@ -0,0 +1,77 @@
|
| 1 |
+
__all__ = [
|
| 2 |
+
"IAAEmboss",
|
| 3 |
+
"IAASuperpixels",
|
| 4 |
+
"IAASharpen",
|
| 5 |
+
"IAAAdditiveGaussianNoise",
|
| 6 |
+
"IAACropAndPad",
|
| 7 |
+
"IAAFliplr",
|
| 8 |
+
"IAAFlipud",
|
| 9 |
+
"IAAAffine",
|
| 10 |
+
"IAAPiecewiseAffine",
|
| 11 |
+
"IAAPerspective",
|
| 12 |
+
]
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class IAAStub:
|
| 16 |
+
def __init__(self, *args, **kwargs):
|
| 17 |
+
cls_name = self.__class__.__name__
|
| 18 |
+
doc_link = "https://albumentations.ai/docs/api_reference/augmentations" + self.doc_link
|
| 19 |
+
raise RuntimeError(
|
| 20 |
+
f"You are trying to use a deprecated augmentation '{cls_name}' which depends on the imgaug library, "
|
| 21 |
+
f"but imgaug is not installed.\n\n"
|
| 22 |
+
"There are two options to fix this error:\n"
|
| 23 |
+
"1. [Recommended]. Switch to the Albumentations' implementation of the augmentation with the same API: "
|
| 24 |
+
f"{self.alternative} - {doc_link}\n"
|
| 25 |
+
"2. Install a version of Albumentations that contains imgaug by running "
|
| 26 |
+
"'pip install -U albumentations[imgaug]'."
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class IAACropAndPad(IAAStub):
|
| 31 |
+
alternative = "CropAndPad"
|
| 32 |
+
doc_link = "/crops/transforms/#albumentations.augmentations.crops.transforms.CropAndPad"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class IAAFliplr(IAAStub):
|
| 36 |
+
alternative = "HorizontalFlip"
|
| 37 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.HorizontalFlip"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class IAAFlipud(IAAStub):
|
| 41 |
+
alternative = "VerticalFlip"
|
| 42 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.VerticalFlip"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class IAAEmboss(IAAStub):
|
| 46 |
+
alternative = "Emboss"
|
| 47 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.Emboss"
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class IAASuperpixels(IAAStub):
|
| 51 |
+
alternative = "Superpixels"
|
| 52 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.Superpixels"
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class IAASharpen(IAAStub):
|
| 56 |
+
alternative = "Sharpen"
|
| 57 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.Sharpen"
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class IAAAdditiveGaussianNoise(IAAStub):
|
| 61 |
+
alternative = "GaussNoise"
|
| 62 |
+
doc_link = "/transforms/#albumentations.augmentations.transforms.GaussNoise"
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class IAAPiecewiseAffine(IAAStub):
|
| 66 |
+
alternative = "PiecewiseAffine"
|
| 67 |
+
doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.PiecewiseAffine"
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class IAAAffine(IAAStub):
|
| 71 |
+
alternative = "Affine"
|
| 72 |
+
doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Affine"
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class IAAPerspective(IAAStub):
|
| 76 |
+
alternative = "Perspective"
|
| 77 |
+
doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Perspective"
|
custom_albumentations/imgaug/transforms.py
ADDED
|
@@ -0,0 +1,391 @@
|
|
| 1 |
+
try:
|
| 2 |
+
import imgaug as ia
|
| 3 |
+
except ImportError as e:
|
| 4 |
+
raise ImportError(
|
| 5 |
+
"You are trying to import an augmentation that depends on the imgaug library, but imgaug is not installed. To "
|
| 6 |
+
"install a version of Albumentations that contains imgaug please run 'pip install -U albumentations[imgaug]'"
|
| 7 |
+
) from e
|
| 8 |
+
|
| 9 |
+
try:
|
| 10 |
+
from imgaug import augmenters as iaa
|
| 11 |
+
except ImportError:
|
| 12 |
+
import imgaug.imgaug.augmenters as iaa
|
| 13 |
+
|
| 14 |
+
import warnings
|
| 15 |
+
|
| 16 |
+
from custom_albumentations.core.bbox_utils import (
|
| 17 |
+
convert_bboxes_from_albumentations,
|
| 18 |
+
convert_bboxes_to_albumentations,
|
| 19 |
+
)
|
| 20 |
+
from custom_albumentations.core.keypoints_utils import (
|
| 21 |
+
convert_keypoints_from_albumentations,
|
| 22 |
+
convert_keypoints_to_albumentations,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
from ..augmentations import Perspective
|
| 26 |
+
from ..core.transforms_interface import (
|
| 27 |
+
BasicTransform,
|
| 28 |
+
DualTransform,
|
| 29 |
+
ImageOnlyTransform,
|
| 30 |
+
to_tuple,
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
__all__ = [
|
| 34 |
+
"BasicIAATransform",
|
| 35 |
+
"DualIAATransform",
|
| 36 |
+
"ImageOnlyIAATransform",
|
| 37 |
+
"IAAEmboss",
|
| 38 |
+
"IAASuperpixels",
|
| 39 |
+
"IAASharpen",
|
| 40 |
+
"IAAAdditiveGaussianNoise",
|
| 41 |
+
"IAACropAndPad",
|
| 42 |
+
"IAAFliplr",
|
| 43 |
+
"IAAFlipud",
|
| 44 |
+
"IAAAffine",
|
| 45 |
+
"IAAPiecewiseAffine",
|
| 46 |
+
"IAAPerspective",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class BasicIAATransform(BasicTransform):
|
| 51 |
+
def __init__(self, always_apply=False, p=0.5):
|
| 52 |
+
super(BasicIAATransform, self).__init__(always_apply, p)
|
| 53 |
+
|
| 54 |
+
@property
|
| 55 |
+
def processor(self):
|
| 56 |
+
return iaa.Noop()
|
| 57 |
+
|
| 58 |
+
def update_params(self, params, **kwargs):
|
| 59 |
+
params = super(BasicIAATransform, self).update_params(params, **kwargs)
|
| 60 |
+
params["deterministic_processor"] = self.processor.to_deterministic()
|
| 61 |
+
return params
|
| 62 |
+
|
| 63 |
+
def apply(self, img, deterministic_processor=None, **params):
|
| 64 |
+
return deterministic_processor.augment_image(img)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class DualIAATransform(DualTransform, BasicIAATransform):
|
| 68 |
+
def apply_to_bboxes(self, bboxes, deterministic_processor=None, rows=0, cols=0, **params):
|
| 69 |
+
if len(bboxes) > 0:
|
| 70 |
+
bboxes = convert_bboxes_from_albumentations(bboxes, "pascal_voc", rows=rows, cols=cols)
|
| 71 |
+
|
| 72 |
+
bboxes_t = ia.BoundingBoxesOnImage([ia.BoundingBox(*bbox[:4]) for bbox in bboxes], (rows, cols))
|
| 73 |
+
bboxes_t = deterministic_processor.augment_bounding_boxes([bboxes_t])[0].bounding_boxes
|
| 74 |
+
bboxes_t = [
|
| 75 |
+
[bbox.x1, bbox.y1, bbox.x2, bbox.y2] + list(bbox_orig[4:])
|
| 76 |
+
for (bbox, bbox_orig) in zip(bboxes_t, bboxes)
|
| 77 |
+
]
|
| 78 |
+
|
| 79 |
+
bboxes = convert_bboxes_to_albumentations(bboxes_t, "pascal_voc", rows=rows, cols=cols)
|
| 80 |
+
return bboxes
|
| 81 |
+
|
| 82 |
+
"""Applies transformation to keypoints.
|
| 83 |
+
Notes:
|
| 84 |
+
Since IAA supports only xy keypoints, scale and orientation will remain unchanged.
|
| 85 |
+
TODO:
|
| 86 |
+
Emit a warning message if child classes of DualIAATransform are instantiated
|
| 87 |
+
inside Compose with keypoints format other than 'xy'.
|
| 88 |
+
"""
|
| 89 |
+
|
| 90 |
+
def apply_to_keypoints(self, keypoints, deterministic_processor=None, rows=0, cols=0, **params):
|
| 91 |
+
if len(keypoints) > 0:
|
| 92 |
+
keypoints = convert_keypoints_from_albumentations(keypoints, "xy", rows=rows, cols=cols)
|
| 93 |
+
keypoints_t = ia.KeypointsOnImage([ia.Keypoint(*kp[:2]) for kp in keypoints], (rows, cols))
|
| 94 |
+
keypoints_t = deterministic_processor.augment_keypoints([keypoints_t])[0].keypoints
|
| 95 |
+
|
| 96 |
+
bboxes_t = [[kp.x, kp.y] + list(kp_orig[2:]) for (kp, kp_orig) in zip(keypoints_t, keypoints)]
|
| 97 |
+
|
| 98 |
+
keypoints = convert_keypoints_to_albumentations(bboxes_t, "xy", rows=rows, cols=cols)
|
| 99 |
+
return keypoints
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
class ImageOnlyIAATransform(ImageOnlyTransform, BasicIAATransform):
|
| 103 |
+
pass
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
class IAACropAndPad(DualIAATransform):
|
| 107 |
+
"""This augmentation is deprecated. Please use CropAndPad instead."""
|
| 108 |
+
|
| 109 |
+
def __init__(self, px=None, percent=None, pad_mode="constant", pad_cval=0, keep_size=True, always_apply=False, p=1):
|
| 110 |
+
super(IAACropAndPad, self).__init__(always_apply, p)
|
| 111 |
+
self.px = px
|
| 112 |
+
self.percent = percent
|
| 113 |
+
self.pad_mode = pad_mode
|
| 114 |
+
self.pad_cval = pad_cval
|
| 115 |
+
self.keep_size = keep_size
|
| 116 |
+
warnings.warn("IAACropAndPad is deprecated. Please use CropAndPad instead", FutureWarning)
|
| 117 |
+
|
| 118 |
+
@property
|
| 119 |
+
def processor(self):
|
| 120 |
+
return iaa.CropAndPad(self.px, self.percent, self.pad_mode, self.pad_cval, self.keep_size)
|
| 121 |
+
|
| 122 |
+
def get_transform_init_args_names(self):
|
| 123 |
+
return ("px", "percent", "pad_mode", "pad_cval", "keep_size")
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class IAAFliplr(DualIAATransform):
|
| 127 |
+
"""This augmentation is deprecated. Please use HorizontalFlip instead."""
|
| 128 |
+
|
| 129 |
+
def __init__(self, always_apply=False, p=0.5):
|
| 130 |
+
super().__init__(always_apply, p)
|
| 131 |
+
warnings.warn("IAAFliplr is deprecated. Please use HorizontalFlip instead.", FutureWarning)
|
| 132 |
+
|
| 133 |
+
@property
|
| 134 |
+
def processor(self):
|
| 135 |
+
return iaa.Fliplr(1)
|
| 136 |
+
|
| 137 |
+
def get_transform_init_args_names(self):
|
| 138 |
+
return ()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class IAAFlipud(DualIAATransform):
|
| 142 |
+
"""This augmentation is deprecated. Please use VerticalFlip instead."""
|
| 143 |
+
|
| 144 |
+
def __init__(self, always_apply=False, p=0.5):
|
| 145 |
+
super().__init__(always_apply, p)
|
| 146 |
+
warnings.warn("IAAFlipud is deprecated. Please use VerticalFlip instead.", FutureWarning)
|
| 147 |
+
|
| 148 |
+
@property
|
| 149 |
+
def processor(self):
|
| 150 |
+
return iaa.Flipud(1)
|
| 151 |
+
|
| 152 |
+
def get_transform_init_args_names(self):
|
| 153 |
+
return ()
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class IAAEmboss(ImageOnlyIAATransform):
|
| 157 |
+
"""Emboss the input image and overlays the result with the original image.
|
| 158 |
+
This augmentation is deprecated. Please use Emboss instead.
|
| 159 |
+
|
| 160 |
+
Args:
|
| 161 |
+
alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is
|
| 162 |
+
visible,at 1.0 only its embossed version is visible. Default: (0.2, 0.5).
|
| 163 |
+
strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7).
|
| 164 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 165 |
+
|
| 166 |
+
Targets:
|
| 167 |
+
image
|
| 168 |
+
"""
|
| 169 |
+
|
| 170 |
+
def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5):
|
| 171 |
+
super(IAAEmboss, self).__init__(always_apply, p)
|
| 172 |
+
self.alpha = to_tuple(alpha, 0.0)
|
| 173 |
+
self.strength = to_tuple(strength, 0.0)
|
| 174 |
+
warnings.warn("This augmentation is deprecated. Please use Emboss instead", FutureWarning)
|
| 175 |
+
|
| 176 |
+
@property
|
| 177 |
+
def processor(self):
|
| 178 |
+
return iaa.Emboss(self.alpha, self.strength)
|
| 179 |
+
|
| 180 |
+
def get_transform_init_args_names(self):
|
| 181 |
+
return ("alpha", "strength")
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class IAASuperpixels(ImageOnlyIAATransform):
|
| 185 |
+
"""Completely or partially transform the input image to its superpixel representation. Uses skimage's version
|
| 186 |
+
of the SLIC algorithm. May be slow.
|
| 187 |
+
|
| 188 |
+
This augmentation is deprecated. Please use Superpixels instead.
|
| 189 |
+
|
| 190 |
+
Args:
|
| 191 |
+
p_replace (float): defines the probability of any superpixel area being replaced by the superpixel, i.e. by
|
| 192 |
+
the average pixel color within its area. Default: 0.1.
|
| 193 |
+
n_segments (int): target number of superpixels to generate. Default: 100.
|
| 194 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 195 |
+
|
| 196 |
+
Targets:
|
| 197 |
+
image
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, p_replace=0.1, n_segments=100, always_apply=False, p=0.5):
|
| 201 |
+
super(IAASuperpixels, self).__init__(always_apply, p)
|
| 202 |
+
self.p_replace = p_replace
|
| 203 |
+
self.n_segments = n_segments
|
| 204 |
+
warnings.warn("IAASuperpixels is deprecated. Please use Superpixels instead.", FutureWarning)
|
| 205 |
+
|
| 206 |
+
@property
|
| 207 |
+
def processor(self):
|
| 208 |
+
return iaa.Superpixels(p_replace=self.p_replace, n_segments=self.n_segments)
|
| 209 |
+
|
| 210 |
+
def get_transform_init_args_names(self):
|
| 211 |
+
return ("p_replace", "n_segments")
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class IAASharpen(ImageOnlyIAATransform):
|
| 215 |
+
"""Sharpen the input image and overlays the result with the original image.
|
| 216 |
+
This augmentation is deprecated. Please use Sharpen instead
|
| 217 |
+
Args:
|
| 218 |
+
alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is
|
| 219 |
+
visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5).
|
| 220 |
+
lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0).
|
| 221 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 222 |
+
|
| 223 |
+
Targets:
|
| 224 |
+
image
|
| 225 |
+
"""
|
| 226 |
+
|
| 227 |
+
def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5):
|
| 228 |
+
super(IAASharpen, self).__init__(always_apply, p)
|
| 229 |
+
self.alpha = to_tuple(alpha, 0)
|
| 230 |
+
self.lightness = to_tuple(lightness, 0)
|
| 231 |
+
warnings.warn("IAASharpen is deprecated. Please use Sharpen instead", FutureWarning)
|
| 232 |
+
|
| 233 |
+
@property
|
| 234 |
+
def processor(self):
|
| 235 |
+
return iaa.Sharpen(self.alpha, self.lightness)
|
| 236 |
+
|
| 237 |
+
def get_transform_init_args_names(self):
|
| 238 |
+
return ("alpha", "lightness")
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
class IAAAdditiveGaussianNoise(ImageOnlyIAATransform):
|
| 242 |
+
"""Add gaussian noise to the input image.
|
| 243 |
+
|
| 244 |
+
This augmentation is deprecated. Please use GaussNoise instead.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
loc (int): mean of the normal distribution that generates the noise. Default: 0.
|
| 248 |
+
scale ((float, float)): standard deviation of the normal distribution that generates the noise.
|
| 249 |
+
Default: (0.01 * 255, 0.05 * 255).
|
| 250 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 251 |
+
|
| 252 |
+
Targets:
|
| 253 |
+
image
|
| 254 |
+
"""
|
| 255 |
+
|
| 256 |
+
def __init__(self, loc=0, scale=(0.01 * 255, 0.05 * 255), per_channel=False, always_apply=False, p=0.5):
|
| 257 |
+
super(IAAAdditiveGaussianNoise, self).__init__(always_apply, p)
|
| 258 |
+
self.loc = loc
|
| 259 |
+
self.scale = to_tuple(scale, 0.0)
|
| 260 |
+
self.per_channel = per_channel
|
| 261 |
+
warnings.warn("IAAAdditiveGaussianNoise is deprecated. Please use GaussNoise instead", FutureWarning)
|
| 262 |
+
|
| 263 |
+
@property
|
| 264 |
+
def processor(self):
|
| 265 |
+
return iaa.AdditiveGaussianNoise(self.loc, self.scale, self.per_channel)
|
| 266 |
+
|
| 267 |
+
def get_transform_init_args_names(self):
|
| 268 |
+
return ("loc", "scale", "per_channel")
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
class IAAPiecewiseAffine(DualIAATransform):
|
| 272 |
+
"""Place a regular grid of points on the input and randomly move the neighbourhood of these point around
|
| 273 |
+
via affine transformations.
|
| 274 |
+
|
| 275 |
+
This augmentation is deprecated. Please use PiecewiseAffine instead.
|
| 276 |
+
|
| 277 |
+
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
|
| 278 |
+
|
| 279 |
+
Args:
|
| 280 |
+
scale ((float, float): factor range that determines how far each point is moved. Default: (0.03, 0.05).
|
| 281 |
+
nb_rows (int): number of rows of points that the regular grid should have. Default: 4.
|
| 282 |
+
nb_cols (int): number of columns of points that the regular grid should have. Default: 4.
|
| 283 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 284 |
+
|
| 285 |
+
Targets:
|
| 286 |
+
image, mask
|
| 287 |
+
"""
|
| 288 |
+
|
| 289 |
+
def __init__(
|
| 290 |
+
self, scale=(0.03, 0.05), nb_rows=4, nb_cols=4, order=1, cval=0, mode="constant", always_apply=False, p=0.5
|
| 291 |
+
):
|
| 292 |
+
super(IAAPiecewiseAffine, self).__init__(always_apply, p)
|
| 293 |
+
self.scale = to_tuple(scale, 0.0)
|
| 294 |
+
self.nb_rows = nb_rows
|
| 295 |
+
self.nb_cols = nb_cols
|
| 296 |
+
self.order = order
|
| 297 |
+
self.cval = cval
|
| 298 |
+
self.mode = mode
|
| 299 |
+
warnings.warn("This IAAPiecewiseAffine is deprecated. Please use PiecewiseAffine instead", FutureWarning)
|
| 300 |
+
|
| 301 |
+
@property
|
| 302 |
+
def processor(self):
|
| 303 |
+
return iaa.PiecewiseAffine(self.scale, self.nb_rows, self.nb_cols, self.order, self.cval, self.mode)
|
| 304 |
+
|
| 305 |
+
def get_transform_init_args_names(self):
|
| 306 |
+
return ("scale", "nb_rows", "nb_cols", "order", "cval", "mode")
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class IAAAffine(DualIAATransform):
|
| 310 |
+
"""Place a regular grid of points on the input and randomly move the neighbourhood of these point around
|
| 311 |
+
via affine transformations.
|
| 312 |
+
|
| 313 |
+
This augmentation is deprecated. Please use Affine instead.
|
| 314 |
+
|
| 315 |
+
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
|
| 316 |
+
|
| 317 |
+
Args:
|
| 318 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 319 |
+
|
| 320 |
+
Targets:
|
| 321 |
+
image, mask
|
| 322 |
+
"""
|
| 323 |
+
|
| 324 |
+
def __init__(
|
| 325 |
+
self,
|
| 326 |
+
scale=1.0,
|
| 327 |
+
translate_percent=None,
|
| 328 |
+
translate_px=None,
|
| 329 |
+
rotate=0.0,
|
| 330 |
+
shear=0.0,
|
| 331 |
+
order=1,
|
| 332 |
+
cval=0,
|
| 333 |
+
mode="reflect",
|
| 334 |
+
always_apply=False,
|
| 335 |
+
p=0.5,
|
| 336 |
+
):
|
| 337 |
+
super(IAAAffine, self).__init__(always_apply, p)
|
| 338 |
+
self.scale = to_tuple(scale, 1.0)
|
| 339 |
+
self.translate_percent = to_tuple(translate_percent, 0)
|
| 340 |
+
self.translate_px = to_tuple(translate_px, 0)
|
| 341 |
+
self.rotate = to_tuple(rotate)
|
| 342 |
+
self.shear = to_tuple(shear)
|
| 343 |
+
self.order = order
|
| 344 |
+
self.cval = cval
|
| 345 |
+
self.mode = mode
|
| 346 |
+
warnings.warn("This IAAAffine is deprecated. Please use Affine instead", FutureWarning)
|
| 347 |
+
|
| 348 |
+
@property
|
| 349 |
+
def processor(self):
|
| 350 |
+
return iaa.Affine(
|
| 351 |
+
self.scale,
|
| 352 |
+
self.translate_percent,
|
| 353 |
+
self.translate_px,
|
| 354 |
+
self.rotate,
|
| 355 |
+
self.shear,
|
| 356 |
+
self.order,
|
| 357 |
+
self.cval,
|
| 358 |
+
self.mode,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
def get_transform_init_args_names(self):
|
| 362 |
+
return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
class IAAPerspective(Perspective):
|
| 366 |
+
"""Perform a random four point perspective transform of the input.
|
| 367 |
+
This augmentation is deprecated. Please use Perspective instead.
|
| 368 |
+
|
| 369 |
+
Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
|
| 370 |
+
|
| 371 |
+
Args:
|
| 372 |
+
scale ((float, float): standard deviation of the normal distributions. These are used to sample
|
| 373 |
+
the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
|
| 374 |
+
p (float): probability of applying the transform. Default: 0.5.
|
| 375 |
+
|
| 376 |
+
Targets:
|
| 377 |
+
image, mask
|
| 378 |
+
"""
|
| 379 |
+
|
| 380 |
+
def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5):
|
| 381 |
+
super(IAAPerspective, self).__init__(always_apply, p)
|
| 382 |
+
self.scale = to_tuple(scale, 1.0)
|
| 383 |
+
self.keep_size = keep_size
|
| 384 |
+
warnings.warn("This IAAPerspective is deprecated. Please use Perspective instead", FutureWarning)
|
| 385 |
+
|
| 386 |
+
@property
|
| 387 |
+
def processor(self):
|
| 388 |
+
return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size)
|
| 389 |
+
|
| 390 |
+
def get_transform_init_args_names(self):
|
| 391 |
+
return ("scale", "keep_size")
|
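The wrappers in this file all follow the same pattern: `update_params` freezes the underlying imgaug augmenter with `to_deterministic()` so that the image, bounding boxes, and keypoints of one sample all receive the same random draw. A standalone sketch of that pattern (requires the optional imgaug dependency; the array shapes are illustrative only):

# Freeze an imgaug augmenter so image and boxes get the identical transform.
import imgaug as ia
from imgaug import augmenters as iaa
import numpy as np

aug = iaa.Affine(rotate=(-15, 15))
det = aug.to_deterministic()  # one random draw, reused for every call below

image = np.zeros((100, 100, 3), dtype=np.uint8)
boxes = ia.BoundingBoxesOnImage(
    [ia.BoundingBox(x1=10, y1=10, x2=40, y2=40)], shape=image.shape
)

image_aug = det.augment_image(image)
boxes_aug = det.augment_bounding_boxes([boxes])[0]  # rotated consistently with image_aug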
custom_albumentations/pytorch/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
from __future__ import absolute_import

from .transforms import *
custom_albumentations/pytorch/functional.py
ADDED
|
@@ -0,0 +1,31 @@
from __future__ import division

import numpy as np
import torch
import torchvision.transforms.functional as F


def img_to_tensor(im, normalize=None):
    tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
    if normalize is not None:
        return F.normalize(tensor, **normalize)
    return tensor


def mask_to_tensor(mask, num_classes, sigmoid):
    if num_classes > 1:
        if not sigmoid:
            # softmax
            long_mask = np.zeros((mask.shape[:2]), dtype=np.int64)
            if len(mask.shape) == 3:
                for c in range(mask.shape[2]):
                    long_mask[mask[..., c] > 0] = c
            else:
                long_mask[mask > 127] = 1
                long_mask[mask == 0] = 0
            mask = long_mask
        else:
            mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32)
    else:
        mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32)
    return torch.from_numpy(mask)
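Both helpers scale uint8 inputs to the [0, 1] range and reorder images to channel-first; `mask_to_tensor` produces either an int64 class map (the softmax case) or a float tensor (the sigmoid case). A short usage sketch with the import path of this repository layout:

# Usage sketch for the two helpers above.
import numpy as np
from custom_albumentations.pytorch.functional import img_to_tensor, mask_to_tensor

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (64, 64), dtype=np.uint8) * 255

t_img = img_to_tensor(image, normalize={"mean": [0.5, 0.5, 0.5], "std": [0.5, 0.5, 0.5]})
print(t_img.shape)   # torch.Size([3, 64, 64]), float32, normalized

t_mask = mask_to_tensor(mask, num_classes=2, sigmoid=False)
print(t_mask.dtype)  # torch.int64 class map (softmax-style target)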
custom_albumentations/pytorch/transforms.py
ADDED
|
@@ -0,0 +1,104 @@
from __future__ import absolute_import

import warnings

import numpy as np
import torch
from torchvision.transforms import functional as F

from ..core.transforms_interface import BasicTransform

__all__ = ["ToTensorV2"]


def img_to_tensor(im, normalize=None):
    tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
    if normalize is not None:
        return F.normalize(tensor, **normalize)
    return tensor


def mask_to_tensor(mask, num_classes, sigmoid):
    if num_classes > 1:
        if not sigmoid:
            # softmax
            long_mask = np.zeros((mask.shape[:2]), dtype=np.int64)
            if len(mask.shape) == 3:
                for c in range(mask.shape[2]):
                    long_mask[mask[..., c] > 0] = c
            else:
                long_mask[mask > 127] = 1
                long_mask[mask == 0] = 0
            mask = long_mask
        else:
            mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32)
    else:
        mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32)
    return torch.from_numpy(mask)


class ToTensor(BasicTransform):
    """Convert image and mask to `torch.Tensor` and divide by 255 if image or mask are `uint8` type.
    This transform is now removed from custom_albumentations. If you need it downgrade the library to version 0.5.2.

    Args:
        num_classes (int): only for segmentation
        sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not.
        normalize (dict, optional): dict with keys [mean, std] to pass it into torchvision.normalize

    """

    def __init__(self, num_classes=1, sigmoid=True, normalize=None):
        raise RuntimeError(
            "`ToTensor` is obsolete and it was removed from custom_albumentations. Please use `ToTensorV2` instead - "
            "https://albumentations.ai/docs/api_reference/pytorch/transforms/"
            "#albumentations.pytorch.transforms.ToTensorV2. "
            "\n\nIf you need `ToTensor` downgrade Albumentations to version 0.5.2."
        )


class ToTensorV2(BasicTransform):
    """Convert image and mask to `torch.Tensor`. The numpy `HWC` image is converted to pytorch `CHW` tensor.
    If the image is in `HW` format (grayscale image), it will be converted to pytorch `HW` tensor.
    This is a simplified and improved version of the old `ToTensor`
    transform (`ToTensor` was deprecated, and now it is not present in Albumentations. You should use `ToTensorV2`
    instead).

    Args:
        transpose_mask (bool): If True and an input mask has three dimensions, this transform will transpose dimensions
            so the shape `[height, width, num_channels]` becomes `[num_channels, height, width]`. The latter format is a
            standard format for PyTorch Tensors. Default: False.
        always_apply (bool): Indicates whether this transformation should be always applied. Default: True.
        p (float): Probability of applying the transform. Default: 1.0.
    """

    def __init__(self, transpose_mask=False, always_apply=True, p=1.0):
        super(ToTensorV2, self).__init__(always_apply=always_apply, p=p)
        self.transpose_mask = transpose_mask

    @property
    def targets(self):
        return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks}

    def apply(self, img, **params):  # skipcq: PYL-W0613
        if len(img.shape) not in [2, 3]:
            raise ValueError("Albumentations only supports images in HW or HWC format")

        if len(img.shape) == 2:
            img = np.expand_dims(img, 2)

        return torch.from_numpy(img.transpose(2, 0, 1))

    def apply_to_mask(self, mask, **params):  # skipcq: PYL-W0613
        if self.transpose_mask and mask.ndim == 3:
            mask = mask.transpose(2, 0, 1)
        return torch.from_numpy(mask)

    def apply_to_masks(self, masks, **params):
        return [self.apply_to_mask(mask, **params) for mask in masks]

    def get_transform_init_args_names(self):
        return ("transpose_mask",)

    def get_params_dependent_on_targets(self, params):
        return {}
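ToTensorV2 is normally the last step of a pipeline. A minimal pipeline sketch, assuming this vendored package re-exports the usual albumentations-style top-level API (Compose, HorizontalFlip):

# Minimal Compose sketch ending in ToTensorV2.
import numpy as np
import custom_albumentations as A
from custom_albumentations.pytorch import ToTensorV2

transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    ToTensorV2(transpose_mask=True),
])

sample = transform(
    image=np.zeros((128, 128, 3), dtype=np.uint8),
    mask=np.zeros((128, 128, 1), dtype=np.uint8),
)
print(sample["image"].shape)  # torch.Size([3, 128, 128])
print(sample["mask"].shape)   # torch.Size([1, 128, 128]) because transpose_mask=True

Note that, unlike the removed ToTensor, ToTensorV2 only reorders axes; any scaling or normalization has to be done by an earlier transform in the Compose.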
custom_albumentations/random_utils.py
ADDED
|
@@ -0,0 +1,96 @@
# Use `Any` as the return type to avoid mypy problems with Union data types,
# because numpy can return single number and ndarray

import random as py_random
from typing import Any, Optional, Sequence, Type, Union

import numpy as np

from .core.transforms_interface import NumType

IntNumType = Union[int, np.ndarray]
Size = Union[int, Sequence[int]]


def get_random_state() -> np.random.RandomState:
    return np.random.RandomState(py_random.randint(0, (1 << 32) - 1))


def uniform(
    low: NumType = 0.0,
    high: NumType = 1.0,
    size: Optional[Size] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.uniform(low, high, size)


def rand(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.rand(d0, d1, *more, **kwargs)  # type: ignore


def randn(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.randn(d0, d1, *more, **kwargs)  # type: ignore


def normal(
    loc: NumType = 0.0,
    scale: NumType = 1.0,
    size: Optional[Size] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.normal(loc, scale, size)


def poisson(
    lam: NumType = 1.0, size: Optional[Size] = None, random_state: Optional[np.random.RandomState] = None
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.poisson(lam, size)


def permutation(
    x: Union[int, Sequence[float], np.ndarray], random_state: Optional[np.random.RandomState] = None
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.permutation(x)


def randint(
    low: IntNumType,
    high: Optional[IntNumType] = None,
    size: Optional[Size] = None,
    dtype: Type = np.int32,
    random_state: Optional[np.random.RandomState] = None,
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.randint(low, high, size, dtype)


def random(size: Optional[NumType] = None, random_state: Optional[np.random.RandomState] = None) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.random(size)  # type: ignore


def choice(
    a: NumType,
    size: Optional[Size] = None,
    replace: bool = True,
    p: Optional[Union[Sequence[float], np.ndarray]] = None,
    random_state: Optional[np.random.RandomState] = None,
) -> Any:
    if random_state is None:
        random_state = get_random_state()
    return random_state.choice(a, size, replace, p)  # type: ignore
custom_controlnet_aux/__init__.py
ADDED
|
@@ -0,0 +1 @@
#Dummy file ensuring this package will be recognized
custom_controlnet_aux/anime_face_segment/__init__.py
ADDED
|
@@ -0,0 +1,66 @@
from .network import UNet
from .util import seg2img
import torch
import os
import cv2
from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME
from huggingface_hub import hf_hub_download
from PIL import Image
from einops import rearrange
from .anime_segmentation import AnimeSegmentation
import numpy as np

class AnimeFaceSegmentor:
    def __init__(self, model, seg_model):
        self.model = model
        self.seg_model = seg_model
        self.device = "cpu"

    @classmethod
    def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"):
        model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators")
        seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename)

        model = UNet()
        ckpt = torch.load(model_path, map_location="cpu")
        model.load_state_dict(ckpt)
        model.eval()

        seg_model = AnimeSegmentation(seg_model_path)
        seg_model.net.eval()
        return cls(model, seg_model)

    def to(self, device):
        self.model.to(device)
        self.seg_model.net.to(device)
        self.device = device
        return self

    def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        with torch.no_grad():
            if remove_background:
                print(input_image.shape)
                mask, input_image = self.seg_model(input_image, 0)  #Don't resize image as it is resized
            image_feed = torch.from_numpy(input_image).float().to(self.device)
            image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
            image_feed = image_feed / 255
            seg = self.model(image_feed).squeeze(dim=0)
            result = seg2img(seg.cpu().detach().numpy())

        detected_map = HWC3(result)
        detected_map = remove_pad(detected_map)
        if remove_background:
            mask = remove_pad(mask)
            H, W, C = detected_map.shape
            tmp = np.zeros([H, W, C + 1])
            tmp[:,:,:C] = detected_map
            tmp[:,:,3:] = mask
            detected_map = tmp

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map[..., :3])

        return detected_map
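A usage sketch for the detector above, assuming network access for custom_hf_download and a local "input.png" placeholder; with the default remove_background=True, the "np" output carries the background mask as an extra channel:

# Hypothetical usage sketch of AnimeFaceSegmentor.
import numpy as np
from PIL import Image
from custom_controlnet_aux.anime_face_segment import AnimeFaceSegmentor

segmentor = AnimeFaceSegmentor.from_pretrained()          # downloads UNet.pth and isnetis.ckpt
image = np.array(Image.open("input.png").convert("RGB"))  # placeholder input path

result = segmentor(image, detect_resolution=512, output_type="np")
print(result.shape)  # (H, W, 4): RGB segmentation map plus the background mask channel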
custom_controlnet_aux/anime_face_segment/anime_segmentation.py
ADDED
|
@@ -0,0 +1,58 @@
#https://github.com/SkyTNT/anime-segmentation/tree/main
#Only adapt isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
import torch.nn as nn
import torch
from .isnet import ISNetDIS
import numpy as np
import cv2
from comfy.model_management import get_torch_device
DEVICE = get_torch_device()

class AnimeSegmentation:
    def __init__(self, ckpt_path):
        super(AnimeSegmentation).__init__()
        sd = torch.load(ckpt_path, map_location="cpu")
        self.net = ISNetDIS()
        #gt_encoder isn't used during inference
        self.net.load_state_dict({k.replace("net.", ''):v for k, v in sd.items() if k.startswith("net.")})
        self.net = self.net.to(DEVICE)
        self.net.eval()

    def get_mask(self, input_img, s=640):
        input_img = (input_img / 255).astype(np.float32)
        if s == 0:
            img_input = np.transpose(input_img, (2, 0, 1))
            img_input = img_input[np.newaxis, :]
            tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
            with torch.no_grad():
                pred = self.net(tmpImg)[0][0].sigmoid()  #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
                pred = pred.cpu().numpy()[0]
                pred = np.transpose(pred, (1, 2, 0))
                #pred = pred[:, :, np.newaxis]
                return pred

        h, w = h0, w0 = input_img.shape[:-1]
        h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
        ph, pw = s - h, s - w
        img_input = np.zeros([s, s, 3], dtype=np.float32)
        img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h))
        img_input = np.transpose(img_input, (2, 0, 1))
        img_input = img_input[np.newaxis, :]
        tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
        with torch.no_grad():
            pred = self.net(tmpImg)[0][0].sigmoid()  #https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
            pred = pred.cpu().numpy()[0]
            pred = np.transpose(pred, (1, 2, 0))
            pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
            #pred = cv2.resize(pred, (w0, h0))[:, :, np.newaxis]
            pred = cv2.resize(pred, (w0, h0))
            return pred

    def __call__(self, np_img, img_size):
        mask = self.get_mask(np_img, int(img_size))
        np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8)
        mask = (mask * 255).astype(np.uint8)
        #np_img = np.concatenate([np_img, mask], axis=2, dtype=np.uint8)
        #mask = mask.repeat(3, axis=2)
        return mask, np_img
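get_mask letterboxes the input into an s-by-s square (keeping aspect ratio and padding the shorter side) before inference, then crops the prediction back. A small standalone sketch of just that geometry, using a hypothetical 480x640 input and s=640:

# Geometry sketch of the letterboxing used in get_mask.
import numpy as np
import cv2

s = 640
h0, w0 = 480, 640                      # hypothetical original height, width
h, w = (s, int(s * w0 / h0)) if h0 > w0 else (int(s * h0 / w0), s)
ph, pw = s - h, s - w                  # total padding on each axis

canvas = np.zeros([s, s, 3], dtype=np.float32)
resized = cv2.resize(np.zeros([h0, w0, 3], dtype=np.float32), (w, h))
canvas[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = resized
print(h, w, ph, pw)                    # 480 640 160 0 -> image centered vertically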
custom_controlnet_aux/anime_face_segment/isnet.py
ADDED
|
@@ -0,0 +1,619 @@
| 1 |
+
# Codes are borrowed from
|
| 2 |
+
# https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from torchvision import models
|
| 8 |
+
|
| 9 |
+
bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def muti_loss_fusion(preds, target):
|
| 13 |
+
loss0 = 0.0
|
| 14 |
+
loss = 0.0
|
| 15 |
+
|
| 16 |
+
for i in range(0, len(preds)):
|
| 17 |
+
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
|
| 18 |
+
tmp_target = F.interpolate(
|
| 19 |
+
target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
|
| 20 |
+
)
|
| 21 |
+
loss = loss + bce_loss(preds[i], tmp_target)
|
| 22 |
+
else:
|
| 23 |
+
loss = loss + bce_loss(preds[i], target)
|
| 24 |
+
if i == 0:
|
| 25 |
+
loss0 = loss
|
| 26 |
+
return loss0, loss
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
fea_loss = nn.MSELoss(reduction="mean")
|
| 30 |
+
kl_loss = nn.KLDivLoss(reduction="mean")
|
| 31 |
+
l1_loss = nn.L1Loss(reduction="mean")
|
| 32 |
+
smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
|
| 36 |
+
loss0 = 0.0
|
| 37 |
+
loss = 0.0
|
| 38 |
+
|
| 39 |
+
for i in range(0, len(preds)):
|
| 40 |
+
if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
|
| 41 |
+
tmp_target = F.interpolate(
|
| 42 |
+
target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
|
| 43 |
+
)
|
| 44 |
+
loss = loss + bce_loss(preds[i], tmp_target)
|
| 45 |
+
else:
|
| 46 |
+
loss = loss + bce_loss(preds[i], target)
|
| 47 |
+
if i == 0:
|
| 48 |
+
loss0 = loss
|
| 49 |
+
|
| 50 |
+
for i in range(0, len(dfs)):
|
| 51 |
+
df = dfs[i]
|
| 52 |
+
fs_i = fs[i]
|
| 53 |
+
if mode == "MSE":
|
| 54 |
+
loss = loss + fea_loss(
|
| 55 |
+
df, fs_i
|
| 56 |
+
) ### add the mse loss of features as additional constraints
|
| 57 |
+
elif mode == "KL":
|
| 58 |
+
loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
|
| 59 |
+
elif mode == "MAE":
|
| 60 |
+
loss = loss + l1_loss(df, fs_i)
|
| 61 |
+
elif mode == "SmoothL1":
|
| 62 |
+
loss = loss + smooth_l1_loss(df, fs_i)
|
| 63 |
+
|
| 64 |
+
return loss0, loss
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class REBNCONV(nn.Module):
|
| 68 |
+
def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
|
| 69 |
+
super(REBNCONV, self).__init__()
|
| 70 |
+
|
| 71 |
+
self.conv_s1 = nn.Conv2d(
|
| 72 |
+
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
|
| 73 |
+
)
|
| 74 |
+
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
| 75 |
+
self.relu_s1 = nn.ReLU(inplace=True)
|
| 76 |
+
|
| 77 |
+
def forward(self, x):
|
| 78 |
+
hx = x
|
| 79 |
+
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
| 80 |
+
|
| 81 |
+
return xout
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
## upsample tensor 'src' to have the same spatial size with tensor 'tar'
|
| 85 |
+
def _upsample_like(src, tar):
|
| 86 |
+
src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
|
| 87 |
+
|
| 88 |
+
return src
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
### RSU-7 ###
|
| 92 |
+
class RSU7(nn.Module):
|
| 93 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
|
| 94 |
+
super(RSU7, self).__init__()
|
| 95 |
+
|
| 96 |
+
self.in_ch = in_ch
|
| 97 |
+
self.mid_ch = mid_ch
|
| 98 |
+
self.out_ch = out_ch
|
| 99 |
+
|
| 100 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1) ## 1 -> 1/2
|
| 101 |
+
|
| 102 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 103 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 104 |
+
|
| 105 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 106 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 107 |
+
|
| 108 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 109 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 110 |
+
|
| 111 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 112 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 113 |
+
|
| 114 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 115 |
+
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 116 |
+
|
| 117 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 118 |
+
|
| 119 |
+
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 120 |
+
|
| 121 |
+
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 122 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 123 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 124 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 125 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 126 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 127 |
+
|
| 128 |
+
def forward(self, x):
|
| 129 |
+
b, c, h, w = x.shape
|
| 130 |
+
|
| 131 |
+
hx = x
|
| 132 |
+
hxin = self.rebnconvin(hx)
|
| 133 |
+
|
| 134 |
+
hx1 = self.rebnconv1(hxin)
|
| 135 |
+
hx = self.pool1(hx1)
|
| 136 |
+
|
| 137 |
+
hx2 = self.rebnconv2(hx)
|
| 138 |
+
hx = self.pool2(hx2)
|
| 139 |
+
|
| 140 |
+
hx3 = self.rebnconv3(hx)
|
| 141 |
+
hx = self.pool3(hx3)
|
| 142 |
+
|
| 143 |
+
hx4 = self.rebnconv4(hx)
|
| 144 |
+
hx = self.pool4(hx4)
|
| 145 |
+
|
| 146 |
+
hx5 = self.rebnconv5(hx)
|
| 147 |
+
hx = self.pool5(hx5)
|
| 148 |
+
|
| 149 |
+
hx6 = self.rebnconv6(hx)
|
| 150 |
+
|
| 151 |
+
hx7 = self.rebnconv7(hx6)
|
| 152 |
+
|
| 153 |
+
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
| 154 |
+
hx6dup = _upsample_like(hx6d, hx5)
|
| 155 |
+
|
| 156 |
+
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
| 157 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 158 |
+
|
| 159 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 160 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 161 |
+
|
| 162 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 163 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 164 |
+
|
| 165 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 166 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 167 |
+
|
| 168 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 169 |
+
|
| 170 |
+
return hx1d + hxin
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
### RSU-6 ###
|
| 174 |
+
class RSU6(nn.Module):
|
| 175 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 176 |
+
super(RSU6, self).__init__()
|
| 177 |
+
|
| 178 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 179 |
+
|
| 180 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 181 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 182 |
+
|
| 183 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 184 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 185 |
+
|
| 186 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 187 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 188 |
+
|
| 189 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 190 |
+
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 191 |
+
|
| 192 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 193 |
+
|
| 194 |
+
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 195 |
+
|
| 196 |
+
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 197 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 198 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 199 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 200 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 201 |
+
|
| 202 |
+
def forward(self, x):
|
| 203 |
+
hx = x
|
| 204 |
+
|
| 205 |
+
hxin = self.rebnconvin(hx)
|
| 206 |
+
|
| 207 |
+
hx1 = self.rebnconv1(hxin)
|
| 208 |
+
hx = self.pool1(hx1)
|
| 209 |
+
|
| 210 |
+
hx2 = self.rebnconv2(hx)
|
| 211 |
+
hx = self.pool2(hx2)
|
| 212 |
+
|
| 213 |
+
hx3 = self.rebnconv3(hx)
|
| 214 |
+
hx = self.pool3(hx3)
|
| 215 |
+
|
| 216 |
+
hx4 = self.rebnconv4(hx)
|
| 217 |
+
hx = self.pool4(hx4)
|
| 218 |
+
|
| 219 |
+
hx5 = self.rebnconv5(hx)
|
| 220 |
+
|
| 221 |
+
hx6 = self.rebnconv6(hx5)
|
| 222 |
+
|
| 223 |
+
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
| 224 |
+
hx5dup = _upsample_like(hx5d, hx4)
|
| 225 |
+
|
| 226 |
+
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
| 227 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 228 |
+
|
| 229 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 230 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 231 |
+
|
| 232 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 233 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 234 |
+
|
| 235 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 236 |
+
|
| 237 |
+
return hx1d + hxin
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
### RSU-5 ###
|
| 241 |
+
class RSU5(nn.Module):
|
| 242 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 243 |
+
super(RSU5, self).__init__()
|
| 244 |
+
|
| 245 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 246 |
+
|
| 247 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 248 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 249 |
+
|
| 250 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 251 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 252 |
+
|
| 253 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 254 |
+
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 255 |
+
|
| 256 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 257 |
+
|
| 258 |
+
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 259 |
+
|
| 260 |
+
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 261 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 262 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 263 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 264 |
+
|
| 265 |
+
def forward(self, x):
|
| 266 |
+
hx = x
|
| 267 |
+
|
| 268 |
+
hxin = self.rebnconvin(hx)
|
| 269 |
+
|
| 270 |
+
hx1 = self.rebnconv1(hxin)
|
| 271 |
+
hx = self.pool1(hx1)
|
| 272 |
+
|
| 273 |
+
hx2 = self.rebnconv2(hx)
|
| 274 |
+
hx = self.pool2(hx2)
|
| 275 |
+
|
| 276 |
+
hx3 = self.rebnconv3(hx)
|
| 277 |
+
hx = self.pool3(hx3)
|
| 278 |
+
|
| 279 |
+
hx4 = self.rebnconv4(hx)
|
| 280 |
+
|
| 281 |
+
hx5 = self.rebnconv5(hx4)
|
| 282 |
+
|
| 283 |
+
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
| 284 |
+
hx4dup = _upsample_like(hx4d, hx3)
|
| 285 |
+
|
| 286 |
+
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
| 287 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 288 |
+
|
| 289 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 290 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 291 |
+
|
| 292 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 293 |
+
|
| 294 |
+
return hx1d + hxin
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
### RSU-4 ###
|
| 298 |
+
class RSU4(nn.Module):
|
| 299 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 300 |
+
super(RSU4, self).__init__()
|
| 301 |
+
|
| 302 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 303 |
+
|
| 304 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 305 |
+
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 306 |
+
|
| 307 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 308 |
+
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
| 309 |
+
|
| 310 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
| 311 |
+
|
| 312 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 313 |
+
|
| 314 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 315 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
| 316 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 317 |
+
|
| 318 |
+
def forward(self, x):
|
| 319 |
+
hx = x
|
| 320 |
+
|
| 321 |
+
hxin = self.rebnconvin(hx)
|
| 322 |
+
|
| 323 |
+
hx1 = self.rebnconv1(hxin)
|
| 324 |
+
hx = self.pool1(hx1)
|
| 325 |
+
|
| 326 |
+
hx2 = self.rebnconv2(hx)
|
| 327 |
+
hx = self.pool2(hx2)
|
| 328 |
+
|
| 329 |
+
hx3 = self.rebnconv3(hx)
|
| 330 |
+
|
| 331 |
+
hx4 = self.rebnconv4(hx3)
|
| 332 |
+
|
| 333 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 334 |
+
hx3dup = _upsample_like(hx3d, hx2)
|
| 335 |
+
|
| 336 |
+
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
| 337 |
+
hx2dup = _upsample_like(hx2d, hx1)
|
| 338 |
+
|
| 339 |
+
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
| 340 |
+
|
| 341 |
+
return hx1d + hxin
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
### RSU-4F ###
|
| 345 |
+
class RSU4F(nn.Module):
|
| 346 |
+
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
| 347 |
+
super(RSU4F, self).__init__()
|
| 348 |
+
|
| 349 |
+
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
| 350 |
+
|
| 351 |
+
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
| 352 |
+
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
| 353 |
+
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
| 354 |
+
|
| 355 |
+
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
| 356 |
+
|
| 357 |
+
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
| 358 |
+
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
| 359 |
+
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
| 360 |
+
|
| 361 |
+
def forward(self, x):
|
| 362 |
+
hx = x
|
| 363 |
+
|
| 364 |
+
hxin = self.rebnconvin(hx)
|
| 365 |
+
|
| 366 |
+
hx1 = self.rebnconv1(hxin)
|
| 367 |
+
hx2 = self.rebnconv2(hx1)
|
| 368 |
+
hx3 = self.rebnconv3(hx2)
|
| 369 |
+
|
| 370 |
+
hx4 = self.rebnconv4(hx3)
|
| 371 |
+
|
| 372 |
+
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
| 373 |
+
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
| 374 |
+
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
| 375 |
+
|
| 376 |
+
return hx1d + hxin
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class myrebnconv(nn.Module):
|
| 380 |
+
def __init__(
|
| 381 |
+
self,
|
| 382 |
+
in_ch=3,
|
| 383 |
+
out_ch=1,
|
| 384 |
+
kernel_size=3,
|
| 385 |
+
stride=1,
|
| 386 |
+
padding=1,
|
| 387 |
+
dilation=1,
|
| 388 |
+
groups=1,
|
| 389 |
+
):
|
| 390 |
+
super(myrebnconv, self).__init__()
|
| 391 |
+
|
| 392 |
+
self.conv = nn.Conv2d(
|
| 393 |
+
in_ch,
|
| 394 |
+
out_ch,
|
| 395 |
+
kernel_size=kernel_size,
|
| 396 |
+
stride=stride,
|
| 397 |
+
padding=padding,
|
| 398 |
+
dilation=dilation,
|
| 399 |
+
groups=groups,
|
| 400 |
+
)
|
| 401 |
+
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.rl(self.bn(self.conv(x)))


class ISNetGTEncoder(nn.Module):
    def __init__(self, in_ch=1, out_ch=1):
        super(ISNetGTEncoder, self).__init__()

        self.conv_in = myrebnconv(
            in_ch, 16, 3, stride=2, padding=1
        )  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)

        self.stage1 = RSU7(16, 16, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 16, 64)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(64, 32, 128)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(128, 32, 256)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(256, 64, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 64, 512)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    @staticmethod
    def compute_loss(args):
        preds, targets = args
        return muti_loss_fusion(preds, targets)

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        # hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)

        # side output
        d1 = self.side1(hx1)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
        return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]


class ISNetDIS(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super(ISNetDIS, self).__init__()

        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # decoder
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

        # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)

    @staticmethod
    def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
        return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)

    @staticmethod
    def compute_loss(args):
        if len(args) == 3:
            ds, dfs, labels = args
            return muti_loss_fusion(ds, labels)
        else:
            ds, dfs, labels, fs = args
            return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")

    def forward(self, x):
        hx = x

        hxin = self.conv_in(hx)
        hx = self.pool_in(hxin)

        # stage 1
        hx1 = self.stage1(hxin)
        hx = self.pool12(hx1)

        # stage 2
        hx2 = self.stage2(hx)
        hx = self.pool23(hx2)

        # stage 3
        hx3 = self.stage3(hx)
        hx = self.pool34(hx3)

        # stage 4
        hx4 = self.stage4(hx)
        hx = self.pool45(hx4)

        # stage 5
        hx5 = self.stage5(hx)
        hx = self.pool56(hx5)

        # stage 6
        hx6 = self.stage6(hx)
        hx6up = _upsample_like(hx6, hx5)

        # -------------------- decoder --------------------
        hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
        hx5dup = _upsample_like(hx5d, hx4)

        hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
        hx4dup = _upsample_like(hx4d, hx3)

        hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
        hx3dup = _upsample_like(hx3d, hx2)

        hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
        hx2dup = _upsample_like(hx2d, hx1)

        hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))

        # side output
        d1 = self.side1(hx1d)
        d1 = _upsample_like(d1, x)

        d2 = self.side2(hx2d)
        d2 = _upsample_like(d2, x)

        d3 = self.side3(hx3d)
        d3 = _upsample_like(d3, x)

        d4 = self.side4(hx4d)
        d4 = _upsample_like(d4, x)

        d5 = self.side5(hx5d)
        d5 = _upsample_like(d5, x)

        d6 = self.side6(hx6)
        d6 = _upsample_like(d6, x)

        # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))

        # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
        return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
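For orientation, a minimal sketch of how ISNetDIS could be exercised outside the annotator pipeline. The 1024x1024 input size, the package-style import path, and the sigmoid over the first side output are assumptions for illustration, not something this diff prescribes.

# Hedged usage sketch: assumes the repository is importable as a package and
# that no pretrained checkpoint is required just to run a forward pass.
import torch
from custom_controlnet_aux.anime_face_segment.isnet import ISNetDIS

model = ISNetDIS(in_ch=3, out_ch=1)
model.eval()

# Dummy RGB batch; sizes divisible by 64 pass cleanly through the
# stride-2 stem plus five pooling stages and their mirrored upsamples.
x = torch.randn(1, 3, 1024, 1024)

with torch.no_grad():
    side_outputs, features = model(x)

# side_outputs[0] (d1) is the full-resolution logit map; sigmoid gives a soft mask.
mask = torch.sigmoid(side_outputs[0])
print(mask.shape)  # torch.Size([1, 1, 1024, 1024])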
custom_controlnet_aux/anime_face_segment/network.py
ADDED
@@ -0,0 +1,100 @@
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision

from custom_controlnet_aux.util import custom_torch_download

class UNet(nn.Module):
    def __init__(self):
        super(UNet, self).__init__()
        self.NUM_SEG_CLASSES = 7  # Background, hair, face, eye, mouth, skin, clothes

        mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=False)
        mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True)
        mob_blocks = mobilenet_v2.features

        # Encoder
        self.en_block0 = nn.Sequential(  # in_ch=3 out_ch=16
            mob_blocks[0],
            mob_blocks[1]
        )
        self.en_block1 = nn.Sequential(  # in_ch=16 out_ch=24
            mob_blocks[2],
            mob_blocks[3],
        )
        self.en_block2 = nn.Sequential(  # in_ch=24 out_ch=32
            mob_blocks[4],
            mob_blocks[5],
            mob_blocks[6],
        )
        self.en_block3 = nn.Sequential(  # in_ch=32 out_ch=96
            mob_blocks[7],
            mob_blocks[8],
            mob_blocks[9],
            mob_blocks[10],
            mob_blocks[11],
            mob_blocks[12],
            mob_blocks[13],
        )
        self.en_block4 = nn.Sequential(  # in_ch=96 out_ch=160
            mob_blocks[14],
            mob_blocks[15],
            mob_blocks[16],
        )

        # Decoder
        self.de_block4 = nn.Sequential(  # in_ch=160 out_ch=96
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(160, 96, kernel_size=3, padding=1),
            nn.InstanceNorm2d(96),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block3 = nn.Sequential(  # in_ch=96x2 out_ch=32
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(96*2, 32, kernel_size=3, padding=1),
            nn.InstanceNorm2d(32),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block2 = nn.Sequential(  # in_ch=32x2 out_ch=24
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(32*2, 24, kernel_size=3, padding=1),
            nn.InstanceNorm2d(24),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )
        self.de_block1 = nn.Sequential(  # in_ch=24x2 out_ch=16
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(24*2, 16, kernel_size=3, padding=1),
            nn.InstanceNorm2d(16),
            nn.LeakyReLU(0.1),
            nn.Dropout(p=0.2)
        )

        self.de_block0 = nn.Sequential(  # in_ch=16x2 out_ch=7
            nn.UpsamplingNearest2d(scale_factor=2),
            nn.Conv2d(16*2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
            nn.Softmax2d()
        )

    def forward(self, x):
        e0 = self.en_block0(x)
        e1 = self.en_block1(e0)
        e2 = self.en_block2(e1)
        e3 = self.en_block3(e2)
        e4 = self.en_block4(e3)

        d4 = self.de_block4(e4)
        c4 = torch.cat((d4, e3), 1)
        d3 = self.de_block3(c4)
        c3 = torch.cat((d3, e2), 1)
        d2 = self.de_block2(c3)
        c2 = torch.cat((d2, e1), 1)
        d1 = self.de_block1(c2)
        c1 = torch.cat((d1, e0), 1)
        y = self.de_block0(c1)

        return y
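A rough sanity-check sketch for this UNet. Note that constructing it loads MobileNetV2 weights through custom_torch_download, so the helper must be able to resolve "mobilenet_v2-b0353104.pth"; the 512x512 input size is an assumption taken from the reshape in util.py below.

# Hedged sketch: forward-pass shape check only; weight availability is assumed.
import torch
from custom_controlnet_aux.anime_face_segment.network import UNet

net = UNet()
net.eval()

x = torch.randn(1, 3, 512, 512)   # RGB anime-face crop (assumed size)
with torch.no_grad():
    y = net(x)                    # per-pixel class probabilities

print(y.shape)                    # torch.Size([1, 7, 512, 512])
print(y.sum(dim=1).mean())        # ~1.0, since de_block0 ends in Softmax2d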
custom_controlnet_aux/anime_face_segment/util.py
ADDED
@@ -0,0 +1,40 @@
#https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py
#The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py
import cv2 as cv
import glob
import numpy as np
import os

"""
COLOR_BACKGROUND = (0,255,255)
COLOR_HAIR = (255,0,0)
COLOR_EYE = (0,0,255)
COLOR_MOUTH = (255,255,255)
COLOR_FACE = (0,255,0)
COLOR_SKIN = (255,255,0)
COLOR_CLOTHES = (255,0,255)
"""
COLOR_BACKGROUND = (255,255,0)
COLOR_HAIR = (0,0,255)
COLOR_EYE = (255,0,0)
COLOR_MOUTH = (255,255,255)
COLOR_FACE = (0,255,0)
COLOR_SKIN = (0,255,255)
COLOR_CLOTHES = (255,0,255)
PALETTE = [COLOR_BACKGROUND, COLOR_HAIR, COLOR_EYE, COLOR_MOUTH, COLOR_FACE, COLOR_SKIN, COLOR_CLOTHES]

def img2seg(path):
    src = cv.imread(path)
    src = src.reshape(-1, 3)
    seg_list = []
    for color in PALETTE:
        seg_list.append(np.where(np.all(src == color, axis=1), 1.0, 0.0))
    dst = np.stack(seg_list, axis=1).reshape(512, 512, 7)

    return dst.astype(np.float32)

def seg2img(src):
    src = np.moveaxis(src, 0, 2)
    dst = [[PALETTE[np.argmax(val)] for val in buf] for buf in src]

    return np.array(dst).astype(np.uint8)
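These two helpers convert between the palette-colored segmentation image and a 7-channel one-hot/probability tensor. A small sketch of seg2img, assuming a channel-first prediction of shape (7, 512, 512) whose channel order matches PALETTE:

# Hedged sketch: map a (7, H, W) class-probability array to the palette image.
import numpy as np
from custom_controlnet_aux.anime_face_segment.util import seg2img

pred = np.random.rand(7, 512, 512).astype(np.float32)  # e.g. softmax output
color_map = seg2img(pred)                               # HWC uint8 image
print(color_map.shape, color_map.dtype)                 # (512, 512, 3) uint8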
custom_controlnet_aux/binary/__init__.py
ADDED
@@ -0,0 +1,38 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import HWC3, resize_image_with_pad

class BinaryDetector:
    def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        if "img" in kwargs:
            warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
            input_image = kwargs.pop("img")

        if input_image is None:
            raise ValueError("input_image must be defined.")

        if not isinstance(input_image, np.ndarray):
            input_image = np.array(input_image, dtype=np.uint8)
            output_type = output_type or "pil"
        else:
            output_type = output_type or "np"

        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)

        img_gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)
        if bin_threshold == 0 or bin_threshold == 255:
            # Otsu's threshold
            otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
            print("Otsu threshold:", otsu_threshold)
        else:
            _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)

        detected_map = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
        detected_map = HWC3(remove_pad(255 - detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
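A short usage sketch for BinaryDetector. The input file name is a placeholder, and as coded above a bin_threshold of 0 (or 255) falls back to Otsu's automatic threshold.

# Hedged sketch: run the black/white preprocessor on a PIL image.
# "example.png" is a hypothetical path, not part of this repository.
from PIL import Image
from custom_controlnet_aux.binary import BinaryDetector

detector = BinaryDetector()
image = Image.open("example.png").convert("RGB")

# Any value strictly between 0 and 255 is used as a fixed grayscale cutoff.
result = detector(image, bin_threshold=100, detect_resolution=512)
result.save("binary_map.png")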
custom_controlnet_aux/canny/__init__.py
ADDED
@@ -0,0 +1,17 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3

class CannyDetector:
    def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
        detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
        detected_map = HWC3(remove_pad(detected_map))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map
custom_controlnet_aux/color/__init__.py
ADDED
@@ -0,0 +1,37 @@
import warnings
import cv2
import numpy as np
from PIL import Image
from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate

def cv2_resize_shortest_edge(image, size):
    h, w = image.shape[:2]
    if h < w:
        new_h = size
        new_w = int(round(w / h * size))
    else:
        new_w = size
        new_h = int(round(h / w * size))
    resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
    return resized_image

def apply_color(img, res=512):
    img = cv2_resize_shortest_edge(img, res)
    h, w = img.shape[:2]

    input_img_color = cv2.resize(img, (w//64, h//64), interpolation=cv2.INTER_CUBIC)
    input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
    return input_img_color

# Color T2I-Adapter-like preprocessing (color grid at 1/64 of the working resolution); the upscale methods are fixed.
class ColorDetector:
    def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
        input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
        input_image = HWC3(input_image)
        detected_map = HWC3(apply_color(input_image, detect_resolution))

        if output_type == "pil":
            detected_map = Image.fromarray(detected_map)

        return detected_map