JasonSmithSO committed
Commit 0034848 · verified · 1 Parent(s): 74f4a06

Upload 777 files

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. __init__.py +1 -214
  3. custom_albumentations/LICENSE +21 -0
  4. custom_albumentations/__init__.py +15 -0
  5. custom_albumentations/augmentations/__init__.py +21 -0
  6. custom_albumentations/augmentations/blur/__init__.py +2 -0
  7. custom_albumentations/augmentations/blur/functional.py +106 -0
  8. custom_albumentations/augmentations/blur/transforms.py +486 -0
  9. custom_albumentations/augmentations/crops/__init__.py +2 -0
  10. custom_albumentations/augmentations/crops/functional.py +317 -0
  11. custom_albumentations/augmentations/crops/transforms.py +943 -0
  12. custom_albumentations/augmentations/domain_adaptation.py +337 -0
  13. custom_albumentations/augmentations/dropout/__init__.py +5 -0
  14. custom_albumentations/augmentations/dropout/channel_dropout.py +72 -0
  15. custom_albumentations/augmentations/dropout/coarse_dropout.py +187 -0
  16. custom_albumentations/augmentations/dropout/cutout.py +79 -0
  17. custom_albumentations/augmentations/dropout/functional.py +29 -0
  18. custom_albumentations/augmentations/dropout/grid_dropout.py +155 -0
  19. custom_albumentations/augmentations/dropout/mask_dropout.py +99 -0
  20. custom_albumentations/augmentations/functional.py +1380 -0
  21. custom_albumentations/augmentations/geometric/__init__.py +4 -0
  22. custom_albumentations/augmentations/geometric/functional.py +1300 -0
  23. custom_albumentations/augmentations/geometric/resize.py +198 -0
  24. custom_albumentations/augmentations/geometric/rotate.py +294 -0
  25. custom_albumentations/augmentations/geometric/transforms.py +1499 -0
  26. custom_albumentations/augmentations/transforms.py +2667 -0
  27. custom_albumentations/augmentations/utils.py +211 -0
  28. custom_albumentations/core/__init__.py +0 -0
  29. custom_albumentations/core/bbox_utils.py +522 -0
  30. custom_albumentations/core/composition.py +552 -0
  31. custom_albumentations/core/keypoints_utils.py +286 -0
  32. custom_albumentations/core/serialization.py +247 -0
  33. custom_albumentations/core/transforms_interface.py +293 -0
  34. custom_albumentations/core/utils.py +137 -0
  35. custom_albumentations/imgaug/__init__.py +0 -0
  36. custom_albumentations/imgaug/stubs.py +77 -0
  37. custom_albumentations/imgaug/transforms.py +391 -0
  38. custom_albumentations/pytorch/__init__.py +3 -0
  39. custom_albumentations/pytorch/functional.py +31 -0
  40. custom_albumentations/pytorch/transforms.py +104 -0
  41. custom_albumentations/random_utils.py +96 -0
  42. custom_controlnet_aux/__init__.py +1 -0
  43. custom_controlnet_aux/anime_face_segment/__init__.py +66 -0
  44. custom_controlnet_aux/anime_face_segment/anime_segmentation.py +58 -0
  45. custom_controlnet_aux/anime_face_segment/isnet.py +619 -0
  46. custom_controlnet_aux/anime_face_segment/network.py +100 -0
  47. custom_controlnet_aux/anime_face_segment/util.py +40 -0
  48. custom_controlnet_aux/binary/__init__.py +38 -0
  49. custom_controlnet_aux/canny/__init__.py +17 -0
  50. custom_controlnet_aux/color/__init__.py +37 -0
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  comfyui_screenshot.png filter=lfs diff=lfs merge=lfs -text
  NotoSans-Regular.ttf filter=lfs diff=lfs merge=lfs -text
+ custom_controlnet_aux/mesh_graphormer/hand_landmarker.task filter=lfs diff=lfs merge=lfs -text
+ custom_controlnet_aux/tests/test_image.png filter=lfs diff=lfs merge=lfs -text
__init__.py CHANGED
@@ -1,214 +1 @@
- import sys, os
- from .utils import here, define_preprocessor_inputs, INPUT
- from pathlib import Path
- import traceback
- import importlib
- from .log import log, blue_text, cyan_text, get_summary, get_label
- from .hint_image_enchance import NODE_CLASS_MAPPINGS as HIE_NODE_CLASS_MAPPINGS
- from .hint_image_enchance import NODE_DISPLAY_NAME_MAPPINGS as HIE_NODE_DISPLAY_NAME_MAPPINGS
- #Ref: https://github.com/comfyanonymous/ComfyUI/blob/76d53c4622fc06372975ed2a43ad345935b8a551/nodes.py#L17
- sys.path.insert(0, str(Path(here, "src").resolve()))
- for pkg_name in ["custom_controlnet_aux", "custom_mmpkg"]:
-     sys.path.append(str(Path(here, "src", pkg_name).resolve()))
-
- #Enable CPU fallback for ops not being supported by MPS like upsample_bicubic2d.out
- #https://github.com/pytorch/pytorch/issues/77764
- #https://github.com/Fannovel16/comfyui_controlnet_aux/issues/2#issuecomment-1763579485
- os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = os.getenv("PYTORCH_ENABLE_MPS_FALLBACK", '1')
-
-
- def load_nodes():
-     shorted_errors = []
-     full_error_messages = []
-     node_class_mappings = {}
-     node_display_name_mappings = {}
-
-     for filename in (here / "node_wrappers").iterdir():
-         module_name = filename.stem
-         if module_name.startswith('.'): continue #Skip hidden files created by the OS (e.g. [.DS_Store](https://en.wikipedia.org/wiki/.DS_Store))
-         try:
-             module = importlib.import_module(
-                 f".node_wrappers.{module_name}", package=__package__
-             )
-             node_class_mappings.update(getattr(module, "NODE_CLASS_MAPPINGS"))
-             if hasattr(module, "NODE_DISPLAY_NAME_MAPPINGS"):
-                 node_display_name_mappings.update(getattr(module, "NODE_DISPLAY_NAME_MAPPINGS"))
-
-             log.debug(f"Imported {module_name} nodes")
-
-         except AttributeError:
-             pass  # wip nodes
-         except Exception:
-             error_message = traceback.format_exc()
-             full_error_messages.append(error_message)
-             error_message = error_message.splitlines()[-1]
-             shorted_errors.append(
-                 f"Failed to import module {module_name} because {error_message}"
-             )
-
-     if len(shorted_errors) > 0:
-         full_err_log = '\n\n'.join(full_error_messages)
-         print(f"\n\nFull error log from comfyui_controlnet_aux: \n{full_err_log}\n\n")
-         log.info(
-             f"Some nodes failed to load:\n\t"
-             + "\n\t".join(shorted_errors)
-             + "\n\n"
-             + "Check that you properly installed the dependencies.\n"
-             + "If you think this is a bug, please report it on the github page (https://github.com/Fannovel16/comfyui_controlnet_aux/issues)"
-         )
-     return node_class_mappings, node_display_name_mappings
-
- AUX_NODE_MAPPINGS, AUX_DISPLAY_NAME_MAPPINGS = load_nodes()
-
- #For nodes that don't map image to image or have special requirements
- AIO_NOT_SUPPORTED = ["InpaintPreprocessor", "MeshGraphormer+ImpactDetector-DepthMapPreprocessor", "DiffusionEdge_Preprocessor"]
- AIO_NOT_SUPPORTED += ["SavePoseKpsAsJsonFile", "FacialPartColoringFromPoseKps", "UpperBodyTrackingFromPoseKps", "RenderPeopleKps", "RenderAnimalKps"]
- AIO_NOT_SUPPORTED += ["Unimatch_OptFlowPreprocessor", "MaskOptFlow"]
-
- def preprocessor_options():
-     auxs = list(AUX_NODE_MAPPINGS.keys())
-     auxs.insert(0, "none")
-     for name in AIO_NOT_SUPPORTED:
-         if name in auxs:
-             auxs.remove(name)
-     return auxs
-
-
- PREPROCESSOR_OPTIONS = preprocessor_options()
-
- class AIO_Preprocessor:
-     @classmethod
-     def INPUT_TYPES(s):
-         return define_preprocessor_inputs(
-             preprocessor=INPUT.COMBO(PREPROCESSOR_OPTIONS, default="none"),
-             resolution=INPUT.RESOLUTION()
-         )
-
-     RETURN_TYPES = ("IMAGE",)
-     FUNCTION = "execute"
-
-     CATEGORY = "ControlNet Preprocessors"
-
-     def execute(self, preprocessor, image, resolution=512):
-         if preprocessor == "none":
-             return (image, )
-         else:
-             aux_class = AUX_NODE_MAPPINGS[preprocessor]
-             input_types = aux_class.INPUT_TYPES()
-             input_types = {
-                 **input_types["required"],
-                 **(input_types["optional"] if "optional" in input_types else {})
-             }
-             params = {}
-             for name, input_type in input_types.items():
-                 if name == "image":
-                     params[name] = image
-                     continue
-
-                 if name == "resolution":
-                     params[name] = resolution
-                     continue
-
-                 if len(input_type) == 2 and ("default" in input_type[1]):
-                     params[name] = input_type[1]["default"]
-                     continue
-
-                 default_values = { "INT": 0, "FLOAT": 0.0 }
-                 if input_type[0] in default_values:
-                     params[name] = default_values[input_type[0]]
-
-             return getattr(aux_class(), aux_class.FUNCTION)(**params)
-
- class ControlNetAuxSimpleAddText:
-     @classmethod
-     def INPUT_TYPES(s):
-         return dict(
-             required=dict(image=INPUT.IMAGE(), text=INPUT.STRING())
-         )
-
-     RETURN_TYPES = ("IMAGE",)
-     FUNCTION = "execute"
-     CATEGORY = "ControlNet Preprocessors"
-     def execute(self, image, text):
-         from PIL import Image, ImageDraw, ImageFont
-         import numpy as np
-         import torch
-
-         font = ImageFont.truetype(str((here / "NotoSans-Regular.ttf").resolve()), 40)
-         img = Image.fromarray(image[0].cpu().numpy().__mul__(255.).astype(np.uint8))
-         ImageDraw.Draw(img).text((0,0), text, fill=(0,255,0), font=font)
-         return (torch.from_numpy(np.array(img)).unsqueeze(0) / 255.,)
-
- class ExecuteAllControlNetPreprocessors:
-     @classmethod
-     def INPUT_TYPES(s):
-         return define_preprocessor_inputs(resolution=INPUT.RESOLUTION())
-     RETURN_TYPES = ("IMAGE",)
-     FUNCTION = "execute"
-
-     CATEGORY = "ControlNet Preprocessors"
-
-     def execute(self, image, resolution=512):
-         try:
-             from comfy_execution.graph_utils import GraphBuilder
-         except:
-             raise RuntimeError("ExecuteAllControlNetPreprocessor requires [Execution Model Inversion](https://github.com/comfyanonymous/ComfyUI/commit/5cfe38). Update ComfyUI/SwarmUI to get this feature")
-
-         graph = GraphBuilder()
-         curr_outputs = []
-         for preprocc in PREPROCESSOR_OPTIONS:
-             preprocc_node = graph.node("AIO_Preprocessor", preprocessor=preprocc, image=image, resolution=resolution)
-             hint_img = preprocc_node.out(0)
-             add_text_node = graph.node("ControlNetAuxSimpleAddText", image=hint_img, text=preprocc)
-             curr_outputs.append(add_text_node.out(0))
-
-         while len(curr_outputs) > 1:
-             _outputs = []
-             for i in range(0, len(curr_outputs), 2):
-                 if i+1 < len(curr_outputs):
-                     image_batch = graph.node("ImageBatch", image1=curr_outputs[i], image2=curr_outputs[i+1])
-                     _outputs.append(image_batch.out(0))
-                 else:
-                     _outputs.append(curr_outputs[i])
-             curr_outputs = _outputs
-
-         return {
-             "result": (curr_outputs[0],),
-             "expand": graph.finalize(),
-         }
-
- class ControlNetPreprocessorSelector:
-     @classmethod
-     def INPUT_TYPES(s):
-         return {
-             "required": {
-                 "preprocessor": (PREPROCESSOR_OPTIONS,),
-             }
-         }
-
-     RETURN_TYPES = (PREPROCESSOR_OPTIONS,)
-     RETURN_NAMES = ("preprocessor",)
-     FUNCTION = "get_preprocessor"
-
-     CATEGORY = "ControlNet Preprocessors"
-
-     def get_preprocessor(self, preprocessor: str):
-         return (preprocessor,)
-
-
- NODE_CLASS_MAPPINGS = {
-     **AUX_NODE_MAPPINGS,
-     "AIO_Preprocessor": AIO_Preprocessor,
-     "ControlNetPreprocessorSelector": ControlNetPreprocessorSelector,
-     **HIE_NODE_CLASS_MAPPINGS,
-     "ExecuteAllControlNetPreprocessors": ExecuteAllControlNetPreprocessors,
-     "ControlNetAuxSimpleAddText": ControlNetAuxSimpleAddText
- }
-
- NODE_DISPLAY_NAME_MAPPINGS = {
-     **AUX_DISPLAY_NAME_MAPPINGS,
-     "AIO_Preprocessor": "AIO Aux Preprocessor",
-     "ControlNetPreprocessorSelector": "Preprocessor Selector",
-     **HIE_NODE_DISPLAY_NAME_MAPPINGS,
-     "ExecuteAllControlNetPreprocessors": "Execute All ControlNet Preprocessors"
- }
+ #Dummy file ensuring this package will be recognized
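
For context, the deleted root __init__.py built its node registry by importing every module under node_wrappers/ and merging their NODE_CLASS_MAPPINGS; this commit reduces the root to a bare package marker. A minimal, self-contained sketch of that dynamic-import pattern (the package and directory names here are hypothetical, not from this repo):

import importlib
from pathlib import Path

def collect_node_mappings(package: str, wrappers_dir: Path) -> dict:
    # Import each wrapper module and merge its NODE_CLASS_MAPPINGS,
    # skipping OS-created hidden files such as .DS_Store.
    mappings = {}
    for path in sorted(wrappers_dir.glob("*.py")):
        if path.stem.startswith("."):
            continue
        module = importlib.import_module(f".{path.stem}", package=package)
        mappings.update(getattr(module, "NODE_CLASS_MAPPINGS", {}))
    return mappings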
custom_albumentations/LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2017 Buslaev Alexander, Alexander Parinov, Vladimir Iglovikov
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
custom_albumentations/__init__.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import absolute_import
+
+ __version__ = "1.3.1"
+
+ from .augmentations import *
+ from .core.composition import *
+ from .core.serialization import *
+ from .core.transforms_interface import *
+
+ try:
+     from .imgaug.transforms import *  # type: ignore
+ except ImportError:
+     # imgaug is not installed by default, so we import stubs.
+     # Run `pip install -U albumentations[imgaug]` if you need augmentations from imgaug.
+     from .imgaug.stubs import *  # type: ignore
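
Since the library is vendored under its own name, downstream code imports custom_albumentations in place of albumentations. A minimal usage sketch (the pipeline below is illustrative, not taken from this commit):

import numpy as np
import custom_albumentations as A  # vendored drop-in for albumentations 1.3.1

# Compose, RandomCrop and Blur are all re-exported at the package top level
# by the star imports above.
transform = A.Compose([
    A.RandomCrop(height=256, width=256, p=1.0),
    A.Blur(blur_limit=7, p=0.5),
])
augmented = transform(image=np.zeros((512, 512, 3), dtype=np.uint8))["image"]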
custom_albumentations/augmentations/__init__.py ADDED
@@ -0,0 +1,21 @@
+ # Common classes
+ from .blur.functional import *
+ from .blur.transforms import *
+ from .crops.functional import *
+ from .crops.transforms import *
+
+ # New transformations go into the individual files listed below
+ from .domain_adaptation import *
+ from .dropout.channel_dropout import *
+ from .dropout.coarse_dropout import *
+ from .dropout.cutout import *
+ from .dropout.functional import *
+ from .dropout.grid_dropout import *
+ from .dropout.mask_dropout import *
+ from .functional import *
+ from .geometric.functional import *
+ from .geometric.resize import *
+ from .geometric.rotate import *
+ from .geometric.transforms import *
+ from .transforms import *
+ from .utils import *
custom_albumentations/augmentations/blur/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .functional import *
+ from .transforms import *
custom_albumentations/augmentations/blur/functional.py ADDED
@@ -0,0 +1,106 @@
+ from itertools import product
+ from math import ceil
+ from typing import Sequence, Union
+
+ import cv2
+ import numpy as np
+
+ from custom_albumentations.augmentations.functional import convolve
+ from custom_albumentations.augmentations.geometric.functional import scale
+ from custom_albumentations.augmentations.utils import (
+     _maybe_process_in_chunks,
+     clipped,
+     preserve_shape,
+ )
+
+ __all__ = ["blur", "median_blur", "gaussian_blur", "glass_blur"]
+
+
+ @preserve_shape
+ def blur(img: np.ndarray, ksize: int) -> np.ndarray:
+     blur_fn = _maybe_process_in_chunks(cv2.blur, ksize=(ksize, ksize))
+     return blur_fn(img)
+
+
+ @preserve_shape
+ def median_blur(img: np.ndarray, ksize: int) -> np.ndarray:
+     if img.dtype == np.float32 and ksize not in {3, 5}:
+         raise ValueError(f"Invalid ksize value {ksize}. For a float32 image the only valid ksize values are 3 and 5")
+
+     blur_fn = _maybe_process_in_chunks(cv2.medianBlur, ksize=ksize)
+     return blur_fn(img)
+
+
+ @preserve_shape
+ def gaussian_blur(img: np.ndarray, ksize: int, sigma: float = 0) -> np.ndarray:
+     # When sigma=0, it is computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`
+     blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma)
+     return blur_fn(img)
+
+
+ @preserve_shape
+ def glass_blur(
+     img: np.ndarray, sigma: float, max_delta: int, iterations: int, dxy: np.ndarray, mode: str
+ ) -> np.ndarray:
+     x = cv2.GaussianBlur(np.array(img), sigmaX=sigma, ksize=(0, 0))
+
+     if mode == "fast":
+         hs = np.arange(img.shape[0] - max_delta, max_delta, -1)
+         ws = np.arange(img.shape[1] - max_delta, max_delta, -1)
+         h: Union[int, np.ndarray] = np.tile(hs, ws.shape[0])
+         w: Union[int, np.ndarray] = np.repeat(ws, hs.shape[0])
+
+         for i in range(iterations):
+             dy = dxy[:, i, 0]
+             dx = dxy[:, i, 1]
+             x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w]
+
+     elif mode == "exact":
+         for ind, (i, h, w) in enumerate(
+             product(
+                 range(iterations),
+                 range(img.shape[0] - max_delta, max_delta, -1),
+                 range(img.shape[1] - max_delta, max_delta, -1),
+             )
+         ):
+             ind = ind if ind < len(dxy) else ind % len(dxy)
+             dy = dxy[ind, i, 0]
+             dx = dxy[ind, i, 1]
+             x[h, w], x[h + dy, w + dx] = x[h + dy, w + dx], x[h, w]
+     else:
+         raise ValueError(f"Unsupported mode `{mode}`. Supports only `fast` and `exact`.")
+
+     return cv2.GaussianBlur(x, sigmaX=sigma, ksize=(0, 0))
+
+
+ def defocus(img: np.ndarray, radius: int, alias_blur: float) -> np.ndarray:
+     length = np.arange(-max(8, radius), max(8, radius) + 1)
+     ksize = 3 if radius <= 8 else 5
+
+     x, y = np.meshgrid(length, length)
+     aliased_disk = np.array((x**2 + y**2) <= radius**2, dtype=np.float32)
+     aliased_disk /= np.sum(aliased_disk)
+
+     kernel = gaussian_blur(aliased_disk, ksize, sigma=alias_blur)
+     return convolve(img, kernel=kernel)
+
+
+ def central_zoom(img: np.ndarray, zoom_factor: int) -> np.ndarray:
+     h, w = img.shape[:2]
+     h_ch, w_ch = ceil(h / zoom_factor), ceil(w / zoom_factor)
+     h_top, w_top = (h - h_ch) // 2, (w - w_ch) // 2
+
+     img = scale(img[h_top : h_top + h_ch, w_top : w_top + w_ch], zoom_factor, cv2.INTER_LINEAR)
+     h_trim_top, w_trim_top = (img.shape[0] - h) // 2, (img.shape[1] - w) // 2
+     return img[h_trim_top : h_trim_top + h, w_trim_top : w_trim_top + w]
+
+
+ @clipped
+ def zoom_blur(img: np.ndarray, zoom_factors: Union[np.ndarray, Sequence[int]]) -> np.ndarray:
+     out = np.zeros_like(img, dtype=np.float32)
+     for zoom_factor in zoom_factors:
+         out += central_zoom(img, zoom_factor)
+
+     img = ((img + out) / (len(zoom_factors) + 1)).astype(img.dtype)
+
+     return img
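
These helpers operate directly on NumPy arrays, so they can be exercised without building a transform pipeline. A quick sketch (array shapes and values are arbitrary, assuming the vendored package imports cleanly):

import numpy as np
from custom_albumentations.augmentations.blur.functional import blur, gaussian_blur

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)
box_blurred = blur(img, ksize=5)             # simple box filter via cv2.blur
gauss_blurred = gaussian_blur(img, ksize=5)  # sigma derived from ksize when sigma=0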
custom_albumentations/augmentations/blur/transforms.py ADDED
@@ -0,0 +1,486 @@
+ import random
+ import warnings
+ from typing import Any, Dict, List, Sequence, Tuple
+
+ import cv2
+ import numpy as np
+
+ from custom_albumentations import random_utils
+ from custom_albumentations.augmentations import functional as FMain
+ from custom_albumentations.augmentations.blur import functional as F
+ from custom_albumentations.core.transforms_interface import (
+     ImageOnlyTransform,
+     ScaleFloatType,
+     ScaleIntType,
+     to_tuple,
+ )
+
+ __all__ = ["Blur", "MotionBlur", "GaussianBlur", "GlassBlur", "AdvancedBlur", "MedianBlur", "Defocus", "ZoomBlur"]
+
+
+ class Blur(ImageOnlyTransform):
+     """Blur the input image using a random-sized kernel.
+
+     Args:
+         blur_limit (int, (int, int)): maximum kernel size for blurring the input image.
+             Should be in range [3, inf). Default: (3, 7).
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5):
+         super().__init__(always_apply, p)
+         self.blur_limit = to_tuple(blur_limit, 3)
+
+     def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray:
+         return F.blur(img, ksize)
+
+     def get_params(self) -> Dict[str, Any]:
+         return {"ksize": int(random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2))))}
+
+     def get_transform_init_args_names(self) -> Tuple[str, ...]:
+         return ("blur_limit",)
+
+
+ class MotionBlur(Blur):
+     """Apply motion blur to the input image using a random-sized kernel.
+
+     Args:
+         blur_limit (int): maximum kernel size for blurring the input image.
+             Should be in range [3, inf). Default: (3, 7).
+         allow_shifted (bool): if set to false creates non-shifted kernels only,
+             otherwise creates randomly shifted kernels. Default: True.
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         blur_limit: ScaleIntType = 7,
+         allow_shifted: bool = True,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(blur_limit=blur_limit, always_apply=always_apply, p=p)
+         self.allow_shifted = allow_shifted
+
+         if not allow_shifted and (self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1):
+             raise ValueError(f"Blur limit must be odd when centered=True. Got: {self.blur_limit}")
+
+     def get_transform_init_args_names(self) -> Tuple[str, ...]:
+         return super().get_transform_init_args_names() + ("allow_shifted",)
+
+     def apply(self, img: np.ndarray, kernel: np.ndarray = None, **params) -> np.ndarray:  # type: ignore
+         return FMain.convolve(img, kernel=kernel)
+
+     def get_params(self) -> Dict[str, Any]:
+         ksize = random.choice(list(range(self.blur_limit[0], self.blur_limit[1] + 1, 2)))
+         if ksize <= 2:
+             raise ValueError("ksize must be > 2. Got: {}".format(ksize))
+         kernel = np.zeros((ksize, ksize), dtype=np.uint8)
+         x1, x2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
+         if x1 == x2:
+             y1, y2 = random.sample(range(ksize), 2)
+         else:
+             y1, y2 = random.randint(0, ksize - 1), random.randint(0, ksize - 1)
+
+         def make_odd_val(v1, v2):
+             len_v = abs(v1 - v2) + 1
+             if len_v % 2 != 1:
+                 if v2 > v1:
+                     v2 -= 1
+                 else:
+                     v1 -= 1
+             return v1, v2
+
+         if not self.allow_shifted:
+             x1, x2 = make_odd_val(x1, x2)
+             y1, y2 = make_odd_val(y1, y2)
+
+             xc = (x1 + x2) / 2
+             yc = (y1 + y2) / 2
+
+             center = ksize / 2 - 0.5
+             dx = xc - center
+             dy = yc - center
+             x1, x2 = [int(i - dx) for i in [x1, x2]]
+             y1, y2 = [int(i - dy) for i in [y1, y2]]
+
+         cv2.line(kernel, (x1, y1), (x2, y2), 1, thickness=1)
+
+         # Normalize kernel
+         return {"kernel": kernel.astype(np.float32) / np.sum(kernel)}
+
+
+ class MedianBlur(Blur):
+     """Blur the input image using a median filter with a random aperture linear size.
+
+     Args:
+         blur_limit (int): maximum aperture linear size for blurring the input image.
+             Must be odd and in range [3, inf). Default: (3, 7).
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(self, blur_limit: ScaleIntType = 7, always_apply: bool = False, p: float = 0.5):
+         super().__init__(blur_limit, always_apply, p)
+
+         if self.blur_limit[0] % 2 != 1 or self.blur_limit[1] % 2 != 1:
+             raise ValueError("MedianBlur supports only odd blur limits.")
+
+     def apply(self, img: np.ndarray, ksize: int = 3, **params) -> np.ndarray:
+         return F.median_blur(img, ksize)
+
+
+ class GaussianBlur(ImageOnlyTransform):
+     """Blur the input image using a Gaussian filter with a random kernel size.
+
+     Args:
+         blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image.
+             Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
+             as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
+             If set single value `blur_limit` will be in range (0, blur_limit).
+             Default: (3, 7).
+         sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf).
+             If set single value `sigma_limit` will be in range (0, sigma_limit).
+             If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         blur_limit: ScaleIntType = (3, 7),
+         sigma_limit: ScaleFloatType = 0,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.blur_limit = to_tuple(blur_limit, 0)
+         self.sigma_limit = to_tuple(sigma_limit if sigma_limit is not None else 0, 0)
+
+         if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0:
+             self.blur_limit = 3, max(3, self.blur_limit[1])
+             warnings.warn(
+                 "blur_limit and sigma_limit minimum value can not be both equal to 0. "
+                 "blur_limit minimum value changed to 3."
+             )
+
+         if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
+             self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
+         ):
+             raise ValueError("GaussianBlur supports only odd blur limits.")
+
+     def apply(self, img: np.ndarray, ksize: int = 3, sigma: float = 0, **params) -> np.ndarray:
+         return F.gaussian_blur(img, ksize, sigma=sigma)
+
+     def get_params(self) -> Dict[str, float]:
+         ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1)
+         if ksize != 0 and ksize % 2 != 1:
+             ksize = (ksize + 1) % (self.blur_limit[1] + 1)
+
+         return {"ksize": ksize, "sigma": random.uniform(*self.sigma_limit)}
+
+     def get_transform_init_args_names(self) -> Tuple[str, str]:
+         return ("blur_limit", "sigma_limit")
+
+
+ class GlassBlur(Blur):
+     """Apply glass noise to the input image.
+
+     Args:
+         sigma (float): standard deviation for Gaussian kernel.
+         max_delta (int): max distance between pixels which are swapped.
+         iterations (int): number of repeats.
+             Should be in range [1, inf). Default: (2).
+         mode (str): mode of computation: fast or exact. Default: "fast".
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+
+     Reference:
+         | https://arxiv.org/abs/1903.12261
+         | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py
+     """
+
+     def __init__(
+         self,
+         sigma: float = 0.7,
+         max_delta: int = 4,
+         iterations: int = 2,
+         always_apply: bool = False,
+         mode: str = "fast",
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply=always_apply, p=p)
+         if iterations < 1:
+             raise ValueError(f"Iterations should be more or equal to 1, but we got {iterations}")
+
+         if mode not in ["fast", "exact"]:
+             raise ValueError(f"Mode should be 'fast' or 'exact', but we got {mode}")
+
+         self.sigma = sigma
+         self.max_delta = max_delta
+         self.iterations = iterations
+         self.mode = mode
+
+     def apply(self, img: np.ndarray, dxy: np.ndarray = None, **params) -> np.ndarray:  # type: ignore
+         assert dxy is not None
+         return F.glass_blur(img, self.sigma, self.max_delta, self.iterations, dxy, self.mode)
+
+     def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, np.ndarray]:
+         img = params["image"]
+
+         # generate array containing all necessary values for transformations
+         width_pixels = img.shape[0] - self.max_delta * 2
+         height_pixels = img.shape[1] - self.max_delta * 2
+         total_pixels = width_pixels * height_pixels
+         dxy = random_utils.randint(-self.max_delta, self.max_delta, size=(total_pixels, self.iterations, 2))
+
+         return {"dxy": dxy}
+
+     def get_transform_init_args_names(self) -> Tuple[str, str, str]:
+         return ("sigma", "max_delta", "iterations")
+
+     @property
+     def targets_as_params(self) -> List[str]:
+         return ["image"]
+
+
+ class AdvancedBlur(ImageOnlyTransform):
+     """Blur the input image using a Generalized Normal filter with randomly selected parameters.
+     This transform also adds multiplicative noise to the generated kernel before convolution.
+
+     Args:
+         blur_limit: maximum Gaussian kernel size for blurring the input image.
+             Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
+             as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
+             If set single value `blur_limit` will be in range (0, blur_limit).
+             Default: (3, 7).
+         sigmaX_limit: Gaussian kernel standard deviation. Must be in range [0, inf).
+             If set single value `sigmaX_limit` will be in range (0, sigma_limit).
+             If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
+         sigmaY_limit: Same as `sigmaX_limit` for the other dimension.
+         rotate_limit: Range from which a random angle used to rotate Gaussian kernel is picked.
+             If limit is a single int an angle is picked from (-rotate_limit, rotate_limit). Default: (-90, 90).
+         beta_limit: Distribution shape parameter, 1 is the normal distribution. Values below 1.0 make distribution
+             tails heavier than normal, values above 1.0 make it lighter than normal. Default: (0.5, 8.0).
+         noise_limit: Multiplicative factor that controls the strength of kernel noise. Must be positive and preferably
+             centered around 1.0. If set single value `noise_limit` will be in range (0, noise_limit).
+             Default: (0.75, 1.25).
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Reference:
+         https://arxiv.org/abs/2107.10833
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         blur_limit: ScaleIntType = (3, 7),
+         sigmaX_limit: ScaleFloatType = (0.2, 1.0),
+         sigmaY_limit: ScaleFloatType = (0.2, 1.0),
+         rotate_limit: ScaleIntType = 90,
+         beta_limit: ScaleFloatType = (0.5, 8.0),
+         noise_limit: ScaleFloatType = (0.9, 1.1),
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.blur_limit = to_tuple(blur_limit, 3)
+         self.sigmaX_limit = self.__check_values(to_tuple(sigmaX_limit, 0.0), name="sigmaX_limit")
+         self.sigmaY_limit = self.__check_values(to_tuple(sigmaY_limit, 0.0), name="sigmaY_limit")
+         self.rotate_limit = to_tuple(rotate_limit)
+         self.beta_limit = to_tuple(beta_limit, low=0.0)
+         self.noise_limit = self.__check_values(to_tuple(noise_limit, 0.0), name="noise_limit")
+
+         if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
+             self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
+         ):
+             raise ValueError("AdvancedBlur supports only odd blur limits.")
+
+         if self.sigmaX_limit[0] == 0 and self.sigmaY_limit[0] == 0:
+             raise ValueError("sigmaX_limit and sigmaY_limit minimum value can not be both equal to 0.")
+
+         if not (self.beta_limit[0] < 1.0 < self.beta_limit[1]):
+             raise ValueError("Beta limit is expected to include 1.0")
+
+     @staticmethod
+     def __check_values(
+         value: Sequence[float], name: str, bounds: Tuple[float, float] = (0, float("inf"))
+     ) -> Sequence[float]:
+         if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
+             raise ValueError(f"{name} values should be between {bounds}")
+         return value
+
+     def apply(self, img: np.ndarray, kernel: np.ndarray = np.array(None), **params) -> np.ndarray:
+         return FMain.convolve(img, kernel=kernel)
+
+     def get_params(self) -> Dict[str, np.ndarray]:
+         ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2)
+         sigmaX = random.uniform(*self.sigmaX_limit)
+         sigmaY = random.uniform(*self.sigmaY_limit)
+         angle = np.deg2rad(random.uniform(*self.rotate_limit))
+
+         # Split into 2 cases to avoid selection of narrow kernels (beta > 1) too often.
+         if random.random() < 0.5:
+             beta = random.uniform(self.beta_limit[0], 1)
+         else:
+             beta = random.uniform(1, self.beta_limit[1])
+
+         noise_matrix = random_utils.uniform(self.noise_limit[0], self.noise_limit[1], size=[ksize, ksize])
+
+         # Generate mesh grid centered at zero.
+         ax = np.arange(-ksize // 2 + 1.0, ksize // 2 + 1.0)
+         # Shape (ksize, ksize, 2)
+         grid = np.stack(np.meshgrid(ax, ax), axis=-1)
+
+         # Calculate rotated sigma matrix
+         d_matrix = np.array([[sigmaX**2, 0], [0, sigmaY**2]])
+         u_matrix = np.array([[np.cos(angle), -np.sin(angle)], [np.sin(angle), np.cos(angle)]])
+         sigma_matrix = np.dot(u_matrix, np.dot(d_matrix, u_matrix.T))
+
+         inverse_sigma = np.linalg.inv(sigma_matrix)
+         # Described in "Parameter Estimation For Multivariate Generalized Gaussian Distributions"
+         kernel = np.exp(-0.5 * np.power(np.sum(np.dot(grid, inverse_sigma) * grid, 2), beta))
+         # Add noise
+         kernel = kernel * noise_matrix
+
+         # Normalize kernel
+         kernel = kernel.astype(np.float32) / np.sum(kernel)
+         return {"kernel": kernel}
+
+     def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str]:
+         return (
+             "blur_limit",
+             "sigmaX_limit",
+             "sigmaY_limit",
+             "rotate_limit",
+             "beta_limit",
+             "noise_limit",
+         )
+
+
+ class Defocus(ImageOnlyTransform):
+     """
+     Apply defocus transform. See https://arxiv.org/abs/1903.12261.
+
+     Args:
+         radius ((int, int) or int): range for radius of defocusing.
+             If limit is a single int, the range will be [1, limit]. Default: (3, 10).
+         alias_blur ((float, float) or float): range for alias_blur of defocusing (sigma of gaussian blur).
+             If limit is a single float, the range will be (0, limit). Default: (0.1, 0.5).
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         Any
+     """
+
+     def __init__(
+         self,
+         radius: ScaleIntType = (3, 10),
+         alias_blur: ScaleFloatType = (0.1, 0.5),
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.radius = to_tuple(radius, low=1)
+         self.alias_blur = to_tuple(alias_blur, low=0)
+
+         if self.radius[0] <= 0:
+             raise ValueError("Parameter radius must be positive")
+
+         if self.alias_blur[0] < 0:
+             raise ValueError("Parameter alias_blur must be non-negative")
+
+     def apply(self, img: np.ndarray, radius: int = 3, alias_blur: float = 0.5, **params) -> np.ndarray:
+         return F.defocus(img, radius, alias_blur)
+
+     def get_params(self) -> Dict[str, Any]:
+         return {
+             "radius": random_utils.randint(self.radius[0], self.radius[1] + 1),
+             "alias_blur": random_utils.uniform(self.alias_blur[0], self.alias_blur[1]),
+         }
+
+     def get_transform_init_args_names(self) -> Tuple[str, str]:
+         return ("radius", "alias_blur")
+
+
+ class ZoomBlur(ImageOnlyTransform):
+     """
+     Apply zoom blur transform. See https://arxiv.org/abs/1903.12261.
+
+     Args:
+         max_factor ((float, float) or float): range for max factor for blurring.
+             If max_factor is a single float, the range will be (1, limit). Default: (1, 1.31).
+             All max_factor values should be larger than 1.
+         step_factor ((float, float) or float): If single float will be used as step parameter for np.arange.
+             If tuple of float step_factor will be in range `[step_factor[0], step_factor[1])`. Default: (0.01, 0.03).
+             All step_factor values should be positive.
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image
+
+     Image types:
+         Any
+     """
+
+     def __init__(
+         self,
+         max_factor: ScaleFloatType = 1.31,
+         step_factor: ScaleFloatType = (0.01, 0.03),
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.max_factor = to_tuple(max_factor, low=1.0)
+         self.step_factor = to_tuple(step_factor, step_factor)
+
+         if self.max_factor[0] < 1:
+             raise ValueError("Max factor must be larger or equal 1")
+         if self.step_factor[0] <= 0:
+             raise ValueError("Step factor must be positive")
+
+     def apply(self, img: np.ndarray, zoom_factors: np.ndarray = np.array(None), **params) -> np.ndarray:
+         assert zoom_factors is not None
+         return F.zoom_blur(img, zoom_factors)
+
+     def get_params(self) -> Dict[str, Any]:
+         max_factor = random.uniform(self.max_factor[0], self.max_factor[1])
+         step_factor = random.uniform(self.step_factor[0], self.step_factor[1])
+         return {"zoom_factors": np.arange(1.0, max_factor, step_factor)}
+
+     def get_transform_init_args_names(self) -> Tuple[str, str]:
+         return ("max_factor", "step_factor")
custom_albumentations/augmentations/crops/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .functional import *
+ from .transforms import *
custom_albumentations/augmentations/crops/functional.py ADDED
@@ -0,0 +1,317 @@
+ from typing import Optional, Sequence, Tuple
+
+ import cv2
+ import numpy as np
+
+ from custom_albumentations.augmentations.utils import (
+     _maybe_process_in_chunks,
+     preserve_channel_dim,
+ )
+
+ from ...core.bbox_utils import denormalize_bbox, normalize_bbox
+ from ...core.transforms_interface import BoxInternalType, KeypointInternalType
+ from ..geometric import functional as FGeometric
+
+ __all__ = [
+     "get_random_crop_coords",
+     "random_crop",
+     "crop_bbox_by_coords",
+     "bbox_random_crop",
+     "crop_keypoint_by_coords",
+     "keypoint_random_crop",
+     "get_center_crop_coords",
+     "center_crop",
+     "bbox_center_crop",
+     "keypoint_center_crop",
+     "crop",
+     "bbox_crop",
+     "clamping_crop",
+     "crop_and_pad",
+     "crop_and_pad_bbox",
+     "crop_and_pad_keypoint",
+ ]
+
+
+ def get_random_crop_coords(height: int, width: int, crop_height: int, crop_width: int, h_start: float, w_start: float):
+     # h_start is [0, 1) and should map to [0, (height - crop_height)] (note inclusive)
+     # This is conceptually equivalent to mapping onto `range(0, (height - crop_height + 1))`
+     # See: https://github.com/albumentations-team/albumentations/pull/1080
+     y1 = int((height - crop_height + 1) * h_start)
+     y2 = y1 + crop_height
+     x1 = int((width - crop_width + 1) * w_start)
+     x2 = x1 + crop_width
+     return x1, y1, x2, y2
+
+
+ def random_crop(img: np.ndarray, crop_height: int, crop_width: int, h_start: float, w_start: float):
+     height, width = img.shape[:2]
+     if height < crop_height or width < crop_width:
+         raise ValueError(
+             "Requested crop size ({crop_height}, {crop_width}) is "
+             "larger than the image size ({height}, {width})".format(
+                 crop_height=crop_height, crop_width=crop_width, height=height, width=width
+             )
+         )
+     x1, y1, x2, y2 = get_random_crop_coords(height, width, crop_height, crop_width, h_start, w_start)
+     img = img[y1:y2, x1:x2]
+     return img
+
+
+ def crop_bbox_by_coords(
+     bbox: BoxInternalType,
+     crop_coords: Tuple[int, int, int, int],
+     crop_height: int,
+     crop_width: int,
+     rows: int,
+     cols: int,
+ ):
+     """Crop a bounding box using the provided coordinates of top-left and bottom-right corners in pixels and the
+     required height and width of the crop.
+
+     Args:
+         bbox (tuple): A cropped box `(x_min, y_min, x_max, y_max)`.
+         crop_coords (tuple): Crop coordinates `(x1, y1, x2, y2)`.
+         crop_height (int):
+         crop_width (int):
+         rows (int): Image rows.
+         cols (int): Image cols.
+
+     Returns:
+         tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`.
+
+     """
+     bbox = denormalize_bbox(bbox, rows, cols)
+     x_min, y_min, x_max, y_max = bbox[:4]
+     x1, y1, _, _ = crop_coords
+     cropped_bbox = x_min - x1, y_min - y1, x_max - x1, y_max - y1
+     return normalize_bbox(cropped_bbox, crop_height, crop_width)
+
+
+ def bbox_random_crop(
+     bbox: BoxInternalType, crop_height: int, crop_width: int, h_start: float, w_start: float, rows: int, cols: int
+ ):
+     crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start)
+     return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)
+
+
+ def crop_keypoint_by_coords(
+     keypoint: KeypointInternalType, crop_coords: Tuple[int, int, int, int]
+ ):  # skipcq: PYL-W0613
+     """Crop a keypoint using the provided coordinates of top-left and bottom-right corners in pixels and the
+     required height and width of the crop.
+
+     Args:
+         keypoint (tuple): A keypoint `(x, y, angle, scale)`.
+         crop_coords (tuple): Crop box coords `(x1, y1, x2, y2)`.
+
+     Returns:
+         A keypoint `(x, y, angle, scale)`.
+
+     """
+     x, y, angle, scale = keypoint[:4]
+     x1, y1, _, _ = crop_coords
+     return x - x1, y - y1, angle, scale
+
+
+ def keypoint_random_crop(
+     keypoint: KeypointInternalType,
+     crop_height: int,
+     crop_width: int,
+     h_start: float,
+     w_start: float,
+     rows: int,
+     cols: int,
+ ):
+     """Keypoint random crop.
+
+     Args:
+         keypoint: (tuple): A keypoint `(x, y, angle, scale)`.
+         crop_height (int): Crop height.
+         crop_width (int): Crop width.
+         h_start (float): Crop height start.
+         w_start (float): Crop width start.
+         rows (int): Image height.
+         cols (int): Image width.
+
+     Returns:
+         A keypoint `(x, y, angle, scale)`.
+
+     """
+     crop_coords = get_random_crop_coords(rows, cols, crop_height, crop_width, h_start, w_start)
+     return crop_keypoint_by_coords(keypoint, crop_coords)
+
+
+ def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
+     y1 = (height - crop_height) // 2
+     y2 = y1 + crop_height
+     x1 = (width - crop_width) // 2
+     x2 = x1 + crop_width
+     return x1, y1, x2, y2
+
+
+ def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
+     height, width = img.shape[:2]
+     if height < crop_height or width < crop_width:
+         raise ValueError(
+             "Requested crop size ({crop_height}, {crop_width}) is "
+             "larger than the image size ({height}, {width})".format(
+                 crop_height=crop_height, crop_width=crop_width, height=height, width=width
+             )
+         )
+     x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
+     img = img[y1:y2, x1:x2]
+     return img
+
+
+ def bbox_center_crop(bbox: BoxInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
+     crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
+     return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)
+
+
+ def keypoint_center_crop(keypoint: KeypointInternalType, crop_height: int, crop_width: int, rows: int, cols: int):
+     """Keypoint center crop.
+
+     Args:
+         keypoint (tuple): A keypoint `(x, y, angle, scale)`.
+         crop_height (int): Crop height.
+         crop_width (int): Crop width.
+         rows (int): Image height.
+         cols (int): Image width.
+
+     Returns:
+         tuple: A keypoint `(x, y, angle, scale)`.
+
+     """
+     crop_coords = get_center_crop_coords(rows, cols, crop_height, crop_width)
+     return crop_keypoint_by_coords(keypoint, crop_coords)
+
+
+ def crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int):
+     height, width = img.shape[:2]
+     if x_max <= x_min or y_max <= y_min:
+         raise ValueError(
+             "We should have x_min < x_max and y_min < y_max. But we got"
+             " (x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max})".format(
+                 x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max
+             )
+         )
+
+     if x_min < 0 or x_max > width or y_min < 0 or y_max > height:
+         raise ValueError(
+             "Values for crop should be non negative and equal or smaller than image sizes"
+             "(x_min = {x_min}, y_min = {y_min}, x_max = {x_max}, y_max = {y_max}, "
+             "height = {height}, width = {width})".format(
+                 x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, height=height, width=width
+             )
+         )
+
+     return img[y_min:y_max, x_min:x_max]
+
+
+ def bbox_crop(bbox: BoxInternalType, x_min: int, y_min: int, x_max: int, y_max: int, rows: int, cols: int):
+     """Crop a bounding box.
+
+     Args:
+         bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`.
+         x_min (int):
+         y_min (int):
+         x_max (int):
+         y_max (int):
+         rows (int): Image rows.
+         cols (int): Image cols.
+
+     Returns:
+         tuple: A cropped bounding box `(x_min, y_min, x_max, y_max)`.
+
+     """
+     crop_coords = x_min, y_min, x_max, y_max
+     crop_height = y_max - y_min
+     crop_width = x_max - x_min
+     return crop_bbox_by_coords(bbox, crop_coords, crop_height, crop_width, rows, cols)
+
+
+ def clamping_crop(img: np.ndarray, x_min: int, y_min: int, x_max: int, y_max: int):
+     h, w = img.shape[:2]
+     if x_min < 0:
+         x_min = 0
+     if y_min < 0:
+         y_min = 0
+     if y_max >= h:
+         y_max = h - 1
+     if x_max >= w:
+         x_max = w - 1
+     return img[int(y_min) : int(y_max), int(x_min) : int(x_max)]
+
+
+ @preserve_channel_dim
+ def crop_and_pad(
+     img: np.ndarray,
+     crop_params: Optional[Sequence[int]],
+     pad_params: Optional[Sequence[int]],
+     pad_value: Optional[float],
+     rows: int,
+     cols: int,
+     interpolation: int,
+     pad_mode: int,
+     keep_size: bool,
+ ) -> np.ndarray:
+     if crop_params is not None and any(i != 0 for i in crop_params):
+         img = crop(img, *crop_params)
+     if pad_params is not None and any(i != 0 for i in pad_params):
+         img = FGeometric.pad_with_params(
+             img, pad_params[0], pad_params[1], pad_params[2], pad_params[3], border_mode=pad_mode, value=pad_value
+         )
+
+     if keep_size:
+         resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(cols, rows), interpolation=interpolation)
+         img = resize_fn(img)
+
+     return img
+
+
+ def crop_and_pad_bbox(
+     bbox: BoxInternalType,
+     crop_params: Optional[Sequence[int]],
+     pad_params: Optional[Sequence[int]],
+     rows,
+     cols,
+     result_rows,
+     result_cols,
+ ) -> BoxInternalType:
+     x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4]
+
+     if crop_params is not None:
+         crop_x, crop_y = crop_params[:2]
+         x1, y1, x2, y2 = x1 - crop_x, y1 - crop_y, x2 - crop_x, y2 - crop_y
+     if pad_params is not None:
+         top, bottom, left, right = pad_params
+         x1, y1, x2, y2 = x1 + left, y1 + top, x2 + left, y2 + top
+
+     return normalize_bbox((x1, y1, x2, y2), result_rows, result_cols)
+
+
+ def crop_and_pad_keypoint(
+     keypoint: KeypointInternalType,
+     crop_params: Optional[Sequence[int]],
+     pad_params: Optional[Sequence[int]],
+     rows: int,
+     cols: int,
+     result_rows: int,
+     result_cols: int,
+     keep_size: bool,
+ ) -> KeypointInternalType:
+     x, y, angle, scale = keypoint[:4]
+
+     if crop_params is not None:
+         crop_x1, crop_y1, crop_x2, crop_y2 = crop_params
+         x, y = x - crop_x1, y - crop_y1
+     if pad_params is not None:
+         top, bottom, left, right = pad_params
+         x, y = x + left, y + top
+
+     if keep_size and (result_cols != cols or result_rows != rows):
+         scale_x = cols / result_cols
+         scale_y = rows / result_rows
+         return FGeometric.keypoint_scale((x, y, angle, scale), scale_x, scale_y)
+
+     return x, y, angle, scale
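
Note how get_random_crop_coords() maps the uniform floats h_start/w_start in [0, 1) to an integer crop origin, so every valid origin from 0 through height - crop_height (inclusive) is reachable. A tiny sketch of the coordinate math:

from custom_albumentations.augmentations.crops.functional import get_random_crop_coords

# h_start=0.0 pins the crop to the top edge; w_start near 1.0 reaches the right edge.
coords = get_random_crop_coords(height=100, width=100, crop_height=60, crop_width=60,
                                h_start=0.0, w_start=0.999)
print(coords)  # (40, 0, 100, 60), i.e. (x1, y1, x2, y2)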
custom_albumentations/augmentations/crops/transforms.py ADDED
@@ -0,0 +1,943 @@
1
+ import math
2
+ import random
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from custom_albumentations.core.bbox_utils import union_of_bboxes
9
+
10
+ from ...core.transforms_interface import (
11
+ BoxInternalType,
12
+ DualTransform,
13
+ KeypointInternalType,
14
+ to_tuple,
15
+ )
16
+ from ..geometric import functional as FGeometric
17
+ from . import functional as F
18
+
19
+ __all__ = [
20
+ "RandomCrop",
21
+ "CenterCrop",
22
+ "Crop",
23
+ "CropNonEmptyMaskIfExists",
24
+ "RandomSizedCrop",
25
+ "RandomResizedCrop",
26
+ "RandomCropNearBBox",
27
+ "RandomSizedBBoxSafeCrop",
28
+ "CropAndPad",
29
+ "RandomCropFromBorders",
30
+ "BBoxSafeRandomCrop",
31
+ ]
32
+
33
+
34
+ class RandomCrop(DualTransform):
35
+ """Crop a random part of the input.
36
+
37
+ Args:
38
+ height (int): height of the crop.
39
+ width (int): width of the crop.
40
+ p (float): probability of applying the transform. Default: 1.
41
+
42
+ Targets:
43
+ image, mask, bboxes, keypoints
44
+
45
+ Image types:
46
+ uint8, float32
47
+ """
48
+
49
+ def __init__(self, height, width, always_apply=False, p=1.0):
50
+ super().__init__(always_apply, p)
51
+ self.height = height
52
+ self.width = width
53
+
54
+ def apply(self, img, h_start=0, w_start=0, **params):
55
+ return F.random_crop(img, self.height, self.width, h_start, w_start)
56
+
57
+ def get_params(self):
58
+ return {"h_start": random.random(), "w_start": random.random()}
59
+
60
+ def apply_to_bbox(self, bbox, **params):
61
+ return F.bbox_random_crop(bbox, self.height, self.width, **params)
62
+
63
+ def apply_to_keypoint(self, keypoint, **params):
64
+ return F.keypoint_random_crop(keypoint, self.height, self.width, **params)
65
+
66
+ def get_transform_init_args_names(self):
67
+ return ("height", "width")
68
+
69
+
70
+ class CenterCrop(DualTransform):
71
+ """Crop the central part of the input.
72
+
73
+ Args:
74
+ height (int): height of the crop.
75
+ width (int): width of the crop.
76
+ p (float): probability of applying the transform. Default: 1.
77
+
78
+ Targets:
79
+ image, mask, bboxes, keypoints
80
+
81
+ Image types:
82
+ uint8, float32
83
+
84
+ Note:
85
+ It is recommended to use uint8 images as input.
86
+ Otherwise the operation will require internal conversion
87
+ float32 -> uint8 -> float32 that causes worse performance.
88
+ """
89
+
90
+ def __init__(self, height, width, always_apply=False, p=1.0):
91
+ super(CenterCrop, self).__init__(always_apply, p)
92
+ self.height = height
93
+ self.width = width
94
+
95
+ def apply(self, img, **params):
96
+ return F.center_crop(img, self.height, self.width)
97
+
98
+ def apply_to_bbox(self, bbox, **params):
99
+ return F.bbox_center_crop(bbox, self.height, self.width, **params)
100
+
101
+ def apply_to_keypoint(self, keypoint, **params):
102
+ return F.keypoint_center_crop(keypoint, self.height, self.width, **params)
103
+
104
+ def get_transform_init_args_names(self):
105
+ return ("height", "width")
106
+
107
+
108
+ class Crop(DualTransform):
109
+ """Crop region from image.
110
+
111
+ Args:
112
+ x_min (int): Minimum upper left x coordinate.
113
+ y_min (int): Minimum upper left y coordinate.
114
+ x_max (int): Maximum lower right x coordinate.
115
+ y_max (int): Maximum lower right y coordinate.
116
+
117
+ Targets:
118
+ image, mask, bboxes, keypoints
119
+
120
+ Image types:
121
+ uint8, float32
122
+ """
123
+
124
+ def __init__(self, x_min=0, y_min=0, x_max=1024, y_max=1024, always_apply=False, p=1.0):
125
+ super(Crop, self).__init__(always_apply, p)
126
+ self.x_min = x_min
127
+ self.y_min = y_min
128
+ self.x_max = x_max
129
+ self.y_max = y_max
130
+
131
+ def apply(self, img, **params):
132
+ return F.crop(img, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max)
133
+
134
+ def apply_to_bbox(self, bbox, **params):
135
+ return F.bbox_crop(bbox, x_min=self.x_min, y_min=self.y_min, x_max=self.x_max, y_max=self.y_max, **params)
136
+
137
+ def apply_to_keypoint(self, keypoint, **params):
138
+ return F.crop_keypoint_by_coords(keypoint, crop_coords=(self.x_min, self.y_min, self.x_max, self.y_max))
139
+
140
+ def get_transform_init_args_names(self):
141
+ return ("x_min", "y_min", "x_max", "y_max")
142
+
143
+
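Crop is deterministic, so its effect is easy to check. A sketch under the same top-level-export assumption as above:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.zeros((100, 100, 3), dtype=np.uint8)
>>> out = A.Crop(x_min=10, y_min=20, x_max=74, y_max=52, p=1.0)(image=image)["image"]
>>> out.shape  # (y_max - y_min, x_max - x_min, channels)
(32, 64, 3)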
144
+ class CropNonEmptyMaskIfExists(DualTransform):
145
+ """Crop area with mask if mask is non-empty, else make random crop.
146
+
147
+ Args:
148
+ height (int): vertical size of crop in pixels
149
+ width (int): horizontal size of crop in pixels
150
+ ignore_values (list of int): values to ignore in mask, `0` values are always ignored
151
+ (e.g. if background value is 5 set `ignore_values=[5]` to ignore)
152
+ ignore_channels (list of int): channels to ignore in mask
153
+ (e.g. if background is a first channel set `ignore_channels=[0]` to ignore)
154
+ p (float): probability of applying the transform. Default: 1.0.
155
+
156
+ Targets:
157
+ image, mask, bboxes, keypoints
158
+
159
+ Image types:
160
+ uint8, float32
161
+ """
162
+
163
+ def __init__(self, height, width, ignore_values=None, ignore_channels=None, always_apply=False, p=1.0):
164
+ super(CropNonEmptyMaskIfExists, self).__init__(always_apply, p)
165
+
166
+ if ignore_values is not None and not isinstance(ignore_values, list):
167
+ raise ValueError("Expected `ignore_values` of type `list`, got `{}`".format(type(ignore_values)))
168
+ if ignore_channels is not None and not isinstance(ignore_channels, list):
169
+ raise ValueError("Expected `ignore_channels` of type `list`, got `{}`".format(type(ignore_channels)))
170
+
171
+ self.height = height
172
+ self.width = width
173
+ self.ignore_values = ignore_values
174
+ self.ignore_channels = ignore_channels
175
+
176
+ def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
177
+ return F.crop(img, x_min, y_min, x_max, y_max)
178
+
179
+ def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params):
180
+ return F.bbox_crop(
181
+ bbox, x_min=x_min, x_max=x_max, y_min=y_min, y_max=y_max, rows=params["rows"], cols=params["cols"]
182
+ )
183
+
184
+ def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params):
185
+ return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))
186
+
187
+ def _preprocess_mask(self, mask):
188
+ mask_height, mask_width = mask.shape[:2]
189
+
190
+ if self.ignore_values is not None:
191
+ ignore_values_np = np.array(self.ignore_values)
192
+ mask = np.where(np.isin(mask, ignore_values_np), 0, mask)
193
+
194
+ if mask.ndim == 3 and self.ignore_channels is not None:
195
+ target_channels = np.array([ch for ch in range(mask.shape[-1]) if ch not in self.ignore_channels])
196
+ mask = np.take(mask, target_channels, axis=-1)
197
+
198
+ if self.height > mask_height or self.width > mask_width:
199
+ raise ValueError(
200
+ "Crop size ({},{}) is larger than image ({},{})".format(
201
+ self.height, self.width, mask_height, mask_width
202
+ )
203
+ )
204
+
205
+ return mask
206
+
207
+ def update_params(self, params, **kwargs):
208
+ super().update_params(params, **kwargs)
209
+ if "mask" in kwargs:
210
+ mask = self._preprocess_mask(kwargs["mask"])
211
+ elif "masks" in kwargs and len(kwargs["masks"]):
212
+ masks = kwargs["masks"]
213
+ mask = self._preprocess_mask(np.copy(masks[0]))  # copy needed: the mask is modified in place below
214
+ for m in masks[1:]:
215
+ mask |= self._preprocess_mask(m)
216
+ else:
217
+ raise RuntimeError("Can not find mask for CropNonEmptyMaskIfExists")
218
+
219
+ mask_height, mask_width = mask.shape[:2]
220
+
221
+ if mask.any():
222
+ mask = mask.sum(axis=-1) if mask.ndim == 3 else mask
223
+ non_zero_yx = np.argwhere(mask)
224
+ y, x = random.choice(non_zero_yx)
225
+ x_min = x - random.randint(0, self.width - 1)
226
+ y_min = y - random.randint(0, self.height - 1)
227
+ x_min = np.clip(x_min, 0, mask_width - self.width)
228
+ y_min = np.clip(y_min, 0, mask_height - self.height)
229
+ else:
230
+ x_min = random.randint(0, mask_width - self.width)
231
+ y_min = random.randint(0, mask_height - self.height)
232
+
233
+ x_max = x_min + self.width
234
+ y_max = y_min + self.height
235
+
236
+ params.update({"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max})
237
+ return params
238
+
239
+ def get_transform_init_args_names(self):
240
+ return ("height", "width", "ignore_values", "ignore_channels")
241
+
242
+
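A sketch of the mask-aware behaviour (illustrative only, assuming the same top-level exports): when the mask contains non-zero pixels, the sampled window is guaranteed to include at least one of them.

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> mask = np.zeros((64, 64), dtype=np.uint8)
>>> mask[40:44, 40:44] = 1
>>> out = A.Compose([A.CropNonEmptyMaskIfExists(height=16, width=16, p=1.0)])(image=image, mask=mask)
>>> bool(out["mask"].any())
True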
243
+ class _BaseRandomSizedCrop(DualTransform):
244
+ # Base class for RandomSizedCrop and RandomResizedCrop
245
+
246
+ def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
247
+ super(_BaseRandomSizedCrop, self).__init__(always_apply, p)
248
+ self.height = height
249
+ self.width = width
250
+ self.interpolation = interpolation
251
+
252
+ def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
253
+ crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
254
+ return FGeometric.resize(crop, self.height, self.width, interpolation)
255
+
256
+ def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
257
+ return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)
258
+
259
+ def apply_to_keypoint(self, keypoint, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
260
+ keypoint = F.keypoint_random_crop(keypoint, crop_height, crop_width, h_start, w_start, rows, cols)
261
+ scale_x = self.width / crop_width
262
+ scale_y = self.height / crop_height
263
+ keypoint = FGeometric.keypoint_scale(keypoint, scale_x, scale_y)
264
+ return keypoint
265
+
266
+
267
+ class RandomSizedCrop(_BaseRandomSizedCrop):
268
+ """Crop a random part of the input and rescale it to some size.
269
+
270
+ Args:
271
+ min_max_height ((int, int)): crop size limits.
272
+ height (int): height after crop and resize.
273
+ width (int): width after crop and resize.
274
+ w2h_ratio (float): aspect ratio of crop.
275
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
276
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
277
+ Default: cv2.INTER_LINEAR.
278
+ p (float): probability of applying the transform. Default: 1.
279
+
280
+ Targets:
281
+ image, mask, bboxes, keypoints
282
+
283
+ Image types:
284
+ uint8, float32
285
+ """
286
+
287
+ def __init__(
288
+ self, min_max_height, height, width, w2h_ratio=1.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0
289
+ ):
290
+ super(RandomSizedCrop, self).__init__(
291
+ height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p
292
+ )
293
+ self.min_max_height = min_max_height
294
+ self.w2h_ratio = w2h_ratio
295
+
296
+ def get_params(self):
297
+ crop_height = random.randint(self.min_max_height[0], self.min_max_height[1])
298
+ return {
299
+ "h_start": random.random(),
300
+ "w_start": random.random(),
301
+ "crop_height": crop_height,
302
+ "crop_width": int(crop_height * self.w2h_ratio),
303
+ }
304
+
305
+ def get_transform_init_args_names(self):
306
+ return "min_max_height", "height", "width", "w2h_ratio", "interpolation"
307
+
308
+
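A hedged usage sketch: the crop height is sampled from `min_max_height` (the width follows via `w2h_ratio`), then the crop is resized to `(height, width)`, so the output size is fixed regardless of the sampled crop:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
>>> aug = A.Compose([A.RandomSizedCrop(min_max_height=(32, 64), height=48, width=48, p=1.0)])
>>> aug(image=image)["image"].shape
(48, 48, 3)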
309
+ class RandomResizedCrop(_BaseRandomSizedCrop):
310
+ """Torchvision's variant of crop a random part of the input and rescale it to some size.
311
+
312
+ Args:
313
+ height (int): height after crop and resize.
314
+ width (int): width after crop and resize.
315
+ scale ((float, float)): range of the crop area, expressed as a fraction of the original image area.
316
+ ratio ((float, float)): range of aspect ratios of the random crop.
317
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
318
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
319
+ Default: cv2.INTER_LINEAR.
320
+ p (float): probability of applying the transform. Default: 1.
321
+
322
+ Targets:
323
+ image, mask, bboxes, keypoints
324
+
325
+ Image types:
326
+ uint8, float32
327
+ """
328
+
329
+ def __init__(
330
+ self,
331
+ height,
332
+ width,
333
+ scale=(0.08, 1.0),
334
+ ratio=(0.75, 1.3333333333333333),
335
+ interpolation=cv2.INTER_LINEAR,
336
+ always_apply=False,
337
+ p=1.0,
338
+ ):
339
+ super(RandomResizedCrop, self).__init__(
340
+ height=height, width=width, interpolation=interpolation, always_apply=always_apply, p=p
341
+ )
342
+ self.scale = scale
343
+ self.ratio = ratio
344
+
345
+ def get_params_dependent_on_targets(self, params):
346
+ img = params["image"]
347
+ area = img.shape[0] * img.shape[1]
348
+
349
+ for _attempt in range(10):
350
+ target_area = random.uniform(*self.scale) * area
351
+ log_ratio = (math.log(self.ratio[0]), math.log(self.ratio[1]))
352
+ aspect_ratio = math.exp(random.uniform(*log_ratio))
353
+
354
+ w = int(round(math.sqrt(target_area * aspect_ratio))) # skipcq: PTC-W0028
355
+ h = int(round(math.sqrt(target_area / aspect_ratio))) # skipcq: PTC-W0028
356
+
357
+ if 0 < w <= img.shape[1] and 0 < h <= img.shape[0]:
358
+ i = random.randint(0, img.shape[0] - h)
359
+ j = random.randint(0, img.shape[1] - w)
360
+ return {
361
+ "crop_height": h,
362
+ "crop_width": w,
363
+ "h_start": i * 1.0 / (img.shape[0] - h + 1e-10),
364
+ "w_start": j * 1.0 / (img.shape[1] - w + 1e-10),
365
+ }
366
+
367
+ # Fallback to central crop
368
+ in_ratio = img.shape[1] / img.shape[0]
369
+ if in_ratio < min(self.ratio):
370
+ w = img.shape[1]
371
+ h = int(round(w / min(self.ratio)))
372
+ elif in_ratio > max(self.ratio):
373
+ h = img.shape[0]
374
+ w = int(round(h * max(self.ratio)))
375
+ else: # whole image
376
+ w = img.shape[1]
377
+ h = img.shape[0]
378
+ i = (img.shape[0] - h) // 2
379
+ j = (img.shape[1] - w) // 2
380
+ return {
381
+ "crop_height": h,
382
+ "crop_width": w,
383
+ "h_start": i * 1.0 / (img.shape[0] - h + 1e-10),
384
+ "w_start": j * 1.0 / (img.shape[1] - w + 1e-10),
385
+ }
386
+
387
+ def get_params(self):
388
+ return {}
389
+
390
+ @property
391
+ def targets_as_params(self):
392
+ return ["image"]
393
+
394
+ def get_transform_init_args_names(self):
395
+ return "height", "width", "scale", "ratio", "interpolation"
396
+
397
+
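A sketch of the torchvision-style behaviour (illustrative, same export assumption): the crop area and aspect ratio are sampled from `scale` and `ratio`, and the result is always resized to `(height, width)`:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
>>> aug = A.Compose([A.RandomResizedCrop(height=64, width=64, scale=(0.5, 1.0), p=1.0)])
>>> aug(image=image)["image"].shape
(64, 64, 3)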
398
+ class RandomCropNearBBox(DualTransform):
399
+ """Crop bbox from image with random shift by x,y coordinates
400
+
401
+ Args:
402
+ max_part_shift (float, (float, float)): Max shift in `height` and `width` dimensions relative
403
+ to `cropping_bbox` dimension.
404
+ If max_part_shift is a single float, the range will be (max_part_shift, max_part_shift).
405
+ Default (0.3, 0.3).
406
+ cropping_box_key (str): Additional target key for cropping box. Default `cropping_bbox`
407
+ p (float): probability of applying the transform. Default: 1.
408
+
409
+ Targets:
410
+ image, mask, bboxes, keypoints
411
+
412
+ Image types:
413
+ uint8, float32
414
+
415
+ Examples:
416
+ >>> aug = Compose([RandomCropNearBBox(max_part_shift=(0.1, 0.5), cropping_box_key='test_box')],
417
+ >>> bbox_params=BboxParams("pascal_voc"))
418
+ >>> result = aug(image=image, bboxes=bboxes, test_box=[0, 5, 10, 20])
419
+
420
+ """
421
+
422
+ def __init__(
423
+ self,
424
+ max_part_shift: Union[float, Tuple[float, float]] = (0.3, 0.3),
425
+ cropping_box_key: str = "cropping_bbox",
426
+ always_apply: bool = False,
427
+ p: float = 1.0,
428
+ ):
429
+ super(RandomCropNearBBox, self).__init__(always_apply, p)
430
+ self.max_part_shift = to_tuple(max_part_shift, low=max_part_shift)
431
+ self.cropping_bbox_key = cropping_box_key
432
+
433
+ if min(self.max_part_shift) < 0 or max(self.max_part_shift) > 1:
434
+ raise ValueError("Invalid max_part_shift. Got: {}".format(max_part_shift))
435
+
436
+ def apply(
437
+ self, img: np.ndarray, x_min: int = 0, x_max: int = 0, y_min: int = 0, y_max: int = 0, **params
438
+ ) -> np.ndarray:
439
+ return F.clamping_crop(img, x_min, y_min, x_max, y_max)
440
+
441
+ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, int]:
442
+ bbox = params[self.cropping_bbox_key]
443
+ h_max_shift = round((bbox[3] - bbox[1]) * self.max_part_shift[0])
444
+ w_max_shift = round((bbox[2] - bbox[0]) * self.max_part_shift[1])
445
+
446
+ x_min = bbox[0] - random.randint(-w_max_shift, w_max_shift)
447
+ x_max = bbox[2] + random.randint(-w_max_shift, w_max_shift)
448
+
449
+ y_min = bbox[1] - random.randint(-h_max_shift, h_max_shift)
450
+ y_max = bbox[3] + random.randint(-h_max_shift, h_max_shift)
451
+
452
+ x_min = max(0, x_min)
453
+ y_min = max(0, y_min)
454
+
455
+ return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
456
+
457
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
458
+ return F.bbox_crop(bbox, **params)
459
+
460
+ def apply_to_keypoint(
461
+ self,
462
+ keypoint: Tuple[float, float, float, float],
463
+ x_min: int = 0,
464
+ x_max: int = 0,
465
+ y_min: int = 0,
466
+ y_max: int = 0,
467
+ **params
468
+ ) -> Tuple[float, float, float, float]:
469
+ return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))
470
+
471
+ @property
472
+ def targets_as_params(self) -> List[str]:
473
+ return [self.cropping_bbox_key]
474
+
475
+ def get_transform_init_args_names(self) -> Tuple[str]:
476
+ return ("max_part_shift",)
477
+
478
+
479
+ class BBoxSafeRandomCrop(DualTransform):
480
+ """Crop a random part of the input without loss of bboxes.
481
+ Args:
482
+ erosion_rate (float): erosion rate applied on input image height before crop.
483
+ p (float): probability of applying the transform. Default: 1.
484
+ Targets:
485
+ image, mask, bboxes
486
+ Image types:
487
+ uint8, float32
488
+ """
489
+
490
+ def __init__(self, erosion_rate=0.0, always_apply=False, p=1.0):
491
+ super(BBoxSafeRandomCrop, self).__init__(always_apply, p)
492
+ self.erosion_rate = erosion_rate
493
+
494
+ def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, **params):
495
+ return F.random_crop(img, crop_height, crop_width, h_start, w_start)
496
+
497
+ def get_params_dependent_on_targets(self, params):
498
+ img_h, img_w = params["image"].shape[:2]
499
+ if len(params["bboxes"]) == 0: # less likely, this class is for use with bboxes.
500
+ erosive_h = int(img_h * (1.0 - self.erosion_rate))
501
+ crop_height = img_h if erosive_h >= img_h else random.randint(erosive_h, img_h)
502
+ return {
503
+ "h_start": random.random(),
504
+ "w_start": random.random(),
505
+ "crop_height": crop_height,
506
+ "crop_width": int(crop_height * img_w / img_h),
507
+ }
508
+ # get union of all bboxes
509
+ x, y, x2, y2 = union_of_bboxes(
510
+ width=img_w, height=img_h, bboxes=params["bboxes"], erosion_rate=self.erosion_rate
511
+ )
512
+ # find bigger region
513
+ bx, by = x * random.random(), y * random.random()
514
+ bx2, by2 = x2 + (1 - x2) * random.random(), y2 + (1 - y2) * random.random()
515
+ bw, bh = bx2 - bx, by2 - by
516
+ crop_height = img_h if bh >= 1.0 else int(img_h * bh)
517
+ crop_width = img_w if bw >= 1.0 else int(img_w * bw)
518
+ h_start = np.clip(0.0 if bh >= 1.0 else by / (1.0 - bh), 0.0, 1.0)
519
+ w_start = np.clip(0.0 if bw >= 1.0 else bx / (1.0 - bw), 0.0, 1.0)
520
+ return {"h_start": h_start, "w_start": w_start, "crop_height": crop_height, "crop_width": crop_width}
521
+
522
+ def apply_to_bbox(self, bbox, crop_height=0, crop_width=0, h_start=0, w_start=0, rows=0, cols=0, **params):
523
+ return F.bbox_random_crop(bbox, crop_height, crop_width, h_start, w_start, rows, cols)
524
+
525
+ @property
526
+ def targets_as_params(self):
527
+ return ["image", "bboxes"]
528
+
529
+ def get_transform_init_args_names(self):
530
+ return ("erosion_rate",)
531
+
532
+
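An illustrative sketch (assuming `BboxParams` is also exported at the top level): with `erosion_rate=0.0` the crop window always contains the union of the boxes, so every input box survives:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> aug = A.Compose(
...     [A.BBoxSafeRandomCrop(erosion_rate=0.0, p=1.0)],
...     bbox_params=A.BboxParams(format="pascal_voc", label_fields=["labels"]),
... )
>>> out = aug(image=image, bboxes=[(20, 20, 60, 60)], labels=[1])
>>> len(out["bboxes"])
1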
533
+ class RandomSizedBBoxSafeCrop(BBoxSafeRandomCrop):
534
+ """Crop a random part of the input and rescale it to some size without loss of bboxes.
535
+ Args:
536
+ height (int): height after crop and resize.
537
+ width (int): width after crop and resize.
538
+ erosion_rate (float): erosion rate applied on input image height before crop.
539
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
540
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
541
+ Default: cv2.INTER_LINEAR.
542
+ p (float): probability of applying the transform. Default: 1.
543
+ Targets:
544
+ image, mask, bboxes
545
+ Image types:
546
+ uint8, float32
547
+ """
548
+
549
+ def __init__(self, height, width, erosion_rate=0.0, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1.0):
550
+ super(RandomSizedBBoxSafeCrop, self).__init__(erosion_rate, always_apply, p)
551
+ self.height = height
552
+ self.width = width
553
+ self.interpolation = interpolation
554
+
555
+ def apply(self, img, crop_height=0, crop_width=0, h_start=0, w_start=0, interpolation=cv2.INTER_LINEAR, **params):
556
+ crop = F.random_crop(img, crop_height, crop_width, h_start, w_start)
557
+ return FGeometric.resize(crop, self.height, self.width, interpolation)
558
+
559
+ def get_transform_init_args_names(self):
560
+ return super().get_transform_init_args_names() + ("height", "width", "interpolation")
561
+
562
+
563
+ class CropAndPad(DualTransform):
564
+ """Crop and pad images by pixel amounts or fractions of image sizes.
565
+ Cropping removes pixels at the sides (i.e. extracts a subimage from a given full image).
566
+ Padding adds pixels to the sides (e.g. black pixels).
567
+ This transformation will never crop images below a height or width of ``1``.
568
+
569
+ Note:
570
+ This transformation automatically resizes images back to their original size. To deactivate this, add the
571
+ parameter ``keep_size=False``.
572
+
573
+ Args:
574
+ px (int or tuple):
575
+ The number of pixels to crop (negative values) or pad (positive values)
576
+ on each side of the image. Either this or the parameter `percent` may
577
+ be set, not both at the same time.
578
+ * If ``None``, then pixel-based cropping/padding will not be used.
579
+ * If ``int``, then that exact number of pixels will always be cropped/padded.
580
+ * If a ``tuple`` of two ``int`` s with values ``a`` and ``b``,
581
+ then each side will be cropped/padded by a random amount sampled
582
+ uniformly per image and side from the interval ``[a, b]``. If
583
+ however `sample_independently` is set to ``False``, only one
584
+ value will be sampled per image and used for all sides.
585
+ * If a ``tuple`` of four entries, then the entries represent top,
586
+ right, bottom, left. Each entry may be a single ``int`` (always
587
+ crop/pad by exactly that value), a ``tuple`` of two ``int`` s
588
+ ``a`` and ``b`` (crop/pad by an amount within ``[a, b]``), a
589
+ ``list`` of ``int`` s (crop/pad by a random value that is
590
+ contained in the ``list``).
591
+ percent (float or tuple):
592
+ The number of pixels to crop (negative values) or pad (positive values)
593
+ on each side of the image given as a *fraction* of the image
594
+ height/width. E.g. if this is set to ``-0.1``, the transformation will
595
+ always crop away ``10%`` of the image's height at both the top and the
596
+ bottom (both ``10%`` each), as well as ``10%`` of the width at the
597
+ right and left.
598
+ Expected value range is ``(-1.0, inf)``.
599
+ Either this or the parameter `px` may be set, not both
600
+ at the same time.
601
+ * If ``None``, then fraction-based cropping/padding will not be
602
+ used.
603
+ * If ``float``, then that fraction will always be cropped/padded.
604
+ * If a ``tuple`` of two ``float`` s with values ``a`` and ``b``,
605
+ then each side will be cropped/padded by a random fraction
606
+ sampled uniformly per image and side from the interval
607
+ ``[a, b]``. If however `sample_independently` is set to
608
+ ``False``, only one value will be sampled per image and used for
609
+ all sides.
610
+ * If a ``tuple`` of four entries, then the entries represent top,
611
+ right, bottom, left. Each entry may be a single ``float``
612
+ (always crop/pad by exactly that percent value), a ``tuple`` of
613
+ two ``float`` s ``a`` and ``b`` (crop/pad by a fraction from
614
+ ``[a, b]``), a ``list`` of ``float`` s (crop/pad by a random
615
+ value that is contained in the list).
616
+ pad_mode (int): OpenCV border mode.
617
+ pad_cval (number, Sequence[number]):
618
+ The constant value to use if the pad mode is ``BORDER_CONSTANT``.
619
+ * If ``number``, then that value will be used.
620
+ * If a ``tuple`` of two ``number`` s and at least one of them is
621
+ a ``float``, then a random number will be uniformly sampled per
622
+ image from the continuous interval ``[a, b]`` and used as the
623
+ value. If both ``number`` s are ``int`` s, the interval is
624
+ discrete.
625
+ * If a ``list`` of ``number``, then a random value will be chosen
626
+ from the elements of the ``list`` and used as the value.
627
+ pad_cval_mask (number, Sequence[number]): Same as pad_cval but only for masks.
628
+ keep_size (bool):
629
+ After cropping and padding, the result image will usually have a
630
+ different height/width compared to the original input image. If this
631
+ parameter is set to ``True``, then the cropped/padded image will be
632
+ resized to the input image's size, i.e. the output shape is always identical to the input shape.
633
+ sample_independently (bool):
634
+ If ``False`` *and* the values for `px`/`percent` result in exactly
635
+ *one* probability distribution for all image sides, only one single
636
+ value will be sampled from that probability distribution and used for
637
+ all sides. I.e. the crop/pad amount then is the same for all sides.
638
+ If ``True``, four values will be sampled independently, one per side.
639
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
640
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
641
+ Default: cv2.INTER_LINEAR.
642
+
643
+ Targets:
644
+ image, mask, bboxes, keypoints
645
+
646
+ Image types:
647
+ any
648
+ """
649
+
650
+ def __init__(
651
+ self,
652
+ px: Optional[Union[int, Sequence[float], Sequence[Tuple]]] = None,
653
+ percent: Optional[Union[float, Sequence[float], Sequence[Tuple]]] = None,
654
+ pad_mode: int = cv2.BORDER_CONSTANT,
655
+ pad_cval: Union[float, Sequence[float]] = 0,
656
+ pad_cval_mask: Union[float, Sequence[float]] = 0,
657
+ keep_size: bool = True,
658
+ sample_independently: bool = True,
659
+ interpolation: int = cv2.INTER_LINEAR,
660
+ always_apply: bool = False,
661
+ p: float = 1.0,
662
+ ):
663
+ super().__init__(always_apply, p)
664
+
665
+ if px is None and percent is None:
666
+ raise ValueError("px and percent are empty!")
667
+ if px is not None and percent is not None:
668
+ raise ValueError("Only px or percent may be set!")
669
+
670
+ self.px = px
671
+ self.percent = percent
672
+
673
+ self.pad_mode = pad_mode
674
+ self.pad_cval = pad_cval
675
+ self.pad_cval_mask = pad_cval_mask
676
+
677
+ self.keep_size = keep_size
678
+ self.sample_independently = sample_independently
679
+
680
+ self.interpolation = interpolation
681
+
682
+ def apply(
683
+ self,
684
+ img: np.ndarray,
685
+ crop_params: Sequence[int] = (),
686
+ pad_params: Sequence[int] = (),
687
+ pad_value: Union[int, float] = 0,
688
+ rows: int = 0,
689
+ cols: int = 0,
690
+ interpolation: int = cv2.INTER_LINEAR,
691
+ **params
692
+ ) -> np.ndarray:
693
+ return F.crop_and_pad(
694
+ img, crop_params, pad_params, pad_value, rows, cols, interpolation, self.pad_mode, self.keep_size
695
+ )
696
+
697
+ def apply_to_mask(
698
+ self,
699
+ img: np.ndarray,
700
+ crop_params: Optional[Sequence[int]] = None,
701
+ pad_params: Optional[Sequence[int]] = None,
702
+ pad_value_mask: Optional[float] = None,
703
+ rows: int = 0,
704
+ cols: int = 0,
705
+ interpolation: int = cv2.INTER_NEAREST,
706
+ **params
707
+ ) -> np.ndarray:
708
+ return F.crop_and_pad(
709
+ img, crop_params, pad_params, pad_value_mask, rows, cols, interpolation, self.pad_mode, self.keep_size
710
+ )
711
+
712
+ def apply_to_bbox(
713
+ self,
714
+ bbox: BoxInternalType,
715
+ crop_params: Optional[Sequence[int]] = None,
716
+ pad_params: Optional[Sequence[int]] = None,
717
+ rows: int = 0,
718
+ cols: int = 0,
719
+ result_rows: int = 0,
720
+ result_cols: int = 0,
721
+ **params
722
+ ) -> BoxInternalType:
723
+ return F.crop_and_pad_bbox(bbox, crop_params, pad_params, rows, cols, result_rows, result_cols)
724
+
725
+ def apply_to_keypoint(
726
+ self,
727
+ keypoint: KeypointInternalType,
728
+ crop_params: Optional[Sequence[int]] = None,
729
+ pad_params: Optional[Sequence[int]] = None,
730
+ rows: int = 0,
731
+ cols: int = 0,
732
+ result_rows: int = 0,
733
+ result_cols: int = 0,
734
+ **params
735
+ ) -> KeypointInternalType:
736
+ return F.crop_and_pad_keypoint(
737
+ keypoint, crop_params, pad_params, rows, cols, result_rows, result_cols, self.keep_size
738
+ )
739
+
740
+ @property
741
+ def targets_as_params(self) -> List[str]:
742
+ return ["image"]
743
+
744
+ @staticmethod
745
+ def __prevent_zero(val1: int, val2: int, max_val: int) -> Tuple[int, int]:
746
+ regain = abs(max_val) + 1
747
+ regain1 = regain // 2
748
+ regain2 = regain // 2
749
+ if regain1 + regain2 < regain:
750
+ regain1 += 1
751
+
752
+ if regain1 > val1:
753
+ diff = regain1 - val1
754
+ regain1 = val1
755
+ regain2 += diff
756
+ elif regain2 > val2:
757
+ diff = regain2 - val2
758
+ regain2 = val2
759
+ regain1 += diff
760
+
761
+ val1 = val1 - regain1
762
+ val2 = val2 - regain2
763
+
764
+ return val1, val2
765
+
766
+ @staticmethod
767
+ def _prevent_zero(crop_params: List[int], height: int, width: int) -> Sequence[int]:
768
+ top, right, bottom, left = crop_params
769
+
770
+ remaining_height = height - (top + bottom)
771
+ remaining_width = width - (left + right)
772
+
773
+ if remaining_height < 1:
774
+ top, bottom = CropAndPad.__prevent_zero(top, bottom, height)
775
+ if remaining_width < 1:
776
+ left, right = CropAndPad.__prevent_zero(left, right, width)
777
+
778
+ return [max(top, 0), max(right, 0), max(bottom, 0), max(left, 0)]
779
+
780
+ def get_params_dependent_on_targets(self, params) -> dict:
781
+ height, width = params["image"].shape[:2]
782
+
783
+ if self.px is not None:
784
+ params = self._get_px_params()
785
+ else:
786
+ params = self._get_percent_params()
787
+ params[0] = int(params[0] * height)
788
+ params[1] = int(params[1] * width)
789
+ params[2] = int(params[2] * height)
790
+ params[3] = int(params[3] * width)
791
+
792
+ pad_params = [max(i, 0) for i in params]
793
+
794
+ crop_params = self._prevent_zero([-min(i, 0) for i in params], height, width)
795
+
796
+ top, right, bottom, left = crop_params
797
+ crop_params = [left, top, width - right, height - bottom]
798
+ result_rows = crop_params[3] - crop_params[1]
799
+ result_cols = crop_params[2] - crop_params[0]
800
+ if result_cols == width and result_rows == height:
801
+ crop_params = []
802
+
803
+ top, right, bottom, left = pad_params
804
+ pad_params = [top, bottom, left, right]
805
+ if any(pad_params):
806
+ result_rows += top + bottom
807
+ result_cols += left + right
808
+ else:
809
+ pad_params = []
810
+
811
+ return {
812
+ "crop_params": crop_params or None,
813
+ "pad_params": pad_params or None,
814
+ "pad_value": None if pad_params is None else self._get_pad_value(self.pad_cval),
815
+ "pad_value_mask": None if pad_params is None else self._get_pad_value(self.pad_cval_mask),
816
+ "result_rows": result_rows,
817
+ "result_cols": result_cols,
818
+ }
819
+
820
+ def _get_px_params(self) -> List[int]:
821
+ if self.px is None:
822
+ raise ValueError("px is not set")
823
+
824
+ if isinstance(self.px, int):
825
+ params = [self.px] * 4
826
+ elif len(self.px) == 2:
827
+ if self.sample_independently:
828
+ params = [random.randrange(*self.px) for _ in range(4)]
829
+ else:
830
+ px = random.randrange(*self.px)
831
+ params = [px] * 4
832
+ else:
833
+ params = [i if isinstance(i, int) else random.randrange(*i) for i in self.px] # type: ignore
834
+
835
+ return params # [top, right, bottom, left]
836
+
837
+ def _get_percent_params(self) -> List[float]:
838
+ if self.percent is None:
839
+ raise ValueError("percent is not set")
840
+
841
+ if isinstance(self.percent, float):
842
+ params = [self.percent] * 4
843
+ elif len(self.percent) == 2:
844
+ if self.sample_independently:
845
+ params = [random.uniform(*self.percent) for _ in range(4)]
846
+ else:
847
+ px = random.uniform(*self.percent)
848
+ params = [px] * 4
849
+ else:
850
+ params = [i if isinstance(i, (int, float)) else random.uniform(*i) for i in self.percent]
851
+
852
+ return params # params = [top, right, bottom, left]
853
+
854
+ @staticmethod
855
+ def _get_pad_value(pad_value: Union[float, Sequence[float]]) -> Union[int, float]:
856
+ if isinstance(pad_value, (int, float)):
857
+ return pad_value
858
+
859
+ if len(pad_value) == 2:
860
+ a, b = pad_value
861
+ if isinstance(a, int) and isinstance(b, int):
862
+ return random.randint(a, b)
863
+
864
+ return random.uniform(a, b)
865
+
866
+ return random.choice(pad_value)
867
+
868
+ def get_transform_init_args_names(self) -> Tuple[str, ...]:
869
+ return (
870
+ "px",
871
+ "percent",
872
+ "pad_mode",
873
+ "pad_cval",
874
+ "pad_cval_mask",
875
+ "keep_size",
876
+ "sample_independently",
877
+ "interpolation",
878
+ )
879
+
880
+
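A sketch for CropAndPad (illustrative, same assumptions): with a symmetric `percent` range each side is independently cropped or padded by up to 10%, and `keep_size=True` (the default) resizes the result back to the input shape:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> aug = A.Compose([A.CropAndPad(percent=(-0.1, 0.1), p=1.0)])
>>> aug(image=image)["image"].shape
(100, 100, 3)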
881
+ class RandomCropFromBorders(DualTransform):
882
+ """Crop bbox from image randomly cut parts from borders without resize at the end
883
+
884
+ Args:
885
+ crop_left (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
886
+ from left side in range [0, crop_left * width)
887
+ crop_right (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
888
+ from right side in range [(1 - crop_right) * width, width)
889
+ crop_top (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
890
+ from top side in range [0, crop_top * height)
891
+ crop_bottom (float): single float value in (0.0, 1.0) range. Default 0.1. Image will be randomly cut
892
+ from bottom side in range [(1 - crop_bottom) * height, height)
893
+ p (float): probability of applying the transform. Default: 1.
894
+
895
+ Targets:
896
+ image, mask, bboxes, keypoints
897
+
898
+ Image types:
899
+ uint8, float32
900
+ """
901
+
902
+ def __init__(
903
+ self,
904
+ crop_left=0.1,
905
+ crop_right=0.1,
906
+ crop_top=0.1,
907
+ crop_bottom=0.1,
908
+ always_apply=False,
909
+ p=1.0,
910
+ ):
911
+ super(RandomCropFromBorders, self).__init__(always_apply, p)
912
+ self.crop_left = crop_left
913
+ self.crop_right = crop_right
914
+ self.crop_top = crop_top
915
+ self.crop_bottom = crop_bottom
916
+
917
+ def get_params_dependent_on_targets(self, params):
918
+ img = params["image"]
919
+ x_min = random.randint(0, int(self.crop_left * img.shape[1]))
920
+ x_max = random.randint(max(x_min + 1, int((1 - self.crop_right) * img.shape[1])), img.shape[1])
921
+ y_min = random.randint(0, int(self.crop_top * img.shape[0]))
922
+ y_max = random.randint(max(y_min + 1, int((1 - self.crop_bottom) * img.shape[0])), img.shape[0])
923
+ return {"x_min": x_min, "x_max": x_max, "y_min": y_min, "y_max": y_max}
924
+
925
+ def apply(self, img, x_min=0, x_max=0, y_min=0, y_max=0, **params):
926
+ return F.clamping_crop(img, x_min, y_min, x_max, y_max)
927
+
928
+ def apply_to_mask(self, mask, x_min=0, x_max=0, y_min=0, y_max=0, **params):
929
+ return F.clamping_crop(mask, x_min, y_min, x_max, y_max)
930
+
931
+ def apply_to_bbox(self, bbox, x_min=0, x_max=0, y_min=0, y_max=0, **params):
932
+ rows, cols = params["rows"], params["cols"]
933
+ return F.bbox_crop(bbox, x_min, y_min, x_max, y_max, rows, cols)
934
+
935
+ def apply_to_keypoint(self, keypoint, x_min=0, x_max=0, y_min=0, y_max=0, **params):
936
+ return F.crop_keypoint_by_coords(keypoint, crop_coords=(x_min, y_min, x_max, y_max))
937
+
938
+ @property
939
+ def targets_as_params(self):
940
+ return ["image"]
941
+
942
+ def get_transform_init_args_names(self):
943
+ return "crop_left", "crop_right", "crop_top", "crop_bottom"
custom_albumentations/augmentations/domain_adaptation.py ADDED
@@ -0,0 +1,337 @@
1
+ import random
2
+ from typing import Any, Callable, Literal, Sequence, Tuple
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from custom_qudida import DomainAdapter
7
+ from skimage.exposure import match_histograms
8
+ from sklearn.decomposition import PCA
9
+ from sklearn.preprocessing import MinMaxScaler, StandardScaler
10
+
11
+ from custom_albumentations.augmentations.utils import (
12
+ clipped,
13
+ get_opencv_dtype_from_numpy,
14
+ is_grayscale_image,
15
+ is_multispectral_image,
16
+ preserve_shape,
17
+ read_rgb_image,
18
+ )
19
+
20
+ from ..core.transforms_interface import ImageOnlyTransform, ScaleFloatType, to_tuple
21
+
22
+ __all__ = [
23
+ "HistogramMatching",
24
+ "FDA",
25
+ "PixelDistributionAdaptation",
26
+ "fourier_domain_adaptation",
27
+ "apply_histogram",
28
+ "adapt_pixel_distribution",
29
+ ]
30
+
31
+
32
+ @clipped
33
+ @preserve_shape
34
+ def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray:
35
+ """
36
+ Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA
37
+
38
+ Args:
39
+ img: source image
40
+ target_img: target image for domain adaptation
41
+ beta: coefficient from source paper
42
+
43
+ Returns:
44
+ transformed image
45
+
46
+ """
47
+
48
+ img = np.squeeze(img)
49
+ target_img = np.squeeze(target_img)
50
+
51
+ if target_img.shape != img.shape:
52
+ raise ValueError(
53
+ "The source and target images must have the same shape,"
54
+ " but got {} and {} respectively.".format(img.shape, target_img.shape)
55
+ )
56
+
57
+ # get fft of both source and target
58
+ fft_src = np.fft.fft2(img.astype(np.float32), axes=(0, 1))
59
+ fft_trg = np.fft.fft2(target_img.astype(np.float32), axes=(0, 1))
60
+
61
+ # extract amplitude and phase of both fft-s
62
+ amplitude_src, phase_src = np.abs(fft_src), np.angle(fft_src)
63
+ amplitude_trg = np.abs(fft_trg)
64
+
65
+ # mutate the amplitude part of source with target
66
+ amplitude_src = np.fft.fftshift(amplitude_src, axes=(0, 1))
67
+ amplitude_trg = np.fft.fftshift(amplitude_trg, axes=(0, 1))
68
+ height, width = amplitude_src.shape[:2]
69
+ border = np.floor(min(height, width) * beta).astype(int)
70
+ center_y, center_x = np.floor([height / 2.0, width / 2.0]).astype(int)
71
+
72
+ y1, y2 = center_y - border, center_y + border + 1
73
+ x1, x2 = center_x - border, center_x + border + 1
74
+
75
+ amplitude_src[y1:y2, x1:x2] = amplitude_trg[y1:y2, x1:x2]
76
+ amplitude_src = np.fft.ifftshift(amplitude_src, axes=(0, 1))
77
+
78
+ # get mutated image
79
+ src_image_transformed = np.fft.ifft2(amplitude_src * np.exp(1j * phase_src), axes=(0, 1))
80
+ src_image_transformed = np.real(src_image_transformed)
81
+
82
+ return src_image_transformed
83
+
84
+
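An illustrative direct call to the helper above; the source and target must share a shape, and `beta` sets the size of the low-frequency window whose amplitude is swapped in:

>>> import numpy as np
>>> src = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> trg = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> out = fourier_domain_adaptation(src, trg, beta=0.05)
>>> out.shape == src.shape
True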
85
+ @preserve_shape
86
+ def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray:
87
+ if img.dtype != reference_image.dtype:
88
+ raise RuntimeError(
89
+ f"Dtype of image and reference image must be the same. Got {img.dtype} and {reference_image.dtype}"
90
+ )
91
+ if img.shape[:2] != reference_image.shape[:2]:
92
+ reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0]))
93
+
94
+ img, reference_image = np.squeeze(img), np.squeeze(reference_image)
95
+
96
+ try:
97
+ matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == 3 else None)
98
+ except TypeError:
99
+ matched = match_histograms(img, reference_image, multichannel=True) # case for scikit-image<0.19.1
100
+ img = cv2.addWeighted(
101
+ matched,
102
+ blend_ratio,
103
+ img,
104
+ 1 - blend_ratio,
105
+ 0,
106
+ dtype=get_opencv_dtype_from_numpy(img.dtype),
107
+ )
108
+ return img
109
+
110
+
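A quick sketch for the helper above (illustrative): the reference is resized internally when the spatial sizes differ, and `blend_ratio` mixes the matched image back into the original:

>>> import numpy as np
>>> img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> ref = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
>>> apply_histogram(img, ref, blend_ratio=0.7).shape
(64, 64, 3)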
111
+ @preserve_shape
112
+ def adapt_pixel_distribution(
113
+ img: np.ndarray, ref: np.ndarray, transform_type: str = "pca", weight: float = 0.5
114
+ ) -> np.ndarray:
115
+ initial_type = img.dtype
116
+ transformer = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}[transform_type]()
117
+ adapter = DomainAdapter(transformer=transformer, ref_img=ref)
118
+ result = adapter(img).astype("float32")
119
+ blended = (img.astype("float32") * (1 - weight) + result * weight).astype(initial_type)
120
+ return blended
121
+
122
+
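A hedged sketch of a direct call (it relies on `DomainAdapter` from the bundled `custom_qudida` package, imported above):

>>> import numpy as np
>>> img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> ref = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
>>> adapt_pixel_distribution(img, ref, transform_type="minmax", weight=0.5).shape
(64, 64, 3)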
123
+ class HistogramMatching(ImageOnlyTransform):
124
+ """
125
+ Apply histogram matching. It manipulates the pixels of an input image so that its histogram matches
126
+ the histogram of the reference image. If the images have multiple channels, the matching is done independently
127
+ for each channel, as long as the number of channels is equal in the input image and the reference.
128
+
129
+ Histogram matching can be used as a lightweight normalisation for image processing,
130
+ such as feature matching, especially in circumstances where the images have been taken from different
131
+ sources or under different conditions (e.g. lighting).
132
+
133
+ See:
134
+ https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_histogram_matching.html
135
+
136
+ Args:
137
+ reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
138
+ it expects a sequence of paths to images.
139
+ blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
140
+ with random blend factor for increased diversity of generated images.
141
+ read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
142
+ and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
143
+ p (float): probability of applying the transform. Default: 1.0.
144
+
145
+ Targets:
146
+ image
147
+
148
+ Image types:
149
+ uint8, uint16, float32
150
+ """
151
+
152
+ def __init__(
153
+ self,
154
+ reference_images: Sequence[Any],
155
+ blend_ratio: Tuple[float, float] = (0.5, 1.0),
156
+ read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
157
+ always_apply: bool = False,
158
+ p: float = 0.5,
159
+ ):
160
+ super().__init__(always_apply=always_apply, p=p)
161
+ self.reference_images = reference_images
162
+ self.read_fn = read_fn
163
+ self.blend_ratio = blend_ratio
164
+
165
+ def apply(self, img, reference_image=None, blend_ratio=0.5, **params):
166
+ return apply_histogram(img, reference_image, blend_ratio)
167
+
168
+ def get_params(self):
169
+ return {
170
+ "reference_image": self.read_fn(random.choice(self.reference_images)),
171
+ "blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
172
+ }
173
+
174
+ def get_transform_init_args_names(self):
175
+ return ("reference_images", "blend_ratio", "read_fn")
176
+
177
+ def _to_dict(self):
178
+ raise NotImplementedError("HistogramMatching can not be serialized.")
179
+
180
+
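A usage sketch mirroring the FDA example below (illustrative; `read_fn=lambda x: x` lets in-memory arrays stand in for file paths):

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> reference = np.random.randint(0, 256, (100, 100, 3), dtype=np.uint8)
>>> aug = A.Compose([A.HistogramMatching([reference], read_fn=lambda x: x, p=1)])
>>> result = aug(image=image)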
181
+ class FDA(ImageOnlyTransform):
182
+ """
183
+ Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA
184
+ Simple "style transfer".
185
+
186
+ Args:
187
+ reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
188
+ it expects a sequence of paths to images.
189
+ beta_limit (float or tuple of float): coefficient beta from the paper. Values below 0.3 are recommended.
190
+ read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
191
+ and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
192
+
193
+ Targets:
194
+ image
195
+
196
+ Image types:
197
+ uint8, float32
198
+
199
+ Reference:
200
+ https://github.com/YanchaoYang/FDA
201
+ https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf
202
+
203
+ Example:
204
+ >>> import numpy as np
205
+ >>> import custom_albumentations as A
206
+ >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
207
+ >>> target_image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
208
+ >>> aug = A.Compose([A.FDA([target_image], p=1, read_fn=lambda x: x)])
209
+ >>> result = aug(image=image)
210
+
211
+ """
212
+
213
+ def __init__(
214
+ self,
215
+ reference_images: Sequence[Any],
216
+ beta_limit: ScaleFloatType = 0.1,
217
+ read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
218
+ always_apply: bool = False,
219
+ p: float = 0.5,
220
+ ):
221
+ super(FDA, self).__init__(always_apply=always_apply, p=p)
222
+ self.reference_images = reference_images
223
+ self.read_fn = read_fn
224
+ self.beta_limit = to_tuple(beta_limit, low=0)
225
+
226
+ def apply(self, img, target_image=None, beta=0.1, **params):
227
+ return fourier_domain_adaptation(img=img, target_img=target_image, beta=beta)
228
+
229
+ def get_params_dependent_on_targets(self, params):
230
+ img = params["image"]
231
+ target_img = self.read_fn(random.choice(self.reference_images))
232
+ target_img = cv2.resize(target_img, dsize=(img.shape[1], img.shape[0]))
233
+
234
+ return {"target_image": target_img}
235
+
236
+ def get_params(self):
237
+ return {"beta": random.uniform(self.beta_limit[0], self.beta_limit[1])}
238
+
239
+ @property
240
+ def targets_as_params(self):
241
+ return ["image"]
242
+
243
+ def get_transform_init_args_names(self):
244
+ return ("reference_images", "beta_limit", "read_fn")
245
+
246
+ def _to_dict(self):
247
+ raise NotImplementedError("FDA can not be serialized.")
248
+
249
+
250
+ class PixelDistributionAdaptation(ImageOnlyTransform):
251
+ """
252
+ A naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler
253
+ or MinMaxScaler) on both the original and the reference image, transforms the original image with the transform
254
+ trained on it, and then applies the inverse transformation using the transform fitted on the reference image.
255
+
256
+ Args:
257
+ reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
258
+ it expects a sequence of paths to images.
259
+ blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
260
+ with random blend factor for increased diversity of generated images.
261
+ read_fn (Callable): User-defined function to read image. Function should get an element of `reference_images`
262
+ and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
263
+ transform_type (str): type of transform; "pca", "standard", "minmax" are allowed.
264
+ p (float): probability of applying the transform. Default: 1.0.
265
+
266
+ Targets:
267
+ image
268
+
269
+ Image types:
270
+ uint8, float32
271
+
272
+ See also: https://github.com/arsenyinfo/qudida
273
+ """
274
+
275
+ def __init__(
276
+ self,
277
+ reference_images: Sequence[Any],
278
+ blend_ratio: Tuple[float, float] = (0.25, 1.0),
279
+ read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
280
+ transform_type: Literal["pca", "standard", "minmax"] = "pca",
281
+ always_apply: bool = False,
282
+ p: float = 0.5,
283
+ ):
284
+ super().__init__(always_apply=always_apply, p=p)
285
+ self.reference_images = reference_images
286
+ self.read_fn = read_fn
287
+ self.blend_ratio = blend_ratio
288
+ expected_transformers = ("pca", "standard", "minmax")
289
+ if transform_type not in expected_transformers:
290
+ raise ValueError(f"Got unexpected transform_type {transform_type}. Expected one of {expected_transformers}")
291
+ self.transform_type = transform_type
292
+
293
+ @staticmethod
294
+ def _validate_shape(img: np.ndarray):
295
+ if is_grayscale_image(img) or is_multispectral_image(img):
296
+ raise ValueError(
297
+ f"Unexpected image shape: expected 3 dimensions, got {len(img.shape)}."
298
+ f"Is it a grayscale or multispectral image? It's not supported for now."
299
+ )
300
+
301
+ def ensure_uint8(self, img: np.ndarray) -> Tuple[np.ndarray, bool]:
302
+ if img.dtype == np.float32:
303
+ if img.min() < 0 or img.max() > 1:
304
+ message = (
305
+ "PixelDistributionAdaptation uses uint8 under the hood, so float32 should be converted,"
306
+ "Can not do it automatically when the image is out of [0..1] range."
307
+ )
308
+ raise TypeError(message)
309
+ return (img * 255).astype("uint8"), True
310
+ return img, False
311
+
312
+ def apply(self, img, reference_image, blend_ratio, **params):
313
+ self._validate_shape(img)
314
+ reference_image, _ = self.ensure_uint8(reference_image)
315
+ img, needs_reconvert = self.ensure_uint8(img)
316
+
317
+ adapted = adapt_pixel_distribution(
318
+ img=img,
319
+ ref=reference_image,
320
+ weight=blend_ratio,
321
+ transform_type=self.transform_type,
322
+ )
323
+ if needs_reconvert:
324
+ adapted = adapted.astype("float32") * (1 / 255)
325
+ return adapted
326
+
327
+ def get_params(self):
328
+ return {
329
+ "reference_image": self.read_fn(random.choice(self.reference_images)),
330
+ "blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
331
+ }
332
+
333
+ def get_transform_init_args_names(self):
334
+ return ("reference_images", "blend_ratio", "read_fn", "transform_type")
335
+
336
+ def _to_dict(self):
337
+ raise NotImplementedError("PixelDistributionAdaptation can not be serialized.")
custom_albumentations/augmentations/dropout/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ from .channel_dropout import *
2
+ from .coarse_dropout import *
3
+ from .cutout import *
4
+ from .grid_dropout import *
5
+ from .mask_dropout import *
custom_albumentations/augmentations/dropout/channel_dropout.py ADDED
@@ -0,0 +1,72 @@
1
+ import random
2
+ from typing import Any, Mapping, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from custom_albumentations.core.transforms_interface import ImageOnlyTransform
7
+
8
+ from .functional import channel_dropout
9
+
10
+ __all__ = ["ChannelDropout"]
11
+
12
+
13
+ class ChannelDropout(ImageOnlyTransform):
14
+ """Randomly Drop Channels in the input Image.
15
+
16
+ Args:
17
+ channel_drop_range (int, int): range from which we choose the number of channels to drop.
18
+ fill_value (int, float): pixel value for the dropped channel.
19
+ p (float): probability of applying the transform. Default: 0.5.
20
+
21
+ Targets:
22
+ image
23
+
24
+ Image types:
25
+ uint8, uint16, uint32, float32
26
+ """
27
+
28
+ def __init__(
29
+ self,
30
+ channel_drop_range: Tuple[int, int] = (1, 1),
31
+ fill_value: Union[int, float] = 0,
32
+ always_apply: bool = False,
33
+ p: float = 0.5,
34
+ ):
35
+ super(ChannelDropout, self).__init__(always_apply, p)
36
+
37
+ self.channel_drop_range = channel_drop_range
38
+
39
+ self.min_channels = channel_drop_range[0]
40
+ self.max_channels = channel_drop_range[1]
41
+
42
+ if not 1 <= self.min_channels <= self.max_channels:
43
+ raise ValueError("Invalid channel_drop_range. Got: {}".format(channel_drop_range))
44
+
45
+ self.fill_value = fill_value
46
+
47
+ def apply(self, img: np.ndarray, channels_to_drop: Tuple[int, ...] = (0,), **params) -> np.ndarray:
48
+ return channel_dropout(img, channels_to_drop, self.fill_value)
49
+
50
+ def get_params_dependent_on_targets(self, params: Mapping[str, Any]):
51
+ img = params["image"]
52
+
53
+ num_channels = img.shape[-1]
54
+
55
+ if len(img.shape) == 2 or num_channels == 1:
56
+ raise NotImplementedError("Images has one channel. ChannelDropout is not defined.")
57
+
58
+ if self.max_channels >= num_channels:
59
+ raise ValueError("Can not drop all channels in ChannelDropout.")
60
+
61
+ num_drop_channels = random.randint(self.min_channels, self.max_channels)
62
+
63
+ channels_to_drop = random.sample(range(num_channels), k=num_drop_channels)
64
+
65
+ return {"channels_to_drop": channels_to_drop}
66
+
67
+ def get_transform_init_args_names(self) -> Tuple[str, ...]:
68
+ return "channel_drop_range", "fill_value"
69
+
70
+ @property
71
+ def targets_as_params(self):
72
+ return ["image"]
custom_albumentations/augmentations/dropout/coarse_dropout.py ADDED
@@ -0,0 +1,187 @@
1
+ import random
2
+ from typing import Iterable, List, Optional, Sequence, Tuple, Union
3
+
4
+ import numpy as np
5
+
6
+ from ...core.transforms_interface import DualTransform, KeypointType
7
+ from .functional import cutout
8
+
9
+ __all__ = ["CoarseDropout"]
10
+
11
+
12
+ class CoarseDropout(DualTransform):
13
+ """CoarseDropout of the rectangular regions in the image.
14
+
15
+ Args:
16
+ max_holes (int): Maximum number of regions to zero out.
17
+ max_height (int, float): Maximum height of the hole.
18
+ If float, it is calculated as a fraction of the image height.
19
+ max_width (int, float): Maximum width of the hole.
20
+ If float, it is calculated as a fraction of the image width.
21
+ min_holes (int): Minimum number of regions to zero out. If `None`,
22
+ `min_holes` is be set to `max_holes`. Default: `None`.
23
+ min_height (int, float): Minimum height of the hole. Default: None. If `None`,
24
+ `min_height` is set to `max_height`. Default: `None`.
25
+ If float, it is calculated as a fraction of the image height.
26
+ min_width (int, float): Minimum width of the hole. If `None`, `min_height` is
27
+ set to `max_width`. Default: `None`.
28
+ If float, it is calculated as a fraction of the image width.
29
+
30
+ fill_value (int, float, list of int, list of float): value for dropped pixels.
31
+ mask_fill_value (int, float, list of int, list of float): fill value for dropped pixels
32
+ in mask. If `None` - mask is not affected. Default: `None`.
33
+
34
+ Targets:
35
+ image, mask, keypoints
36
+
37
+ Image types:
38
+ uint8, float32
39
+
40
+ Reference:
41
+ | https://arxiv.org/abs/1708.04552
42
+ | https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
43
+ | https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ max_holes: int = 8,
49
+ max_height: int = 8,
50
+ max_width: int = 8,
51
+ min_holes: Optional[int] = None,
52
+ min_height: Optional[int] = None,
53
+ min_width: Optional[int] = None,
54
+ fill_value: int = 0,
55
+ mask_fill_value: Optional[int] = None,
56
+ always_apply: bool = False,
57
+ p: float = 0.5,
58
+ ):
59
+ super(CoarseDropout, self).__init__(always_apply, p)
60
+ self.max_holes = max_holes
61
+ self.max_height = max_height
62
+ self.max_width = max_width
63
+ self.min_holes = min_holes if min_holes is not None else max_holes
64
+ self.min_height = min_height if min_height is not None else max_height
65
+ self.min_width = min_width if min_width is not None else max_width
66
+ self.fill_value = fill_value
67
+ self.mask_fill_value = mask_fill_value
68
+ if not 0 < self.min_holes <= self.max_holes:
69
+ raise ValueError("Invalid combination of min_holes and max_holes. Got: {}".format([min_holes, max_holes]))
70
+
71
+ self.check_range(self.max_height)
72
+ self.check_range(self.min_height)
73
+ self.check_range(self.max_width)
74
+ self.check_range(self.min_width)
75
+
76
+ if not 0 < self.min_height <= self.max_height:
77
+ raise ValueError(
78
+ "Invalid combination of min_height and max_height. Got: {}".format([min_height, max_height])
79
+ )
80
+ if not 0 < self.min_width <= self.max_width:
81
+ raise ValueError("Invalid combination of min_width and max_width. Got: {}".format([min_width, max_width]))
82
+
83
+ def check_range(self, dimension):
84
+ if isinstance(dimension, float) and not 0 <= dimension < 1.0:
85
+ raise ValueError(
86
+ "Invalid value {}. If using floats, the value should be in the range [0.0, 1.0)".format(dimension)
87
+ )
88
+
89
+ def apply(
90
+ self,
91
+ img: np.ndarray,
92
+ fill_value: Union[int, float] = 0,
93
+ holes: Iterable[Tuple[int, int, int, int]] = (),
94
+ **params
95
+ ) -> np.ndarray:
96
+ return cutout(img, holes, fill_value)
97
+
98
+ def apply_to_mask(
99
+ self,
100
+ img: np.ndarray,
101
+ mask_fill_value: Union[int, float] = 0,
102
+ holes: Iterable[Tuple[int, int, int, int]] = (),
103
+ **params
104
+ ) -> np.ndarray:
105
+ if mask_fill_value is None:
106
+ return img
107
+ return cutout(img, holes, mask_fill_value)
108
+
109
+ def get_params_dependent_on_targets(self, params):
110
+ img = params["image"]
111
+ height, width = img.shape[:2]
112
+
113
+ holes = []
114
+ for _n in range(random.randint(self.min_holes, self.max_holes)):
115
+ if all(
116
+ [
117
+ isinstance(self.min_height, int),
118
+ isinstance(self.min_width, int),
119
+ isinstance(self.max_height, int),
120
+ isinstance(self.max_width, int),
121
+ ]
122
+ ):
123
+ hole_height = random.randint(self.min_height, self.max_height)
124
+ hole_width = random.randint(self.min_width, self.max_width)
125
+ elif all(
126
+ [
127
+ isinstance(self.min_height, float),
128
+ isinstance(self.min_width, float),
129
+ isinstance(self.max_height, float),
130
+ isinstance(self.max_width, float),
131
+ ]
132
+ ):
133
+ hole_height = int(height * random.uniform(self.min_height, self.max_height))
134
+ hole_width = int(width * random.uniform(self.min_width, self.max_width))
135
+ else:
136
+ raise ValueError(
137
+ "Min width, max width, \
138
+ min height and max height \
139
+ should all either be ints or floats. \
140
+ Got: {} respectively".format(
141
+ [
142
+ type(self.min_width),
143
+ type(self.max_width),
144
+ type(self.min_height),
145
+ type(self.max_height),
146
+ ]
147
+ )
148
+ )
149
+
150
+ y1 = random.randint(0, height - hole_height)
151
+ x1 = random.randint(0, width - hole_width)
152
+ y2 = y1 + hole_height
153
+ x2 = x1 + hole_width
154
+ holes.append((x1, y1, x2, y2))
155
+
156
+ return {"holes": holes}
157
+
158
+ @property
159
+ def targets_as_params(self):
160
+ return ["image"]
161
+
162
+ def _keypoint_in_hole(self, keypoint: KeypointType, hole: Tuple[int, int, int, int]) -> bool:
163
+ x1, y1, x2, y2 = hole
164
+ x, y = keypoint[:2]
165
+ return x1 <= x < x2 and y1 <= y < y2
166
+
167
+ def apply_to_keypoints(
168
+ self, keypoints: Sequence[KeypointType], holes: Iterable[Tuple[int, int, int, int]] = (), **params
169
+ ) -> List[KeypointType]:
170
+ result = set(keypoints)
171
+ for hole in holes:
172
+ for kp in keypoints:
173
+ if self._keypoint_in_hole(kp, hole):
174
+ result.discard(kp)
175
+ return list(result)
176
+
177
+ def get_transform_init_args_names(self):
178
+ return (
179
+ "max_holes",
180
+ "max_height",
181
+ "max_width",
182
+ "min_holes",
183
+ "min_height",
184
+ "min_width",
185
+ "fill_value",
186
+ "mask_fill_value",
187
+ )
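A usage sketch (illustrative, assuming top-level exports): up to `max_holes` rectangles of at most `max_height` x `max_width` pixels are filled with `fill_value`; passing `mask_fill_value` would apply the same holes to the mask:

>>> import numpy as np
>>> import custom_albumentations as A
>>> image = np.random.randint(1, 256, (64, 64, 3), dtype=np.uint8)
>>> aug = A.Compose([A.CoarseDropout(max_holes=4, max_height=8, max_width=8, fill_value=0, p=1.0)])
>>> aug(image=image)["image"].shape
(64, 64, 3)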
custom_albumentations/augmentations/dropout/cutout.py ADDED
@@ -0,0 +1,79 @@
+ import random
+ import warnings
+ from typing import Any, Dict, Tuple, Union
+
+ import numpy as np
+
+ from custom_albumentations.core.transforms_interface import ImageOnlyTransform
+
+ from .functional import cutout
+
+ __all__ = ["Cutout"]
+
+
+ class Cutout(ImageOnlyTransform):
+     """CoarseDropout of the square regions in the image.
+
+     Args:
+         num_holes (int): number of regions to zero out
+         max_h_size (int): maximum height of the hole
+         max_w_size (int): maximum width of the hole
+         fill_value (int, float, list of int, list of float): value for dropped pixels.
+
+     Targets:
+         image
+
+     Image types:
+         uint8, float32
+
+     Reference:
+     |  https://arxiv.org/abs/1708.04552
+     |  https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
+     |  https://github.com/aleju/imgaug/blob/master/imgaug/augmenters/arithmetic.py
+     """
+
+     def __init__(
+         self,
+         num_holes: int = 8,
+         max_h_size: int = 8,
+         max_w_size: int = 8,
+         fill_value: Union[int, float] = 0,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super(Cutout, self).__init__(always_apply, p)
+         self.num_holes = num_holes
+         self.max_h_size = max_h_size
+         self.max_w_size = max_w_size
+         self.fill_value = fill_value
+         warnings.warn(
+             f"{self.__class__.__name__} has been deprecated. Please use CoarseDropout",
+             FutureWarning,
+         )
+
+     def apply(self, img: np.ndarray, fill_value: Union[int, float] = 0, holes=(), **params):
+         return cutout(img, holes, fill_value)
+
+     def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         img = params["image"]
+         height, width = img.shape[:2]
+
+         holes = []
+         for _n in range(self.num_holes):
+             y = random.randint(0, height)
+             x = random.randint(0, width)
+
+             y1 = np.clip(y - self.max_h_size // 2, 0, height)
+             y2 = np.clip(y1 + self.max_h_size, 0, height)
+             x1 = np.clip(x - self.max_w_size // 2, 0, width)
+             x2 = np.clip(x1 + self.max_w_size, 0, width)
+             holes.append((x1, y1, x2, y2))
+
+         return {"holes": holes}
+
+     @property
+     def targets_as_params(self):
+         return ["image"]
+
+     def get_transform_init_args_names(self) -> Tuple[str, ...]:
+         return ("num_holes", "max_h_size", "max_w_size")
custom_albumentations/augmentations/dropout/functional.py ADDED
@@ -0,0 +1,29 @@
+ from typing import Iterable, List, Tuple, Union
+
+ import numpy as np
+
+ from custom_albumentations.augmentations.utils import preserve_shape
+
+ __all__ = ["cutout", "channel_dropout"]
+
+
+ @preserve_shape
+ def channel_dropout(
+     img: np.ndarray, channels_to_drop: Union[int, Tuple[int, ...], np.ndarray], fill_value: Union[int, float] = 0
+ ) -> np.ndarray:
+     if len(img.shape) == 2 or img.shape[2] == 1:
+         raise NotImplementedError("Only one channel. ChannelDropout is not defined.")
+
+     img = img.copy()
+     img[..., channels_to_drop] = fill_value
+     return img
+
+
+ def cutout(
+     img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]], fill_value: Union[int, float] = 0
+ ) -> np.ndarray:
+     # Make a copy of the input image since we don't want to modify it directly
+     img = img.copy()
+     for x1, y1, x2, y2 in holes:
+         img[y1:y2, x1:x2] = fill_value
+     return img
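The functional helpers are easy to sanity-check in isolation; `cutout` takes holes as (x1, y1, x2, y2) corners and works on a copy of the input, as in this small sketch:

import numpy as np

from custom_albumentations.augmentations.dropout.functional import cutout

img = np.ones((4, 4), dtype=np.uint8)
holes = [(1, 0, 3, 2)]  # (x1, y1, x2, y2): columns 1-2, rows 0-1

out = cutout(img, holes, fill_value=0)
assert out[0:2, 1:3].sum() == 0  # the hole is zeroed out
assert img.sum() == 16           # the input array is untouched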
custom_albumentations/augmentations/dropout/grid_dropout.py ADDED
@@ -0,0 +1,155 @@
+ import random
+ from typing import Iterable, Optional, Tuple
+
+ import numpy as np
+
+ from ...core.transforms_interface import DualTransform
+ from . import functional as F
+
+ __all__ = ["GridDropout"]
+
+
+ class GridDropout(DualTransform):
+     """GridDropout, drops out rectangular regions of an image and the corresponding mask in a grid fashion.
+
+     Args:
+         ratio (float): the ratio of the mask holes to the unit_size (same for horizontal and vertical directions).
+             Must be between 0 and 1. Default: 0.5.
+         unit_size_min (int): minimum size of the grid unit. Must be between 2 and the image shorter edge.
+             If `None`, holes_number_x and holes_number_y are used to set up the grid. Default: `None`.
+         unit_size_max (int): maximum size of the grid unit. Must be between 2 and the image shorter edge.
+             If `None`, holes_number_x and holes_number_y are used to set up the grid. Default: `None`.
+         holes_number_x (int): the number of grid units in the x direction. Must be between 1 and image width//2.
+             If `None`, grid unit width is set as image_width//10. Default: `None`.
+         holes_number_y (int): the number of grid units in the y direction. Must be between 1 and image height//2.
+             If `None`, grid unit height is set equal to the grid unit width or image height, whichever is smaller.
+         shift_x (int): offset of the grid start in the x direction from the (0,0) coordinate.
+             Clipped between 0 and grid unit width - hole width. Default: 0.
+         shift_y (int): offset of the grid start in the y direction from the (0,0) coordinate.
+             Clipped between 0 and grid unit height - hole height. Default: 0.
+         random_offset (boolean): whether to offset the grid randomly between 0 and grid unit size - hole size.
+             If `True`, entered shift_x, shift_y are ignored and set randomly. Default: `False`.
+         fill_value (int): value for the dropped pixels. Default: 0.
+         mask_fill_value (int): value for the dropped pixels in mask.
+             If `None`, the transformation is not applied to the mask. Default: `None`.
+
+     Targets:
+         image, mask
+
+     Image types:
+         uint8, float32
+
+     References:
+         https://arxiv.org/abs/2001.04086
+
+     """
+
+     def __init__(
+         self,
+         ratio: float = 0.5,
+         unit_size_min: Optional[int] = None,
+         unit_size_max: Optional[int] = None,
+         holes_number_x: Optional[int] = None,
+         holes_number_y: Optional[int] = None,
+         shift_x: int = 0,
+         shift_y: int = 0,
+         random_offset: bool = False,
+         fill_value: int = 0,
+         mask_fill_value: Optional[int] = None,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super(GridDropout, self).__init__(always_apply, p)
+         self.ratio = ratio
+         self.unit_size_min = unit_size_min
+         self.unit_size_max = unit_size_max
+         self.holes_number_x = holes_number_x
+         self.holes_number_y = holes_number_y
+         self.shift_x = shift_x
+         self.shift_y = shift_y
+         self.random_offset = random_offset
+         self.fill_value = fill_value
+         self.mask_fill_value = mask_fill_value
+         if not 0 < self.ratio <= 1:
+             raise ValueError("ratio must be between 0 and 1.")
+
+     def apply(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
+         return F.cutout(img, holes, self.fill_value)
+
+     def apply_to_mask(self, img: np.ndarray, holes: Iterable[Tuple[int, int, int, int]] = (), **params) -> np.ndarray:
+         if self.mask_fill_value is None:
+             return img
+
+         return F.cutout(img, holes, self.mask_fill_value)
+
+     def get_params_dependent_on_targets(self, params):
+         img = params["image"]
+         height, width = img.shape[:2]
+         # set grid using unit size limits
+         if self.unit_size_min and self.unit_size_max:
+             if not 2 <= self.unit_size_min <= self.unit_size_max:
+                 raise ValueError("Max unit size should be >= min size, both at least 2 pixels.")
+             if self.unit_size_max > min(height, width):
+                 raise ValueError("Grid size limits must be within the shortest image edge.")
+             # random.randint is inclusive on both ends, so the max bound is used as-is
+             unit_width = random.randint(self.unit_size_min, self.unit_size_max)
+             unit_height = unit_width
+         else:
+             # set grid using holes numbers
+             if self.holes_number_x is None:
+                 unit_width = max(2, width // 10)
+             else:
+                 if not 1 <= self.holes_number_x <= width // 2:
+                     raise ValueError("The hole_number_x must be between 1 and image width//2.")
+                 unit_width = width // self.holes_number_x
+             if self.holes_number_y is None:
+                 unit_height = max(min(unit_width, height), 2)
+             else:
+                 if not 1 <= self.holes_number_y <= height // 2:
+                     raise ValueError("The hole_number_y must be between 1 and image height//2.")
+                 unit_height = height // self.holes_number_y
+
+         hole_width = int(unit_width * self.ratio)
+         hole_height = int(unit_height * self.ratio)
+         # min 1 pixel and max unit length - 1
+         hole_width = min(max(hole_width, 1), unit_width - 1)
+         hole_height = min(max(hole_height, 1), unit_height - 1)
+         # set offset of the grid
+         if self.shift_x is None:
+             shift_x = 0
+         else:
+             shift_x = min(max(0, self.shift_x), unit_width - hole_width)
+         if self.shift_y is None:
+             shift_y = 0
+         else:
+             shift_y = min(max(0, self.shift_y), unit_height - hole_height)
+         if self.random_offset:
+             shift_x = random.randint(0, unit_width - hole_width)
+             shift_y = random.randint(0, unit_height - hole_height)
+         holes = []
+         for i in range(width // unit_width + 1):
+             for j in range(height // unit_height + 1):
+                 x1 = min(shift_x + unit_width * i, width)
+                 y1 = min(shift_y + unit_height * j, height)
+                 x2 = min(x1 + hole_width, width)
+                 y2 = min(y1 + hole_height, height)
+                 holes.append((x1, y1, x2, y2))
+
+         return {"holes": holes}
+
+     @property
+     def targets_as_params(self):
+         return ["image"]
+
+     def get_transform_init_args_names(self):
+         return (
+             "ratio",
+             "unit_size_min",
+             "unit_size_max",
+             "holes_number_x",
+             "holes_number_y",
+             "shift_x",
+             "shift_y",
+             "random_offset",
+             "fill_value",
+             "mask_fill_value",
+         )
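A short sketch of GridDropout's geometry, assuming the standard image/mask call convention: with holes_number_x = holes_number_y = 10 on a 100x100 image, each grid unit is 10 px, so ratio=0.5 gives 5x5 holes and roughly ratio**2 = 25% of the pixels are dropped; mask_fill_value=0 applies the same holes to the mask.

import numpy as np

from custom_albumentations.augmentations.dropout.grid_dropout import GridDropout

image = np.full((100, 100, 3), 255, dtype=np.uint8)
mask = np.ones((100, 100), dtype=np.uint8)

aug = GridDropout(ratio=0.5, holes_number_x=10, holes_number_y=10, mask_fill_value=0, p=1.0)
out = aug(image=image, mask=mask)

dropped_fraction = (out["mask"] == 0).mean()  # close to 0.25 for this configuration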
custom_albumentations/augmentations/dropout/mask_dropout.py ADDED
@@ -0,0 +1,99 @@
+ import random
+ from typing import Any, Dict, Optional, Tuple, Union
+
+ import cv2
+ import numpy as np
+ from skimage.measure import label
+
+ from ...core.transforms_interface import DualTransform, to_tuple
+
+ __all__ = ["MaskDropout"]
+
+
+ class MaskDropout(DualTransform):
+     """
+     Image & mask augmentation that zeroes out mask and image regions corresponding
+     to randomly chosen object instances in the mask.
+
+     The mask must be a single-channel image; zero values are treated as background.
+     The image can have any number of channels.
+
+     Inspired by https://www.kaggle.com/c/severstal-steel-defect-detection/discussion/114254
+
+     Args:
+         max_objects: Maximum number of labels that can be zeroed out. Can be a tuple, in this case it's [min, max].
+         image_fill_value: Fill value to use when filling the image.
+             Can be 'inpaint' to apply inpainting (works only for 3-channel images).
+         mask_fill_value: Fill value to use when filling the mask.
+
+     Targets:
+         image, mask
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         max_objects: int = 1,
+         image_fill_value: Union[int, float, str] = 0,
+         mask_fill_value: Union[int, float] = 0,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super(MaskDropout, self).__init__(always_apply, p)
+         self.max_objects = to_tuple(max_objects, 1)
+         self.image_fill_value = image_fill_value
+         self.mask_fill_value = mask_fill_value
+
+     @property
+     def targets_as_params(self):
+         return ["mask"]
+
+     def get_params_dependent_on_targets(self, params) -> Dict[str, Any]:
+         mask = params["mask"]
+
+         label_image, num_labels = label(mask, return_num=True)
+
+         if num_labels == 0:
+             dropout_mask = None
+         else:
+             objects_to_drop = random.randint(int(self.max_objects[0]), int(self.max_objects[1]))
+             objects_to_drop = min(num_labels, objects_to_drop)
+
+             if objects_to_drop == num_labels:
+                 dropout_mask = mask > 0
+             else:
+                 labels_index = random.sample(range(1, num_labels + 1), objects_to_drop)
+                 dropout_mask = np.zeros((mask.shape[0], mask.shape[1]), dtype=bool)
+                 for label_index in labels_index:
+                     dropout_mask |= label_image == label_index
+
+         params.update({"dropout_mask": dropout_mask})
+         return params
+
+     def apply(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
+         if dropout_mask is None:
+             return img
+
+         if self.image_fill_value == "inpaint":
+             dropout_mask = dropout_mask.astype(np.uint8)
+             _, _, w, h = cv2.boundingRect(dropout_mask)
+             radius = min(3, max(w, h) // 2)
+             img = cv2.inpaint(img, dropout_mask, radius, cv2.INPAINT_NS)
+         else:
+             img = img.copy()
+             img[dropout_mask] = self.image_fill_value
+
+         return img
+
+     def apply_to_mask(self, img: np.ndarray, dropout_mask: Optional[np.ndarray] = None, **params) -> np.ndarray:
+         if dropout_mask is None:
+             return img
+
+         img = img.copy()
+         img[dropout_mask] = self.mask_fill_value
+         return img
+
+     def get_transform_init_args_names(self) -> Tuple[str, ...]:
+         return "max_objects", "image_fill_value", "mask_fill_value"
custom_albumentations/augmentations/functional.py ADDED
@@ -0,0 +1,1380 @@
1
+ from __future__ import division
2
+
3
+ from typing import Optional, Sequence, Union
4
+ from warnings import warn
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import skimage
9
+
10
+ from custom_albumentations import random_utils
11
+ from custom_albumentations.augmentations.utils import (
12
+ MAX_VALUES_BY_DTYPE,
13
+ _maybe_process_in_chunks,
14
+ clip,
15
+ clipped,
16
+ ensure_contiguous,
17
+ is_grayscale_image,
18
+ is_rgb_image,
19
+ non_rgb_warning,
20
+ preserve_channel_dim,
21
+ preserve_shape,
22
+ )
23
+
24
+ __all__ = [
25
+ "add_fog",
26
+ "add_rain",
27
+ "add_shadow",
28
+ "add_gravel",
29
+ "add_snow",
30
+ "add_sun_flare",
31
+ "add_weighted",
32
+ "adjust_brightness_torchvision",
33
+ "adjust_contrast_torchvision",
34
+ "adjust_hue_torchvision",
35
+ "adjust_saturation_torchvision",
36
+ "brightness_contrast_adjust",
37
+ "channel_shuffle",
38
+ "clahe",
39
+ "convolve",
40
+ "downscale",
41
+ "equalize",
42
+ "fancy_pca",
43
+ "from_float",
44
+ "gamma_transform",
45
+ "gauss_noise",
46
+ "image_compression",
47
+ "invert",
48
+ "iso_noise",
49
+ "linear_transformation_rgb",
50
+ "move_tone_curve",
51
+ "multiply",
52
+ "noop",
53
+ "normalize",
54
+ "posterize",
55
+ "shift_hsv",
56
+ "shift_rgb",
57
+ "solarize",
58
+ "superpixels",
59
+ "swap_tiles_on_image",
60
+ "to_float",
61
+ "to_gray",
62
+ "gray_to_rgb",
63
+ "unsharp_mask",
64
+ ]
65
+
66
+
67
+ def normalize_cv2(img, mean, denominator):
68
+ if mean.shape and len(mean) != 4 and mean.shape != img.shape:
69
+ mean = np.array(mean.tolist() + [0] * (4 - len(mean)), dtype=np.float64)
70
+ if not denominator.shape:
71
+ denominator = np.array([denominator.tolist()] * 4, dtype=np.float64)
72
+ elif len(denominator) != 4 and denominator.shape != img.shape:
73
+ denominator = np.array(denominator.tolist() + [1] * (4 - len(denominator)), dtype=np.float64)
74
+
75
+ img = np.ascontiguousarray(img.astype("float32"))
76
+ cv2.subtract(img, mean.astype(np.float64), img)
77
+ cv2.multiply(img, denominator.astype(np.float64), img)
78
+ return img
79
+
80
+
81
+ def normalize_numpy(img, mean, denominator):
82
+ img = img.astype(np.float32)
83
+ img -= mean
84
+ img *= denominator
85
+ return img
86
+
87
+
88
+ def normalize(img, mean, std, max_pixel_value=255.0):
89
+ mean = np.array(mean, dtype=np.float32)
90
+ mean *= max_pixel_value
91
+
92
+ std = np.array(std, dtype=np.float32)
93
+ std *= max_pixel_value
94
+
95
+ denominator = np.reciprocal(std, dtype=np.float32)
96
+
97
+ if img.ndim == 3 and img.shape[-1] == 3:
98
+ return normalize_cv2(img, mean, denominator)
99
+ return normalize_numpy(img, mean, denominator)
100
+
101
+
102
+ def _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift):
103
+ dtype = img.dtype
104
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
105
+ hue, sat, val = cv2.split(img)
106
+
107
+ if hue_shift != 0:
108
+ lut_hue = np.arange(0, 256, dtype=np.int16)
109
+ lut_hue = np.mod(lut_hue + hue_shift, 180).astype(dtype)
110
+ hue = cv2.LUT(hue, lut_hue)
111
+
112
+ if sat_shift != 0:
113
+ lut_sat = np.arange(0, 256, dtype=np.int16)
114
+ lut_sat = np.clip(lut_sat + sat_shift, 0, 255).astype(dtype)
115
+ sat = cv2.LUT(sat, lut_sat)
116
+
117
+ if val_shift != 0:
118
+ lut_val = np.arange(0, 256, dtype=np.int16)
119
+ lut_val = np.clip(lut_val + val_shift, 0, 255).astype(dtype)
120
+ val = cv2.LUT(val, lut_val)
121
+
122
+ img = cv2.merge((hue, sat, val)).astype(dtype)
123
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
124
+ return img
125
+
126
+
127
+ def _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift):
128
+ dtype = img.dtype
129
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
130
+ hue, sat, val = cv2.split(img)
131
+
132
+ if hue_shift != 0:
133
+ hue = cv2.add(hue, hue_shift)
134
+ hue = np.mod(hue, 360) # OpenCV fails with negative values
135
+
136
+ if sat_shift != 0:
137
+ sat = clip(cv2.add(sat, sat_shift), dtype, 1.0)
138
+
139
+ if val_shift != 0:
140
+ val = clip(cv2.add(val, val_shift), dtype, 1.0)
141
+
142
+ img = cv2.merge((hue, sat, val))
143
+ img = cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
144
+ return img
145
+
146
+
147
+ @preserve_shape
148
+ def shift_hsv(img, hue_shift, sat_shift, val_shift):
149
+ if hue_shift == 0 and sat_shift == 0 and val_shift == 0:
150
+ return img
151
+
152
+ is_gray = is_grayscale_image(img)
153
+ if is_gray:
154
+ if hue_shift != 0 or sat_shift != 0:
155
+ hue_shift = 0
156
+ sat_shift = 0
157
+ warn(
158
+ "HueSaturationValue: hue_shift and sat_shift are not applicable to grayscale image. "
159
+ "Set them to 0 or use RGB image"
160
+ )
161
+ img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
162
+
163
+ if img.dtype == np.uint8:
164
+ img = _shift_hsv_uint8(img, hue_shift, sat_shift, val_shift)
165
+ else:
166
+ img = _shift_hsv_non_uint8(img, hue_shift, sat_shift, val_shift)
167
+
168
+ if is_gray:
169
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
170
+
171
+ return img
172
+
173
+
174
+ def solarize(img, threshold=128):
175
+ """Invert all pixel values above a threshold.
176
+
177
+ Args:
178
+ img (numpy.ndarray): The image to solarize.
179
+ threshold (int): All pixels above this greyscale level are inverted.
180
+
181
+ Returns:
182
+ numpy.ndarray: Solarized image.
183
+
184
+ """
185
+ dtype = img.dtype
186
+ max_val = MAX_VALUES_BY_DTYPE[dtype]
187
+
188
+ if dtype == np.dtype("uint8"):
189
+ lut = [(i if i < threshold else max_val - i) for i in range(max_val + 1)]
190
+
191
+ prev_shape = img.shape
192
+ img = cv2.LUT(img, np.array(lut, dtype=dtype))
193
+
194
+ if len(prev_shape) != len(img.shape):
195
+ img = np.expand_dims(img, -1)
196
+ return img
197
+
198
+ result_img = img.copy()
199
+ cond = img >= threshold
200
+ result_img[cond] = max_val - result_img[cond]
201
+ return result_img
202
+
203
+
204
+ @preserve_shape
205
+ def posterize(img, bits):
206
+ """Reduce the number of bits for each color channel.
207
+
208
+ Args:
209
+ img (numpy.ndarray): image to posterize.
210
+ bits (int): number of high bits. Must be in range [0, 8]
211
+
212
+ Returns:
213
+ numpy.ndarray: Image with reduced color channels.
214
+
215
+ """
216
+ bits = np.uint8(bits)
217
+
218
+ if img.dtype != np.uint8:
219
+ raise TypeError("Image must have uint8 channel type")
220
+ if np.any((bits < 0) | (bits > 8)):
221
+ raise ValueError("bits must be in range [0, 8]")
222
+
223
+ if not bits.shape or len(bits) == 1:
224
+ if bits == 0:
225
+ return np.zeros_like(img)
226
+ if bits == 8:
227
+ return img.copy()
228
+
229
+ lut = np.arange(0, 256, dtype=np.uint8)
230
+ mask = ~np.uint8(2 ** (8 - bits) - 1)
231
+ lut &= mask
232
+
233
+ return cv2.LUT(img, lut)
234
+
235
+ if not is_rgb_image(img):
236
+ raise TypeError("If bits is iterable image must be RGB")
237
+
238
+ result_img = np.empty_like(img)
239
+ for i, channel_bits in enumerate(bits):
240
+ if channel_bits == 0:
241
+ result_img[..., i] = np.zeros_like(img[..., i])
242
+ elif channel_bits == 8:
243
+ result_img[..., i] = img[..., i].copy()
244
+ else:
245
+ lut = np.arange(0, 256, dtype=np.uint8)
246
+ mask = ~np.uint8(2 ** (8 - channel_bits) - 1)
247
+ lut &= mask
248
+
249
+ result_img[..., i] = cv2.LUT(img[..., i], lut)
250
+
251
+ return result_img
252
+
253
+
254
+ def _equalize_pil(img, mask=None):
255
+ histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
256
+ h = [_f for _f in histogram if _f]
257
+
258
+ if len(h) <= 1:
259
+ return img.copy()
260
+
261
+ step = np.sum(h[:-1]) // 255
262
+ if not step:
263
+ return img.copy()
264
+
265
+ lut = np.empty(256, dtype=np.uint8)
266
+ n = step // 2
267
+ for i in range(256):
268
+ lut[i] = min(n // step, 255)
269
+ n += histogram[i]
270
+
271
+ return cv2.LUT(img, np.array(lut))
272
+
273
+
274
+ def _equalize_cv(img, mask=None):
275
+ if mask is None:
276
+ return cv2.equalizeHist(img)
277
+
278
+ histogram = cv2.calcHist([img], [0], mask, [256], (0, 256)).ravel()
279
+ i = 0
280
+ for val in histogram:
281
+ if val > 0:
282
+ break
283
+ i += 1
284
+ i = min(i, 255)
285
+
286
+ total = np.sum(histogram)
287
+ if histogram[i] == total:
288
+ return np.full_like(img, i)
289
+
290
+ scale = 255.0 / (total - histogram[i])
291
+ _sum = 0
292
+
293
+ lut = np.zeros(256, dtype=np.uint8)
294
+ i += 1
295
+ for i in range(i, len(histogram)):
296
+ _sum += histogram[i]
297
+ lut[i] = clip(round(_sum * scale), np.dtype("uint8"), 255)
298
+
299
+ return cv2.LUT(img, lut)
300
+
301
+
302
+ @preserve_channel_dim
303
+ def equalize(img, mask=None, mode="cv", by_channels=True):
304
+ """Equalize the image histogram.
305
+
306
+ Args:
307
+ img (numpy.ndarray): RGB or grayscale image.
308
+ mask (numpy.ndarray): An optional mask. If given, only the pixels selected by
309
+ the mask are included in the analysis. Maybe 1 channel or 3 channel array.
310
+ mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method.
311
+ by_channels (bool): If True, use equalization by channels separately,
312
+ else convert image to YCbCr representation and use equalization by `Y` channel.
313
+
314
+ Returns:
315
+ numpy.ndarray: Equalized image.
316
+
317
+ """
318
+ if img.dtype != np.uint8:
319
+ raise TypeError("Image must have uint8 channel type")
320
+
321
+ modes = ["cv", "pil"]
322
+
323
+ if mode not in modes:
324
+ raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode))
325
+ if mask is not None:
326
+ if is_rgb_image(mask) and is_grayscale_image(img):
327
+ raise ValueError("Wrong mask shape. Image shape: {}. " "Mask shape: {}".format(img.shape, mask.shape))
328
+ if not by_channels and not is_grayscale_image(mask):
329
+ raise ValueError(
330
+ "When by_channels=False only 1-channel mask supports. " "Mask shape: {}".format(mask.shape)
331
+ )
332
+
333
+ if mode == "pil":
334
+ function = _equalize_pil
335
+ else:
336
+ function = _equalize_cv
337
+
338
+ if mask is not None:
339
+ mask = mask.astype(np.uint8)
340
+
341
+ if is_grayscale_image(img):
342
+ return function(img, mask)
343
+
344
+ if not by_channels:
345
+ result_img = cv2.cvtColor(img, cv2.COLOR_RGB2YCrCb)
346
+ result_img[..., 0] = function(result_img[..., 0], mask)
347
+ return cv2.cvtColor(result_img, cv2.COLOR_YCrCb2RGB)
348
+
349
+ result_img = np.empty_like(img)
350
+ for i in range(3):
351
+ if mask is None:
352
+ _mask = None
353
+ elif is_grayscale_image(mask):
354
+ _mask = mask
355
+ else:
356
+ _mask = mask[..., i]
357
+
358
+ result_img[..., i] = function(img[..., i], _mask)
359
+
360
+ return result_img
361
+
362
+
363
+ @preserve_shape
364
+ def move_tone_curve(img, low_y, high_y):
365
+ """Rescales the relationship between bright and dark areas of the image by manipulating its tone curve.
366
+
367
+ Args:
368
+ img (numpy.ndarray): RGB or grayscale image.
369
+ low_y (float): y-position of a Bezier control point used
370
+ to adjust the tone curve, must be in range [0, 1]
371
+ high_y (float): y-position of a Bezier control point used
372
+ to adjust image tone curve, must be in range [0, 1]
373
+ """
374
+ input_dtype = img.dtype
375
+
376
+ if low_y < 0 or low_y > 1:
377
+ raise ValueError("low_shift must be in range [0, 1]")
378
+ if high_y < 0 or high_y > 1:
379
+ raise ValueError("high_shift must be in range [0, 1]")
380
+
381
+ if input_dtype != np.uint8:
382
+ raise ValueError("Unsupported image type {}".format(input_dtype))
383
+
384
+ t = np.linspace(0.0, 1.0, 256)
385
+
386
+ # Defines responze of a four-point bezier curve
387
+ def evaluate_bez(t):
388
+ return 3 * (1 - t) ** 2 * t * low_y + 3 * (1 - t) * t**2 * high_y + t**3
389
+
390
+ evaluate_bez = np.vectorize(evaluate_bez)
391
+ remapping = np.rint(evaluate_bez(t) * 255).astype(np.uint8)
392
+
393
+ lut_fn = _maybe_process_in_chunks(cv2.LUT, lut=remapping)
394
+ img = lut_fn(img)
395
+ return img
396
+
397
+
398
+ @clipped
399
+ def _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift):
400
+ if r_shift == g_shift == b_shift:
401
+ return img + r_shift
402
+
403
+ result_img = np.empty_like(img)
404
+ shifts = [r_shift, g_shift, b_shift]
405
+ for i, shift in enumerate(shifts):
406
+ result_img[..., i] = img[..., i] + shift
407
+
408
+ return result_img
409
+
410
+
411
+ def _shift_image_uint8(img, value):
412
+ max_value = MAX_VALUES_BY_DTYPE[img.dtype]
413
+
414
+ lut = np.arange(0, max_value + 1).astype("float32")
415
+ lut += value
416
+
417
+ lut = np.clip(lut, 0, max_value).astype(img.dtype)
418
+ return cv2.LUT(img, lut)
419
+
420
+
421
+ @preserve_shape
422
+ def _shift_rgb_uint8(img, r_shift, g_shift, b_shift):
423
+ if r_shift == g_shift == b_shift:
424
+ h, w, c = img.shape
425
+ img = img.reshape([h, w * c])
426
+
427
+ return _shift_image_uint8(img, r_shift)
428
+
429
+ result_img = np.empty_like(img)
430
+ shifts = [r_shift, g_shift, b_shift]
431
+ for i, shift in enumerate(shifts):
432
+ result_img[..., i] = _shift_image_uint8(img[..., i], shift)
433
+
434
+ return result_img
435
+
436
+
437
+ def shift_rgb(img, r_shift, g_shift, b_shift):
438
+ if img.dtype == np.uint8:
439
+ return _shift_rgb_uint8(img, r_shift, g_shift, b_shift)
440
+
441
+ return _shift_rgb_non_uint8(img, r_shift, g_shift, b_shift)
442
+
443
+
444
+ @clipped
445
+ def linear_transformation_rgb(img, transformation_matrix):
446
+ result_img = cv2.transform(img, transformation_matrix)
447
+
448
+ return result_img
449
+
450
+
451
+ @preserve_channel_dim
452
+ def clahe(img, clip_limit=2.0, tile_grid_size=(8, 8)):
453
+ if img.dtype != np.uint8:
454
+ raise TypeError("clahe supports only uint8 inputs")
455
+
456
+ clahe_mat = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
457
+
458
+ if len(img.shape) == 2 or img.shape[2] == 1:
459
+ img = clahe_mat.apply(img)
460
+ else:
461
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
462
+ img[:, :, 0] = clahe_mat.apply(img[:, :, 0])
463
+ img = cv2.cvtColor(img, cv2.COLOR_LAB2RGB)
464
+
465
+ return img
466
+
467
+
468
+ @preserve_shape
469
+ def convolve(img, kernel):
470
+ conv_fn = _maybe_process_in_chunks(cv2.filter2D, ddepth=-1, kernel=kernel)
471
+ return conv_fn(img)
472
+
473
+
474
+ @preserve_shape
475
+ def image_compression(img, quality, image_type):
476
+ if image_type in [".jpeg", ".jpg"]:
477
+ quality_flag = cv2.IMWRITE_JPEG_QUALITY
478
+ elif image_type == ".webp":
479
+ quality_flag = cv2.IMWRITE_WEBP_QUALITY
480
+ else:
481
+ NotImplementedError("Only '.jpg' and '.webp' compression transforms are implemented. ")
482
+
483
+ input_dtype = img.dtype
484
+ needs_float = False
485
+
486
+ if input_dtype == np.float32:
487
+ warn(
488
+ "Image compression augmentation "
489
+ "is most effective with uint8 inputs, "
490
+ "{} is used as input.".format(input_dtype),
491
+ UserWarning,
492
+ )
493
+ img = from_float(img, dtype=np.dtype("uint8"))
494
+ needs_float = True
495
+ elif input_dtype not in (np.uint8, np.float32):
496
+ raise ValueError("Unexpected dtype {} for image augmentation".format(input_dtype))
497
+
498
+ _, encoded_img = cv2.imencode(image_type, img, (int(quality_flag), quality))
499
+ img = cv2.imdecode(encoded_img, cv2.IMREAD_UNCHANGED)
500
+
501
+ if needs_float:
502
+ img = to_float(img, max_value=255)
503
+ return img
504
+
505
+
506
+ @preserve_shape
507
+ def add_snow(img, snow_point, brightness_coeff):
508
+ """Bleaches out pixels, imitation snow.
509
+
510
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
511
+
512
+ Args:
513
+ img (numpy.ndarray): Image.
514
+ snow_point: Number of show points.
515
+ brightness_coeff: Brightness coefficient.
516
+
517
+ Returns:
518
+ numpy.ndarray: Image.
519
+
520
+ """
521
+ non_rgb_warning(img)
522
+
523
+ input_dtype = img.dtype
524
+ needs_float = False
525
+
526
+ snow_point *= 127.5 # = 255 / 2
527
+ snow_point += 85 # = 255 / 3
528
+
529
+ if input_dtype == np.float32:
530
+ img = from_float(img, dtype=np.dtype("uint8"))
531
+ needs_float = True
532
+ elif input_dtype not in (np.uint8, np.float32):
533
+ raise ValueError("Unexpected dtype {} for RandomSnow augmentation".format(input_dtype))
534
+
535
+ image_HLS = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
536
+ image_HLS = np.array(image_HLS, dtype=np.float32)
537
+
538
+ image_HLS[:, :, 1][image_HLS[:, :, 1] < snow_point] *= brightness_coeff
539
+
540
+ image_HLS[:, :, 1] = clip(image_HLS[:, :, 1], np.uint8, 255)
541
+
542
+ image_HLS = np.array(image_HLS, dtype=np.uint8)
543
+
544
+ image_RGB = cv2.cvtColor(image_HLS, cv2.COLOR_HLS2RGB)
545
+
546
+ if needs_float:
547
+ image_RGB = to_float(image_RGB, max_value=255)
548
+
549
+ return image_RGB
550
+
551
+
552
+ @preserve_shape
553
+ def add_rain(
554
+ img,
555
+ slant,
556
+ drop_length,
557
+ drop_width,
558
+ drop_color,
559
+ blur_value,
560
+ brightness_coefficient,
561
+ rain_drops,
562
+ ):
563
+ """
564
+
565
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
566
+
567
+ Args:
568
+ img (numpy.ndarray): Image.
569
+ slant (int):
570
+ drop_length:
571
+ drop_width:
572
+ drop_color:
573
+ blur_value (int): Rainy view are blurry.
574
+ brightness_coefficient (float): Rainy days are usually shady.
575
+ rain_drops:
576
+
577
+ Returns:
578
+ numpy.ndarray: Image.
579
+
580
+ """
581
+ non_rgb_warning(img)
582
+
583
+ input_dtype = img.dtype
584
+ needs_float = False
585
+
586
+ if input_dtype == np.float32:
587
+ img = from_float(img, dtype=np.dtype("uint8"))
588
+ needs_float = True
589
+ elif input_dtype not in (np.uint8, np.float32):
590
+ raise ValueError("Unexpected dtype {} for RandomRain augmentation".format(input_dtype))
591
+
592
+ image = img.copy()
593
+
594
+ for rain_drop_x0, rain_drop_y0 in rain_drops:
595
+ rain_drop_x1 = rain_drop_x0 + slant
596
+ rain_drop_y1 = rain_drop_y0 + drop_length
597
+
598
+ cv2.line(
599
+ image,
600
+ (rain_drop_x0, rain_drop_y0),
601
+ (rain_drop_x1, rain_drop_y1),
602
+ drop_color,
603
+ drop_width,
604
+ )
605
+
606
+ image = cv2.blur(image, (blur_value, blur_value)) # rainy view are blurry
607
+ image_hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.float32)
608
+ image_hsv[:, :, 2] *= brightness_coefficient
609
+
610
+ image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
611
+
612
+ if needs_float:
613
+ image_rgb = to_float(image_rgb, max_value=255)
614
+
615
+ return image_rgb
616
+
617
+
618
+ @preserve_shape
619
+ def add_fog(img, fog_coef, alpha_coef, haze_list):
620
+ """Add fog to the image.
621
+
622
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
623
+
624
+ Args:
625
+ img (numpy.ndarray): Image.
626
+ fog_coef (float): Fog coefficient.
627
+ alpha_coef (float): Alpha coefficient.
628
+ haze_list (list):
629
+
630
+ Returns:
631
+ numpy.ndarray: Image.
632
+
633
+ """
634
+ non_rgb_warning(img)
635
+
636
+ input_dtype = img.dtype
637
+ needs_float = False
638
+
639
+ if input_dtype == np.float32:
640
+ img = from_float(img, dtype=np.dtype("uint8"))
641
+ needs_float = True
642
+ elif input_dtype not in (np.uint8, np.float32):
643
+ raise ValueError("Unexpected dtype {} for RandomFog augmentation".format(input_dtype))
644
+
645
+ width = img.shape[1]
646
+
647
+ hw = max(int(width // 3 * fog_coef), 10)
648
+
649
+ for haze_points in haze_list:
650
+ x, y = haze_points
651
+ overlay = img.copy()
652
+ output = img.copy()
653
+ alpha = alpha_coef * fog_coef
654
+ rad = hw // 2
655
+ point = (x + hw // 2, y + hw // 2)
656
+ cv2.circle(overlay, point, int(rad), (255, 255, 255), -1)
657
+ cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)
658
+
659
+ img = output.copy()
660
+
661
+ image_rgb = cv2.blur(img, (hw // 10, hw // 10))
662
+
663
+ if needs_float:
664
+ image_rgb = to_float(image_rgb, max_value=255)
665
+
666
+ return image_rgb
667
+
668
+
669
+ @preserve_shape
670
+ def add_sun_flare(img, flare_center_x, flare_center_y, src_radius, src_color, circles):
671
+ """Add sun flare.
672
+
673
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
674
+
675
+ Args:
676
+ img (numpy.ndarray):
677
+ flare_center_x (float):
678
+ flare_center_y (float):
679
+ src_radius:
680
+ src_color (int, int, int):
681
+ circles (list):
682
+
683
+ Returns:
684
+ numpy.ndarray:
685
+
686
+ """
687
+ non_rgb_warning(img)
688
+
689
+ input_dtype = img.dtype
690
+ needs_float = False
691
+
692
+ if input_dtype == np.float32:
693
+ img = from_float(img, dtype=np.dtype("uint8"))
694
+ needs_float = True
695
+ elif input_dtype not in (np.uint8, np.float32):
696
+ raise ValueError("Unexpected dtype {} for RandomSunFlareaugmentation".format(input_dtype))
697
+
698
+ overlay = img.copy()
699
+ output = img.copy()
700
+
701
+ for alpha, (x, y), rad3, (r_color, g_color, b_color) in circles:
702
+ cv2.circle(overlay, (x, y), rad3, (r_color, g_color, b_color), -1)
703
+
704
+ cv2.addWeighted(overlay, alpha, output, 1 - alpha, 0, output)
705
+
706
+ point = (int(flare_center_x), int(flare_center_y))
707
+
708
+ overlay = output.copy()
709
+ num_times = src_radius // 10
710
+ alpha = np.linspace(0.0, 1, num=num_times)
711
+ rad = np.linspace(1, src_radius, num=num_times)
712
+ for i in range(num_times):
713
+ cv2.circle(overlay, point, int(rad[i]), src_color, -1)
714
+ alp = alpha[num_times - i - 1] * alpha[num_times - i - 1] * alpha[num_times - i - 1]
715
+ cv2.addWeighted(overlay, alp, output, 1 - alp, 0, output)
716
+
717
+ image_rgb = output
718
+
719
+ if needs_float:
720
+ image_rgb = to_float(image_rgb, max_value=255)
721
+
722
+ return image_rgb
723
+
724
+
725
+ @ensure_contiguous
726
+ @preserve_shape
727
+ def add_shadow(img, vertices_list):
728
+ """Add shadows to the image.
729
+
730
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
731
+
732
+ Args:
733
+ img (numpy.ndarray):
734
+ vertices_list (list):
735
+
736
+ Returns:
737
+ numpy.ndarray:
738
+
739
+ """
740
+ non_rgb_warning(img)
741
+ input_dtype = img.dtype
742
+ needs_float = False
743
+
744
+ if input_dtype == np.float32:
745
+ img = from_float(img, dtype=np.dtype("uint8"))
746
+ needs_float = True
747
+ elif input_dtype not in (np.uint8, np.float32):
748
+ raise ValueError("Unexpected dtype {} for RandomShadow augmentation".format(input_dtype))
749
+
750
+ image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
751
+ mask = np.zeros_like(img)
752
+
753
+ # adding all shadow polygons on empty mask, single 255 denotes only red channel
754
+ for vertices in vertices_list:
755
+ cv2.fillPoly(mask, vertices, 255)
756
+
757
+ # if red channel is hot, image's "Lightness" channel's brightness is lowered
758
+ red_max_value_ind = mask[:, :, 0] == 255
759
+ image_hls[:, :, 1][red_max_value_ind] = image_hls[:, :, 1][red_max_value_ind] * 0.5
760
+
761
+ image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
762
+
763
+ if needs_float:
764
+ image_rgb = to_float(image_rgb, max_value=255)
765
+
766
+ return image_rgb
767
+
768
+
769
+ @ensure_contiguous
770
+ @preserve_shape
771
+ def add_gravel(img: np.ndarray, gravels: list):
772
+ """Add gravel to the image.
773
+
774
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
775
+
776
+ Args:
777
+ img (numpy.ndarray): image to add gravel to
778
+ gravels (list): list of gravel parameters. (float, float, float, float):
779
+ (top-left x, top-left y, bottom-right x, bottom right y)
780
+
781
+ Returns:
782
+ numpy.ndarray:
783
+ """
784
+ non_rgb_warning(img)
785
+ input_dtype = img.dtype
786
+ needs_float = False
787
+
788
+ if input_dtype == np.float32:
789
+ img = from_float(img, dtype=np.dtype("uint8"))
790
+ needs_float = True
791
+ elif input_dtype not in (np.uint8, np.float32):
792
+ raise ValueError("Unexpected dtype {} for AddGravel augmentation".format(input_dtype))
793
+
794
+ image_hls = cv2.cvtColor(img, cv2.COLOR_RGB2HLS)
795
+
796
+ for gravel in gravels:
797
+ y1, y2, x1, x2, sat = gravel
798
+ image_hls[x1:x2, y1:y2, 1] = sat
799
+
800
+ image_rgb = cv2.cvtColor(image_hls, cv2.COLOR_HLS2RGB)
801
+
802
+ if needs_float:
803
+ image_rgb = to_float(image_rgb, max_value=255)
804
+
805
+ return image_rgb
806
+
807
+
808
+ def invert(img: np.ndarray) -> np.ndarray:
809
+ # Supports all the valid dtypes
810
+ # clips the img to avoid unexpected behaviour.
811
+ return MAX_VALUES_BY_DTYPE[img.dtype] - img
812
+
813
+
814
+ def channel_shuffle(img, channels_shuffled):
815
+ img = img[..., channels_shuffled]
816
+ return img
817
+
818
+
819
+ @preserve_shape
820
+ def gamma_transform(img, gamma):
821
+ if img.dtype == np.uint8:
822
+ table = (np.arange(0, 256.0 / 255, 1.0 / 255) ** gamma) * 255
823
+ img = cv2.LUT(img, table.astype(np.uint8))
824
+ else:
825
+ img = np.power(img, gamma)
826
+
827
+ return img
828
+
829
+
830
+ @clipped
831
+ def gauss_noise(image, gauss):
832
+ image = image.astype("float32")
833
+ return image + gauss
834
+
835
+
836
+ @clipped
837
+ def _brightness_contrast_adjust_non_uint(img, alpha=1, beta=0, beta_by_max=False):
838
+ dtype = img.dtype
839
+ img = img.astype("float32")
840
+
841
+ if alpha != 1:
842
+ img *= alpha
843
+ if beta != 0:
844
+ if beta_by_max:
845
+ max_value = MAX_VALUES_BY_DTYPE[dtype]
846
+ img += beta * max_value
847
+ else:
848
+ img += beta * np.mean(img)
849
+ return img
850
+
851
+
852
+ @preserve_shape
853
+ def _brightness_contrast_adjust_uint(img, alpha=1, beta=0, beta_by_max=False):
854
+ dtype = np.dtype("uint8")
855
+
856
+ max_value = MAX_VALUES_BY_DTYPE[dtype]
857
+
858
+ lut = np.arange(0, max_value + 1).astype("float32")
859
+
860
+ if alpha != 1:
861
+ lut *= alpha
862
+ if beta != 0:
863
+ if beta_by_max:
864
+ lut += beta * max_value
865
+ else:
866
+ lut += (alpha * beta) * np.mean(img)
867
+
868
+ lut = np.clip(lut, 0, max_value).astype(dtype)
869
+ img = cv2.LUT(img, lut)
870
+ return img
871
+
872
+
873
+ def brightness_contrast_adjust(img, alpha=1, beta=0, beta_by_max=False):
874
+ if img.dtype == np.uint8:
875
+ return _brightness_contrast_adjust_uint(img, alpha, beta, beta_by_max)
876
+
877
+ return _brightness_contrast_adjust_non_uint(img, alpha, beta, beta_by_max)
878
+
879
+
880
+ @clipped
881
+ def iso_noise(image, color_shift=0.05, intensity=0.5, random_state=None, **kwargs):
882
+ """
883
+ Apply poisson noise to image to simulate camera sensor noise.
884
+
885
+ Args:
886
+ image (numpy.ndarray): Input image, currently, only RGB, uint8 images are supported.
887
+ color_shift (float):
888
+ intensity (float): Multiplication factor for noise values. Values of ~0.5 are produce noticeable,
889
+ yet acceptable level of noise.
890
+ random_state:
891
+ **kwargs:
892
+
893
+ Returns:
894
+ numpy.ndarray: Noised image
895
+
896
+ """
897
+ if image.dtype != np.uint8:
898
+ raise TypeError("Image must have uint8 channel type")
899
+ if not is_rgb_image(image):
900
+ raise TypeError("Image must be RGB")
901
+
902
+ one_over_255 = float(1.0 / 255.0)
903
+ image = np.multiply(image, one_over_255, dtype=np.float32)
904
+ hls = cv2.cvtColor(image, cv2.COLOR_RGB2HLS)
905
+ _, stddev = cv2.meanStdDev(hls)
906
+
907
+ luminance_noise = random_utils.poisson(stddev[1] * intensity * 255, size=hls.shape[:2], random_state=random_state)
908
+ color_noise = random_utils.normal(0, color_shift * 360 * intensity, size=hls.shape[:2], random_state=random_state)
909
+
910
+ hue = hls[..., 0]
911
+ hue += color_noise
912
+ hue[hue < 0] += 360
913
+ hue[hue > 360] -= 360
914
+
915
+ luminance = hls[..., 1]
916
+ luminance += (luminance_noise / 255) * (1.0 - luminance)
917
+
918
+ image = cv2.cvtColor(hls, cv2.COLOR_HLS2RGB) * 255
919
+ return image.astype(np.uint8)
920
+
921
+
922
+ def to_gray(img):
923
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
924
+ return cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
925
+
926
+
927
+ def gray_to_rgb(img):
928
+ return cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
929
+
930
+
931
+ @preserve_shape
932
+ def downscale(img, scale, down_interpolation=cv2.INTER_AREA, up_interpolation=cv2.INTER_LINEAR):
933
+ h, w = img.shape[:2]
934
+
935
+ need_cast = (
936
+ up_interpolation != cv2.INTER_NEAREST or down_interpolation != cv2.INTER_NEAREST
937
+ ) and img.dtype == np.uint8
938
+ if need_cast:
939
+ img = to_float(img)
940
+ downscaled = cv2.resize(img, None, fx=scale, fy=scale, interpolation=down_interpolation)
941
+ upscaled = cv2.resize(downscaled, (w, h), interpolation=up_interpolation)
942
+ if need_cast:
943
+ upscaled = from_float(np.clip(upscaled, 0, 1), dtype=np.dtype("uint8"))
944
+ return upscaled
945
+
946
+
947
+ def to_float(img, max_value=None):
948
+ if max_value is None:
949
+ try:
950
+ max_value = MAX_VALUES_BY_DTYPE[img.dtype]
951
+ except KeyError:
952
+ raise RuntimeError(
953
+ "Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by "
954
+ "passing the max_value argument".format(img.dtype)
955
+ )
956
+ return img.astype("float32") / max_value
957
+
958
+
959
+ def from_float(img, dtype, max_value=None):
960
+ if max_value is None:
961
+ try:
962
+ max_value = MAX_VALUES_BY_DTYPE[dtype]
963
+ except KeyError:
964
+ raise RuntimeError(
965
+ "Can't infer the maximum value for dtype {}. You need to specify the maximum value manually by "
966
+ "passing the max_value argument".format(dtype)
967
+ )
968
+ return (img * max_value).astype(dtype)
969
+
970
+
971
+ def noop(input_obj, **params): # skipcq: PYL-W0613
972
+ return input_obj
973
+
974
+
975
+ def swap_tiles_on_image(image, tiles):
976
+ """
977
+ Swap tiles on image.
978
+
979
+ Args:
980
+ image (np.ndarray): Input image.
981
+ tiles (np.ndarray): array of tuples(
982
+ current_left_up_corner_row, current_left_up_corner_col,
983
+ old_left_up_corner_row, old_left_up_corner_col,
984
+ height_tile, width_tile)
985
+
986
+ Returns:
987
+ np.ndarray: Output image.
988
+
989
+ """
990
+ new_image = image.copy()
991
+
992
+ for tile in tiles:
993
+ new_image[tile[0] : tile[0] + tile[4], tile[1] : tile[1] + tile[5]] = image[
994
+ tile[2] : tile[2] + tile[4], tile[3] : tile[3] + tile[5]
995
+ ]
996
+
997
+ return new_image
998
+
999
+
1000
+ @clipped
1001
+ def _multiply_uint8(img, multiplier):
1002
+ img = img.astype(np.float32)
1003
+ return np.multiply(img, multiplier)
1004
+
1005
+
1006
+ @preserve_shape
1007
+ def _multiply_uint8_optimized(img, multiplier):
1008
+ if is_grayscale_image(img) or len(multiplier) == 1:
1009
+ multiplier = multiplier[0]
1010
+ lut = np.arange(0, 256, dtype=np.float32)
1011
+ lut *= multiplier
1012
+ lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype])
1013
+ func = _maybe_process_in_chunks(cv2.LUT, lut=lut)
1014
+ return func(img)
1015
+
1016
+ channels = img.shape[-1]
1017
+ lut = [np.arange(0, 256, dtype=np.float32)] * channels
1018
+ lut = np.stack(lut, axis=-1)
1019
+
1020
+ lut *= multiplier
1021
+ lut = clip(lut, np.uint8, MAX_VALUES_BY_DTYPE[img.dtype])
1022
+
1023
+ images = []
1024
+ for i in range(channels):
1025
+ func = _maybe_process_in_chunks(cv2.LUT, lut=lut[:, i])
1026
+ images.append(func(img[:, :, i]))
1027
+ return np.stack(images, axis=-1)
1028
+
1029
+
1030
+ @clipped
1031
+ def _multiply_non_uint8(img, multiplier):
1032
+ return img * multiplier
1033
+
1034
+
1035
+ def multiply(img, multiplier):
1036
+ """
1037
+ Args:
1038
+ img (numpy.ndarray): Image.
1039
+ multiplier (numpy.ndarray): Multiplier coefficient.
1040
+
1041
+ Returns:
1042
+ numpy.ndarray: Image multiplied by `multiplier` coefficient.
1043
+
1044
+ """
1045
+ if img.dtype == np.uint8:
1046
+ if len(multiplier.shape) == 1:
1047
+ return _multiply_uint8_optimized(img, multiplier)
1048
+
1049
+ return _multiply_uint8(img, multiplier)
1050
+
1051
+ return _multiply_non_uint8(img, multiplier)
1052
+
1053
+
1054
+ def bbox_from_mask(mask):
1055
+ """Create bounding box from binary mask (fast version)
1056
+
1057
+ Args:
1058
+ mask (numpy.ndarray): binary mask.
1059
+
1060
+ Returns:
1061
+ tuple: A bounding box tuple `(x_min, y_min, x_max, y_max)`.
1062
+
1063
+ """
1064
+ rows = np.any(mask, axis=1)
1065
+ if not rows.any():
1066
+ return -1, -1, -1, -1
1067
+ cols = np.any(mask, axis=0)
1068
+ y_min, y_max = np.where(rows)[0][[0, -1]]
1069
+ x_min, x_max = np.where(cols)[0][[0, -1]]
1070
+ return x_min, y_min, x_max + 1, y_max + 1
1071
+
1072
+
1073
+ def mask_from_bbox(img, bbox):
1074
+ """Create binary mask from bounding box
1075
+
1076
+ Args:
1077
+ img (numpy.ndarray): input image
1078
+ bbox: A bounding box tuple `(x_min, y_min, x_max, y_max)`
1079
+
1080
+ Returns:
1081
+ mask (numpy.ndarray): binary mask
1082
+
1083
+ """
1084
+
1085
+ mask = np.zeros(img.shape[:2], dtype=np.uint8)
1086
+ x_min, y_min, x_max, y_max = bbox
1087
+ mask[y_min:y_max, x_min:x_max] = 1
1088
+ return mask
1089
+
1090
+
1091
+ def fancy_pca(img, alpha=0.1):
1092
+ """Perform 'Fancy PCA' augmentation from:
1093
+ http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
1094
+
1095
+ Args:
1096
+ img (numpy.ndarray): numpy array with (h, w, rgb) shape, as ints between 0-255
1097
+ alpha (float): how much to perturb/scale the eigen vecs and vals
1098
+ the paper used std=0.1
1099
+
1100
+ Returns:
1101
+ numpy.ndarray: numpy image-like array as uint8 range(0, 255)
1102
+
1103
+ """
1104
+ if not is_rgb_image(img) or img.dtype != np.uint8:
1105
+ raise TypeError("Image must be RGB image in uint8 format.")
1106
+
1107
+ orig_img = img.astype(float).copy()
1108
+
1109
+ img = img / 255.0 # rescale to 0 to 1 range
1110
+
1111
+ # flatten image to columns of RGB
1112
+ img_rs = img.reshape(-1, 3)
1113
+ # img_rs shape (640000, 3)
1114
+
1115
+ # center mean
1116
+ img_centered = img_rs - np.mean(img_rs, axis=0)
1117
+
1118
+ # paper says 3x3 covariance matrix
1119
+ img_cov = np.cov(img_centered, rowvar=False)
1120
+
1121
+ # eigen values and eigen vectors
1122
+ eig_vals, eig_vecs = np.linalg.eigh(img_cov)
1123
+
1124
+ # sort values and vector
1125
+ sort_perm = eig_vals[::-1].argsort()
1126
+ eig_vals[::-1].sort()
1127
+ eig_vecs = eig_vecs[:, sort_perm]
1128
+
1129
+ # get [p1, p2, p3]
1130
+ m1 = np.column_stack((eig_vecs))
1131
+
1132
+ # get 3x1 matrix of eigen values multiplied by random variable draw from normal
1133
+ # distribution with mean of 0 and standard deviation of 0.1
1134
+ m2 = np.zeros((3, 1))
1135
+ # according to the paper alpha should only be draw once per augmentation (not once per channel)
1136
+ # alpha = np.random.normal(0, alpha_std)
1137
+
1138
+ # broad cast to speed things up
1139
+ m2[:, 0] = alpha * eig_vals[:]
1140
+
1141
+ # this is the vector that we're going to add to each pixel in a moment
1142
+ add_vect = np.matrix(m1) * np.matrix(m2)
1143
+
1144
+ for idx in range(3): # RGB
1145
+ orig_img[..., idx] += add_vect[idx] * 255
1146
+
1147
+ # for image processing it was found that working with float 0.0 to 1.0
1148
+ # was easier than integers between 0-255
1149
+ # orig_img /= 255.0
1150
+ orig_img = np.clip(orig_img, 0.0, 255.0)
1151
+
1152
+ # orig_img *= 255
1153
+ orig_img = orig_img.astype(np.uint8)
1154
+
1155
+ return orig_img
1156
+
1157
+
1158
+ def _adjust_brightness_torchvision_uint8(img, factor):
1159
+ lut = np.arange(0, 256) * factor
1160
+ lut = np.clip(lut, 0, 255).astype(np.uint8)
1161
+ return cv2.LUT(img, lut)
1162
+
1163
+
1164
+ @preserve_shape
1165
+ def adjust_brightness_torchvision(img, factor):
1166
+ if factor == 0:
1167
+ return np.zeros_like(img)
1168
+ elif factor == 1:
1169
+ return img
1170
+
1171
+ if img.dtype == np.uint8:
1172
+ return _adjust_brightness_torchvision_uint8(img, factor)
1173
+
1174
+ return clip(img * factor, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype])
1175
+
1176
+
1177
+ def _adjust_contrast_torchvision_uint8(img, factor, mean):
1178
+ lut = np.arange(0, 256) * factor
1179
+ lut = lut + mean * (1 - factor)
1180
+ lut = clip(lut, img.dtype, 255)
1181
+
1182
+ return cv2.LUT(img, lut)
1183
+
1184
+
1185
+ @preserve_shape
1186
+ def adjust_contrast_torchvision(img, factor):
1187
+ if factor == 1:
1188
+ return img
1189
+
1190
+ if is_grayscale_image(img):
1191
+ mean = img.mean()
1192
+ else:
1193
+ mean = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY).mean()
1194
+
1195
+ if factor == 0:
1196
+ if img.dtype != np.float32:
1197
+ mean = int(mean + 0.5)
1198
+ return np.full_like(img, mean, dtype=img.dtype)
1199
+
1200
+ if img.dtype == np.uint8:
1201
+ return _adjust_contrast_torchvision_uint8(img, factor, mean)
1202
+
1203
+ return clip(
1204
+ img.astype(np.float32) * factor + mean * (1 - factor),
1205
+ img.dtype,
1206
+ MAX_VALUES_BY_DTYPE[img.dtype],
1207
+ )
1208
+
1209
+
1210
+ @preserve_shape
1211
+ def adjust_saturation_torchvision(img, factor, gamma=0):
1212
+ if factor == 1:
1213
+ return img
1214
+
1215
+ if is_grayscale_image(img):
1216
+ gray = img
1217
+ return gray
1218
+ else:
1219
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
1220
+ gray = cv2.cvtColor(gray, cv2.COLOR_GRAY2RGB)
1221
+
1222
+ if factor == 0:
1223
+ return gray
1224
+
1225
+ result = cv2.addWeighted(img, factor, gray, 1 - factor, gamma=gamma)
1226
+ if img.dtype == np.uint8:
1227
+ return result
1228
+
1229
+ # OpenCV does not clip values for float dtype
1230
+ return clip(result, img.dtype, MAX_VALUES_BY_DTYPE[img.dtype])
1231
+
1232
+
1233
+ def _adjust_hue_torchvision_uint8(img, factor):
1234
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
1235
+
1236
+ lut = np.arange(0, 256, dtype=np.int16)
1237
+ lut = np.mod(lut + 180 * factor, 180).astype(np.uint8)
1238
+ img[..., 0] = cv2.LUT(img[..., 0], lut)
1239
+
1240
+ return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
1241
+
1242
+
1243
+ def adjust_hue_torchvision(img, factor):
1244
+ if is_grayscale_image(img):
1245
+ return img
1246
+
1247
+ if factor == 0:
1248
+ return img
1249
+
1250
+ if img.dtype == np.uint8:
1251
+ return _adjust_hue_torchvision_uint8(img, factor)
1252
+
1253
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2HSV)
1254
+ img[..., 0] = np.mod(img[..., 0] + factor * 360, 360)
1255
+ return cv2.cvtColor(img, cv2.COLOR_HSV2RGB)
1256
+
1257
+
1258
+ @preserve_shape
1259
+ def superpixels(
1260
+ image: np.ndarray, n_segments: int, replace_samples: Sequence[bool], max_size: Optional[int], interpolation: int
1261
+ ) -> np.ndarray:
1262
+ if not np.any(replace_samples):
1263
+ return image
1264
+
1265
+ orig_shape = image.shape
1266
+ if max_size is not None:
1267
+ size = max(image.shape[:2])
1268
+ if size > max_size:
1269
+ scale = max_size / size
1270
+ height, width = image.shape[:2]
1271
+ new_height, new_width = int(height * scale), int(width * scale)
1272
+ resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(new_width, new_height), interpolation=interpolation)
1273
+ image = resize_fn(image)
1274
+
1275
+ segments = skimage.segmentation.slic(
1276
+ image, n_segments=n_segments, compactness=10, channel_axis=-1 if image.ndim > 2 else None
1277
+ )
1278
+
1279
+ min_value = 0
1280
+ max_value = MAX_VALUES_BY_DTYPE[image.dtype]
1281
+ image = np.copy(image)
1282
+ if image.ndim == 2:
1283
+ image = image.reshape(*image.shape, 1)
1284
+ nb_channels = image.shape[2]
1285
+ for c in range(nb_channels):
1286
+ # segments+1 here because otherwise regionprops always misses the last label
1287
+ regions = skimage.measure.regionprops(segments + 1, intensity_image=image[..., c])
1288
+ for ridx, region in enumerate(regions):
1289
+ # with mod here, because slic can sometimes create more superpixel than requested.
1290
+ # replace_samples then does not have enough values, so we just start over with the first one again.
1291
+ if replace_samples[ridx % len(replace_samples)]:
1292
+ mean_intensity = region.mean_intensity
1293
+ image_sp_c = image[..., c]
1294
+
1295
+ if image_sp_c.dtype.kind in ["i", "u", "b"]:
1296
+ # After rounding the value can end up slightly outside of the value_range. Hence, we need to clip.
1297
+ # We do clip via min(max(...)) instead of np.clip because
1298
+ # the latter one does not seem to keep dtypes for dtypes with large itemsizes (e.g. uint64).
1299
+ value: Union[int, float]
1300
+ value = int(np.round(mean_intensity))
1301
+ value = min(max(value, min_value), max_value)
1302
+ else:
1303
+ value = mean_intensity
1304
+
1305
+ image_sp_c[segments == ridx] = value
1306
+
1307
+ if orig_shape != image.shape:
1308
+ resize_fn = _maybe_process_in_chunks(
1309
+ cv2.resize, dsize=(orig_shape[1], orig_shape[0]), interpolation=interpolation
1310
+ )
1311
+ image = resize_fn(image)
1312
+
1313
+ return image
1314
+
1315
+
1316
+ @clipped
1317
+ def add_weighted(img1, alpha, img2, beta):
1318
+ return img1.astype(float) * alpha + img2.astype(float) * beta
1319
+
1320
+
1321
+ @clipped
1322
+ @preserve_shape
1323
+ def unsharp_mask(image: np.ndarray, ksize: int, sigma: float = 0.0, alpha: float = 0.2, threshold: int = 10):
1324
+ blur_fn = _maybe_process_in_chunks(cv2.GaussianBlur, ksize=(ksize, ksize), sigmaX=sigma)
1325
+
1326
+ input_dtype = image.dtype
1327
+ if input_dtype == np.uint8:
1328
+ image = to_float(image)
1329
+ elif input_dtype not in (np.uint8, np.float32):
1330
+ raise ValueError("Unexpected dtype {} for UnsharpMask augmentation".format(input_dtype))
1331
+
1332
+ blur = blur_fn(image)
1333
+ residual = image - blur
1334
+
1335
+ # Do not sharpen noise
1336
+ mask = np.abs(residual) * 255 > threshold
1337
+ mask = mask.astype("float32")
1338
+
1339
+ sharp = image + alpha * residual
1340
+ # Avoid color noise artefacts.
1341
+ sharp = np.clip(sharp, 0, 1)
1342
+
1343
+ soft_mask = blur_fn(mask)
1344
+ output = soft_mask * sharp + (1 - soft_mask) * image
1345
+ return from_float(output, dtype=input_dtype)
1346
+
1347
+
1348
+ @preserve_shape
1349
+ def pixel_dropout(image: np.ndarray, drop_mask: np.ndarray, drop_value: Union[float, Sequence[float]]) -> np.ndarray:
1350
+ if isinstance(drop_value, (int, float)) and drop_value == 0:
1351
+ drop_values = np.zeros_like(image)
1352
+ else:
1353
+ drop_values = np.full_like(image, drop_value) # type: ignore
1354
+ return np.where(drop_mask, drop_values, image)
1355
+
1356
+
1357
+ @clipped
1358
+ @preserve_shape
1359
+ def spatter(
1360
+ img: np.ndarray,
1361
+ non_mud: Optional[np.ndarray],
1362
+ mud: Optional[np.ndarray],
1363
+ rain: Optional[np.ndarray],
1364
+ mode: str,
1365
+ ) -> np.ndarray:
1366
+ non_rgb_warning(img)
1367
+
1368
+ coef = MAX_VALUES_BY_DTYPE[img.dtype]
1369
+ img = img.astype(np.float32) * (1 / coef)
1370
+
1371
+ if mode == "rain":
1372
+ assert rain is not None
1373
+ img = img + rain
1374
+ elif mode == "mud":
1375
+ assert non_mud is not None and mud is not None
1376
+ img = img * non_mud + mud
1377
+ else:
1378
+ raise ValueError("Unsupported spatter mode: " + str(mode))
1379
+
1380
+ return img * 255
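
A quick usage sketch for the spatter helper above (illustrative only: the array
shapes and the rain layout are assumptions for the example, and the clipped
decorator is assumed to cast the result back to the input dtype, as elsewhere
in this library):

import numpy as np

img = np.full((64, 64, 3), 128, dtype=np.uint8)
rain = np.zeros((64, 64, 3), dtype=np.float32)
rain[10:12, :, :] = 0.4                                    # a faint horizontal streak
out = spatter(img, non_mud=None, mud=None, rain=rain, mode="rain")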
custom_albumentations/augmentations/geometric/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .functional import *
2
+ from .resize import *
3
+ from .rotate import *
4
+ from .transforms import *
custom_albumentations/augmentations/geometric/functional.py ADDED
@@ -0,0 +1,1300 @@
1
+ import math
2
+ from typing import List, Optional, Sequence, Tuple, Union
3
+
4
+ import cv2
5
+ import numpy as np
6
+ import skimage.transform
7
+ from scipy.ndimage import gaussian_filter
8
+
9
+ from custom_albumentations.augmentations.utils import (
10
+ _maybe_process_in_chunks,
11
+ angle_2pi_range,
12
+ clipped,
13
+ preserve_channel_dim,
14
+ preserve_shape,
15
+ )
16
+
17
+ from ... import random_utils
18
+ from ...core.bbox_utils import denormalize_bbox, normalize_bbox
19
+ from ...core.transforms_interface import (
20
+ BoxInternalType,
21
+ FillValueType,
22
+ ImageColorType,
23
+ KeypointInternalType,
24
+ )
25
+
26
+ __all__ = [
27
+ "optical_distortion",
28
+ "elastic_transform_approx",
29
+ "grid_distortion",
30
+ "pad",
31
+ "pad_with_params",
32
+ "bbox_rot90",
33
+ "keypoint_rot90",
34
+ "rotate",
35
+ "bbox_rotate",
36
+ "keypoint_rotate",
37
+ "shift_scale_rotate",
38
+ "keypoint_shift_scale_rotate",
39
+ "bbox_shift_scale_rotate",
40
+ "elastic_transform",
41
+ "resize",
42
+ "scale",
43
+ "keypoint_scale",
44
+ "py3round",
45
+ "_func_max_size",
46
+ "longest_max_size",
47
+ "smallest_max_size",
48
+ "perspective",
49
+ "perspective_bbox",
50
+ "rotation2DMatrixToEulerAngles",
51
+ "perspective_keypoint",
52
+ "_is_identity_matrix",
53
+ "warp_affine",
54
+ "keypoint_affine",
55
+ "bbox_affine",
56
+ "safe_rotate",
57
+ "bbox_safe_rotate",
58
+ "keypoint_safe_rotate",
59
+ "piecewise_affine",
60
+ "to_distance_maps",
61
+ "from_distance_maps",
62
+ "keypoint_piecewise_affine",
63
+ "bbox_piecewise_affine",
64
+ "bbox_flip",
65
+ "bbox_hflip",
66
+ "bbox_transpose",
67
+ "bbox_vflip",
68
+ "hflip",
69
+ "hflip_cv2",
70
+ "transpose",
71
+ "keypoint_flip",
72
+ "keypoint_hflip",
73
+ "keypoint_transpose",
74
+ "keypoint_vflip",
75
+ ]
76
+
77
+
78
+ def bbox_rot90(bbox: BoxInternalType, factor: int, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
79
+ """Rotates a bounding box by 90 degrees CCW (see np.rot90)
80
+
81
+ Args:
82
+ bbox: A bounding box tuple (x_min, y_min, x_max, y_max).
83
+ factor: Number of CCW rotations. Must be in set {0, 1, 2, 3}. See np.rot90.
84
+ rows: Image rows.
85
+ cols: Image cols.
86
+
87
+ Returns:
88
+ tuple: A bounding box tuple (x_min, y_min, x_max, y_max).
89
+
90
+ """
91
+ if factor not in {0, 1, 2, 3}:
92
+ raise ValueError("Parameter factor must be in set {0, 1, 2, 3}")
93
+ x_min, y_min, x_max, y_max = bbox[:4]
94
+ if factor == 1:
95
+ bbox = y_min, 1 - x_max, y_max, 1 - x_min
96
+ elif factor == 2:
97
+ bbox = 1 - x_max, 1 - y_max, 1 - x_min, 1 - y_min
98
+ elif factor == 3:
99
+ bbox = 1 - y_max, x_min, 1 - y_min, x_max
100
+ return bbox
101
+
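# Editor's note: worked example for the factor=1 branch above (rows/cols are
# unused because the box is in normalized coordinates):
assert bbox_rot90((0.2, 0.1, 0.4, 0.3), 1, rows=100, cols=100) == (0.1, 0.6, 0.3, 0.8)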
102
+
103
+ @angle_2pi_range
104
+ def keypoint_rot90(keypoint: KeypointInternalType, factor: int, rows: int, cols: int, **params) -> KeypointInternalType:
105
+ """Rotates a keypoint by 90 degrees CCW (see np.rot90)
106
+
107
+ Args:
108
+ keypoint: A keypoint `(x, y, angle, scale)`.
109
+ factor: Number of CCW rotations. Must be in set {0, 1, 2, 3}. See np.rot90.
110
+ rows: Image height.
111
+ cols: Image width.
112
+
113
+ Returns:
114
+ tuple: A keypoint `(x, y, angle, scale)`.
115
+
116
+ Raises:
117
+ ValueError: if factor not in set {0, 1, 2, 3}
118
+
119
+ """
120
+ x, y, angle, scale = keypoint[:4]
121
+
122
+ if factor not in {0, 1, 2, 3}:
123
+ raise ValueError("Parameter factor must be in set {0, 1, 2, 3}")
124
+
125
+ if factor == 1:
126
+ x, y, angle = y, (cols - 1) - x, angle - math.pi / 2
127
+ elif factor == 2:
128
+ x, y, angle = (cols - 1) - x, (rows - 1) - y, angle - math.pi
129
+ elif factor == 3:
130
+ x, y, angle = (rows - 1) - y, x, angle + math.pi / 2
131
+
132
+ return x, y, angle, scale
133
+
134
+
135
+ @preserve_channel_dim
136
+ def rotate(
137
+ img: np.ndarray,
138
+ angle: float,
139
+ interpolation: int = cv2.INTER_LINEAR,
140
+ border_mode: int = cv2.BORDER_REFLECT_101,
141
+ value: Optional[ImageColorType] = None,
142
+ ):
143
+ height, width = img.shape[:2]
144
+ # for images we use additional shifts of (0.5, 0.5) as otherwise
145
+ # we get an ugly black border for 90deg rotations
146
+ matrix = cv2.getRotationMatrix2D((width / 2 - 0.5, height / 2 - 0.5), angle, 1.0)
147
+
148
+ warp_fn = _maybe_process_in_chunks(
149
+ cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
150
+ )
151
+ return warp_fn(img)
152
+
153
+
154
+ def bbox_rotate(bbox: BoxInternalType, angle: float, method: str, rows: int, cols: int) -> BoxInternalType:
155
+ """Rotates a bounding box by angle degrees.
156
+
157
+ Args:
158
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
159
+ angle: Angle of rotation in degrees.
160
+ method: Rotation method used. Should be one of: "largest_box", "ellipse". Default: "largest_box".
161
+ rows: Image rows.
162
+ cols: Image cols.
163
+
164
+ Returns:
165
+ A bounding box `(x_min, y_min, x_max, y_max)`.
166
+
167
+ References:
168
+ https://arxiv.org/abs/2109.13488
169
+
170
+ """
171
+ x_min, y_min, x_max, y_max = bbox[:4]
172
+ scale = cols / float(rows)
173
+ if method == "largest_box":
174
+ x = np.array([x_min, x_max, x_max, x_min]) - 0.5
175
+ y = np.array([y_min, y_min, y_max, y_max]) - 0.5
176
+ elif method == "ellipse":
177
+ w = (x_max - x_min) / 2
178
+ h = (y_max - y_min) / 2
179
+ data = np.arange(0, 360, dtype=np.float32)
180
+ x = w * np.sin(np.radians(data)) + (w + x_min - 0.5)
181
+ y = h * np.cos(np.radians(data)) + (h + y_min - 0.5)
182
+ else:
183
+ raise ValueError(f"Method {method} is not a valid rotation method.")
184
+ angle = np.deg2rad(angle)
185
+ x_t = (np.cos(angle) * x * scale + np.sin(angle) * y) / scale
186
+ y_t = -np.sin(angle) * x * scale + np.cos(angle) * y
187
+ x_t = x_t + 0.5
188
+ y_t = y_t + 0.5
189
+
190
+ x_min, x_max = min(x_t), max(x_t)
191
+ y_min, y_max = min(y_t), max(y_t)
192
+
193
+ return x_min, y_min, x_max, y_max
194
+
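# Editor's note: a small sanity check for bbox_rotate above; the full-image box
# on a square image is invariant under a 90-degree rotation:
import math
rot = bbox_rotate((0.0, 0.0, 1.0, 1.0), 90, "largest_box", rows=100, cols=100)
assert all(math.isclose(a, b, abs_tol=1e-6) for a, b in zip(rot, (0.0, 0.0, 1.0, 1.0)))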
195
+
196
+ @angle_2pi_range
197
+ def keypoint_rotate(keypoint, angle, rows, cols, **params):
198
+ """Rotate a keypoint by angle.
199
+
200
+ Args:
201
+ keypoint (tuple): A keypoint `(x, y, angle, scale)`.
202
+ angle (float): Rotation angle.
203
+ rows (int): Image height.
204
+ cols (int): Image width.
205
+
206
+ Returns:
207
+ tuple: A keypoint `(x, y, angle, scale)`.
208
+
209
+ """
210
+ center = (cols - 1) * 0.5, (rows - 1) * 0.5
211
+ matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
212
+ x, y, a, s = keypoint[:4]
213
+ x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
214
+ return x, y, a + math.radians(angle), s
215
+
216
+
217
+ @preserve_channel_dim
218
+ def shift_scale_rotate(
219
+ img, angle, scale, dx, dy, interpolation=cv2.INTER_LINEAR, border_mode=cv2.BORDER_REFLECT_101, value=None
220
+ ):
221
+ height, width = img.shape[:2]
222
+ # for images we use additional shifts of (0.5, 0.5) as otherwise
223
+ # we get an ugly black border for 90deg rotations
224
+ center = (width / 2 - 0.5, height / 2 - 0.5)
225
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
226
+ matrix[0, 2] += dx * width
227
+ matrix[1, 2] += dy * height
228
+
229
+ warp_affine_fn = _maybe_process_in_chunks(
230
+ cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
231
+ )
232
+ return warp_affine_fn(img)
233
+
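# Editor's note: the matrix assembled above, spelled out for concrete numbers;
# dx/dy are fractions of width/height, not pixels. For a 100x200 (HxW) image
# with angle=30, scale=1.2, dx=0.1, dy=-0.05:
m = cv2.getRotationMatrix2D((200 / 2 - 0.5, 100 / 2 - 0.5), 30, 1.2)
m[0, 2] += 0.1 * 200     # +20 px along x
m[1, 2] += -0.05 * 100   # -5 px along y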
234
+
235
+ @angle_2pi_range
236
+ def keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols, **params):
237
+ (
238
+ x,
239
+ y,
240
+ a,
241
+ s,
242
+ ) = keypoint[:4]
243
+ height, width = rows, cols
244
+ center = (cols - 1) * 0.5, (rows - 1) * 0.5
245
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
246
+ matrix[0, 2] += dx * width
247
+ matrix[1, 2] += dy * height
248
+
249
+ x, y = cv2.transform(np.array([[[x, y]]]), matrix).squeeze()
250
+ angle = a + math.radians(angle)
251
+ scale = s * scale
252
+
253
+ return x, y, angle, scale
254
+
255
+
256
+ def bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, rotate_method, rows, cols, **kwargs): # skipcq: PYL-W0613
257
+ """Rotates, shifts and scales a bounding box. Rotation is made by angle degrees,
258
+ scaling is made by scale factor and shifting is made by dx and dy.
259
+
260
+
261
+ Args:
262
+ bbox (tuple): A bounding box `(x_min, y_min, x_max, y_max)`.
263
+ angle (float): Angle of rotation in degrees.
264
+ scale (float): Scale factor.
265
+ dx (float): Shift along the x-axis, as a fraction of the image width.
266
+ dy (float): Shift along the y-axis, as a fraction of the image height.
267
+ rotate_method(str): Rotation method used. Should be one of: "largest_box", "ellipse".
268
+ Default: "largest_box".
269
+ rows (int): Image rows.
270
+ cols (int): Image cols.
271
+
272
+ Returns:
273
+ A bounding box `(x_min, y_min, x_max, y_max)`.
274
+
275
+ """
276
+ height, width = rows, cols
277
+ center = (width / 2, height / 2)
278
+ if rotate_method == "ellipse":
279
+ x_min, y_min, x_max, y_max = bbox_rotate(bbox, angle, rotate_method, rows, cols)
280
+ matrix = cv2.getRotationMatrix2D(center, 0, scale)
281
+ else:
282
+ x_min, y_min, x_max, y_max = bbox[:4]
283
+ matrix = cv2.getRotationMatrix2D(center, angle, scale)
284
+ matrix[0, 2] += dx * width
285
+ matrix[1, 2] += dy * height
286
+ x = np.array([x_min, x_max, x_max, x_min])
287
+ y = np.array([y_min, y_min, y_max, y_max])
288
+ ones = np.ones(shape=(len(x)))
289
+ points_ones = np.vstack([x, y, ones]).transpose()
290
+ points_ones[:, 0] *= width
291
+ points_ones[:, 1] *= height
292
+ tr_points = matrix.dot(points_ones.T).T
293
+ tr_points[:, 0] /= width
294
+ tr_points[:, 1] /= height
295
+
296
+ x_min, x_max = min(tr_points[:, 0]), max(tr_points[:, 0])
297
+ y_min, y_max = min(tr_points[:, 1]), max(tr_points[:, 1])
298
+
299
+ return x_min, y_min, x_max, y_max
300
+
301
+
302
+ @preserve_shape
303
+ def elastic_transform(
304
+ img: np.ndarray,
305
+ alpha: float,
306
+ sigma: float,
307
+ alpha_affine: float,
308
+ interpolation: int = cv2.INTER_LINEAR,
309
+ border_mode: int = cv2.BORDER_REFLECT_101,
310
+ value: Optional[ImageColorType] = None,
311
+ random_state: Optional[np.random.RandomState] = None,
312
+ approximate: bool = False,
313
+ same_dxdy: bool = False,
314
+ ):
315
+ """Elastic deformation of images as described in [Simard2003]_ (with modifications).
316
+ Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
317
+
318
+ .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
319
+ Convolutional Neural Networks applied to Visual Document Analysis", in
320
+ Proc. of the International Conference on Document Analysis and
321
+ Recognition, 2003.
322
+ """
323
+ height, width = img.shape[:2]
324
+
325
+ # Random affine
326
+ center_square = np.array((height, width), dtype=np.float32) // 2
327
+ square_size = min((height, width)) // 3
328
+ alpha = float(alpha)
329
+ sigma = float(sigma)
330
+ alpha_affine = float(alpha_affine)
331
+
332
+ pts1 = np.array(
333
+ [
334
+ center_square + square_size,
335
+ [center_square[0] + square_size, center_square[1] - square_size],
336
+ center_square - square_size,
337
+ ],
338
+ dtype=np.float32,
339
+ )
340
+ pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype(
341
+ np.float32
342
+ )
343
+ matrix = cv2.getAffineTransform(pts1, pts2)
344
+
345
+ warp_fn = _maybe_process_in_chunks(
346
+ cv2.warpAffine, M=matrix, dsize=(width, height), flags=interpolation, borderMode=border_mode, borderValue=value
347
+ )
348
+ img = warp_fn(img)
349
+
350
+ if approximate:
351
+ # Approximate computation: smooth the displacement map with a large enough kernel.
352
+ # On large images (512+) this is approximately 2x faster.
353
+ dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
354
+ cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
355
+ dx *= alpha
356
+ if same_dxdy:
357
+ # Speed up even more
358
+ dy = dx
359
+ else:
360
+ dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
361
+ cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
362
+ dy *= alpha
363
+ else:
364
+ dx = np.float32(
365
+ gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha
366
+ )
367
+ if same_dxdy:
368
+ # Speed up
369
+ dy = dx
370
+ else:
371
+ dy = np.float32(
372
+ gaussian_filter((random_utils.rand(height, width, random_state=random_state) * 2 - 1), sigma) * alpha
373
+ )
374
+
375
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
376
+
377
+ map_x = np.float32(x + dx)
378
+ map_y = np.float32(y + dy)
379
+
380
+ remap_fn = _maybe_process_in_chunks(
381
+ cv2.remap, map1=map_x, map2=map_y, interpolation=interpolation, borderMode=border_mode, borderValue=value
382
+ )
383
+ return remap_fn(img)
384
+
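# Editor's note: a standalone sketch of the displacement-field idea used above
# (illustrative constants, not the library's exact parameters):
import cv2
import numpy as np

h, w = 64, 64
img = np.random.randint(0, 256, (h, w, 3), dtype=np.uint8)
rng = np.random.default_rng(0)
dx = cv2.GaussianBlur(rng.random((h, w), dtype=np.float32) * 2 - 1, (17, 17), 5) * 10
dy = cv2.GaussianBlur(rng.random((h, w), dtype=np.float32) * 2 - 1, (17, 17), 5) * 10
x, y = np.meshgrid(np.arange(w, dtype=np.float32), np.arange(h, dtype=np.float32))
warped = cv2.remap(img, x + dx, y + dy, interpolation=cv2.INTER_LINEAR,
                   borderMode=cv2.BORDER_REFLECT_101)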
385
+
386
+ @preserve_channel_dim
387
+ def resize(img, height, width, interpolation=cv2.INTER_LINEAR):
388
+ img_height, img_width = img.shape[:2]
389
+ if height == img_height and width == img_width:
390
+ return img
391
+ resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(width, height), interpolation=interpolation)
392
+ return resize_fn(img)
393
+
394
+
395
+ @preserve_channel_dim
396
+ def scale(img: np.ndarray, scale: float, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray:
397
+ height, width = img.shape[:2]
398
+ new_height, new_width = int(height * scale), int(width * scale)
399
+ return resize(img, new_height, new_width, interpolation)
400
+
401
+
402
+ def keypoint_scale(keypoint: KeypointInternalType, scale_x: float, scale_y: float) -> KeypointInternalType:
403
+ """Scales a keypoint by scale_x and scale_y.
404
+
405
+ Args:
406
+ keypoint: A keypoint `(x, y, angle, scale)`.
407
+ scale_x: Scale coefficient x-axis.
408
+ scale_y: Scale coefficient y-axis.
409
+
410
+ Returns:
411
+ A keypoint `(x, y, angle, scale)`.
412
+
413
+ """
414
+ x, y, angle, scale = keypoint[:4]
415
+ return x * scale_x, y * scale_y, angle, scale * max(scale_x, scale_y)
416
+
417
+
418
+ def py3round(number):
419
+ """Unified rounding in all python versions."""
420
+ if abs(round(number) - number) == 0.5:
421
+ return int(2.0 * round(number / 2.0))
422
+
423
+ return int(round(number))
424
+
425
+
426
+ def _func_max_size(img, max_size, interpolation, func):
427
+ height, width = img.shape[:2]
428
+
429
+ scale = max_size / float(func(width, height))
430
+
431
+ if scale != 1.0:
432
+ new_height, new_width = tuple(py3round(dim * scale) for dim in (height, width))
433
+ img = resize(img, height=new_height, width=new_width, interpolation=interpolation)
434
+ return img
435
+
436
+
437
+ @preserve_channel_dim
438
+ def longest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray:
439
+ return _func_max_size(img, max_size, interpolation, max)
440
+
441
+
442
+ @preserve_channel_dim
443
+ def smallest_max_size(img: np.ndarray, max_size: int, interpolation: int) -> np.ndarray:
444
+ return _func_max_size(img, max_size, interpolation, min)
445
+
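# Editor's note: concrete numbers for the two helpers above. For a 768x1024
# (HxW) image and max_size=512:
#   longest_max_size:  scale = 512 / 1024 = 0.5   -> resized to 384x512
#   smallest_max_size: scale = 512 / 768  ~ 0.667 -> resized to 512x683
# py3round rounds .5 ties to even in every Python version, e.g.
# py3round(2.5) == 2 and py3round(3.5) == 4.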
446
+
447
+ @preserve_channel_dim
448
+ def perspective(
449
+ img: np.ndarray,
450
+ matrix: np.ndarray,
451
+ max_width: int,
452
+ max_height: int,
453
+ border_val: Union[int, float, List[int], List[float], np.ndarray],
454
+ border_mode: int,
455
+ keep_size: bool,
456
+ interpolation: int,
457
+ ):
458
+ h, w = img.shape[:2]
459
+ perspective_func = _maybe_process_in_chunks(
460
+ cv2.warpPerspective,
461
+ M=matrix,
462
+ dsize=(max_width, max_height),
463
+ borderMode=border_mode,
464
+ borderValue=border_val,
465
+ flags=interpolation,
466
+ )
467
+ warped = perspective_func(img)
468
+
469
+ if keep_size:
470
+ return resize(warped, h, w, interpolation=interpolation)
471
+
472
+ return warped
473
+
474
+
475
+ def perspective_bbox(
476
+ bbox: BoxInternalType,
477
+ height: int,
478
+ width: int,
479
+ matrix: np.ndarray,
480
+ max_width: int,
481
+ max_height: int,
482
+ keep_size: bool,
483
+ ) -> BoxInternalType:
484
+ x1, y1, x2, y2 = denormalize_bbox(bbox, height, width)[:4]
485
+
486
+ points = np.array([[x1, y1], [x2, y1], [x2, y2], [x1, y2]], dtype=np.float32)
487
+
488
+ x1, y1, x2, y2 = float("inf"), float("inf"), 0, 0
489
+ for pt in points:
490
+ pt = perspective_keypoint(pt.tolist() + [0, 0], height, width, matrix, max_width, max_height, keep_size)
491
+ x, y = pt[:2]
492
+ x1 = min(x1, x)
493
+ x2 = max(x2, x)
494
+ y1 = min(y1, y)
495
+ y2 = max(y2, y)
496
+
497
+ return normalize_bbox((x1, y1, x2, y2), height if keep_size else max_height, width if keep_size else max_width)
498
+
499
+
500
+ def rotation2DMatrixToEulerAngles(matrix: np.ndarray, y_up: bool = False) -> float:
501
+ """
502
+ Args:
503
+ matrix (np.ndarray): Rotation matrix
504
+ y_up (bool): is Y axis looks up or down
505
+ """
506
+ if y_up:
507
+ return np.arctan2(matrix[1, 0], matrix[0, 0])
508
+ return np.arctan2(-matrix[1, 0], matrix[0, 0])
509
+
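# Editor's note: quick check that the angle of an OpenCV rotation matrix is
# recovered by the helper above (the default y_up=False matches OpenCV's
# y-down image coordinates):
import math
m = cv2.getRotationMatrix2D((0, 0), 30, 1.0)
assert math.isclose(rotation2DMatrixToEulerAngles(m[:2, :2]), math.radians(30), abs_tol=1e-9)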
510
+
511
+ @angle_2pi_range
512
+ def perspective_keypoint(
513
+ keypoint: KeypointInternalType,
514
+ height: int,
515
+ width: int,
516
+ matrix: np.ndarray,
517
+ max_width: int,
518
+ max_height: int,
519
+ keep_size: bool,
520
+ ) -> KeypointInternalType:
521
+ x, y, angle, scale = keypoint
522
+
523
+ keypoint_vector = np.array([x, y], dtype=np.float32).reshape([1, 1, 2])
524
+
525
+ x, y = cv2.perspectiveTransform(keypoint_vector, matrix)[0, 0]
526
+ angle += rotation2DMatrixToEulerAngles(matrix[:2, :2], y_up=True)
527
+
528
+ scale_x = np.sign(matrix[0, 0]) * np.sqrt(matrix[0, 0] ** 2 + matrix[0, 1] ** 2)
529
+ scale_y = np.sign(matrix[1, 1]) * np.sqrt(matrix[1, 0] ** 2 + matrix[1, 1] ** 2)
530
+ scale *= max(scale_x, scale_y)
531
+
532
+ if keep_size:
533
+ scale_x = width / max_width
534
+ scale_y = height / max_height
535
+ return keypoint_scale((x, y, angle, scale), scale_x, scale_y)
536
+
537
+ return x, y, angle, scale
538
+
539
+
540
+ def _is_identity_matrix(matrix: skimage.transform.ProjectiveTransform) -> bool:
541
+ return np.allclose(matrix.params, np.eye(3, dtype=np.float32))
542
+
543
+
544
+ @preserve_channel_dim
545
+ def warp_affine(
546
+ image: np.ndarray,
547
+ matrix: skimage.transform.ProjectiveTransform,
548
+ interpolation: int,
549
+ cval: Union[int, float, Sequence[int], Sequence[float]],
550
+ mode: int,
551
+ output_shape: Sequence[int],
552
+ ) -> np.ndarray:
553
+ if _is_identity_matrix(matrix):
554
+ return image
555
+
556
+ dsize = int(np.round(output_shape[1])), int(np.round(output_shape[0]))
557
+ warp_fn = _maybe_process_in_chunks(
558
+ cv2.warpAffine, M=matrix.params[:2], dsize=dsize, flags=interpolation, borderMode=mode, borderValue=cval
559
+ )
560
+ tmp = warp_fn(image)
561
+ return tmp
562
+
563
+
564
+ @angle_2pi_range
565
+ def keypoint_affine(
566
+ keypoint: KeypointInternalType,
567
+ matrix: skimage.transform.ProjectiveTransform,
568
+ scale: dict,
569
+ ) -> KeypointInternalType:
570
+ if _is_identity_matrix(matrix):
571
+ return keypoint
572
+
573
+ x, y, a, s = keypoint[:4]
574
+ x, y = cv2.transform(np.array([[[x, y]]]), matrix.params[:2]).squeeze()
575
+ a += rotation2DMatrixToEulerAngles(matrix.params[:2])
576
+ s *= np.max([scale["x"], scale["y"]])
577
+ return x, y, a, s
578
+
579
+
580
+ def bbox_affine(
581
+ bbox: BoxInternalType,
582
+ matrix: skimage.transform.ProjectiveTransform,
583
+ rotate_method: str,
584
+ rows: int,
585
+ cols: int,
586
+ output_shape: Sequence[int],
587
+ ) -> BoxInternalType:
588
+ if _is_identity_matrix(matrix):
589
+ return bbox
590
+ x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
591
+ if rotate_method == "largest_box":
592
+ points = np.array(
593
+ [
594
+ [x_min, y_min],
595
+ [x_max, y_min],
596
+ [x_max, y_max],
597
+ [x_min, y_max],
598
+ ]
599
+ )
600
+ elif rotate_method == "ellipse":
601
+ w = (x_max - x_min) / 2
602
+ h = (y_max - y_min) / 2
603
+ data = np.arange(0, 360, dtype=np.float32)
604
+ x = w * np.sin(np.radians(data)) + (w + x_min - 0.5)
605
+ y = h * np.cos(np.radians(data)) + (h + y_min - 0.5)
606
+ points = np.hstack([x.reshape(-1, 1), y.reshape(-1, 1)])
607
+ else:
608
+ raise ValueError(f"Method {rotate_method} is not a valid rotation method.")
609
+ points = skimage.transform.matrix_transform(points, matrix.params)
610
+ x_min = np.min(points[:, 0])
611
+ x_max = np.max(points[:, 0])
612
+ y_min = np.min(points[:, 1])
613
+ y_max = np.max(points[:, 1])
614
+
615
+ return normalize_bbox((x_min, y_min, x_max, y_max), output_shape[0], output_shape[1])
616
+
617
+
618
+ @preserve_channel_dim
619
+ def safe_rotate(
620
+ img: np.ndarray,
621
+ matrix: np.ndarray,
622
+ interpolation: int,
623
+ value: FillValueType = None,
624
+ border_mode: int = cv2.BORDER_REFLECT_101,
625
+ ) -> np.ndarray:
626
+ h, w = img.shape[:2]
627
+ warp_fn = _maybe_process_in_chunks(
628
+ cv2.warpAffine,
629
+ M=matrix,
630
+ dsize=(w, h),
631
+ flags=interpolation,
632
+ borderMode=border_mode,
633
+ borderValue=value,
634
+ )
635
+ return warp_fn(img)
636
+
637
+
638
+ def bbox_safe_rotate(bbox: BoxInternalType, matrix: np.ndarray, cols: int, rows: int) -> BoxInternalType:
639
+ x1, y1, x2, y2 = denormalize_bbox(bbox, rows, cols)[:4]
640
+ points = np.array(
641
+ [
642
+ [x1, y1, 1],
643
+ [x2, y1, 1],
644
+ [x2, y2, 1],
645
+ [x1, y2, 1],
646
+ ]
647
+ )
648
+ points = points @ matrix.T
649
+ x1 = points[:, 0].min()
650
+ x2 = points[:, 0].max()
651
+ y1 = points[:, 1].min()
652
+ y2 = points[:, 1].max()
653
+
654
+ def fix_point(pt1: float, pt2: float, max_val: float) -> Tuple[float, float]:
655
+ # These errors should be very small, on the order of 1-2 pixels.
656
+ if pt1 < 0:
657
+ return 0, pt2 + pt1
658
+ if pt2 > max_val:
659
+ return pt1 - (pt2 - max_val), max_val
660
+ return pt1, pt2
661
+
662
+ x1, x2 = fix_point(x1, x2, cols)
663
+ y1, y2 = fix_point(y1, y2, rows)
664
+
665
+ return normalize_bbox((x1, y1, x2, y2), rows, cols)
666
+
667
+
668
+ def keypoint_safe_rotate(
669
+ keypoint: KeypointInternalType,
670
+ matrix: np.ndarray,
671
+ angle: float,
672
+ scale_x: float,
673
+ scale_y: float,
674
+ cols: int,
675
+ rows: int,
676
+ ) -> KeypointInternalType:
677
+ x, y, a, s = keypoint[:4]
678
+ point = np.array([[x, y, 1]])
679
+ x, y = (point @ matrix.T)[0]
680
+
681
+ # To avoid problems with float errors
682
+ x = np.clip(x, 0, cols - 1)
683
+ y = np.clip(y, 0, rows - 1)
684
+
685
+ a += angle
686
+ s *= max(scale_x, scale_y)
687
+ return x, y, a, s
688
+
689
+
690
+ @clipped
691
+ def piecewise_affine(
692
+ img: np.ndarray,
693
+ matrix: Optional[skimage.transform.PiecewiseAffineTransform],
694
+ interpolation: int,
695
+ mode: str,
696
+ cval: float,
697
+ ) -> np.ndarray:
698
+ if matrix is None:
699
+ return img
700
+ return skimage.transform.warp(
701
+ img, matrix, order=interpolation, mode=mode, cval=cval, preserve_range=True, output_shape=img.shape
702
+ )
703
+
704
+
705
+ def to_distance_maps(
706
+ keypoints: Sequence[Tuple[float, float]], height: int, width: int, inverted: bool = False
707
+ ) -> np.ndarray:
708
+ """Generate a ``(H,W,N)`` array of distance maps for ``N`` keypoints.
709
+
710
+ The ``n``-th distance map contains at every location ``(y, x)`` the
711
+ euclidean distance to the ``n``-th keypoint.
712
+
713
+ This function can be used as a helper when augmenting keypoints with a
714
+ method that only supports the augmentation of images.
715
+
716
+ Args:
717
+ keypoints: keypoint coordinates
718
+ height: image height
719
+ width: image width
720
+ inverted (bool): If ``True``, inverted distance maps are returned where each
721
+ distance value d is replaced by ``d/(d+1)``, i.e. the distance
722
+ maps have values in the range ``(0.0, 1.0]`` with ``1.0`` denoting
723
+ exactly the position of the respective keypoint.
724
+
725
+ Returns:
726
+ (H, W, N) ndarray
727
+ A ``float32`` array containing ``N`` distance maps for ``N``
728
+ keypoints. Each location ``(y, x, n)`` in the array denotes the
729
+ euclidean distance at ``(y, x)`` to the ``n``-th keypoint.
730
+ If `inverted` is ``True``, the distance ``d`` is replaced
731
+ by ``d/(d+1)``. The height and width of the array match the
732
+ height and width in ``KeypointsOnImage.shape``.
733
+ """
734
+ distance_maps = np.zeros((height, width, len(keypoints)), dtype=np.float32)
735
+
736
+ yy = np.arange(0, height)
737
+ xx = np.arange(0, width)
738
+ grid_xx, grid_yy = np.meshgrid(xx, yy)
739
+
740
+ for i, (x, y) in enumerate(keypoints):
741
+ distance_maps[:, :, i] = (grid_xx - x) ** 2 + (grid_yy - y) ** 2
742
+
743
+ distance_maps = np.sqrt(distance_maps)
744
+ if inverted:
745
+ return 1 / (distance_maps + 1)
746
+ return distance_maps
747
+
748
+
749
+ def from_distance_maps(
750
+ distance_maps: np.ndarray,
751
+ inverted: bool,
752
+ if_not_found_coords: Optional[Union[Sequence[int], dict]],
753
+ threshold: Optional[float] = None,
754
+ ) -> List[Tuple[float, float]]:
755
+ """Convert outputs of ``to_distance_maps()`` to ``KeypointsOnImage``.
756
+ This is the inverse of `to_distance_maps`.
757
+
758
+ Args:
759
+ distance_maps (np.ndarray): The distance maps. ``N`` is the number of keypoints.
760
+ inverted (bool): Whether the given distance maps were generated in inverted mode
761
+ (i.e. :func:`KeypointsOnImage.to_distance_maps` was called with ``inverted=True``) or in non-inverted mode.
762
+ if_not_found_coords (tuple, list, dict or None, optional):
763
+ Coordinates to use for keypoints that cannot be found in `distance_maps`.
764
+
765
+ * If this is a ``list``/``tuple``, it must contain two ``int`` values.
766
+ * If it is a ``dict``, it must contain the keys ``x`` and ``y`` with each containing one ``int`` value.
767
+ * If this is ``None``, then the keypoint will not be added.
768
+ threshold (float): The search for keypoints works by searching for the
769
+ argmin (non-inverted) or argmax (inverted) in each channel. This
770
+ parameter contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit
771
+ as a keypoint. Use ``None`` to use no min/max.
772
775
+ """
776
+ if distance_maps.ndim != 3:
777
+ raise ValueError(
778
+ f"Expected three-dimensional input, "
779
+ f"got {distance_maps.ndim} dimensions and shape {distance_maps.shape}."
780
+ )
781
+ height, width, nb_keypoints = distance_maps.shape
782
+
783
+ drop_if_not_found = False
784
+ if if_not_found_coords is None:
785
+ drop_if_not_found = True
786
+ if_not_found_x = -1
787
+ if_not_found_y = -1
788
+ elif isinstance(if_not_found_coords, (tuple, list)):
789
+ if len(if_not_found_coords) != 2:
790
+ raise ValueError(
791
+ f"Expected tuple/list 'if_not_found_coords' to contain exactly two entries, "
792
+ f"got {len(if_not_found_coords)}."
793
+ )
794
+ if_not_found_x = if_not_found_coords[0]
795
+ if_not_found_y = if_not_found_coords[1]
796
+ elif isinstance(if_not_found_coords, dict):
797
+ if_not_found_x = if_not_found_coords["x"]
798
+ if_not_found_y = if_not_found_coords["y"]
799
+ else:
800
+ raise ValueError(
801
+ f"Expected if_not_found_coords to be None or tuple or list or dict, got {type(if_not_found_coords)}."
802
+ )
803
+
804
+ keypoints = []
805
+ for i in range(nb_keypoints):
806
+ if inverted:
807
+ hitidx_flat = np.argmax(distance_maps[..., i])
808
+ else:
809
+ hitidx_flat = np.argmin(distance_maps[..., i])
810
+ hitidx_ndim = np.unravel_index(hitidx_flat, (height, width))
811
+ if not inverted and threshold is not None:
812
+ found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] < threshold
813
+ elif inverted and threshold is not None:
814
+ found = distance_maps[hitidx_ndim[0], hitidx_ndim[1], i] >= threshold
815
+ else:
816
+ found = True
817
+ if found:
818
+ keypoints.append((float(hitidx_ndim[1]), float(hitidx_ndim[0])))
819
+ else:
820
+ if not drop_if_not_found:
821
+ keypoints.append((if_not_found_x, if_not_found_y))
822
+
823
+ return keypoints
824
+
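# Editor's note: roundtrip sketch for the two helpers above. Keypoints at
# integer coordinates survive the map/unmap exactly:
kps = [(10.0, 20.0), (3.0, 7.0)]
maps = to_distance_maps(kps, height=32, width=32, inverted=True)
assert from_distance_maps(maps, inverted=True, if_not_found_coords=None) == kps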
825
+
826
+ def keypoint_piecewise_affine(
827
+ keypoint: KeypointInternalType,
828
+ matrix: Optional[skimage.transform.PiecewiseAffineTransform],
829
+ h: int,
830
+ w: int,
831
+ keypoints_threshold: float,
832
+ ) -> KeypointInternalType:
833
+ if matrix is None:
834
+ return keypoint
835
+ x, y, a, s = keypoint[:4]
836
+ dist_maps = to_distance_maps([(x, y)], h, w, True)
837
+ dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0)
838
+ x, y = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)[0]
839
+ return x, y, a, s
840
+
841
+
842
+ def bbox_piecewise_affine(
843
+ bbox: BoxInternalType,
844
+ matrix: Optional[skimage.transform.PiecewiseAffineTransform],
845
+ h: int,
846
+ w: int,
847
+ keypoints_threshold: float,
848
+ ) -> BoxInternalType:
849
+ if matrix is None:
850
+ return bbox
851
+ x1, y1, x2, y2 = denormalize_bbox(bbox, h, w)[:4]
852
+ keypoints = [
853
+ (x1, y1),
854
+ (x2, y1),
855
+ (x2, y2),
856
+ (x1, y2),
857
+ ]
858
+ dist_maps = to_distance_maps(keypoints, h, w, True)
859
+ dist_maps = piecewise_affine(dist_maps, matrix, 0, "constant", 0)
860
+ keypoints = from_distance_maps(dist_maps, True, {"x": -1, "y": -1}, keypoints_threshold)
861
+ keypoints = [i for i in keypoints if 0 <= i[0] < w and 0 <= i[1] < h]
862
+ keypoints_arr = np.array(keypoints)
863
+ x1 = keypoints_arr[:, 0].min()
864
+ y1 = keypoints_arr[:, 1].min()
865
+ x2 = keypoints_arr[:, 0].max()
866
+ y2 = keypoints_arr[:, 1].max()
867
+ return normalize_bbox((x1, y1, x2, y2), h, w)
868
+
869
+
870
+ def vflip(img: np.ndarray) -> np.ndarray:
871
+ return np.ascontiguousarray(img[::-1, ...])
872
+
873
+
874
+ def hflip(img: np.ndarray) -> np.ndarray:
875
+ return np.ascontiguousarray(img[:, ::-1, ...])
876
+
877
+
878
+ def hflip_cv2(img: np.ndarray) -> np.ndarray:
879
+ return cv2.flip(img, 1)
880
+
881
+
882
+ @preserve_shape
883
+ def random_flip(img: np.ndarray, code: int) -> np.ndarray:
884
+ return cv2.flip(img, code)
885
+
886
+
887
+ def transpose(img: np.ndarray) -> np.ndarray:
888
+ return img.transpose(1, 0, 2) if len(img.shape) > 2 else img.transpose(1, 0)
889
+
890
+
891
+ def rot90(img: np.ndarray, factor: int) -> np.ndarray:
892
+ img = np.rot90(img, factor)
893
+ return np.ascontiguousarray(img)
894
+
895
+
896
+ def bbox_vflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
897
+ """Flip a bounding box vertically around the x-axis.
898
+
899
+ Args:
900
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
901
+ rows: Image rows.
902
+ cols: Image cols.
903
+
904
+ Returns:
905
+ tuple: A bounding box `(x_min, y_min, x_max, y_max)`.
906
+
907
+ """
908
+ x_min, y_min, x_max, y_max = bbox[:4]
909
+ return x_min, 1 - y_max, x_max, 1 - y_min
910
+
911
+
912
+ def bbox_hflip(bbox: BoxInternalType, rows: int, cols: int) -> BoxInternalType: # skipcq: PYL-W0613
913
+ """Flip a bounding box horizontally around the y-axis.
914
+
915
+ Args:
916
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
917
+ rows: Image rows.
918
+ cols: Image cols.
919
+
920
+ Returns:
921
+ A bounding box `(x_min, y_min, x_max, y_max)`.
922
+
923
+ """
924
+ x_min, y_min, x_max, y_max = bbox[:4]
925
+ return 1 - x_max, y_min, 1 - x_min, y_max
926
+
927
+
928
+ def bbox_flip(bbox: BoxInternalType, d: int, rows: int, cols: int) -> BoxInternalType:
929
+ """Flip a bounding box either vertically, horizontally or both depending on the value of `d`.
930
+
931
+ Args:
932
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
933
+ d: dimension. 0 for vertical flip, 1 for horizontal, -1 for transpose
934
+ rows: Image rows.
935
+ cols: Image cols.
936
+
937
+ Returns:
938
+ A bounding box `(x_min, y_min, x_max, y_max)`.
939
+
940
+ Raises:
941
+ ValueError: if value of `d` is not -1, 0 or 1.
942
+
943
+ """
944
+ if d == 0:
945
+ bbox = bbox_vflip(bbox, rows, cols)
946
+ elif d == 1:
947
+ bbox = bbox_hflip(bbox, rows, cols)
948
+ elif d == -1:
949
+ bbox = bbox_hflip(bbox, rows, cols)
950
+ bbox = bbox_vflip(bbox, rows, cols)
951
+ else:
952
+ raise ValueError("Invalid d value {}. Valid values are -1, 0 and 1".format(d))
953
+ return bbox
954
+
955
+
956
+ def bbox_transpose(
957
+ bbox: KeypointInternalType, axis: int, rows: int, cols: int
958
+ ) -> KeypointInternalType: # skipcq: PYL-W0613
959
+ """Transposes a bounding box along given axis.
960
+
961
+ Args:
962
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
963
+ axis: 0 - main axis, 1 - secondary axis.
964
+ rows: Image rows.
965
+ cols: Image cols.
966
+
967
+ Returns:
968
+ A bounding box tuple `(x_min, y_min, x_max, y_max)`.
969
+
970
+ Raises:
971
+ ValueError: If axis not equal to 0 or 1.
972
+
973
+ """
974
+ x_min, y_min, x_max, y_max = bbox[:4]
975
+ if axis not in {0, 1}:
976
+ raise ValueError("Axis must be either 0 or 1.")
977
+ if axis == 0:
978
+ bbox = (y_min, x_min, y_max, x_max)
979
+ if axis == 1:
980
+ bbox = (1 - y_max, 1 - x_max, 1 - y_min, 1 - x_min)
981
+ return bbox
982
+
983
+
984
+ @angle_2pi_range
985
+ def keypoint_vflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType:
986
+ """Flip a keypoint vertically around the x-axis.
987
+
988
+ Args:
989
+ keypoint: A keypoint `(x, y, angle, scale)`.
990
+ rows: Image height.
991
+ cols: Image width.
992
+
993
+ Returns:
994
+ tuple: A keypoint `(x, y, angle, scale)`.
995
+
996
+ """
997
+ x, y, angle, scale = keypoint[:4]
998
+ angle = -angle
999
+ return x, (rows - 1) - y, angle, scale
1000
+
1001
+
1002
+ @angle_2pi_range
1003
+ def keypoint_hflip(keypoint: KeypointInternalType, rows: int, cols: int) -> KeypointInternalType:
1004
+ """Flip a keypoint horizontally around the y-axis.
1005
+
1006
+ Args:
1007
+ keypoint: A keypoint `(x, y, angle, scale)`.
1008
+ rows: Image height.
1009
+ cols: Image width.
1010
+
1011
+ Returns:
1012
+ A keypoint `(x, y, angle, scale)`.
1013
+
1014
+ """
1015
+ x, y, angle, scale = keypoint[:4]
1016
+ angle = math.pi - angle
1017
+ return (cols - 1) - x, y, angle, scale
1018
+
1019
+
1020
+ def keypoint_flip(keypoint: KeypointInternalType, d: int, rows: int, cols: int) -> KeypointInternalType:
1021
+ """Flip a keypoint either vertically, horizontally or both depending on the value of `d`.
1022
+
1023
+ Args:
1024
+ keypoint: A keypoint `(x, y, angle, scale)`.
1025
+ d: Number of flip. Must be -1, 0 or 1:
1026
+ * 0 - vertical flip,
1027
+ * 1 - horizontal flip,
1028
+ * -1 - vertical and horizontal flip.
1029
+ rows: Image height.
1030
+ cols: Image width.
1031
+
1032
+ Returns:
1033
+ A keypoint `(x, y, angle, scale)`.
1034
+
1035
+ Raises:
1036
+ ValueError: if value of `d` is not -1, 0 or 1.
1037
+
1038
+ """
1039
+ if d == 0:
1040
+ keypoint = keypoint_vflip(keypoint, rows, cols)
1041
+ elif d == 1:
1042
+ keypoint = keypoint_hflip(keypoint, rows, cols)
1043
+ elif d == -1:
1044
+ keypoint = keypoint_hflip(keypoint, rows, cols)
1045
+ keypoint = keypoint_vflip(keypoint, rows, cols)
1046
+ else:
1047
+ raise ValueError(f"Invalid d value {d}. Valid values are -1, 0 and 1")
1048
+ return keypoint
1049
+
1050
+
1051
+ def keypoint_transpose(keypoint: KeypointInternalType) -> KeypointInternalType:
1052
+ """Transpose a keypoint along the main diagonal by swapping x and y.
1053
+
1054
+ Args:
1055
+ keypoint: A keypoint `(x, y, angle, scale)`.
1056
+
1057
+ Returns:
1058
+ A keypoint `(x, y, angle, scale)`.
1059
+
1060
+ """
1061
+ x, y, angle, scale = keypoint[:4]
1062
+
1063
+ if angle <= np.pi:
1064
+ angle = np.pi - angle
1065
+ else:
1066
+ angle = 3 * np.pi - angle
1067
+
1068
+ return y, x, angle, scale
1069
+
1070
+
1071
+ @preserve_channel_dim
1072
+ def pad(
1073
+ img: np.ndarray,
1074
+ min_height: int,
1075
+ min_width: int,
1076
+ border_mode: int = cv2.BORDER_REFLECT_101,
1077
+ value: Optional[ImageColorType] = None,
1078
+ ) -> np.ndarray:
1079
+ height, width = img.shape[:2]
1080
+
1081
+ if height < min_height:
1082
+ h_pad_top = int((min_height - height) / 2.0)
1083
+ h_pad_bottom = min_height - height - h_pad_top
1084
+ else:
1085
+ h_pad_top = 0
1086
+ h_pad_bottom = 0
1087
+
1088
+ if width < min_width:
1089
+ w_pad_left = int((min_width - width) / 2.0)
1090
+ w_pad_right = min_width - width - w_pad_left
1091
+ else:
1092
+ w_pad_left = 0
1093
+ w_pad_right = 0
1094
+
1095
+ img = pad_with_params(img, h_pad_top, h_pad_bottom, w_pad_left, w_pad_right, border_mode, value)
1096
+
1097
+ if img.shape[:2] != (max(min_height, height), max(min_width, width)):
1098
+ raise RuntimeError(
1099
+ "Invalid result shape. Got: {}. Expected: {}".format(
1100
+ img.shape[:2], (max(min_height, height), max(min_width, width))
1101
+ )
1102
+ )
1103
+
1104
+ return img
1105
+
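# Editor's note: illustrative call of pad above. A 100x150 image padded to at
# least 128x128 gets 14 px on top, 14 px on the bottom, and nothing on the sides:
padded = pad(np.zeros((100, 150, 3), dtype=np.uint8), min_height=128, min_width=128)
assert padded.shape[:2] == (128, 150)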
1106
+
1107
+ @preserve_channel_dim
1108
+ def pad_with_params(
1109
+ img: np.ndarray,
1110
+ h_pad_top: int,
1111
+ h_pad_bottom: int,
1112
+ w_pad_left: int,
1113
+ w_pad_right: int,
1114
+ border_mode: int = cv2.BORDER_REFLECT_101,
1115
+ value: Optional[ImageColorType] = None,
1116
+ ) -> np.ndarray:
1117
+ pad_fn = _maybe_process_in_chunks(
1118
+ cv2.copyMakeBorder,
1119
+ top=h_pad_top,
1120
+ bottom=h_pad_bottom,
1121
+ left=w_pad_left,
1122
+ right=w_pad_right,
1123
+ borderType=border_mode,
1124
+ value=value,
1125
+ )
1126
+ return pad_fn(img)
1127
+
1128
+
1129
+ @preserve_shape
1130
+ def optical_distortion(
1131
+ img: np.ndarray,
1132
+ k: int = 0,
1133
+ dx: int = 0,
1134
+ dy: int = 0,
1135
+ interpolation: int = cv2.INTER_LINEAR,
1136
+ border_mode: int = cv2.BORDER_REFLECT_101,
1137
+ value: Optional[ImageColorType] = None,
1138
+ ) -> np.ndarray:
1139
+ """Barrel / pincushion distortion. An unconventional augmentation.
1140
+
1141
+ Reference:
1142
+ | https://stackoverflow.com/questions/6199636/formulas-for-barrel-pincushion-distortion
1143
+ | https://stackoverflow.com/questions/10364201/image-transformation-in-opencv
1144
+ | https://stackoverflow.com/questions/2477774/correcting-fisheye-distortion-programmatically
1145
+ | http://www.coldvision.io/2017/03/02/advanced-lane-finding-using-opencv/
1146
+ """
1147
+ height, width = img.shape[:2]
1148
+
1149
+ fx = width
1150
+ fy = height
1151
+
1152
+ cx = width * 0.5 + dx
1153
+ cy = height * 0.5 + dy
1154
+
1155
+ camera_matrix = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
1156
+
1157
+ distortion = np.array([k, k, 0, 0, 0], dtype=np.float32)
1158
+ map1, map2 = cv2.initUndistortRectifyMap(
1159
+ camera_matrix, distortion, None, None, (width, height), cv2.CV_32FC1 # type: ignore[attr-defined]
1160
+ )
1161
+ return cv2.remap(img, map1, map2, interpolation=interpolation, borderMode=border_mode, borderValue=value)
1162
+
1163
+
1164
+ @preserve_shape
1165
+ def grid_distortion(
1166
+ img: np.ndarray,
1167
+ num_steps: int = 10,
1168
+ xsteps: Tuple = (),
1169
+ ysteps: Tuple = (),
1170
+ interpolation: int = cv2.INTER_LINEAR,
1171
+ border_mode: int = cv2.BORDER_REFLECT_101,
1172
+ value: Optional[ImageColorType] = None,
1173
+ ) -> np.ndarray:
1174
+ """Perform a grid distortion of an input image.
1175
+
1176
+ Reference:
1177
+ http://pythology.blogspot.sg/2014/03/interpolation-on-regular-distorted-grid.html
1178
+ """
1179
+ height, width = img.shape[:2]
1180
+
1181
+ x_step = width // num_steps
1182
+ xx = np.zeros(width, np.float32)
1183
+ prev = 0
1184
+ for idx in range(num_steps + 1):
1185
+ x = idx * x_step
1186
+ start = int(x)
1187
+ end = int(x) + x_step
1188
+ if end > width:
1189
+ end = width
1190
+ cur = width
1191
+ else:
1192
+ cur = prev + x_step * xsteps[idx]
1193
+
1194
+ xx[start:end] = np.linspace(prev, cur, end - start)
1195
+ prev = cur
1196
+
1197
+ y_step = height // num_steps
1198
+ yy = np.zeros(height, np.float32)
1199
+ prev = 0
1200
+ for idx in range(num_steps + 1):
1201
+ y = idx * y_step
1202
+ start = int(y)
1203
+ end = int(y) + y_step
1204
+ if end > height:
1205
+ end = height
1206
+ cur = height
1207
+ else:
1208
+ cur = prev + y_step * ysteps[idx]
1209
+
1210
+ yy[start:end] = np.linspace(prev, cur, end - start)
1211
+ prev = cur
1212
+
1213
+ map_x, map_y = np.meshgrid(xx, yy)
1214
+ map_x = map_x.astype(np.float32)
1215
+ map_y = map_y.astype(np.float32)
1216
+
1217
+ remap_fn = _maybe_process_in_chunks(
1218
+ cv2.remap,
1219
+ map1=map_x,
1220
+ map2=map_y,
1221
+ interpolation=interpolation,
1222
+ borderMode=border_mode,
1223
+ borderValue=value,
1224
+ )
1225
+ return remap_fn(img)
1226
+
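# Editor's note: how the x-axis map above is built for num_steps=2, width=8,
# xsteps=(1.0, 1.5, 1.0):
#   cell 0: pixels [0, 4) map linearly from 0 to 0 + 4 * 1.0 = 4   (unchanged)
#   cell 1: pixels [4, 8) map linearly from 4 to 4 + 4 * 1.5 = 10  (stretched)
#   the last grid line starts at the image border, so it contributes nothing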
1227
+
1228
+ @preserve_shape
1229
+ def elastic_transform_approx(
1230
+ img: np.ndarray,
1231
+ alpha: float,
1232
+ sigma: float,
1233
+ alpha_affine: float,
1234
+ interpolation: int = cv2.INTER_LINEAR,
1235
+ border_mode: int = cv2.BORDER_REFLECT_101,
1236
+ value: Optional[ImageColorType] = None,
1237
+ random_state: Optional[np.random.RandomState] = None,
1238
+ ) -> np.ndarray:
1239
+ """Elastic deformation of images as described in [Simard2003]_ (with modifications for speed).
1240
+ Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
1241
+
1242
+ .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
1243
+ Convolutional Neural Networks applied to Visual Document Analysis", in
1244
+ Proc. of the International Conference on Document Analysis and
1245
+ Recognition, 2003.
1246
+ """
1247
+ height, width = img.shape[:2]
1248
+
1249
+ # Random affine
1250
+ center_square = np.array((height, width), dtype=np.float32) // 2
1251
+ square_size = min((height, width)) // 3
1252
+ alpha = float(alpha)
1253
+ sigma = float(sigma)
1254
+ alpha_affine = float(alpha_affine)
1255
+
1256
+ pts1 = np.array(
1257
+ [
1258
+ center_square + square_size,
1259
+ [center_square[0] + square_size, center_square[1] - square_size],
1260
+ center_square - square_size,
1261
+ ],
1262
+ dtype=np.float32,
1263
+ )
1264
+ pts2 = pts1 + random_utils.uniform(-alpha_affine, alpha_affine, size=pts1.shape, random_state=random_state).astype(
1265
+ np.float32
1266
+ )
1267
+ matrix = cv2.getAffineTransform(pts1, pts2)
1268
+
1269
+ warp_fn = _maybe_process_in_chunks(
1270
+ cv2.warpAffine,
1271
+ M=matrix,
1272
+ dsize=(width, height),
1273
+ flags=interpolation,
1274
+ borderMode=border_mode,
1275
+ borderValue=value,
1276
+ )
1277
+ img = warp_fn(img)
1278
+
1279
+ dx = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
1280
+ cv2.GaussianBlur(dx, (17, 17), sigma, dst=dx)
1281
+ dx *= alpha
1282
+
1283
+ dy = random_utils.rand(height, width, random_state=random_state).astype(np.float32) * 2 - 1
1284
+ cv2.GaussianBlur(dy, (17, 17), sigma, dst=dy)
1285
+ dy *= alpha
1286
+
1287
+ x, y = np.meshgrid(np.arange(width), np.arange(height))
1288
+
1289
+ map_x = np.float32(x + dx)
1290
+ map_y = np.float32(y + dy)
1291
+
1292
+ remap_fn = _maybe_process_in_chunks(
1293
+ cv2.remap,
1294
+ map1=map_x,
1295
+ map2=map_y,
1296
+ interpolation=interpolation,
1297
+ borderMode=border_mode,
1298
+ borderValue=value,
1299
+ )
1300
+ return remap_fn(img)
custom_albumentations/augmentations/geometric/resize.py ADDED
@@ -0,0 +1,198 @@
1
+ import random
2
+ from typing import Dict, Sequence, Tuple, Union
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ from ...core.transforms_interface import (
8
+ BoxInternalType,
9
+ DualTransform,
10
+ KeypointInternalType,
11
+ to_tuple,
12
+ )
13
+ from . import functional as F
14
+
15
+ __all__ = ["RandomScale", "LongestMaxSize", "SmallestMaxSize", "Resize"]
16
+
17
+
18
+ class RandomScale(DualTransform):
19
+ """Randomly resize the input. Output image size is different from the input image size.
20
+
21
+ Args:
22
+ scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
23
+ range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
24
+ If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
25
+ Default: (-0.1, 0.1).
26
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
27
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
28
+ Default: cv2.INTER_LINEAR.
29
+ p (float): probability of applying the transform. Default: 0.5.
30
+
31
+ Targets:
32
+ image, mask, bboxes, keypoints
33
+
34
+ Image types:
35
+ uint8, float32
36
+ """
37
+
38
+ def __init__(self, scale_limit=0.1, interpolation=cv2.INTER_LINEAR, always_apply=False, p=0.5):
39
+ super(RandomScale, self).__init__(always_apply, p)
40
+ self.scale_limit = to_tuple(scale_limit, bias=1.0)
41
+ self.interpolation = interpolation
42
+
43
+ def get_params(self):
44
+ return {"scale": random.uniform(self.scale_limit[0], self.scale_limit[1])}
45
+
46
+ def apply(self, img, scale=0, interpolation=cv2.INTER_LINEAR, **params):
47
+ return F.scale(img, scale, interpolation)
48
+
49
+ def apply_to_bbox(self, bbox, **params):
50
+ # Bounding box coordinates are scale invariant
51
+ return bbox
52
+
53
+ def apply_to_keypoint(self, keypoint, scale=0, **params):
54
+ return F.keypoint_scale(keypoint, scale, scale)
55
+
56
+ def get_transform_init_args(self):
57
+ return {"interpolation": self.interpolation, "scale_limit": to_tuple(self.scale_limit, bias=-1.0)}
58
+
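# Editor's note: with the default scale_limit=0.1, to_tuple(0.1, bias=1.0)
# yields (0.9, 1.1), so the sampled scale factor lies in [0.9, 1.1];
# get_transform_init_args subtracts the bias back out for serialization.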
59
+
60
+ class LongestMaxSize(DualTransform):
61
+ """Rescale an image so that maximum side is equal to max_size, keeping the aspect ratio of the initial image.
62
+
63
+ Args:
64
+ max_size (int, list of int): maximum size of the image after the transformation. When using a list, max size
65
+ will be randomly selected from the values in the list.
66
+ interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
67
+ p (float): probability of applying the transform. Default: 1.
68
+
69
+ Targets:
70
+ image, mask, bboxes, keypoints
71
+
72
+ Image types:
73
+ uint8, float32
74
+ """
75
+
76
+ def __init__(
77
+ self,
78
+ max_size: Union[int, Sequence[int]] = 1024,
79
+ interpolation: int = cv2.INTER_LINEAR,
80
+ always_apply: bool = False,
81
+ p: float = 1,
82
+ ):
83
+ super(LongestMaxSize, self).__init__(always_apply, p)
84
+ self.interpolation = interpolation
85
+ self.max_size = max_size
86
+
87
+ def apply(
88
+ self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params
89
+ ) -> np.ndarray:
90
+ return F.longest_max_size(img, max_size=max_size, interpolation=interpolation)
91
+
92
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
93
+ # Bounding box coordinates are scale invariant
94
+ return bbox
95
+
96
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType:
97
+ height = params["rows"]
98
+ width = params["cols"]
99
+
100
+ scale = max_size / max([height, width])
101
+ return F.keypoint_scale(keypoint, scale, scale)
102
+
103
+ def get_params(self) -> Dict[str, int]:
104
+ return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)}
105
+
106
+ def get_transform_init_args_names(self) -> Tuple[str, ...]:
107
+ return ("max_size", "interpolation")
108
+
109
+
110
+ class SmallestMaxSize(DualTransform):
111
+ """Rescale an image so that minimum side is equal to max_size, keeping the aspect ratio of the initial image.
112
+
113
+ Args:
114
+ max_size (int, list of int): maximum size of smallest side of the image after the transformation. When using a
115
+ list, max size will be randomly selected from the values in the list.
116
+ interpolation (OpenCV flag): interpolation method. Default: cv2.INTER_LINEAR.
117
+ p (float): probability of applying the transform. Default: 1.
118
+
119
+ Targets:
120
+ image, mask, bboxes, keypoints
121
+
122
+ Image types:
123
+ uint8, float32
124
+ """
125
+
126
+ def __init__(
127
+ self,
128
+ max_size: Union[int, Sequence[int]] = 1024,
129
+ interpolation: int = cv2.INTER_LINEAR,
130
+ always_apply: bool = False,
131
+ p: float = 1,
132
+ ):
133
+ super(SmallestMaxSize, self).__init__(always_apply, p)
134
+ self.interpolation = interpolation
135
+ self.max_size = max_size
136
+
137
+ def apply(
138
+ self, img: np.ndarray, max_size: int = 1024, interpolation: int = cv2.INTER_LINEAR, **params
139
+ ) -> np.ndarray:
140
+ return F.smallest_max_size(img, max_size=max_size, interpolation=interpolation)
141
+
142
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
143
+ return bbox
144
+
145
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, max_size: int = 1024, **params) -> KeypointInternalType:
146
+ height = params["rows"]
147
+ width = params["cols"]
148
+
149
+ scale = max_size / min([height, width])
150
+ return F.keypoint_scale(keypoint, scale, scale)
151
+
152
+ def get_params(self) -> Dict[str, int]:
153
+ return {"max_size": self.max_size if isinstance(self.max_size, int) else random.choice(self.max_size)}
154
+
155
+ def get_transform_init_args_names(self) -> Tuple[str, ...]:
156
+ return ("max_size", "interpolation")
157
+
158
+
159
+ class Resize(DualTransform):
160
+ """Resize the input to the given height and width.
161
+
162
+ Args:
163
+ height (int): desired height of the output.
164
+ width (int): desired width of the output.
165
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
166
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
167
+ Default: cv2.INTER_LINEAR.
168
+ p (float): probability of applying the transform. Default: 1.
169
+
170
+ Targets:
171
+ image, mask, bboxes, keypoints
172
+
173
+ Image types:
174
+ uint8, float32
175
+ """
176
+
177
+ def __init__(self, height, width, interpolation=cv2.INTER_LINEAR, always_apply=False, p=1):
178
+ super(Resize, self).__init__(always_apply, p)
179
+ self.height = height
180
+ self.width = width
181
+ self.interpolation = interpolation
182
+
183
+ def apply(self, img, interpolation=cv2.INTER_LINEAR, **params):
184
+ return F.resize(img, height=self.height, width=self.width, interpolation=interpolation)
185
+
186
+ def apply_to_bbox(self, bbox, **params):
187
+ # Bounding box coordinates are scale invariant
188
+ return bbox
189
+
190
+ def apply_to_keypoint(self, keypoint, **params):
191
+ height = params["rows"]
192
+ width = params["cols"]
193
+ scale_x = self.width / width
194
+ scale_y = self.height / height
195
+ return F.keypoint_scale(keypoint, scale_x, scale_y)
196
+
197
+ def get_transform_init_args_names(self):
198
+ return ("height", "width", "interpolation")
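
A short usage sketch for the resize transforms defined in this file (the import
path follows the file location above; the call convention is the standard
albumentations one, which this vendored copy is assumed to keep):

import numpy as np
from custom_albumentations.augmentations.geometric.resize import LongestMaxSize, Resize

img = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
resized = Resize(height=256, width=256)(image=img)["image"]   # exactly 256x256
capped = LongestMaxSize(max_size=512)(image=img)["image"]     # 384x512, aspect ratio kept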
custom_albumentations/augmentations/geometric/rotate.py ADDED
@@ -0,0 +1,294 @@
1
+ import math
2
+ import random
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
4
+
5
+ import cv2
6
+ import numpy as np
7
+
8
+ from ...core.transforms_interface import (
9
+ BoxInternalType,
10
+ DualTransform,
11
+ FillValueType,
12
+ KeypointInternalType,
13
+ to_tuple,
14
+ )
15
+ from ..crops import functional as FCrops
16
+ from . import functional as F
17
+
18
+ __all__ = ["Rotate", "RandomRotate90", "SafeRotate"]
19
+
20
+
21
+ class RandomRotate90(DualTransform):
22
+ """Randomly rotate the input by 90 degrees zero or more times.
23
+
24
+ Args:
25
+ p (float): probability of applying the transform. Default: 0.5.
26
+
27
+ Targets:
28
+ image, mask, bboxes, keypoints
29
+
30
+ Image types:
31
+ uint8, float32
32
+ """
33
+
34
+ def apply(self, img, factor=0, **params):
35
+ """
36
+ Args:
37
+ factor (int): number of times the input will be rotated by 90 degrees.
38
+ """
39
+ return np.ascontiguousarray(np.rot90(img, factor))
40
+
41
+ def get_params(self):
42
+ # Random int in the range [0, 3]
43
+ return {"factor": random.randint(0, 3)}
44
+
45
+ def apply_to_bbox(self, bbox, factor=0, **params):
46
+ return F.bbox_rot90(bbox, factor, **params)
47
+
48
+ def apply_to_keypoint(self, keypoint, factor=0, **params):
49
+ return F.keypoint_rot90(keypoint, factor, **params)
50
+
51
+ def get_transform_init_args_names(self):
52
+ return ()
53
+
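# Editor's note: what apply above does for a fixed factor, illustrated with a
# plain array (factor=2 is an upside-down flip):
img = np.arange(12, dtype=np.uint8).reshape(3, 4)
assert np.array_equal(np.ascontiguousarray(np.rot90(img, 2)), img[::-1, ::-1])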
54
+
+ class Rotate(DualTransform):
+     """Rotate the input by an angle selected randomly from the uniform distribution.
+
+     Args:
+         limit ((int, int) or int): range from which a random angle is picked. If limit is a single int
+             an angle is picked from (-limit, limit). Default: (-90, 90)
+         interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
+             cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
+             Default: cv2.INTER_LINEAR.
+         border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
+             cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             Default: cv2.BORDER_REFLECT_101
+         value (int, float, list of ints, list of floats): padding value if border_mode is cv2.BORDER_CONSTANT.
+         mask_value (int, float, list of ints, list of floats): padding value if border_mode is
+             cv2.BORDER_CONSTANT applied for masks.
+         rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
+             Default: "largest_box"
+         crop_border (bool): If True, makes the largest possible crop within the rotated image. Default: False
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image, mask, bboxes, keypoints
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         limit=90,
+         interpolation=cv2.INTER_LINEAR,
+         border_mode=cv2.BORDER_REFLECT_101,
+         value=None,
+         mask_value=None,
+         rotate_method="largest_box",
+         crop_border=False,
+         always_apply=False,
+         p=0.5,
+     ):
+         super(Rotate, self).__init__(always_apply, p)
+         self.limit = to_tuple(limit)
+         self.interpolation = interpolation
+         self.border_mode = border_mode
+         self.value = value
+         self.mask_value = mask_value
+         self.rotate_method = rotate_method
+         self.crop_border = crop_border
+
+         if rotate_method not in ["largest_box", "ellipse"]:
+             raise ValueError(f"Rotation method {self.rotate_method} is not valid.")
+
+     def apply(
+         self, img, angle=0, interpolation=cv2.INTER_LINEAR, x_min=None, x_max=None, y_min=None, y_max=None, **params
+     ):
+         img_out = F.rotate(img, angle, interpolation, self.border_mode, self.value)
+         if self.crop_border:
+             img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max)
+         return img_out
+
+     def apply_to_mask(self, img, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, **params):
+         img_out = F.rotate(img, angle, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
+         if self.crop_border:
+             img_out = FCrops.crop(img_out, x_min, y_min, x_max, y_max)
+         return img_out
+
+     def apply_to_bbox(self, bbox, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params):
+         bbox_out = F.bbox_rotate(bbox, angle, self.rotate_method, rows, cols)
+         if self.crop_border:
+             bbox_out = FCrops.bbox_crop(bbox_out, x_min, y_min, x_max, y_max, rows, cols)
+         return bbox_out
+
+     def apply_to_keypoint(
+         self, keypoint, angle=0, x_min=None, x_max=None, y_min=None, y_max=None, cols=0, rows=0, **params
+     ):
+         keypoint_out = F.keypoint_rotate(keypoint, angle, rows, cols, **params)
+         if self.crop_border:
+             keypoint_out = FCrops.crop_keypoint_by_coords(keypoint_out, (x_min, y_min, x_max, y_max))
+         return keypoint_out
+
+     @staticmethod
+     def _rotated_rect_with_max_area(h, w, angle):
+         """
+         Given a rectangle of size wxh that has been rotated by 'angle' (in
+         degrees), computes the width and height of the largest possible
+         axis-aligned rectangle (maximal area) within the rotated rectangle.
+
+         Code from: https://stackoverflow.com/questions/16702966/rotate-image-and-crop-out-black-borders
+         """
+
+         angle = math.radians(angle)
+         width_is_longer = w >= h
+         side_long, side_short = (w, h) if width_is_longer else (h, w)
+
+         # since the solutions for angle, -angle and 180-angle are all the same,
+         # it is sufficient to look at the first quadrant and the absolute values of sin,cos:
+         sin_a, cos_a = abs(math.sin(angle)), abs(math.cos(angle))
+         if side_short <= 2.0 * sin_a * cos_a * side_long or abs(sin_a - cos_a) < 1e-10:
+             # half constrained case: two crop corners touch the longer side,
+             # the other two corners are on the mid-line parallel to the longer line
+             x = 0.5 * side_short
+             wr, hr = (x / sin_a, x / cos_a) if width_is_longer else (x / cos_a, x / sin_a)
+         else:
+             # fully constrained case: crop touches all 4 sides
+             cos_2a = cos_a * cos_a - sin_a * sin_a
+             wr, hr = (w * cos_a - h * sin_a) / cos_2a, (h * cos_a - w * sin_a) / cos_2a
+
+         return dict(
+             x_min=max(0, int(w / 2 - wr / 2)),
+             x_max=min(w, int(w / 2 + wr / 2)),
+             y_min=max(0, int(h / 2 - hr / 2)),
+             y_max=min(h, int(h / 2 + hr / 2)),
+         )
+
+     @property
+     def targets_as_params(self) -> List[str]:
+         return ["image"]
+
+     def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         out_params = {"angle": random.uniform(self.limit[0], self.limit[1])}
+         if self.crop_border:
+             h, w = params["image"].shape[:2]
+             out_params.update(self._rotated_rect_with_max_area(h, w, out_params["angle"]))
+         return out_params
+
+     def get_transform_init_args_names(self):
+         return ("limit", "interpolation", "border_mode", "value", "mask_value", "rotate_method", "crop_border")
+
+
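A small sketch of the crop_border behaviour (editorial, not part of the diff): with crop_border=True, _rotated_rect_with_max_area supplies the crop window, so the output is the largest axis-aligned rectangle inside the rotated frame and is usually smaller than the input.

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
rotate = A.Rotate(limit=(30, 30), crop_border=True, p=1.0)  # fixed 30-degree angle
out = rotate(image=image)["image"]
assert out.shape[0] <= 256 and out.shape[1] <= 256  # cropped, never padded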
+ class SafeRotate(DualTransform):
+     """Rotate the input inside the input's frame by an angle selected randomly from the uniform distribution.
+
+     The resulting image may have artifacts in it. After rotation, the image may have a different aspect ratio, and
+     after resizing, it returns to its original shape with the original aspect ratio of the image. For this reason we
+     may see some artifacts.
+
+     Args:
+         limit ((int, int) or int): range from which a random angle is picked. If limit is a single int
+             an angle is picked from (-limit, limit). Default: (-90, 90)
+         interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
+             cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
+             Default: cv2.INTER_LINEAR.
+         border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
+             cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             Default: cv2.BORDER_REFLECT_101
+         value (int, float, list of ints, list of floats): padding value if border_mode is cv2.BORDER_CONSTANT.
+         mask_value (int, float, list of ints, list of floats): padding value if border_mode is
+             cv2.BORDER_CONSTANT applied for masks.
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image, mask, bboxes, keypoints
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         limit: Union[float, Tuple[float, float]] = 90,
+         interpolation: int = cv2.INTER_LINEAR,
+         border_mode: int = cv2.BORDER_REFLECT_101,
+         value: FillValueType = None,
+         mask_value: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super(SafeRotate, self).__init__(always_apply, p)
+         self.limit = to_tuple(limit)
+         self.interpolation = interpolation
+         self.border_mode = border_mode
+         self.value = value
+         self.mask_value = mask_value
+
+     def apply(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray:
+         return F.safe_rotate(img, matrix, self.interpolation, self.value, self.border_mode)
+
+     def apply_to_mask(self, img: np.ndarray, matrix: np.ndarray = np.array(None), **params) -> np.ndarray:
+         return F.safe_rotate(img, matrix, cv2.INTER_NEAREST, self.mask_value, self.border_mode)
+
+     def apply_to_bbox(self, bbox: BoxInternalType, cols: int = 0, rows: int = 0, **params) -> BoxInternalType:
+         return F.bbox_safe_rotate(bbox, params["matrix"], cols, rows)
+
+     def apply_to_keypoint(
+         self,
+         keypoint: KeypointInternalType,
+         angle: float = 0,
+         scale_x: float = 0,
+         scale_y: float = 0,
+         cols: int = 0,
+         rows: int = 0,
+         **params
+     ) -> KeypointInternalType:
+         return F.keypoint_safe_rotate(keypoint, params["matrix"], angle, scale_x, scale_y, cols, rows)
+
+     @property
+     def targets_as_params(self) -> List[str]:
+         return ["image"]
+
+     def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
+         angle = random.uniform(self.limit[0], self.limit[1])
+
+         image = params["image"]
+         h, w = image.shape[:2]
+
+         # https://stackoverflow.com/questions/43892506/opencv-python-rotate-image-without-cropping-sides
+         image_center = (w / 2, h / 2)
+
+         # Rotation Matrix
+         rotation_mat = cv2.getRotationMatrix2D(image_center, angle, 1.0)
+
+         # rotation calculates the cos and sin, taking absolutes of those.
+         abs_cos = abs(rotation_mat[0, 0])
+         abs_sin = abs(rotation_mat[0, 1])
+
+         # find the new width and height bounds
+         new_w = math.ceil(h * abs_sin + w * abs_cos)
+         new_h = math.ceil(h * abs_cos + w * abs_sin)
+
+         scale_x = w / new_w
+         scale_y = h / new_h
+
+         # Shift the image to create padding
+         rotation_mat[0, 2] += new_w / 2 - image_center[0]
+         rotation_mat[1, 2] += new_h / 2 - image_center[1]
+
+         # Rescale to original size
+         scale_mat = np.diag(np.ones(3))
+         scale_mat[0, 0] *= scale_x
+         scale_mat[1, 1] *= scale_y
+         _tmp = np.diag(np.ones(3))
+         _tmp[:2] = rotation_mat
+         _tmp = scale_mat @ _tmp
+         rotation_mat = _tmp[:2]
+
+         return {"matrix": rotation_mat, "angle": angle, "scale_x": scale_x, "scale_y": scale_y}
+
+     def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str]:
+         return ("limit", "interpolation", "border_mode", "value", "mask_value")
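A hedged sketch (editorial, not part of the diff): SafeRotate keeps the full frame, so the output shape always equals the input shape and the content is shrunk by the (scale_x, scale_y) factors computed above.

import cv2
import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (128, 256, 3), dtype=np.uint8)
safe = A.SafeRotate(limit=45, border_mode=cv2.BORDER_CONSTANT, value=0, p=1.0)
out = safe(image=image)["image"]
assert out.shape == image.shape  # the frame is preserved, unlike Rotate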
custom_albumentations/augmentations/geometric/transforms.py ADDED
@@ -0,0 +1,1499 @@
+ import math
+ import random
+ from enum import Enum
+ from typing import Dict, Optional, Sequence, Tuple, Union
+
+ import cv2
+ import numpy as np
+ import skimage.transform
+
+ from custom_albumentations.core.bbox_utils import denormalize_bbox, normalize_bbox
+
+ from ... import random_utils
+ from ...core.transforms_interface import (
+     BoxInternalType,
+     DualTransform,
+     ImageColorType,
+     KeypointInternalType,
+     ScaleFloatType,
+     to_tuple,
+ )
+ from ..functional import bbox_from_mask
+ from . import functional as F
+
+ __all__ = [
+     "ShiftScaleRotate",
+     "ElasticTransform",
+     "Perspective",
+     "Affine",
+     "PiecewiseAffine",
+     "VerticalFlip",
+     "HorizontalFlip",
+     "Flip",
+     "Transpose",
+     "OpticalDistortion",
+     "GridDistortion",
+     "PadIfNeeded",
+ ]
+
+
+ class ShiftScaleRotate(DualTransform):
+     """Randomly apply affine transforms: translate, scale and rotate the input.
+
+     Args:
+         shift_limit ((float, float) or float): shift factor range for both height and width. If shift_limit
+             is a single float value, the range will be (-shift_limit, shift_limit). Absolute values for lower and
+             upper bounds should lie in range [0, 1]. Default: (-0.0625, 0.0625).
+         scale_limit ((float, float) or float): scaling factor range. If scale_limit is a single float value, the
+             range will be (-scale_limit, scale_limit). Note that the scale_limit will be biased by 1.
+             If scale_limit is a tuple, like (low, high), sampling will be done from the range (1 + low, 1 + high).
+             Default: (-0.1, 0.1).
+         rotate_limit ((int, int) or int): rotation range. If rotate_limit is a single int value, the
+             range will be (-rotate_limit, rotate_limit). Default: (-45, 45).
+         interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
+             cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
+             Default: cv2.INTER_LINEAR.
+         border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
+             cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             Default: cv2.BORDER_REFLECT_101
+         value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
+         mask_value (int, float, list of int, list of float): padding value if border_mode is
+             cv2.BORDER_CONSTANT applied for masks.
+         shift_limit_x ((float, float) or float): shift factor range for width. If it is set then this value
+             instead of shift_limit will be used for shifting width. If shift_limit_x is a single float value,
+             the range will be (-shift_limit_x, shift_limit_x). Absolute values for lower and upper bounds should lie in
+             the range [0, 1]. Default: None.
+         shift_limit_y ((float, float) or float): shift factor range for height. If it is set then this value
+             instead of shift_limit will be used for shifting height. If shift_limit_y is a single float value,
+             the range will be (-shift_limit_y, shift_limit_y). Absolute values for lower and upper bounds should lie
+             in the range [0, 1]. Default: None.
+         rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or "ellipse".
+             Default: "largest_box"
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image, mask, bboxes, keypoints
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         shift_limit=0.0625,
+         scale_limit=0.1,
+         rotate_limit=45,
+         interpolation=cv2.INTER_LINEAR,
+         border_mode=cv2.BORDER_REFLECT_101,
+         value=None,
+         mask_value=None,
+         shift_limit_x=None,
+         shift_limit_y=None,
+         rotate_method="largest_box",
+         always_apply=False,
+         p=0.5,
+     ):
+         super(ShiftScaleRotate, self).__init__(always_apply, p)
+         self.shift_limit_x = to_tuple(shift_limit_x if shift_limit_x is not None else shift_limit)
+         self.shift_limit_y = to_tuple(shift_limit_y if shift_limit_y is not None else shift_limit)
+         self.scale_limit = to_tuple(scale_limit, bias=1.0)
+         self.rotate_limit = to_tuple(rotate_limit)
+         self.interpolation = interpolation
+         self.border_mode = border_mode
+         self.value = value
+         self.mask_value = mask_value
+         self.rotate_method = rotate_method
+
+         if self.rotate_method not in ["largest_box", "ellipse"]:
+             raise ValueError(f"Rotation method {self.rotate_method} is not valid.")
+
+     def apply(self, img, angle=0, scale=0, dx=0, dy=0, interpolation=cv2.INTER_LINEAR, **params):
+         return F.shift_scale_rotate(img, angle, scale, dx, dy, interpolation, self.border_mode, self.value)
+
+     def apply_to_mask(self, img, angle=0, scale=0, dx=0, dy=0, **params):
+         return F.shift_scale_rotate(img, angle, scale, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
+
+     def apply_to_keypoint(self, keypoint, angle=0, scale=0, dx=0, dy=0, rows=0, cols=0, **params):
+         return F.keypoint_shift_scale_rotate(keypoint, angle, scale, dx, dy, rows, cols)
+
+     def get_params(self):
+         return {
+             "angle": random.uniform(self.rotate_limit[0], self.rotate_limit[1]),
+             "scale": random.uniform(self.scale_limit[0], self.scale_limit[1]),
+             "dx": random.uniform(self.shift_limit_x[0], self.shift_limit_x[1]),
+             "dy": random.uniform(self.shift_limit_y[0], self.shift_limit_y[1]),
+         }
+
+     def apply_to_bbox(self, bbox, angle, scale, dx, dy, **params):
+         return F.bbox_shift_scale_rotate(bbox, angle, scale, dx, dy, self.rotate_method, **params)
+
+     def get_transform_init_args(self):
+         return {
+             "shift_limit_x": self.shift_limit_x,
+             "shift_limit_y": self.shift_limit_y,
+             "scale_limit": to_tuple(self.scale_limit, bias=-1.0),
+             "rotate_limit": self.rotate_limit,
+             "interpolation": self.interpolation,
+             "border_mode": self.border_mode,
+             "value": self.value,
+             "mask_value": self.mask_value,
+             "rotate_method": self.rotate_method,
+         }
+
+
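A usage sketch (editorial, not part of the diff): note that scale_limit=0.1 yields a sampled scale in [0.9, 1.1], because to_tuple(scale_limit, bias=1.0) shifts the range by 1 in __init__ above.

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
ssr = A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.1, rotate_limit=45, p=1.0)
out = ssr(image=image)["image"]
assert out.shape == image.shape  # the warp keeps the original canvas size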
+ class ElasticTransform(DualTransform):
+     """Elastic deformation of images as described in [Simard2003]_ (with modifications).
+     Based on https://gist.github.com/ernestum/601cdf56d2b424757de5
+
+     .. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
+        Convolutional Neural Networks applied to Visual Document Analysis", in
+        Proc. of the International Conference on Document Analysis and
+        Recognition, 2003.
+
+     Args:
+         alpha (float): scaling factor for the random displacement field.
+         sigma (float): Gaussian filter parameter.
+         alpha_affine (float): The range will be (-alpha_affine, alpha_affine)
+         interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
+             cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
+             Default: cv2.INTER_LINEAR.
+         border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
+             cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
+             Default: cv2.BORDER_REFLECT_101
+         value (int, float, list of ints, list of floats): padding value if border_mode is cv2.BORDER_CONSTANT.
+         mask_value (int, float, list of ints, list of floats): padding value if border_mode is
+             cv2.BORDER_CONSTANT applied for masks.
+         approximate (boolean): Whether to smooth the displacement map with a fixed kernel size.
+             Enabling this option gives ~2X speedup on large images.
+         same_dxdy (boolean): Whether to use the same randomly generated shift for x and y.
+             Enabling this option gives ~2X speedup.
+
+     Targets:
+         image, mask, bboxes
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         alpha=1,
+         sigma=50,
+         alpha_affine=50,
+         interpolation=cv2.INTER_LINEAR,
+         border_mode=cv2.BORDER_REFLECT_101,
+         value=None,
+         mask_value=None,
+         always_apply=False,
+         approximate=False,
+         same_dxdy=False,
+         p=0.5,
+     ):
+         super(ElasticTransform, self).__init__(always_apply, p)
+         self.alpha = alpha
+         self.alpha_affine = alpha_affine
+         self.sigma = sigma
+         self.interpolation = interpolation
+         self.border_mode = border_mode
+         self.value = value
+         self.mask_value = mask_value
+         self.approximate = approximate
+         self.same_dxdy = same_dxdy
+
+     def apply(self, img, random_state=None, interpolation=cv2.INTER_LINEAR, **params):
+         return F.elastic_transform(
+             img,
+             self.alpha,
+             self.sigma,
+             self.alpha_affine,
+             interpolation,
+             self.border_mode,
+             self.value,
+             np.random.RandomState(random_state),
+             self.approximate,
+             self.same_dxdy,
+         )
+
+     def apply_to_mask(self, img, random_state=None, **params):
+         return F.elastic_transform(
+             img,
+             self.alpha,
+             self.sigma,
+             self.alpha_affine,
+             cv2.INTER_NEAREST,
+             self.border_mode,
+             self.mask_value,
+             np.random.RandomState(random_state),
+             self.approximate,
+             self.same_dxdy,
+         )
+
+     def apply_to_bbox(self, bbox, random_state=None, **params):
+         rows, cols = params["rows"], params["cols"]
+         mask = np.zeros((rows, cols), dtype=np.uint8)
+         bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
+         x_min, y_min, x_max, y_max = bbox_denorm[:4]
+         x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
+         mask[y_min:y_max, x_min:x_max] = 1
+         mask = F.elastic_transform(
+             mask,
+             self.alpha,
+             self.sigma,
+             self.alpha_affine,
+             cv2.INTER_NEAREST,
+             self.border_mode,
+             self.mask_value,
+             np.random.RandomState(random_state),
+             self.approximate,
+         )
+         bbox_returned = bbox_from_mask(mask)
+         bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
+         return bbox_returned
+
+     def get_params(self):
+         return {"random_state": random.randint(0, 10000)}
+
+     def get_transform_init_args_names(self):
+         return (
+             "alpha",
+             "sigma",
+             "alpha_affine",
+             "interpolation",
+             "border_mode",
+             "value",
+             "mask_value",
+             "approximate",
+             "same_dxdy",
+         )
+
+
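A hedged sketch (editorial, not part of the diff): the shared random_state from get_params drives both image and mask, which keeps them aligned; approximate=True trades a little fidelity for roughly 2x speed on large images.

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
mask = (np.random.rand(512, 512) > 0.5).astype(np.uint8)
elastic = A.ElasticTransform(alpha=1, sigma=50, alpha_affine=50, approximate=True, p=1.0)
out = elastic(image=image, mask=mask)  # image and mask receive the same displacement field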
+ class Perspective(DualTransform):
+     """Perform a random four point perspective transform of the input.
+
+     Args:
+         scale (float or (float, float)): standard deviation of the normal distributions. These are used to sample
+             the random distances of the subimage's corners from the full image's corners.
+             If scale is a single float value, the range will be (0, scale). Default: (0.05, 0.1).
+         keep_size (bool): Whether to resize images back to their original size after applying the perspective
+             transform. If set to False, the resulting images may end up having different shapes
+             and will always be a list, never an array. Default: True
+         pad_mode (OpenCV flag): OpenCV border mode.
+         pad_val (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
+             Default: 0
+         mask_pad_val (int, float, list of int, list of float): padding value for mask
+             if border_mode is cv2.BORDER_CONSTANT. Default: 0
+         fit_output (bool): If True, the image plane size and position will be adjusted to still capture
+             the whole image after perspective transformation. (Followed by image resizing if keep_size is set to True.)
+             Otherwise, parts of the transformed image may be outside of the image plane.
+             This setting should not be set to True when using large scale values as it could lead to very large images.
+             Default: False
+         interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm.
+             Default: cv2.INTER_LINEAR.
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image, mask, keypoints, bboxes
+
+     Image types:
+         uint8, float32
+     """
+
+     def __init__(
+         self,
+         scale=(0.05, 0.1),
+         keep_size=True,
+         pad_mode=cv2.BORDER_CONSTANT,
+         pad_val=0,
+         mask_pad_val=0,
+         fit_output=False,
+         interpolation=cv2.INTER_LINEAR,
+         always_apply=False,
+         p=0.5,
+     ):
+         super().__init__(always_apply, p)
+         self.scale = to_tuple(scale, 0)
+         self.keep_size = keep_size
+         self.pad_mode = pad_mode
+         self.pad_val = pad_val
+         self.mask_pad_val = mask_pad_val
+         self.fit_output = fit_output
+         self.interpolation = interpolation
+
+     def apply(self, img, matrix=None, max_height=None, max_width=None, **params):
+         return F.perspective(
+             img, matrix, max_width, max_height, self.pad_val, self.pad_mode, self.keep_size, params["interpolation"]
+         )
+
+     def apply_to_bbox(self, bbox, matrix=None, max_height=None, max_width=None, **params):
+         return F.perspective_bbox(bbox, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size)
+
+     def apply_to_keypoint(self, keypoint, matrix=None, max_height=None, max_width=None, **params):
+         return F.perspective_keypoint(
+             keypoint, params["rows"], params["cols"], matrix, max_width, max_height, self.keep_size
+         )
+
+     @property
+     def targets_as_params(self):
+         return ["image"]
+
+     def get_params_dependent_on_targets(self, params):
+         h, w = params["image"].shape[:2]
+
+         scale = random_utils.uniform(*self.scale)
+         points = random_utils.normal(0, scale, [4, 2])
+         points = np.mod(np.abs(points), 0.32)
+
+         # top left -- no changes needed, just use jitter
+         # top right
+         points[1, 0] = 1.0 - points[1, 0]  # w = 1.0 - jitter
+         # bottom right
+         points[2] = 1.0 - points[2]  # w = 1.0 - jitter
+         # bottom left
+         points[3, 1] = 1.0 - points[3, 1]  # h = 1.0 - jitter
+
+         points[:, 0] *= w
+         points[:, 1] *= h
+
+         # Obtain a consistent order of the points and unpack them individually.
+         # Warning: don't just do (tl, tr, br, bl) = _order_points(...) here,
+         # because the reordered points are used further below.
+         points = self._order_points(points)
+         tl, tr, br, bl = points
+
+         # compute the width of the new image, which will be the
+         # maximum distance between bottom-right and bottom-left
+         # x-coordinates or the top-right and top-left x-coordinates
+         min_width = None
+         max_width = None
+         while min_width is None or min_width < 2:
+             width_top = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
+             width_bottom = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
+             max_width = int(max(width_top, width_bottom))
+             min_width = int(min(width_top, width_bottom))
+             if min_width < 2:
+                 step_size = (2 - min_width) / 2
+                 tl[0] -= step_size
+                 tr[0] += step_size
+                 bl[0] -= step_size
+                 br[0] += step_size
+
+         # compute the height of the new image, which will be the maximum distance between the top-right
+         # and bottom-right y-coordinates or the top-left and bottom-left y-coordinates
+         min_height = None
+         max_height = None
+         while min_height is None or min_height < 2:
+             height_right = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
+             height_left = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
+             max_height = int(max(height_right, height_left))
+             min_height = int(min(height_right, height_left))
+             if min_height < 2:
+                 step_size = (2 - min_height) / 2
+                 tl[1] -= step_size
+                 tr[1] -= step_size
+                 bl[1] += step_size
+                 br[1] += step_size
+
+         # now that we have the dimensions of the new image, construct
+         # the set of destination points to obtain a "birds eye view",
+         # (i.e. top-down view) of the image, again specifying points
+         # in the top-left, top-right, bottom-right, and bottom-left order
+         # do not use width-1 or height-1 here, as for e.g. width=3, height=2
+         # the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
+         dst = np.array([[0, 0], [max_width, 0], [max_width, max_height], [0, max_height]], dtype=np.float32)
+
+         # compute the perspective transform matrix and then apply it
+         m = cv2.getPerspectiveTransform(points, dst)
+
+         if self.fit_output:
+             m, max_width, max_height = self._expand_transform(m, (h, w))
+
+         return {"matrix": m, "max_height": max_height, "max_width": max_width, "interpolation": self.interpolation}
+
+     @classmethod
+     def _expand_transform(cls, matrix, shape):
+         height, width = shape
+         # do not use width-1 or height-1 here, as for e.g. width=3, height=2
+         # the bottom right coordinate is at (3.0, 2.0) and not (2.0, 1.0)
+         rect = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)
+         dst = cv2.perspectiveTransform(np.array([rect]), matrix)[0]
+
+         # get min x, y over transformed 4 points
+         # then modify target points by subtracting these minima => shift to (0, 0)
+         dst -= dst.min(axis=0, keepdims=True)
+         dst = np.around(dst, decimals=0)
+
+         matrix_expanded = cv2.getPerspectiveTransform(rect, dst)
+         max_width, max_height = dst.max(axis=0)
+         return matrix_expanded, int(max_width), int(max_height)
+
+     @staticmethod
+     def _order_points(pts: np.ndarray) -> np.ndarray:
+         pts = np.array(sorted(pts, key=lambda x: x[0]))
+         left = pts[:2]  # points with smallest x coordinate - left points
+         right = pts[2:]  # points with greatest x coordinate - right points
+
+         if left[0][1] < left[1][1]:
+             tl, bl = left
+         else:
+             bl, tl = left
+
+         if right[0][1] < right[1][1]:
+             tr, br = right
+         else:
+             br, tr = right
+
+         return np.array([tl, tr, br, bl], dtype=np.float32)
+
+     def get_transform_init_args_names(self):
+         return "scale", "keep_size", "pad_mode", "pad_val", "mask_pad_val", "fit_output", "interpolation"
+
+
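A usage sketch (editorial, not part of the diff): with keep_size=True (the default) the warped image is resized back to the input shape; with keep_size=False the output shape depends on the sampled corner jitter.

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (240, 320, 3), dtype=np.uint8)
out = A.Perspective(scale=(0.05, 0.1), keep_size=True, p=1.0)(image=image)["image"]
assert out.shape == image.shape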
+ class Affine(DualTransform):
+     """Augmentation to apply affine transformations to images.
+     This is mostly a wrapper around the corresponding classes and functions in OpenCV.
+
+     Affine transformations involve:
+
+     - Translation ("move" image on the x-/y-axis)
+     - Rotation
+     - Scaling ("zoom" in/out)
+     - Shear (move one side of the image, turning a square into a trapezoid)
+
+     All such transformations can create "new" pixels in the image without a defined content, e.g.
+     if the image is translated to the left, pixels are created on the right.
+     A method has to be defined to deal with these pixel values.
+     The parameters `cval` and `mode` of this class deal with this.
+
+     Some transformations involve interpolations between several pixels
+     of the input image to generate output pixel values. The parameters `interpolation` and
+     `mask_interpolation` deal with the method of interpolation used for this.
+
+     Args:
+         scale (number, tuple of number or dict): Scaling factor to use, where ``1.0`` denotes "no change" and
+             ``0.5`` is zoomed out to ``50`` percent of the original size.
+             * If a single number, then that value will be used for all images.
+             * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
+               The same range will be used for both x- and y-axis. To keep the aspect ratio, set
+               ``keep_ratio=True``, then the same value will be used for both x- and y-axis.
+             * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
+               Each of these keys can have the same values as described above.
+               Using a dictionary allows setting different values for the two axes and sampling will then happen
+               *independently* per axis, resulting in samples that differ between the axes. Note that when
+               ``keep_ratio=True``, the x- and y-axis ranges should be the same.
+         translate_percent (None, number, tuple of number or dict): Translation as a fraction of the image height/width
+             (x-translation, y-translation), where ``0`` denotes "no change"
+             and ``0.5`` denotes "half of the axis size".
+             * If ``None`` then equivalent to ``0.0`` unless `translate_px` has a value other than ``None``.
+             * If a single number, then that value will be used for all images.
+             * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``.
+               That sampled fraction value will be used identically for both x- and y-axis.
+             * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
+               Each of these keys can have the same values as described above.
+               Using a dictionary allows setting different values for the two axes and sampling will then happen
+               *independently* per axis, resulting in samples that differ between the axes.
+         translate_px (None, int, tuple of int or dict): Translation in pixels.
+             * If ``None`` then equivalent to ``0`` unless `translate_percent` has a value other than ``None``.
+             * If a single int, then that value will be used for all images.
+             * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from
+               the discrete interval ``[a..b]``. That number will be used identically for both x- and y-axis.
+             * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
+               Each of these keys can have the same values as described above.
+               Using a dictionary allows setting different values for the two axes and sampling will then happen
+               *independently* per axis, resulting in samples that differ between the axes.
+         rotate (number or tuple of number): Rotation in degrees (**NOT** radians), i.e. expected value range is
+             around ``[-360, 360]``. Rotation happens around the *center* of the image,
+             not the top left corner as in some other frameworks.
+             * If a number, then that value will be used for all images.
+             * If a tuple ``(a, b)``, then a value will be uniformly sampled per image from the interval ``[a, b]``
+               and used as the rotation value.
+         shear (number, tuple of number or dict): Shear in degrees (**NOT** radians), i.e. expected value range is
+             around ``[-360, 360]``, with reasonable values being in the range of ``[-45, 45]``.
+             * If a number, then that value will be used for all images as
+               the shear on the x-axis (no shear on the y-axis will be done).
+             * If a tuple ``(a, b)``, then two values will be uniformly sampled per image
+               from the interval ``[a, b]`` and be used as the x- and y-shear value.
+             * If a dictionary, then it is expected to have the keys ``x`` and/or ``y``.
+               Each of these keys can have the same values as described above.
+               Using a dictionary allows setting different values for the two axes and sampling will then happen
+               *independently* per axis, resulting in samples that differ between the axes.
+         interpolation (int): OpenCV interpolation flag.
+         mask_interpolation (int): OpenCV interpolation flag.
+         cval (number or sequence of number): The constant value to use when filling in newly created pixels.
+             (E.g. translating by 1px to the right will create a new 1px-wide column of pixels
+             on the left of the image).
+             The value is only used when `mode=constant`. The expected value range is ``[0, 255]`` for ``uint8`` images.
+         cval_mask (number or tuple of number): Same as cval but only for masks.
+         mode (int): OpenCV border flag.
+         fit_output (bool): If True, the image plane size and position will be adjusted to tightly capture
+             the whole image after affine transformation (`translate_percent` and `translate_px` are ignored).
+             Otherwise (``False``), parts of the transformed image may end up outside the image plane.
+             Fitting the output shape can be useful to avoid corners of the image being outside the image plane
+             after applying rotations. Default: False
+         keep_ratio (bool): When True, the original aspect ratio will be kept when the random scale is applied.
+             Default: False.
+         rotate_method (str): rotation method used for the bounding boxes. Should be one of "largest_box" or
+             "ellipse" [1].
+             Default: "largest_box"
+         p (float): probability of applying the transform. Default: 0.5.
+
+     Targets:
+         image, mask, keypoints, bboxes
+
+     Image types:
+         uint8, float32
+
+     Reference:
+         [1] https://arxiv.org/abs/2109.13488
+     """
+
+     def __init__(
+         self,
+         scale: Optional[Union[float, Sequence[float], dict]] = None,
+         translate_percent: Optional[Union[float, Sequence[float], dict]] = None,
+         translate_px: Optional[Union[int, Sequence[int], dict]] = None,
+         rotate: Optional[Union[float, Sequence[float]]] = None,
+         shear: Optional[Union[float, Sequence[float], dict]] = None,
+         interpolation: int = cv2.INTER_LINEAR,
+         mask_interpolation: int = cv2.INTER_NEAREST,
+         cval: Union[int, float, Sequence[int], Sequence[float]] = 0,
+         cval_mask: Union[int, float, Sequence[int], Sequence[float]] = 0,
+         mode: int = cv2.BORDER_CONSTANT,
+         fit_output: bool = False,
+         keep_ratio: bool = False,
+         rotate_method: str = "largest_box",
+         always_apply: bool = False,
+         p: float = 0.5,
+     ):
+         super().__init__(always_apply=always_apply, p=p)
+
+         params = [scale, translate_percent, translate_px, rotate, shear]
+         if all([p is None for p in params]):
+             scale = {"x": (0.9, 1.1), "y": (0.9, 1.1)}
+             translate_percent = {"x": (-0.1, 0.1), "y": (-0.1, 0.1)}
+             rotate = (-15, 15)
+             shear = {"x": (-10, 10), "y": (-10, 10)}
+         else:
+             scale = scale if scale is not None else 1.0
+             rotate = rotate if rotate is not None else 0.0
+             shear = shear if shear is not None else 0.0
+
+         self.interpolation = interpolation
+         self.mask_interpolation = mask_interpolation
+         self.cval = cval
+         self.cval_mask = cval_mask
+         self.mode = mode
+         self.scale = self._handle_dict_arg(scale, "scale")
+         self.translate_percent, self.translate_px = self._handle_translate_arg(translate_px, translate_percent)
+         self.rotate = to_tuple(rotate, rotate)
+         self.fit_output = fit_output
+         self.shear = self._handle_dict_arg(shear, "shear")
+         self.keep_ratio = keep_ratio
+         self.rotate_method = rotate_method
+
+         if self.keep_ratio and self.scale["x"] != self.scale["y"]:
+             raise ValueError(
+                 "When keep_ratio is True, the x and y scale range should be identical. got {}".format(self.scale)
+             )
+
+     def get_transform_init_args_names(self):
+         return (
+             "interpolation",
+             "mask_interpolation",
+             "cval",
+             "mode",
+             "scale",
+             "translate_percent",
+             "translate_px",
+             "rotate",
+             "fit_output",
+             "shear",
+             "cval_mask",
+             "keep_ratio",
+             "rotate_method",
+         )
+
+     @staticmethod
+     def _handle_dict_arg(val: Union[float, Sequence[float], dict], name: str, default: float = 1.0):
+         if isinstance(val, dict):
+             if "x" not in val and "y" not in val:
+                 raise ValueError(
+                     f'Expected {name} dictionary to contain at least key "x" or key "y". Found neither of them.'
+                 )
+             x = val.get("x", default)
+             y = val.get("y", default)
+             return {"x": to_tuple(x, x), "y": to_tuple(y, y)}
+         return {"x": to_tuple(val, val), "y": to_tuple(val, val)}
+
+     @classmethod
+     def _handle_translate_arg(
+         cls,
+         translate_px: Optional[Union[float, Sequence[float], dict]],
+         translate_percent: Optional[Union[float, Sequence[float], dict]],
+     ):
+         if translate_percent is None and translate_px is None:
+             translate_px = 0
+
+         if translate_percent is not None and translate_px is not None:
+             raise ValueError(
+                 "Expected either translate_percent or translate_px to be provided, but both of them were provided."
+             )
+
+         if translate_percent is not None:
+             # translate by percent
+             return cls._handle_dict_arg(translate_percent, "translate_percent", default=0.0), translate_px
+
+         if translate_px is None:
+             raise ValueError("translate_px is None.")
+         # translate by pixels
+         return translate_percent, cls._handle_dict_arg(translate_px, "translate_px")
+
+     def apply(
+         self,
+         img: np.ndarray,
+         matrix: skimage.transform.ProjectiveTransform = None,
+         output_shape: Sequence[int] = (),
+         **params
+     ) -> np.ndarray:
+         return F.warp_affine(
+             img,
+             matrix,
+             interpolation=self.interpolation,
+             cval=self.cval,
+             mode=self.mode,
+             output_shape=output_shape,
+         )
+
+     def apply_to_mask(
+         self,
+         img: np.ndarray,
+         matrix: skimage.transform.ProjectiveTransform = None,
+         output_shape: Sequence[int] = (),
+         **params
+     ) -> np.ndarray:
+         return F.warp_affine(
+             img,
+             matrix,
+             interpolation=self.mask_interpolation,
+             cval=self.cval_mask,
+             mode=self.mode,
+             output_shape=output_shape,
+         )
+
+     def apply_to_bbox(
+         self,
+         bbox: BoxInternalType,
+         matrix: skimage.transform.ProjectiveTransform = None,
+         rows: int = 0,
+         cols: int = 0,
+         output_shape: Sequence[int] = (),
+         **params
+     ) -> BoxInternalType:
+         return F.bbox_affine(bbox, matrix, self.rotate_method, rows, cols, output_shape)
+
+     def apply_to_keypoint(
+         self,
+         keypoint: KeypointInternalType,
+         matrix: Optional[skimage.transform.ProjectiveTransform] = None,
+         scale: Optional[dict] = None,
+         **params
+     ) -> KeypointInternalType:
+         assert scale is not None and matrix is not None
+         return F.keypoint_affine(keypoint, matrix=matrix, scale=scale)
+
+     @property
+     def targets_as_params(self):
+         return ["image"]
+
+     def get_params_dependent_on_targets(self, params: dict) -> dict:
+         h, w = params["image"].shape[:2]
+
+         translate: Dict[str, Union[int, float]]
+         if self.translate_px is not None:
+             translate = {key: random.randint(*value) for key, value in self.translate_px.items()}
+         elif self.translate_percent is not None:
+             translate = {key: random.uniform(*value) for key, value in self.translate_percent.items()}
+             translate["x"] = translate["x"] * w
+             translate["y"] = translate["y"] * h
+         else:
+             translate = {"x": 0, "y": 0}
+
+         # Look to issue https://github.com/albumentations-team/albumentations/issues/1079
+         shear = {key: -random.uniform(*value) for key, value in self.shear.items()}
+         scale = {key: random.uniform(*value) for key, value in self.scale.items()}
+         if self.keep_ratio:
+             scale["y"] = scale["x"]
+
+         # Look to issue https://github.com/albumentations-team/albumentations/issues/1079
+         rotate = -random.uniform(*self.rotate)
+
+         # for images we use additional shifts of (0.5, 0.5) as otherwise
+         # we get an ugly black border for 90deg rotations
+         shift_x = w / 2 - 0.5
+         shift_y = h / 2 - 0.5
+
+         matrix_to_topleft = skimage.transform.SimilarityTransform(translation=[-shift_x, -shift_y])
+         matrix_shear_y_rot = skimage.transform.AffineTransform(rotation=-np.pi / 2)
+         matrix_shear_y = skimage.transform.AffineTransform(shear=np.deg2rad(shear["y"]))
+         matrix_shear_y_rot_inv = skimage.transform.AffineTransform(rotation=np.pi / 2)
+         matrix_transforms = skimage.transform.AffineTransform(
+             scale=(scale["x"], scale["y"]),
+             translation=(translate["x"], translate["y"]),
+             rotation=np.deg2rad(rotate),
+             shear=np.deg2rad(shear["x"]),
+         )
+         matrix_to_center = skimage.transform.SimilarityTransform(translation=[shift_x, shift_y])
+         matrix = (
+             matrix_to_topleft
+             + matrix_shear_y_rot
+             + matrix_shear_y
+             + matrix_shear_y_rot_inv
+             + matrix_transforms
+             + matrix_to_center
+         )
+         if self.fit_output:
+             matrix, output_shape = self._compute_affine_warp_output_shape(matrix, params["image"].shape)
+         else:
+             output_shape = params["image"].shape
+
+         return {
+             "rotate": rotate,
+             "scale": scale,
+             "matrix": matrix,
+             "output_shape": output_shape,
+         }
+
+     @staticmethod
+     def _compute_affine_warp_output_shape(
+         matrix: skimage.transform.ProjectiveTransform, input_shape: Sequence[int]
+     ) -> Tuple[skimage.transform.ProjectiveTransform, Sequence[int]]:
+         height, width = input_shape[:2]
+
+         if height == 0 or width == 0:
+             return matrix, input_shape
+
+         # determine shape of output image
+         corners = np.array([[0, 0], [0, height - 1], [width - 1, height - 1], [width - 1, 0]])
+         corners = matrix(corners)
+         minc = corners[:, 0].min()
+         minr = corners[:, 1].min()
+         maxc = corners[:, 0].max()
+         maxr = corners[:, 1].max()
+         out_height = maxr - minr + 1
+         out_width = maxc - minc + 1
+         if len(input_shape) == 3:
+             output_shape = np.ceil((out_height, out_width, input_shape[2]))
+         else:
+             output_shape = np.ceil((out_height, out_width))
+         output_shape_tuple = tuple([int(v) for v in output_shape.tolist()])
+         # fit output image in new shape
+         translation = (-minc, -minr)
+         matrix_to_fit = skimage.transform.SimilarityTransform(translation=translation)
+         matrix = matrix + matrix_to_fit
+         return matrix, output_shape_tuple
+
+
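A hedged sketch (editorial, not part of the diff): per-axis dictionaries are sampled independently, and keep_ratio=True requires identical x/y scale ranges, after which the sampled x scale is reused for y (see get_params_dependent_on_targets above).

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (224, 224, 3), dtype=np.uint8)
affine = A.Affine(
    scale={"x": (0.8, 1.2), "y": (0.8, 1.2)},  # identical ranges, required by keep_ratio=True
    translate_percent={"x": (-0.1, 0.1), "y": (-0.1, 0.1)},
    rotate=(-15, 15),
    shear={"x": (-10, 10), "y": (-10, 10)},
    keep_ratio=True,
    p=1.0,
)
out = affine(image=image)["image"]
assert out.shape == image.shape  # fit_output=False keeps the input canvas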
+ class PiecewiseAffine(DualTransform):
+     """Apply affine transformations that differ between local neighbourhoods.
+     This augmentation places a regular grid of points on an image and randomly moves the neighbourhood of these
+     points around via affine transformations. This leads to local distortions.
+
+     This is mostly a wrapper around scikit-image's ``PiecewiseAffine``.
+     See also ``Affine`` for a similar technique.
+
+     Note:
+         This augmenter is very slow. Try to use ``ElasticTransform`` instead, which is at least 10x faster.
+
+     Note:
+         For coordinate-based inputs (keypoints, bounding boxes, polygons, ...),
+         this augmenter still has to perform an image-based augmentation,
+         which makes it significantly slower and less exact for such inputs than other transforms.
+
+     Args:
+         scale (float, tuple of float): Each point on the regular grid is moved around via a normal distribution.
+             This scale factor is equivalent to the normal distribution's sigma.
+             Note that the jitter (how far each point is moved in which direction) is multiplied by the height/width of
+             the image if ``absolute_scale=False`` (default), so this scale can be the same for different sized images.
+             Recommended values are in the range ``0.01`` to ``0.05`` (weak to strong augmentations).
+             * If a single ``float``, then that value will always be used as the scale.
+             * If a tuple ``(a, b)`` of ``float`` s, then a random value will
+               be uniformly sampled per image from the interval ``[a, b]``.
+         nb_rows (int, tuple of int): Number of rows of points that the regular grid should have.
+             Must be at least ``2``. For large images, you might want to pick a higher value than ``4``.
+             You might have to then adjust scale to lower values.
+             * If a single ``int``, then that value will always be used as the number of rows.
+             * If a tuple ``(a, b)``, then a value from the discrete interval
+               ``[a..b]`` will be uniformly sampled per image.
+         nb_cols (int, tuple of int): Number of columns. Analogous to `nb_rows`.
+         interpolation (int): The order of interpolation. The order has to be in the range 0-5:
+             - 0: Nearest-neighbor
+             - 1: Bi-linear (default)
+             - 2: Bi-quadratic
+             - 3: Bi-cubic
+             - 4: Bi-quartic
+             - 5: Bi-quintic
+         mask_interpolation (int): same as interpolation but for mask.
+         cval (number): The constant value to use when filling in newly created pixels.
+         cval_mask (number): Same as cval but only for masks.
+         mode (str): {'constant', 'edge', 'symmetric', 'reflect', 'wrap'}, optional
+             Points outside the boundaries of the input are filled according
+             to the given mode. Modes match the behaviour of `numpy.pad`.
+         absolute_scale (bool): Take `scale` as an absolute value rather than a relative value.
+         keypoints_threshold (float): Used as threshold in conversion from distance maps to keypoints.
+             The search for keypoints works by searching for the
+             argmin (non-inverted) or argmax (inverted) in each channel. This
+             parameter contains the maximum (non-inverted) or minimum (inverted) value to accept in order to view a hit
+             as a keypoint. Use ``None`` to use no min/max. Default: 0.01
+
+     Targets:
+         image, mask, keypoints, bboxes
+
+     Image types:
+         uint8, float32
+
+     """
+
+     def __init__(
+         self,
+         scale: ScaleFloatType = (0.03, 0.05),
+         nb_rows: Union[int, Sequence[int]] = 4,
+         nb_cols: Union[int, Sequence[int]] = 4,
+         interpolation: int = 1,
+         mask_interpolation: int = 0,
+         cval: int = 0,
+         cval_mask: int = 0,
+         mode: str = "constant",
+         absolute_scale: bool = False,
+         always_apply: bool = False,
+         keypoints_threshold: float = 0.01,
+         p: float = 0.5,
+     ):
+         super(PiecewiseAffine, self).__init__(always_apply, p)
+
+         self.scale = to_tuple(scale, scale)
+         self.nb_rows = to_tuple(nb_rows, nb_rows)
+         self.nb_cols = to_tuple(nb_cols, nb_cols)
+         self.interpolation = interpolation
+         self.mask_interpolation = mask_interpolation
+         self.cval = cval
+         self.cval_mask = cval_mask
+         self.mode = mode
+         self.absolute_scale = absolute_scale
+         self.keypoints_threshold = keypoints_threshold
+
+     def get_transform_init_args_names(self):
+         return (
+             "scale",
+             "nb_rows",
+             "nb_cols",
+             "interpolation",
+             "mask_interpolation",
+             "cval",
+             "cval_mask",
+             "mode",
+             "absolute_scale",
+             "keypoints_threshold",
+         )
+
+     @property
+     def targets_as_params(self):
+         return ["image"]
+
+     def get_params_dependent_on_targets(self, params) -> dict:
+         h, w = params["image"].shape[:2]
+
+         nb_rows = np.clip(random.randint(*self.nb_rows), 2, None)
+         nb_cols = np.clip(random.randint(*self.nb_cols), 2, None)
+         nb_cells = nb_cols * nb_rows
+         scale = random.uniform(*self.scale)
+
+         jitter: np.ndarray = random_utils.normal(0, scale, (nb_cells, 2))
+         if not np.any(jitter > 0):
+             for i in range(10):  # See: https://github.com/albumentations-team/albumentations/issues/1442
+                 jitter = random_utils.normal(0, scale, (nb_cells, 2))
+                 if np.any(jitter > 0):
+                     break
+             if not np.any(jitter > 0):
+                 return {"matrix": None}
+
+         y = np.linspace(0, h, nb_rows)
+         x = np.linspace(0, w, nb_cols)
+
+         # (H, W) and (H, W) for H=rows, W=cols
+         xx_src, yy_src = np.meshgrid(x, y)
+
+         # (1, HW, 2) => (HW, 2) for H=rows, W=cols
+         points_src = np.dstack([yy_src.flat, xx_src.flat])[0]
+
+         if self.absolute_scale:
+             jitter[:, 0] = jitter[:, 0] / h if h > 0 else 0.0
+             jitter[:, 1] = jitter[:, 1] / w if w > 0 else 0.0
+
+         jitter[:, 0] = jitter[:, 0] * h
+         jitter[:, 1] = jitter[:, 1] * w
+
+         points_dest = np.copy(points_src)
+         points_dest[:, 0] = points_dest[:, 0] + jitter[:, 0]
+         points_dest[:, 1] = points_dest[:, 1] + jitter[:, 1]
+
+         # Restrict all destination points to be inside the image plane.
+         # This is necessary, as otherwise keypoints could be augmented
+         # outside of the image plane and these would be replaced by
+         # (-1, -1), which would not conform with the behaviour of the other augmenters.
+         points_dest[:, 0] = np.clip(points_dest[:, 0], 0, h - 1)
+         points_dest[:, 1] = np.clip(points_dest[:, 1], 0, w - 1)
+
+         matrix = skimage.transform.PiecewiseAffineTransform()
+         matrix.estimate(points_src[:, ::-1], points_dest[:, ::-1])
+
+         return {
+             "matrix": matrix,
+         }
+
+     def apply(
+         self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params
+     ) -> np.ndarray:
+         return F.piecewise_affine(img, matrix, self.interpolation, self.mode, self.cval)
+
+     def apply_to_mask(
+         self, img: np.ndarray, matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None, **params
+     ) -> np.ndarray:
+         return F.piecewise_affine(img, matrix, self.mask_interpolation, self.mode, self.cval_mask)
+
+     def apply_to_bbox(
+         self,
+         bbox: BoxInternalType,
+         rows: int = 0,
+         cols: int = 0,
+         matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None,
+         **params
+     ) -> BoxInternalType:
+         return F.bbox_piecewise_affine(bbox, matrix, rows, cols, self.keypoints_threshold)
+
+     def apply_to_keypoint(
+         self,
+         keypoint: KeypointInternalType,
+         rows: int = 0,
+         cols: int = 0,
+         matrix: Optional[skimage.transform.PiecewiseAffineTransform] = None,
+         **params
+     ):
+         return F.keypoint_piecewise_affine(keypoint, matrix, rows, cols, self.keypoints_threshold)
+
+
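A hedged sketch (editorial, not part of the diff): a coarse 4x4 control grid with scale around 0.03 gives mild local warps; as the note above says, this transform is slow, so ElasticTransform is usually the faster alternative.

import numpy as np
import custom_albumentations as A  # assumes top-level re-export

image = np.random.randint(0, 256, (256, 256, 3), dtype=np.uint8)
pwa = A.PiecewiseAffine(scale=(0.03, 0.05), nb_rows=4, nb_cols=4, p=1.0)
out = pwa(image=image)["image"]
assert out.shape == image.shape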
+ class PadIfNeeded(DualTransform):
+     """Pad sides of the image if its size is less than the desired size.
+
+     Args:
+         min_height (int): minimal result image height.
+         min_width (int): minimal result image width.
+         pad_height_divisor (int): if not None, ensures image height is divisible by this value.
+         pad_width_divisor (int): if not None, ensures image width is divisible by this value.
+         position (Union[str, PositionType]): Position of the image. Should be one of PositionType.CENTER,
+             PositionType.TOP_LEFT, PositionType.TOP_RIGHT, PositionType.BOTTOM_LEFT, PositionType.BOTTOM_RIGHT,
+             or PositionType.RANDOM. Default: PositionType.CENTER.
+         border_mode (OpenCV flag): OpenCV border mode.
+         value (int, float, list of int, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
+         mask_value (int, float, list of int, list of float): padding value for mask
+             if border_mode is cv2.BORDER_CONSTANT.
+         p (float): probability of applying the transform. Default: 1.0.
+
+     Targets:
+         image, mask, bboxes, keypoints
+
+     Image types:
+         uint8, float32
+     """
+
+     class PositionType(Enum):
+         CENTER = "center"
+         TOP_LEFT = "top_left"
+         TOP_RIGHT = "top_right"
+         BOTTOM_LEFT = "bottom_left"
+         BOTTOM_RIGHT = "bottom_right"
+         RANDOM = "random"
+
+     def __init__(
+         self,
+         min_height: Optional[int] = 1024,
+         min_width: Optional[int] = 1024,
+         pad_height_divisor: Optional[int] = None,
+         pad_width_divisor: Optional[int] = None,
+         position: Union[PositionType, str] = PositionType.CENTER,
+         border_mode: int = cv2.BORDER_REFLECT_101,
+         value: Optional[ImageColorType] = None,
+         mask_value: Optional[ImageColorType] = None,
+         always_apply: bool = False,
+         p: float = 1.0,
+     ):
+         if (min_height is None) == (pad_height_divisor is None):
+             raise ValueError("Only one of 'min_height' and 'pad_height_divisor' parameters must be set")
+
+         if (min_width is None) == (pad_width_divisor is None):
+             raise ValueError("Only one of 'min_width' and 'pad_width_divisor' parameters must be set")
+
+         super(PadIfNeeded, self).__init__(always_apply, p)
+         self.min_height = min_height
+         self.min_width = min_width
+         self.pad_width_divisor = pad_width_divisor
+         self.pad_height_divisor = pad_height_divisor
+         self.position = PadIfNeeded.PositionType(position)
+         self.border_mode = border_mode
+         self.value = value
+         self.mask_value = mask_value
+
+     def update_params(self, params, **kwargs):
+         params = super(PadIfNeeded, self).update_params(params, **kwargs)
+         rows = params["rows"]
+         cols = params["cols"]
+
+         if self.min_height is not None:
+             if rows < self.min_height:
+                 h_pad_top = int((self.min_height - rows) / 2.0)
+                 h_pad_bottom = self.min_height - rows - h_pad_top
+             else:
+                 h_pad_top = 0
+                 h_pad_bottom = 0
+         else:
+             pad_remained = rows % self.pad_height_divisor
+             pad_rows = self.pad_height_divisor - pad_remained if pad_remained > 0 else 0
+
+             h_pad_top = pad_rows // 2
+             h_pad_bottom = pad_rows - h_pad_top
+
+         if self.min_width is not None:
+             if cols < self.min_width:
+                 w_pad_left = int((self.min_width - cols) / 2.0)
+                 w_pad_right = self.min_width - cols - w_pad_left
+             else:
+                 w_pad_left = 0
+                 w_pad_right = 0
+         else:
+             pad_remainder = cols % self.pad_width_divisor
+             pad_cols = self.pad_width_divisor - pad_remainder if pad_remainder > 0 else 0
+
+             w_pad_left = pad_cols // 2
+             w_pad_right = pad_cols - w_pad_left
+
+         h_pad_top, h_pad_bottom, w_pad_left, w_pad_right = self.__update_position_params(
+             h_top=h_pad_top, h_bottom=h_pad_bottom, w_left=w_pad_left, w_right=w_pad_right
+         )
+
+         params.update(
+             {
+                 "pad_top": h_pad_top,
+                 "pad_bottom": h_pad_bottom,
+                 "pad_left": w_pad_left,
+                 "pad_right": w_pad_right,
+             }
+         )
+         return params
+
+     def apply(
+         self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params
+     ) -> np.ndarray:
+         return F.pad_with_params(
+             img,
+             pad_top,
+             pad_bottom,
+             pad_left,
+             pad_right,
+             border_mode=self.border_mode,
+             value=self.value,
+         )
+
+     def apply_to_mask(
+         self, img: np.ndarray, pad_top: int = 0, pad_bottom: int = 0, pad_left: int = 0, pad_right: int = 0, **params
+     ) -> np.ndarray:
+         return F.pad_with_params(
+             img,
+             pad_top,
+             pad_bottom,
+             pad_left,
+             pad_right,
+             border_mode=self.border_mode,
+             value=self.mask_value,
+         )
+
+     def apply_to_bbox(
+         self,
+         bbox: BoxInternalType,
+         pad_top: int = 0,
+         pad_bottom: int = 0,
+         pad_left: int = 0,
+         pad_right: int = 0,
+         rows: int = 0,
+         cols: int = 0,
+         **params
+     ) -> BoxInternalType:
+         x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
+         bbox = x_min + pad_left, y_min + pad_top, x_max + pad_left, y_max + pad_top
+         return normalize_bbox(bbox, rows + pad_top + pad_bottom, cols + pad_left + pad_right)
+
+     def apply_to_keypoint(
+         self,
+         keypoint: KeypointInternalType,
+         pad_top: int = 0,
+         pad_bottom: int = 0,
+         pad_left: int = 0,
+         pad_right: int = 0,
+         **params
1141
+ ) -> KeypointInternalType:
1142
+ x, y, angle, scale = keypoint[:4]
1143
+ return x + pad_left, y + pad_top, angle, scale
1144
+
1145
+ def get_transform_init_args_names(self):
1146
+ return (
1147
+ "min_height",
1148
+ "min_width",
1149
+ "pad_height_divisor",
1150
+ "pad_width_divisor",
1151
+ "border_mode",
1152
+ "value",
1153
+ "mask_value",
1154
+ )
1155
+
1156
+ def __update_position_params(
1157
+ self, h_top: int, h_bottom: int, w_left: int, w_right: int
1158
+ ) -> Tuple[int, int, int, int]:
1159
+ if self.position == PadIfNeeded.PositionType.TOP_LEFT:
1160
+ h_bottom += h_top
1161
+ w_right += w_left
1162
+ h_top = 0
1163
+ w_left = 0
1164
+
1165
+ elif self.position == PadIfNeeded.PositionType.TOP_RIGHT:
1166
+ h_bottom += h_top
1167
+ w_left += w_right
1168
+ h_top = 0
1169
+ w_right = 0
1170
+
1171
+ elif self.position == PadIfNeeded.PositionType.BOTTOM_LEFT:
1172
+ h_top += h_bottom
1173
+ w_right += w_left
1174
+ h_bottom = 0
1175
+ w_left = 0
1176
+
1177
+ elif self.position == PadIfNeeded.PositionType.BOTTOM_RIGHT:
1178
+ h_top += h_bottom
1179
+ w_left += w_right
1180
+ h_bottom = 0
1181
+ w_right = 0
1182
+
1183
+ elif self.position == PadIfNeeded.PositionType.RANDOM:
1184
+ h_pad = h_top + h_bottom
1185
+ w_pad = w_left + w_right
1186
+ h_top = random.randint(0, h_pad)
1187
+ h_bottom = h_pad - h_top
1188
+ w_left = random.randint(0, w_pad)
1189
+ w_right = w_pad - w_left
1190
+
1191
+ return h_top, h_bottom, w_left, w_right
1192
+
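As a usage sketch (assuming the standard albumentations call convention that these transforms inherit), the divisor arguments pad each side up to the next multiple:

import cv2
import numpy as np

pad = PadIfNeeded(
    min_height=None, min_width=None,              # exactly one of min_* / *_divisor may be set per axis
    pad_height_divisor=32, pad_width_divisor=32,
    border_mode=cv2.BORDER_CONSTANT, value=0, p=1.0,
)
image = np.zeros((50, 70, 3), dtype=np.uint8)
padded = pad(image=image)["image"]                # assumed albumentations-style __call__
assert padded.shape[:2] == (64, 96)               # 50 -> 64 and 70 -> 96, multiples of 32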
1193
+
1194
+ class VerticalFlip(DualTransform):
1195
+ """Flip the input vertically around the x-axis.
1196
+
1197
+ Args:
1198
+ p (float): probability of applying the transform. Default: 0.5.
1199
+
1200
+ Targets:
1201
+ image, mask, bboxes, keypoints
1202
+
1203
+ Image types:
1204
+ uint8, float32
1205
+ """
1206
+
1207
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
1208
+ return F.vflip(img)
1209
+
1210
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
1211
+ return F.bbox_vflip(bbox, **params)
1212
+
1213
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
1214
+ return F.keypoint_vflip(keypoint, **params)
1215
+
1216
+ def get_transform_init_args_names(self):
1217
+ return ()
1218
+
1219
+
1220
+ class HorizontalFlip(DualTransform):
1221
+ """Flip the input horizontally around the y-axis.
1222
+
1223
+ Args:
1224
+ p (float): probability of applying the transform. Default: 0.5.
1225
+
1226
+ Targets:
1227
+ image, mask, bboxes, keypoints
1228
+
1229
+ Image types:
1230
+ uint8, float32
1231
+ """
1232
+
1233
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
1234
+ if img.ndim == 3 and img.shape[2] > 1 and img.dtype == np.uint8:
1235
+ # Opencv is faster than numpy only in case of
1236
+ # non-gray scale 8bits images
1237
+ return F.hflip_cv2(img)
1238
+
1239
+ return F.hflip(img)
1240
+
1241
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
1242
+ return F.bbox_hflip(bbox, **params)
1243
+
1244
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
1245
+ return F.keypoint_hflip(keypoint, **params)
1246
+
1247
+ def get_transform_init_args_names(self):
1248
+ return ()
1249
+
1250
+
1251
+ class Flip(DualTransform):
1252
+ """Flip the input either horizontally, vertically or both horizontally and vertically.
1253
+
1254
+ Args:
1255
+ p (float): probability of applying the transform. Default: 0.5.
1256
+
1257
+ Targets:
1258
+ image, mask, bboxes, keypoints
1259
+
1260
+ Image types:
1261
+ uint8, float32
1262
+ """
1263
+
1264
+ def apply(self, img: np.ndarray, d: int = 0, **params) -> np.ndarray:
1265
+ """Args:
1266
+ d (int): code that specifies how to flip the input. 0 for vertical flipping, 1 for horizontal flipping,
1267
+ -1 for both vertical and horizontal flipping (which can also be seen as rotating the input by
1268
+ 180 degrees).
1269
+ """
1270
+ return F.random_flip(img, d)
1271
+
1272
+ def get_params(self):
1273
+ # Random int in the range [-1, 1]
1274
+ return {"d": random.randint(-1, 1)}
1275
+
1276
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
1277
+ return F.bbox_flip(bbox, **params)
1278
+
1279
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
1280
+ return F.keypoint_flip(keypoint, **params)
1281
+
1282
+ def get_transform_init_args_names(self):
1283
+ return ()
1284
+
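The `d` code sampled in `get_params` matches OpenCV's `cv2.flip` convention, which is presumably what `F.random_flip` wraps; a standalone illustration:

import cv2
import numpy as np

img = np.arange(4, dtype=np.uint8).reshape(2, 2)
flipped_v = cv2.flip(img, 0)       # d = 0: vertical flip (around the x-axis)
flipped_h = cv2.flip(img, 1)       # d = 1: horizontal flip (around the y-axis)
flipped_both = cv2.flip(img, -1)   # d = -1: both axes, i.e. a 180-degree rotation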
1285
+
1286
+ class Transpose(DualTransform):
1287
+ """Transpose the input by swapping rows and columns.
1288
+
1289
+ Args:
1290
+ p (float): probability of applying the transform. Default: 0.5.
1291
+
1292
+ Targets:
1293
+ image, mask, bboxes, keypoints
1294
+
1295
+ Image types:
1296
+ uint8, float32
1297
+ """
1298
+
1299
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
1300
+ return F.transpose(img)
1301
+
1302
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
1303
+ return F.bbox_transpose(bbox, 0, **params)
1304
+
1305
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
1306
+ return F.keypoint_transpose(keypoint)
1307
+
1308
+ def get_transform_init_args_names(self):
1309
+ return ()
1310
+
1311
+
1312
+ class OpticalDistortion(DualTransform):
1313
+ """
1314
+ Args:
1315
+ distort_limit (float, (float, float)): If distort_limit is a single float, the range
1316
+ will be (-distort_limit, distort_limit). Default: (-0.05, 0.05).
1317
+ shift_limit (float, (float, float)): If shift_limit is a single float, the range
1318
+ will be (-shift_limit, shift_limit). Default: (-0.05, 0.05).
1319
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
1320
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
1321
+ Default: cv2.INTER_LINEAR.
1322
+ border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
1323
+ cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
1324
+ Default: cv2.BORDER_REFLECT_101
1325
+ value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
1326
+ mask_value (int, float,
1327
+ list of ints,
1328
+ list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
1329
+
1330
+ Targets:
1331
+ image, mask, bbox
1332
+
1333
+ Image types:
1334
+ uint8, float32
1335
+ """
1336
+
1337
+ def __init__(
1338
+ self,
1339
+ distort_limit: ScaleFloatType = 0.05,
1340
+ shift_limit: ScaleFloatType = 0.05,
1341
+ interpolation: int = cv2.INTER_LINEAR,
1342
+ border_mode: int = cv2.BORDER_REFLECT_101,
1343
+ value: Optional[ImageColorType] = None,
1344
+ mask_value: Optional[ImageColorType] = None,
1345
+ always_apply: bool = False,
1346
+ p: float = 0.5,
1347
+ ):
1348
+ super(OpticalDistortion, self).__init__(always_apply, p)
1349
+ self.shift_limit = to_tuple(shift_limit)
1350
+ self.distort_limit = to_tuple(distort_limit)
1351
+ self.interpolation = interpolation
1352
+ self.border_mode = border_mode
1353
+ self.value = value
1354
+ self.mask_value = mask_value
1355
+
1356
+ def apply(
1357
+ self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, interpolation: int = cv2.INTER_LINEAR, **params
1358
+ ) -> np.ndarray:
1359
+ return F.optical_distortion(img, k, dx, dy, interpolation, self.border_mode, self.value)
1360
+
1361
+ def apply_to_mask(self, img: np.ndarray, k: int = 0, dx: int = 0, dy: int = 0, **params) -> np.ndarray:
1362
+ return F.optical_distortion(img, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
1363
+
1364
+ def apply_to_bbox(self, bbox: BoxInternalType, k: int = 0, dx: int = 0, dy: int = 0, **params) -> BoxInternalType:
1365
+ rows, cols = params["rows"], params["cols"]
1366
+ mask = np.zeros((rows, cols), dtype=np.uint8)
1367
+ bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
1368
+ x_min, y_min, x_max, y_max = bbox_denorm[:4]
1369
+ x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
1370
+ mask[y_min:y_max, x_min:x_max] = 1
1371
+ mask = F.optical_distortion(mask, k, dx, dy, cv2.INTER_NEAREST, self.border_mode, self.mask_value)
1372
+ bbox_returned = bbox_from_mask(mask)
1373
+ bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
1374
+ return bbox_returned
1375
+
1376
+ def get_params(self):
1377
+ return {
1378
+ "k": random.uniform(self.distort_limit[0], self.distort_limit[1]),
1379
+ "dx": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
1380
+ "dy": round(random.uniform(self.shift_limit[0], self.shift_limit[1])),
1381
+ }
1382
+
1383
+ def get_transform_init_args_names(self):
1384
+ return (
1385
+ "distort_limit",
1386
+ "shift_limit",
1387
+ "interpolation",
1388
+ "border_mode",
1389
+ "value",
1390
+ "mask_value",
1391
+ )
1392
+
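Note the rasterize-warp-refit trick in `apply_to_bbox` above: instead of warping the four corner points, the box is drawn into a binary mask, the mask goes through the same distortion as the image, and a new box is fitted to the surviving pixels. A generic sketch of the idea, with a plain horizontal shift standing in for the optical distortion:

import numpy as np

def bbox_through_warp(bbox, warp, rows, cols):
    """Track an (x_min, y_min, x_max, y_max) pixel box through an arbitrary image warp."""
    x_min, y_min, x_max, y_max = bbox
    mask = np.zeros((rows, cols), dtype=np.uint8)
    mask[y_min:y_max, x_min:x_max] = 1
    mask = warp(mask)                     # the same warp the image receives
    ys, xs = np.nonzero(mask)
    return xs.min(), ys.min(), xs.max() + 1, ys.max() + 1

shift_right = lambda m: np.roll(m, 5, axis=1)   # stand-in "distortion"
print(bbox_through_warp((10, 10, 20, 20), shift_right, 64, 64))  # box shifted right by 5: (15, 10, 25, 20)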
1393
+
1394
+ class GridDistortion(DualTransform):
1395
+ """
1396
+ Args:
1397
+ num_steps (int): count of grid cells on each side.
1398
+ distort_limit (float, (float, float)): If distort_limit is a single float, the range
1399
+ will be (-distort_limit, distort_limit). Default: (-0.3, 0.3).
1400
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
1401
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
1402
+ Default: cv2.INTER_LINEAR.
1403
+ border_mode (OpenCV flag): flag that is used to specify the pixel extrapolation method. Should be one of:
1404
+ cv2.BORDER_CONSTANT, cv2.BORDER_REPLICATE, cv2.BORDER_REFLECT, cv2.BORDER_WRAP, cv2.BORDER_REFLECT_101.
1405
+ Default: cv2.BORDER_REFLECT_101
1406
+ value (int, float, list of ints, list of float): padding value if border_mode is cv2.BORDER_CONSTANT.
1407
+ mask_value (int, float,
1408
+ list of ints,
1409
+ list of float): padding value if border_mode is cv2.BORDER_CONSTANT applied for masks.
1410
+ normalized (bool): if true, the distortion is normalized so that it does not go outside the image. Default: False.
1411
+ See https://github.com/albumentations-team/albumentations/pull/722 for more information.
1412
+
1413
+ Targets:
1414
+ image, mask, bbox
1415
+
1416
+ Image types:
1417
+ uint8, float32
1418
+ """
1419
+
1420
+ def __init__(
1421
+ self,
1422
+ num_steps: int = 5,
1423
+ distort_limit: ScaleFloatType = 0.3,
1424
+ interpolation: int = cv2.INTER_LINEAR,
1425
+ border_mode: int = cv2.BORDER_REFLECT_101,
1426
+ value: Optional[ImageColorType] = None,
1427
+ mask_value: Optional[ImageColorType] = None,
1428
+ normalized: bool = False,
1429
+ always_apply: bool = False,
1430
+ p: float = 0.5,
1431
+ ):
1432
+ super(GridDistortion, self).__init__(always_apply, p)
1433
+ self.num_steps = num_steps
1434
+ self.distort_limit = to_tuple(distort_limit)
1435
+ self.interpolation = interpolation
1436
+ self.border_mode = border_mode
1437
+ self.value = value
1438
+ self.mask_value = mask_value
1439
+ self.normalized = normalized
1440
+
1441
+ def apply(
1442
+ self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), interpolation: int = cv2.INTER_LINEAR, **params
1443
+ ) -> np.ndarray:
1444
+ return F.grid_distortion(img, self.num_steps, stepsx, stepsy, interpolation, self.border_mode, self.value)
1445
+
1446
+ def apply_to_mask(self, img: np.ndarray, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> np.ndarray:
1447
+ return F.grid_distortion(
1448
+ img, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value
1449
+ )
1450
+
1451
+ def apply_to_bbox(self, bbox: BoxInternalType, stepsx: Tuple = (), stepsy: Tuple = (), **params) -> BoxInternalType:
1452
+ rows, cols = params["rows"], params["cols"]
1453
+ mask = np.zeros((rows, cols), dtype=np.uint8)
1454
+ bbox_denorm = F.denormalize_bbox(bbox, rows, cols)
1455
+ x_min, y_min, x_max, y_max = bbox_denorm[:4]
1456
+ x_min, y_min, x_max, y_max = int(x_min), int(y_min), int(x_max), int(y_max)
1457
+ mask[y_min:y_max, x_min:x_max] = 1
1458
+ mask = F.grid_distortion(
1459
+ mask, self.num_steps, stepsx, stepsy, cv2.INTER_NEAREST, self.border_mode, self.mask_value
1460
+ )
1461
+ bbox_returned = bbox_from_mask(mask)
1462
+ bbox_returned = F.normalize_bbox(bbox_returned, rows, cols)
1463
+ return bbox_returned
1464
+
1465
+ def _normalize(self, h, w, xsteps, ysteps):
1466
+ # compensate for smaller last steps in source image.
1467
+ x_step = w // self.num_steps
1468
+ last_x_step = min(w, ((self.num_steps + 1) * x_step)) - (self.num_steps * x_step)
1469
+ xsteps[-1] *= last_x_step / x_step
1470
+
1471
+ y_step = h // self.num_steps
1472
+ last_y_step = min(h, ((self.num_steps + 1) * y_step)) - (self.num_steps * y_step)
1473
+ ysteps[-1] *= last_y_step / y_step
1474
+
1475
+ # now normalize such that distortion never leaves image bounds.
1476
+ tx = w / math.floor(w / self.num_steps)
1477
+ ty = h / math.floor(h / self.num_steps)
1478
+ xsteps = np.array(xsteps) * (tx / np.sum(xsteps))
1479
+ ysteps = np.array(ysteps) * (ty / np.sum(ysteps))
1480
+
1481
+ return {"stepsx": xsteps, "stepsy": ysteps}
1482
+
1483
+ @property
1484
+ def targets_as_params(self):
1485
+ return ["image"]
1486
+
1487
+ def get_params_dependent_on_targets(self, params):
1488
+ h, w = params["image"].shape[:2]
1489
+
1490
+ stepsx = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)]
1491
+ stepsy = [1 + random.uniform(self.distort_limit[0], self.distort_limit[1]) for _ in range(self.num_steps + 1)]
1492
+
1493
+ if self.normalized:
1494
+ return self._normalize(h, w, stepsx, stepsy)
1495
+
1496
+ return {"stepsx": stepsx, "stepsy": stepsy}
1497
+
1498
+ def get_transform_init_args_names(self):
1499
+ return "num_steps", "distort_limit", "interpolation", "border_mode", "value", "mask_value", "normalized"
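Each entry of `stepsx`/`stepsy` above is a multiplier of the nominal grid-cell size, sampled as 1 ± distort_limit; the `normalized=True` branch rescales the multipliers so the distorted grid cannot leave the image bounds. A toy check of that rescaling (values illustrative):

import numpy as np

num_steps, w = 5, 100
xsteps = 1 + np.random.uniform(-0.3, 0.3, num_steps + 1)  # per-cell width multipliers
tx = w / (w // num_steps)                                 # target sum, as in _normalize
xsteps = xsteps * (tx / xsteps.sum())
assert np.isclose(xsteps.sum(), tx)                       # multipliers sum back to the image-spanning total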
custom_albumentations/augmentations/transforms.py ADDED
@@ -0,0 +1,2667 @@
1
+ from __future__ import absolute_import, division
2
+
3
+ import math
4
+ import numbers
5
+ import random
6
+ import warnings
7
+ from enum import IntEnum
8
+ from types import LambdaType
9
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
10
+
11
+ import cv2
12
+ import numpy as np
13
+ from scipy import special
14
+ from scipy.ndimage import gaussian_filter
15
+
16
+ from custom_albumentations import random_utils
17
+ from custom_albumentations.augmentations.blur.functional import blur
18
+ from custom_albumentations.augmentations.utils import (
19
+ get_num_channels,
20
+ is_grayscale_image,
21
+ is_rgb_image,
22
+ )
23
+
24
+ from ..core.transforms_interface import (
25
+ DualTransform,
26
+ ImageOnlyTransform,
27
+ NoOp,
28
+ ScaleFloatType,
29
+ to_tuple,
30
+ )
31
+ from ..core.utils import format_args
32
+ from . import functional as F
33
+
34
+ __all__ = [
35
+ "Normalize",
36
+ "RandomGamma",
37
+ "RandomGridShuffle",
38
+ "HueSaturationValue",
39
+ "RGBShift",
40
+ "RandomBrightness",
41
+ "RandomContrast",
42
+ "GaussNoise",
43
+ "CLAHE",
44
+ "ChannelShuffle",
45
+ "InvertImg",
46
+ "ToGray",
47
+ "ToRGB",
48
+ "ToSepia",
49
+ "JpegCompression",
50
+ "ImageCompression",
51
+ "ToFloat",
52
+ "FromFloat",
53
+ "RandomBrightnessContrast",
54
+ "RandomSnow",
55
+ "RandomGravel",
56
+ "RandomRain",
57
+ "RandomFog",
58
+ "RandomSunFlare",
59
+ "RandomShadow",
60
+ "RandomToneCurve",
61
+ "Lambda",
62
+ "ISONoise",
63
+ "Solarize",
64
+ "Equalize",
65
+ "Posterize",
66
+ "Downscale",
67
+ "MultiplicativeNoise",
68
+ "FancyPCA",
69
+ "ColorJitter",
70
+ "Sharpen",
71
+ "Emboss",
72
+ "Superpixels",
73
+ "TemplateTransform",
74
+ "RingingOvershoot",
75
+ "UnsharpMask",
76
+ "PixelDropout",
77
+ "Spatter",
78
+ ]
79
+
80
+
81
+ class RandomGridShuffle(DualTransform):
82
+ """
83
+ Randomly shuffle the grid's cells on the image.
84
+
85
+ Args:
86
+ grid ((int, int)): size of grid for splitting image.
87
+
88
+ Targets:
89
+ image, mask, keypoints
90
+
91
+ Image types:
92
+ uint8, float32
93
+ """
94
+
95
+ def __init__(self, grid: Tuple[int, int] = (3, 3), always_apply: bool = False, p: float = 0.5):
96
+ super(RandomGridShuffle, self).__init__(always_apply, p)
97
+ self.grid = grid
98
+
99
+ def apply(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params):
100
+ return F.swap_tiles_on_image(img, tiles)
101
+
102
+ def apply_to_mask(self, img: np.ndarray, tiles: np.ndarray = np.array(None), **params):
103
+ return F.swap_tiles_on_image(img, tiles)
104
+
105
+ def apply_to_keypoint(
106
+ self, keypoint: Tuple[float, ...], tiles: np.ndarray = np.array(None), rows: int = 0, cols: int = 0, **params
107
+ ):
108
+ for (
109
+ current_left_up_corner_row,
110
+ current_left_up_corner_col,
111
+ old_left_up_corner_row,
112
+ old_left_up_corner_col,
113
+ height_tile,
114
+ width_tile,
115
+ ) in tiles:
116
+ x, y = keypoint[:2]
117
+
118
+ if (old_left_up_corner_row <= y < (old_left_up_corner_row + height_tile)) and (
119
+ old_left_up_corner_col <= x < (old_left_up_corner_col + width_tile)
120
+ ):
121
+ x = x - old_left_up_corner_col + current_left_up_corner_col
122
+ y = y - old_left_up_corner_row + current_left_up_corner_row
123
+ keypoint = (x, y) + tuple(keypoint[2:])
124
+ break
125
+
126
+ return keypoint
127
+
128
+ def get_params_dependent_on_targets(self, params):
129
+ height, width = params["image"].shape[:2]
130
+ n, m = self.grid
131
+
132
+ if n <= 0 or m <= 0:
133
+ raise ValueError("Grid's values must be positive. Current grid [%s, %s]" % (n, m))
134
+
135
+ if n > height // 2 or m > width // 2:
136
+ raise ValueError("Grid cells are too small: the requested grid would just shuffle individual pixels of the image")
137
+
138
+ height_split = np.linspace(0, height, n + 1, dtype=np.int32)
139
+ width_split = np.linspace(0, width, m + 1, dtype=np.int32)
140
+
141
+ height_matrix, width_matrix = np.meshgrid(height_split, width_split, indexing="ij")
142
+
143
+ index_height_matrix = height_matrix[:-1, :-1]
144
+ index_width_matrix = width_matrix[:-1, :-1]
145
+
146
+ shifted_index_height_matrix = height_matrix[1:, 1:]
147
+ shifted_index_width_matrix = width_matrix[1:, 1:]
148
+
149
+ height_tile_sizes = shifted_index_height_matrix - index_height_matrix
150
+ width_tile_sizes = shifted_index_width_matrix - index_width_matrix
151
+
152
+ tiles_sizes = np.stack((height_tile_sizes, width_tile_sizes), axis=2)
153
+
154
+ index_matrix = np.indices((n, m))
155
+ new_index_matrix = np.stack(index_matrix, axis=2)
156
+
157
+ for bbox_size in np.unique(tiles_sizes.reshape(-1, 2), axis=0):
158
+ eq_mat = np.all(tiles_sizes == bbox_size, axis=2)
159
+ new_index_matrix[eq_mat] = random_utils.permutation(new_index_matrix[eq_mat])
160
+
161
+ new_index_matrix = np.split(new_index_matrix, 2, axis=2)
162
+
163
+ old_x = index_height_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)
164
+ old_y = index_width_matrix[new_index_matrix[0], new_index_matrix[1]].reshape(-1)
165
+
166
+ shift_x = height_tile_sizes.reshape(-1)
167
+ shift_y = width_tile_sizes.reshape(-1)
168
+
169
+ curr_x = index_height_matrix.reshape(-1)
170
+ curr_y = index_width_matrix.reshape(-1)
171
+
172
+ tiles = np.stack([curr_x, curr_y, old_x, old_y, shift_x, shift_y], axis=1)
173
+
174
+ return {"tiles": tiles}
175
+
176
+ @property
177
+ def targets_as_params(self):
178
+ return ["image"]
179
+
180
+ def get_transform_init_args_names(self):
181
+ return ("grid",)
182
+
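A usage sketch (again assuming the standard albumentations call convention): the image and mask receive the same tile permutation, so their alignment is preserved.

import numpy as np

shuffle = RandomGridShuffle(grid=(2, 2), p=1.0)
image = np.arange(16, dtype=np.uint8).reshape(4, 4)
mask = (image > 7).astype(np.uint8)
out = shuffle(image=image, mask=mask)         # assumed albumentations-style __call__
shuffled_image, shuffled_mask = out["image"], out["mask"]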
183
+
184
+ class Normalize(ImageOnlyTransform):
185
+ """Normalization is applied by the formula: `img = (img - mean * max_pixel_value) / (std * max_pixel_value)`
186
+
187
+ Args:
188
+ mean (float, list of float): mean values
189
+ std (float, list of float): std values
190
+ max_pixel_value (float): maximum possible pixel value
191
+
192
+ Targets:
193
+ image
194
+
195
+ Image types:
196
+ uint8, float32
197
+ """
198
+
199
+ def __init__(
200
+ self,
201
+ mean=(0.485, 0.456, 0.406),
202
+ std=(0.229, 0.224, 0.225),
203
+ max_pixel_value=255.0,
204
+ always_apply=False,
205
+ p=1.0,
206
+ ):
207
+ super(Normalize, self).__init__(always_apply, p)
208
+ self.mean = mean
209
+ self.std = std
210
+ self.max_pixel_value = max_pixel_value
211
+
212
+ def apply(self, image, **params):
213
+ return F.normalize(image, self.mean, self.std, self.max_pixel_value)
214
+
215
+ def get_transform_init_args_names(self):
216
+ return ("mean", "std", "max_pixel_value")
217
+
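A worked instance of the formula in the docstring, for a uint8 pixel value of 128 with the ImageNet defaults:

import numpy as np

img = np.full((2, 2, 3), 128, dtype=np.uint8)
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
max_pixel_value = 255.0

out = (img - mean * max_pixel_value) / (std * max_pixel_value)
# red channel: (128 - 0.485 * 255) / (0.229 * 255) ≈ 0.074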
218
+
219
+ class ImageCompression(ImageOnlyTransform):
220
+ """Decreases image quality by Jpeg, WebP compression of an image.
221
+
222
+ Args:
223
+ quality_lower (float): lower bound on the image quality.
224
+ Should be in [0, 100] range for jpeg and [1, 100] for webp.
225
+ quality_upper (float): upper bound on the image quality.
226
+ Should be in [0, 100] range for jpeg and [1, 100] for webp.
227
+ compression_type (ImageCompressionType): should be ImageCompressionType.JPEG or ImageCompressionType.WEBP.
228
+ Default: ImageCompressionType.JPEG
229
+
230
+ Targets:
231
+ image
232
+
233
+ Image types:
234
+ uint8, float32
235
+ """
236
+
237
+ class ImageCompressionType(IntEnum):
238
+ JPEG = 0
239
+ WEBP = 1
240
+
241
+ def __init__(
242
+ self,
243
+ quality_lower=99,
244
+ quality_upper=100,
245
+ compression_type=ImageCompressionType.JPEG,
246
+ always_apply=False,
247
+ p=0.5,
248
+ ):
249
+ super(ImageCompression, self).__init__(always_apply, p)
250
+
251
+ self.compression_type = ImageCompression.ImageCompressionType(compression_type)
252
+ low_thresh_quality_assert = 0
253
+
254
+ if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
255
+ low_thresh_quality_assert = 1
256
+
257
+ if not low_thresh_quality_assert <= quality_lower <= 100:
258
+ raise ValueError("Invalid quality_lower. Got: {}".format(quality_lower))
259
+ if not low_thresh_quality_assert <= quality_upper <= 100:
260
+ raise ValueError("Invalid quality_upper. Got: {}".format(quality_upper))
261
+
262
+ self.quality_lower = quality_lower
263
+ self.quality_upper = quality_upper
264
+
265
+ def apply(self, image, quality=100, image_type=".jpg", **params):
266
+ if not image.ndim == 2 and image.shape[-1] not in (1, 3, 4):
267
+ raise TypeError("ImageCompression transformation expects 1, 3 or 4 channel images.")
268
+ return F.image_compression(image, quality, image_type)
269
+
270
+ def get_params(self):
271
+ image_type = ".jpg"
272
+
273
+ if self.compression_type == ImageCompression.ImageCompressionType.WEBP:
274
+ image_type = ".webp"
275
+
276
+ return {
277
+ "quality": random.randint(self.quality_lower, self.quality_upper),
278
+ "image_type": image_type,
279
+ }
280
+
281
+ def get_transform_init_args(self):
282
+ return {
283
+ "quality_lower": self.quality_lower,
284
+ "quality_upper": self.quality_upper,
285
+ "compression_type": self.compression_type.value,
286
+ }
287
+
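The round-trip that `F.image_compression` most plausibly performs can be reproduced with OpenCV alone; a sketch (the quality value is illustrative):

import cv2
import numpy as np

img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
ok, buf = cv2.imencode(".jpg", img, [int(cv2.IMWRITE_JPEG_QUALITY), 10])
assert ok
degraded = cv2.imdecode(buf, cv2.IMREAD_UNCHANGED)  # same shape back, JPEG artifacts added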
288
+
289
+ class JpegCompression(ImageCompression):
290
+ """Decreases image quality by Jpeg compression of an image.
291
+
292
+ Args:
293
+ quality_lower (float): lower bound on the jpeg quality. Should be in [0, 100] range
294
+ quality_upper (float): upper bound on the jpeg quality. Should be in [0, 100] range
295
+
296
+ Targets:
297
+ image
298
+
299
+ Image types:
300
+ uint8, float32
301
+ """
302
+
303
+ def __init__(self, quality_lower=99, quality_upper=100, always_apply=False, p=0.5):
304
+ super(JpegCompression, self).__init__(
305
+ quality_lower=quality_lower,
306
+ quality_upper=quality_upper,
307
+ compression_type=ImageCompression.ImageCompressionType.JPEG,
308
+ always_apply=always_apply,
309
+ p=p,
310
+ )
311
+ warnings.warn(
312
+ f"{self.__class__.__name__} has been deprecated. Please use ImageCompression",
313
+ FutureWarning,
314
+ )
315
+
316
+ def get_transform_init_args(self):
317
+ return {
318
+ "quality_lower": self.quality_lower,
319
+ "quality_upper": self.quality_upper,
320
+ }
321
+
322
+
323
+ class RandomSnow(ImageOnlyTransform):
324
+ """Bleach out some pixel values simulating snow.
325
+
326
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
327
+
328
+ Args:
329
+ snow_point_lower (float): lower bound of the amount of snow. Should be in [0, 1] range
330
+ snow_point_upper (float): upper bound of the amount of snow. Should be in [0, 1] range
331
+ brightness_coeff (float): a larger number leads to more snow on the image. Should be >= 0
332
+
333
+ Targets:
334
+ image
335
+
336
+ Image types:
337
+ uint8, float32
338
+ """
339
+
340
+ def __init__(
341
+ self,
342
+ snow_point_lower=0.1,
343
+ snow_point_upper=0.3,
344
+ brightness_coeff=2.5,
345
+ always_apply=False,
346
+ p=0.5,
347
+ ):
348
+ super(RandomSnow, self).__init__(always_apply, p)
349
+
350
+ if not 0 <= snow_point_lower <= snow_point_upper <= 1:
351
+ raise ValueError(
352
+ "Invalid combination of snow_point_lower and snow_point_upper. Got: {}".format(
353
+ (snow_point_lower, snow_point_upper)
354
+ )
355
+ )
356
+ if brightness_coeff < 0:
357
+ raise ValueError("brightness_coeff must be greater than or equal to 0. Got: {}".format(brightness_coeff))
358
+
359
+ self.snow_point_lower = snow_point_lower
360
+ self.snow_point_upper = snow_point_upper
361
+ self.brightness_coeff = brightness_coeff
362
+
363
+ def apply(self, image, snow_point=0.1, **params):
364
+ return F.add_snow(image, snow_point, self.brightness_coeff)
365
+
366
+ def get_params(self):
367
+ return {"snow_point": random.uniform(self.snow_point_lower, self.snow_point_upper)}
368
+
369
+ def get_transform_init_args_names(self):
370
+ return ("snow_point_lower", "snow_point_upper", "brightness_coeff")
371
+
372
+
373
+ class RandomGravel(ImageOnlyTransform):
374
+ """Add gravels.
375
+
376
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
377
+
378
+ Args:
379
+ gravel_roi (float, float, float, float): (top-left x, top-left y,
380
+ bottom-right x, bottom right y). Should be in [0, 1] range
381
+ number_of_patches (int): number of gravel patches required
382
+
383
+ Targets:
384
+ image
385
+
386
+ Image types:
387
+ uint8, float32
388
+ """
389
+
390
+ def __init__(
391
+ self,
392
+ gravel_roi: tuple = (0.1, 0.4, 0.9, 0.9),
393
+ number_of_patches: int = 2,
394
+ always_apply: bool = False,
395
+ p: float = 0.5,
396
+ ):
397
+ super(RandomGravel, self).__init__(always_apply, p)
398
+
399
+ (gravel_lower_x, gravel_lower_y, gravel_upper_x, gravel_upper_y) = gravel_roi
400
+
401
+ if not 0 <= gravel_lower_x < gravel_upper_x <= 1 or not 0 <= gravel_lower_y < gravel_upper_y <= 1:
402
+ raise ValueError("Invalid gravel_roi. Got: %s." % gravel_roi)
403
+ if number_of_patches < 1:
404
+ raise ValueError("Invalid gravel number_of_patches. Got: %s." % number_of_patches)
405
+
406
+ self.gravel_roi = gravel_roi
407
+ self.number_of_patches = number_of_patches
408
+
409
+ def generate_gravel_patch(self, rectangular_roi):
410
+ x1, y1, x2, y2 = rectangular_roi
412
+ area = abs((x2 - x1) * (y2 - y1))
413
+ count = area // 10
414
+ gravels = np.empty([count, 2], dtype=np.int64)
415
+ gravels[:, 0] = random_utils.randint(x1, x2, count)
416
+ gravels[:, 1] = random_utils.randint(y1, y2, count)
417
+ return gravels
418
+
419
+ def apply(self, image, gravels_infos=(), **params):
420
+ return F.add_gravel(image, gravels_infos)
421
+
422
+ @property
423
+ def targets_as_params(self):
424
+ return ["image"]
425
+
426
+ def get_params_dependent_on_targets(self, params):
427
+ img = params["image"]
428
+ height, width = img.shape[:2]
429
+
430
+ x_min, y_min, x_max, y_max = self.gravel_roi
431
+ x_min = int(x_min * width)
432
+ x_max = int(x_max * width)
433
+ y_min = int(y_min * height)
434
+ y_max = int(y_max * height)
435
+
436
+ max_height = 200
437
+ max_width = 30
438
+
439
+ rectangular_rois = np.zeros([self.number_of_patches, 4], dtype=np.int64)
440
+ xx1 = random_utils.randint(x_min + 1, x_max, self.number_of_patches) # xmax
441
+ xx2 = random_utils.randint(x_min, xx1) # xmin
442
+ yy1 = random_utils.randint(y_min + 1, y_max, self.number_of_patches) # ymax
443
+ yy2 = random_utils.randint(y_min, yy1) # ymin
444
+
445
+ rectangular_rois[:, 0] = xx2
446
+ rectangular_rois[:, 1] = yy2
447
+ rectangular_rois[:, 2] = [min(tup) for tup in zip(xx1, xx2 + max_height)]
448
+ rectangular_rois[:, 3] = [min(tup) for tup in zip(yy1, yy2 + max_width)]
449
+
450
+ minx = []
451
+ maxx = []
452
+ miny = []
453
+ maxy = []
454
+ val = []
455
+ for roi in rectangular_rois:
456
+ gravels = self.generate_gravel_patch(roi)
457
+ x = gravels[:, 0]
458
+ y = gravels[:, 1]
459
+ r = random_utils.randint(1, 4, len(gravels))
460
+ sat = random_utils.randint(0, 255, len(gravels))
461
+ miny.append(np.maximum(y - r, 0))
462
+ maxy.append(np.minimum(y + r, y))
463
+ minx.append(np.maximum(x - r, 0))
464
+ maxx.append(np.minimum(x + r, x))
465
+ val.append(sat)
466
+
467
+ return {
468
+ "gravels_infos": np.stack(
469
+ [
470
+ np.concatenate(miny),
471
+ np.concatenate(maxy),
472
+ np.concatenate(minx),
473
+ np.concatenate(maxx),
474
+ np.concatenate(val),
475
+ ],
476
+ 1,
477
+ )
478
+ }
479
+
480
+ def get_transform_init_args_names(self):
481
+ return ("gravel_roi", "number_of_patches")
482
+
483
+
484
+ class RandomRain(ImageOnlyTransform):
485
+ """Adds rain effects.
486
+
487
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
488
+
489
+ Args:
490
+ slant_lower: should be in range [-20, 20].
491
+ slant_upper: should be in range [-20, 20].
492
+ drop_length: should be in range [0, 100].
493
+ drop_width: should be in range [1, 5].
494
+ drop_color (list of (r, g, b)): color of the rain lines.
495
+ blur_value (int): rainy views are blurry, so the image is blurred by this amount.
496
+ brightness_coefficient (float): rainy days are usually shady. Should be in range [0, 1].
497
+ rain_type: One of [None, "drizzle", "heavy", "torrential"]
498
+
499
+ Targets:
500
+ image
501
+
502
+ Image types:
503
+ uint8, float32
504
+ """
505
+
506
+ def __init__(
507
+ self,
508
+ slant_lower=-10,
509
+ slant_upper=10,
510
+ drop_length=20,
511
+ drop_width=1,
512
+ drop_color=(200, 200, 200),
513
+ blur_value=7,
514
+ brightness_coefficient=0.7,
515
+ rain_type=None,
516
+ always_apply=False,
517
+ p=0.5,
518
+ ):
519
+ super(RandomRain, self).__init__(always_apply, p)
520
+
521
+ if rain_type not in ["drizzle", "heavy", "torrential", None]:
522
+ "rain_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type)
523
+ "raint_type must be one of ({}). Got: {}".format(["drizzle", "heavy", "torrential", None], rain_type)
524
+ )
525
+ if not -20 <= slant_lower <= slant_upper <= 20:
526
+ raise ValueError(
527
+ "Invalid combination of slant_lower and slant_upper. Got: {}".format((slant_lower, slant_upper))
528
+ )
529
+ if not 1 <= drop_width <= 5:
530
+ raise ValueError("drop_width must be in range [1, 5]. Got: {}".format(drop_width))
531
+ if not 0 <= drop_length <= 100:
532
+ raise ValueError("drop_length must be in range [0, 100]. Got: {}".format(drop_length))
533
+ if not 0 <= brightness_coefficient <= 1:
534
+ raise ValueError("brightness_coefficient must be in range [0, 1]. Got: {}".format(brightness_coefficient))
535
+
536
+ self.slant_lower = slant_lower
537
+ self.slant_upper = slant_upper
538
+
539
+ self.drop_length = drop_length
540
+ self.drop_width = drop_width
541
+ self.drop_color = drop_color
542
+ self.blur_value = blur_value
543
+ self.brightness_coefficient = brightness_coefficient
544
+ self.rain_type = rain_type
545
+
546
+ def apply(self, image, slant=10, drop_length=20, rain_drops=(), **params):
547
+ return F.add_rain(
548
+ image,
549
+ slant,
550
+ drop_length,
551
+ self.drop_width,
552
+ self.drop_color,
553
+ self.blur_value,
554
+ self.brightness_coefficient,
555
+ rain_drops,
556
+ )
557
+
558
+ @property
559
+ def targets_as_params(self):
560
+ return ["image"]
561
+
562
+ def get_params_dependent_on_targets(self, params):
563
+ img = params["image"]
564
+ slant = int(random.uniform(self.slant_lower, self.slant_upper))
565
+
566
+ height, width = img.shape[:2]
567
+ area = height * width
568
+
569
+ if self.rain_type == "drizzle":
570
+ num_drops = area // 770
571
+ drop_length = 10
572
+ elif self.rain_type == "heavy":
573
+ num_drops = width * height // 600
574
+ drop_length = 30
575
+ elif self.rain_type == "torrential":
576
+ num_drops = area // 500
577
+ drop_length = 60
578
+ else:
579
+ drop_length = self.drop_length
580
+ num_drops = area // 600
581
+
582
+ rain_drops = []
583
+
584
+ for _i in range(num_drops):  # if you want heavier rain, try increasing num_drops
585
+ if slant < 0:
586
+ x = random.randint(slant, width)
587
+ else:
588
+ x = random.randint(0, width - slant)
589
+
590
+ y = random.randint(0, height - drop_length)
591
+
592
+ rain_drops.append((x, y))
593
+
594
+ return {"drop_length": drop_length, "slant": slant, "rain_drops": rain_drops}
595
+
596
+ def get_transform_init_args_names(self):
597
+ return (
598
+ "slant_lower",
599
+ "slant_upper",
600
+ "drop_length",
601
+ "drop_width",
602
+ "drop_color",
603
+ "blur_value",
604
+ "brightness_coefficient",
605
+ "rain_type",
606
+ )
607
+
608
+
609
+ class RandomFog(ImageOnlyTransform):
610
+ """Simulates fog for the image
611
+
612
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
613
+
614
+ Args:
615
+ fog_coef_lower (float): lower limit for fog intensity coefficient. Should be in [0, 1] range.
616
+ fog_coef_upper (float): upper limit for fog intensity coefficient. Should be in [0, 1] range.
617
+ alpha_coef (float): transparency of the fog circles. Should be in [0, 1] range.
618
+
619
+ Targets:
620
+ image
621
+
622
+ Image types:
623
+ uint8, float32
624
+ """
625
+
626
+ def __init__(
627
+ self,
628
+ fog_coef_lower=0.3,
629
+ fog_coef_upper=1,
630
+ alpha_coef=0.08,
631
+ always_apply=False,
632
+ p=0.5,
633
+ ):
634
+ super(RandomFog, self).__init__(always_apply, p)
635
+
636
+ if not 0 <= fog_coef_lower <= fog_coef_upper <= 1:
637
+ "Invalid combination of fog_coef_lower and fog_coef_upper. Got: {}".format(
638
+ "Invalid combination if fog_coef_lower and fog_coef_upper. Got: {}".format(
639
+ (fog_coef_lower, fog_coef_upper)
640
+ )
641
+ )
642
+ if not 0 <= alpha_coef <= 1:
643
+ raise ValueError("alpha_coef must be in range [0, 1]. Got: {}".format(alpha_coef))
644
+
645
+ self.fog_coef_lower = fog_coef_lower
646
+ self.fog_coef_upper = fog_coef_upper
647
+ self.alpha_coef = alpha_coef
648
+
649
+ def apply(self, image, fog_coef=0.1, haze_list=(), **params):
650
+ return F.add_fog(image, fog_coef, self.alpha_coef, haze_list)
651
+
652
+ @property
653
+ def targets_as_params(self):
654
+ return ["image"]
655
+
656
+ def get_params_dependent_on_targets(self, params):
657
+ img = params["image"]
658
+ fog_coef = random.uniform(self.fog_coef_lower, self.fog_coef_upper)
659
+
660
+ height, width = imshape = img.shape[:2]
661
+
662
+ hw = max(1, int(width // 3 * fog_coef))
663
+
664
+ haze_list = []
665
+ midx = width // 2 - 2 * hw
666
+ midy = height // 2 - hw
667
+ index = 1
668
+
669
+ while midx > -hw or midy > -hw:
670
+ for _i in range(hw // 10 * index):
671
+ x = random.randint(midx, width - midx - hw)
672
+ y = random.randint(midy, height - midy - hw)
673
+ haze_list.append((x, y))
674
+
675
+ midx -= 3 * hw * width // sum(imshape)
676
+ midy -= 3 * hw * height // sum(imshape)
677
+ index += 1
678
+
679
+ return {"haze_list": haze_list, "fog_coef": fog_coef}
680
+
681
+ def get_transform_init_args_names(self):
682
+ return ("fog_coef_lower", "fog_coef_upper", "alpha_coef")
683
+
684
+
685
+ class RandomSunFlare(ImageOnlyTransform):
686
+ """Simulates Sun Flare for the image
687
+
688
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
689
+
690
+ Args:
691
+ flare_roi (float, float, float, float): region of the image where flare will
692
+ appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1].
693
+ angle_lower (float): should be in range [0, `angle_upper`].
694
+ angle_upper (float): should be in range [`angle_lower`, 1].
695
+ num_flare_circles_lower (int): lower limit for the number of flare circles.
696
+ Should be in range [0, `num_flare_circles_upper`].
697
+ num_flare_circles_upper (int): upper limit for the number of flare circles.
698
+ Should be in range [`num_flare_circles_lower`, inf].
699
+ src_radius (int): radius of the flare's source circle, in pixels.
700
+ src_color ((int, int, int)): color of the flare
701
+
702
+ Targets:
703
+ image
704
+
705
+ Image types:
706
+ uint8, float32
707
+ """
708
+
709
+ def __init__(
710
+ self,
711
+ flare_roi=(0, 0, 1, 0.5),
712
+ angle_lower=0,
713
+ angle_upper=1,
714
+ num_flare_circles_lower=6,
715
+ num_flare_circles_upper=10,
716
+ src_radius=400,
717
+ src_color=(255, 255, 255),
718
+ always_apply=False,
719
+ p=0.5,
720
+ ):
721
+ super(RandomSunFlare, self).__init__(always_apply, p)
722
+
723
+ (
724
+ flare_center_lower_x,
725
+ flare_center_lower_y,
726
+ flare_center_upper_x,
727
+ flare_center_upper_y,
728
+ ) = flare_roi
729
+
730
+ if (
731
+ not 0 <= flare_center_lower_x < flare_center_upper_x <= 1
732
+ or not 0 <= flare_center_lower_y < flare_center_upper_y <= 1
733
+ ):
734
+ raise ValueError("Invalid flare_roi. Got: {}".format(flare_roi))
735
+ if not 0 <= angle_lower < angle_upper <= 1:
736
+ raise ValueError(
737
+ "Invalid combination of angle_lower and angle_upper. Got: {}".format((angle_lower, angle_upper))
738
+ )
739
+ if not 0 <= num_flare_circles_lower < num_flare_circles_upper:
740
+ raise ValueError(
741
+ "Invalid combination of num_flare_circles_lower and num_flare_circles_upper. Got: {}".format(
742
+ (num_flare_circles_lower, num_flare_circles_upper)
743
+ )
744
+ )
745
+
746
+ self.flare_center_lower_x = flare_center_lower_x
747
+ self.flare_center_upper_x = flare_center_upper_x
748
+
749
+ self.flare_center_lower_y = flare_center_lower_y
750
+ self.flare_center_upper_y = flare_center_upper_y
751
+
752
+ self.angle_lower = angle_lower
753
+ self.angle_upper = angle_upper
754
+ self.num_flare_circles_lower = num_flare_circles_lower
755
+ self.num_flare_circles_upper = num_flare_circles_upper
756
+
757
+ self.src_radius = src_radius
758
+ self.src_color = src_color
759
+
760
+ def apply(self, image, flare_center_x=0.5, flare_center_y=0.5, circles=(), **params):
761
+ return F.add_sun_flare(
762
+ image,
763
+ flare_center_x,
764
+ flare_center_y,
765
+ self.src_radius,
766
+ self.src_color,
767
+ circles,
768
+ )
769
+
770
+ @property
771
+ def targets_as_params(self):
772
+ return ["image"]
773
+
774
+ def get_params_dependent_on_targets(self, params):
775
+ img = params["image"]
776
+ height, width = img.shape[:2]
777
+
778
+ angle = 2 * math.pi * random.uniform(self.angle_lower, self.angle_upper)
779
+
780
+ flare_center_x = random.uniform(self.flare_center_lower_x, self.flare_center_upper_x)
781
+ flare_center_y = random.uniform(self.flare_center_lower_y, self.flare_center_upper_y)
782
+
783
+ flare_center_x = int(width * flare_center_x)
784
+ flare_center_y = int(height * flare_center_y)
785
+
786
+ num_circles = random.randint(self.num_flare_circles_lower, self.num_flare_circles_upper)
787
+
788
+ circles = []
789
+
790
+ x = []
791
+ y = []
792
+
793
+ def line(t):
794
+ return (flare_center_x + t * math.cos(angle), flare_center_y + t * math.sin(angle))
795
+
796
+ for t_val in range(-flare_center_x, width - flare_center_x, 10):
797
+ rand_x, rand_y = line(t_val)
798
+ x.append(rand_x)
799
+ y.append(rand_y)
800
+
801
+ for _i in range(num_circles):
802
+ alpha = random.uniform(0.05, 0.2)
803
+ r = random.randint(0, len(x) - 1)
804
+ rad = random.randint(1, max(height // 100 - 2, 2))
805
+
806
+ r_color = random.randint(max(self.src_color[0] - 50, 0), self.src_color[0])
807
+ g_color = random.randint(max(self.src_color[1] - 50, 0), self.src_color[1])
808
+ b_color = random.randint(max(self.src_color[2] - 50, 0), self.src_color[2])
809
+
810
+ circles += [
811
+ (
812
+ alpha,
813
+ (int(x[r]), int(y[r])),
814
+ pow(rad, 3),
815
+ (r_color, g_color, b_color),
816
+ )
817
+ ]
818
+
819
+ return {
820
+ "circles": circles,
821
+ "flare_center_x": flare_center_x,
822
+ "flare_center_y": flare_center_y,
823
+ }
824
+
825
+ def get_transform_init_args(self):
826
+ return {
827
+ "flare_roi": (
828
+ self.flare_center_lower_x,
829
+ self.flare_center_lower_y,
830
+ self.flare_center_upper_x,
831
+ self.flare_center_upper_y,
832
+ ),
833
+ "angle_lower": self.angle_lower,
834
+ "angle_upper": self.angle_upper,
835
+ "num_flare_circles_lower": self.num_flare_circles_lower,
836
+ "num_flare_circles_upper": self.num_flare_circles_upper,
837
+ "src_radius": self.src_radius,
838
+ "src_color": self.src_color,
839
+ }
840
+
841
+
842
+ class RandomShadow(ImageOnlyTransform):
843
+ """Simulates shadows for the image
844
+
845
+ From https://github.com/UjjwalSaxena/Automold--Road-Augmentation-Library
846
+
847
+ Args:
848
+ shadow_roi (float, float, float, float): region of the image where shadows
849
+ will appear (x_min, y_min, x_max, y_max). All values should be in range [0, 1].
850
+ num_shadows_lower (int): Lower limit for the possible number of shadows.
851
+ Should be in range [0, `num_shadows_upper`].
852
+ num_shadows_upper (int): Upper limit for the possible number of shadows.
853
+ Should be in range [`num_shadows_lower`, inf].
854
+ shadow_dimension (int): number of edges in the shadow polygons
855
+
856
+ Targets:
857
+ image
858
+
859
+ Image types:
860
+ uint8, float32
861
+ """
862
+
863
+ def __init__(
864
+ self,
865
+ shadow_roi=(0, 0.5, 1, 1),
866
+ num_shadows_lower=1,
867
+ num_shadows_upper=2,
868
+ shadow_dimension=5,
869
+ always_apply=False,
870
+ p=0.5,
871
+ ):
872
+ super(RandomShadow, self).__init__(always_apply, p)
873
+
874
+ (shadow_lower_x, shadow_lower_y, shadow_upper_x, shadow_upper_y) = shadow_roi
875
+
876
+ if not 0 <= shadow_lower_x <= shadow_upper_x <= 1 or not 0 <= shadow_lower_y <= shadow_upper_y <= 1:
877
+ raise ValueError("Invalid shadow_roi. Got: {}".format(shadow_roi))
878
+ if not 0 <= num_shadows_lower <= num_shadows_upper:
879
+ raise ValueError(
880
+ "Invalid combination of num_shadows_lower and num_shadows_upper. Got: {}".format(
881
+ (num_shadows_lower, num_shadows_upper)
882
+ )
883
+ )
884
+
885
+ self.shadow_roi = shadow_roi
886
+
887
+ self.num_shadows_lower = num_shadows_lower
888
+ self.num_shadows_upper = num_shadows_upper
889
+
890
+ self.shadow_dimension = shadow_dimension
891
+
892
+ def apply(self, image, vertices_list=(), **params):
893
+ return F.add_shadow(image, vertices_list)
894
+
895
+ @property
896
+ def targets_as_params(self):
897
+ return ["image"]
898
+
899
+ def get_params_dependent_on_targets(self, params):
900
+ img = params["image"]
901
+ height, width = img.shape[:2]
902
+
903
+ num_shadows = random.randint(self.num_shadows_lower, self.num_shadows_upper)
904
+
905
+ x_min, y_min, x_max, y_max = self.shadow_roi
906
+
907
+ x_min = int(x_min * width)
908
+ x_max = int(x_max * width)
909
+ y_min = int(y_min * height)
910
+ y_max = int(y_max * height)
911
+
912
+ vertices_list = []
913
+
914
+ for _index in range(num_shadows):
915
+ vertex = []
916
+ for _dimension in range(self.shadow_dimension):
917
+ vertex.append((random.randint(x_min, x_max), random.randint(y_min, y_max)))
918
+
919
+ vertices = np.array([vertex], dtype=np.int32)
920
+ vertices_list.append(vertices)
921
+
922
+ return {"vertices_list": vertices_list}
923
+
924
+ def get_transform_init_args_names(self):
925
+ return (
926
+ "shadow_roi",
927
+ "num_shadows_lower",
928
+ "num_shadows_upper",
929
+ "shadow_dimension",
930
+ )
931
+
932
+
933
+ class RandomToneCurve(ImageOnlyTransform):
934
+ """Randomly change the relationship between bright and dark areas of the image by manipulating its tone curve.
935
+
936
+ Args:
937
+ scale (float): standard deviation of the normal distribution.
938
+ Used to sample random distances to move two control points that modify the image's curve.
939
+ Values should be in range [0, 1]. Default: 0.1
940
+
941
+
942
+ Targets:
943
+ image
944
+
945
+ Image types:
946
+ uint8
947
+ """
948
+
949
+ def __init__(
950
+ self,
951
+ scale=0.1,
952
+ always_apply=False,
953
+ p=0.5,
954
+ ):
955
+ super(RandomToneCurve, self).__init__(always_apply, p)
956
+ self.scale = scale
957
+
958
+ def apply(self, image, low_y, high_y, **params):
959
+ return F.move_tone_curve(image, low_y, high_y)
960
+
961
+ def get_params(self):
962
+ return {
963
+ "low_y": np.clip(random_utils.normal(loc=0.25, scale=self.scale), 0, 1),
964
+ "high_y": np.clip(random_utils.normal(loc=0.75, scale=self.scale), 0, 1),
965
+ }
966
+
967
+ def get_transform_init_args_names(self):
968
+ return ("scale",)
969
+
970
+
971
+ class HueSaturationValue(ImageOnlyTransform):
972
+ """Randomly change hue, saturation and value of the input image.
973
+
974
+ Args:
975
+ hue_shift_limit ((int, int) or int): range for changing hue. If hue_shift_limit is a single int, the range
976
+ will be (-hue_shift_limit, hue_shift_limit). Default: (-20, 20).
977
+ sat_shift_limit ((int, int) or int): range for changing saturation. If sat_shift_limit is a single int,
978
+ the range will be (-sat_shift_limit, sat_shift_limit). Default: (-30, 30).
979
+ val_shift_limit ((int, int) or int): range for changing value. If val_shift_limit is a single int, the range
980
+ will be (-val_shift_limit, val_shift_limit). Default: (-20, 20).
981
+ p (float): probability of applying the transform. Default: 0.5.
982
+
983
+ Targets:
984
+ image
985
+
986
+ Image types:
987
+ uint8, float32
988
+ """
989
+
990
+ def __init__(
991
+ self,
992
+ hue_shift_limit=20,
993
+ sat_shift_limit=30,
994
+ val_shift_limit=20,
995
+ always_apply=False,
996
+ p=0.5,
997
+ ):
998
+ super(HueSaturationValue, self).__init__(always_apply, p)
999
+ self.hue_shift_limit = to_tuple(hue_shift_limit)
1000
+ self.sat_shift_limit = to_tuple(sat_shift_limit)
1001
+ self.val_shift_limit = to_tuple(val_shift_limit)
1002
+
1003
+ def apply(self, image, hue_shift=0, sat_shift=0, val_shift=0, **params):
1004
+ if not is_rgb_image(image) and not is_grayscale_image(image):
1005
+ raise TypeError("HueSaturationValue transformation expects 1-channel or 3-channel images.")
1006
+ return F.shift_hsv(image, hue_shift, sat_shift, val_shift)
1007
+
1008
+ def get_params(self):
1009
+ return {
1010
+ "hue_shift": random.uniform(self.hue_shift_limit[0], self.hue_shift_limit[1]),
1011
+ "sat_shift": random.uniform(self.sat_shift_limit[0], self.sat_shift_limit[1]),
1012
+ "val_shift": random.uniform(self.val_shift_limit[0], self.val_shift_limit[1]),
1013
+ }
1014
+
1015
+ def get_transform_init_args_names(self):
1016
+ return ("hue_shift_limit", "sat_shift_limit", "val_shift_limit")
1017
+
1018
+
1019
+ class Solarize(ImageOnlyTransform):
1020
+ """Invert all pixel values above a threshold.
1021
+
1022
+ Args:
1023
+ threshold ((int, int) or int, or (float, float) or float): range for solarizing threshold.
1024
+ If threshold is a single value, the range will be [threshold, threshold]. Default: 128.
1025
+ p (float): probability of applying the transform. Default: 0.5.
1026
+
1027
+ Targets:
1028
+ image
1029
+
1030
+ Image types:
1031
+ any
1032
+ """
1033
+
1034
+ def __init__(self, threshold=128, always_apply=False, p=0.5):
1035
+ super(Solarize, self).__init__(always_apply, p)
1036
+
1037
+ if isinstance(threshold, (int, float)):
1038
+ self.threshold = to_tuple(threshold, low=threshold)
1039
+ else:
1040
+ self.threshold = to_tuple(threshold, low=0)
1041
+
1042
+ def apply(self, image, threshold=0, **params):
1043
+ return F.solarize(image, threshold)
1044
+
1045
+ def get_params(self):
1046
+ return {"threshold": random.uniform(self.threshold[0], self.threshold[1])}
1047
+
1048
+ def get_transform_init_args_names(self):
1049
+ return ("threshold",)
1050
+
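Solarization inverts every pixel at or above the threshold; for uint8 input that is a one-liner (a sketch, not necessarily the library's exact kernel):

import numpy as np

img = np.array([[0, 100, 128, 200, 255]], dtype=np.uint8)
threshold = 128
solarized = np.where(img >= threshold, 255 - img, img)
# -> [[0, 100, 127, 55, 0]]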
1051
+
1052
+ class Posterize(ImageOnlyTransform):
1053
+ """Reduce the number of bits for each color channel.
1054
+
1055
+ Args:
1056
+ num_bits ((int, int) or int,
1057
+ or list of ints [r, g, b],
1058
+ or list of ints [[r1, r1], [g1, g2], [b1, b2]]): number of high bits.
1059
+ If num_bits is a single value, the range will be [num_bits, num_bits].
1060
+ Must be in range [0, 8]. Default: 4.
1061
+ p (float): probability of applying the transform. Default: 0.5.
1062
+
1063
+ Targets:
1064
+ image
1065
+
1066
+ Image types:
1067
+ uint8
1068
+ """
1069
+
1070
+ def __init__(self, num_bits=4, always_apply=False, p=0.5):
1071
+ super(Posterize, self).__init__(always_apply, p)
1072
+
1073
+ if isinstance(num_bits, (list, tuple)):
1074
+ if len(num_bits) == 3:
1075
+ self.num_bits = [to_tuple(i, 0) for i in num_bits]
1076
+ else:
1077
+ self.num_bits = to_tuple(num_bits, 0)
1078
+ else:
1079
+ self.num_bits = to_tuple(num_bits, num_bits)
1080
+
1081
+ def apply(self, image, num_bits=1, **params):
1082
+ return F.posterize(image, num_bits)
1083
+
1084
+ def get_params(self):
1085
+ if len(self.num_bits) == 3:
1086
+ return {"num_bits": [random.randint(i[0], i[1]) for i in self.num_bits]}
1087
+ return {"num_bits": random.randint(self.num_bits[0], self.num_bits[1])}
1088
+
1089
+ def get_transform_init_args_names(self):
1090
+ return ("num_bits",)
1091
+
1092
+
1093
+ class Equalize(ImageOnlyTransform):
1094
+ """Equalize the image histogram.
1095
+
1096
+ Args:
1097
+ mode (str): {'cv', 'pil'}. Use OpenCV or Pillow equalization method.
1098
+ by_channels (bool): If True, use equalization by channels separately,
1099
+ else convert image to YCbCr representation and use equalization by `Y` channel.
1100
+ mask (np.ndarray, callable): If given, only the pixels selected by
1101
+ the mask are included in the analysis. May be a 1-channel or 3-channel array or a callable.
1102
+ Function signature must include `image` argument.
1103
+ mask_params (list of str): Params for mask function.
1104
+
1105
+ Targets:
1106
+ image
1107
+
1108
+ Image types:
1109
+ uint8
1110
+ """
1111
+
1112
+ def __init__(
1113
+ self,
1114
+ mode="cv",
1115
+ by_channels=True,
1116
+ mask=None,
1117
+ mask_params=(),
1118
+ always_apply=False,
1119
+ p=0.5,
1120
+ ):
1121
+ modes = ["cv", "pil"]
1122
+ if mode not in modes:
1123
+ raise ValueError("Unsupported equalization mode. Supports: {}. " "Got: {}".format(modes, mode))
1124
+
1125
+ super(Equalize, self).__init__(always_apply, p)
1126
+ self.mode = mode
1127
+ self.by_channels = by_channels
1128
+ self.mask = mask
1129
+ self.mask_params = mask_params
1130
+
1131
+ def apply(self, image, mask=None, **params):
1132
+ return F.equalize(image, mode=self.mode, by_channels=self.by_channels, mask=mask)
1133
+
1134
+ def get_params_dependent_on_targets(self, params):
1135
+ if not callable(self.mask):
1136
+ return {"mask": self.mask}
1137
+
1138
+ return {"mask": self.mask(**params)}
1139
+
1140
+ @property
1141
+ def targets_as_params(self):
1142
+ return ["image"] + list(self.mask_params)
1143
+
1144
+ def get_transform_init_args_names(self):
1145
+ return ("mode", "by_channels")
1146
+
1147
+
1148
+ class RGBShift(ImageOnlyTransform):
1149
+ """Randomly shift values for each channel of the input RGB image.
1150
+
1151
+ Args:
1152
+ r_shift_limit ((int, int) or int): range for changing values for the red channel. If r_shift_limit is a single
1153
+ int, the range will be (-r_shift_limit, r_shift_limit). Default: (-20, 20).
1154
+ g_shift_limit ((int, int) or int): range for changing values for the green channel. If g_shift_limit is a
1155
+ single int, the range will be (-g_shift_limit, g_shift_limit). Default: (-20, 20).
1156
+ b_shift_limit ((int, int) or int): range for changing values for the blue channel. If b_shift_limit is a single
1157
+ int, the range will be (-b_shift_limit, b_shift_limit). Default: (-20, 20).
1158
+ p (float): probability of applying the transform. Default: 0.5.
1159
+
1160
+ Targets:
1161
+ image
1162
+
1163
+ Image types:
1164
+ uint8, float32
1165
+ """
1166
+
1167
+ def __init__(
1168
+ self,
1169
+ r_shift_limit=20,
1170
+ g_shift_limit=20,
1171
+ b_shift_limit=20,
1172
+ always_apply=False,
1173
+ p=0.5,
1174
+ ):
1175
+ super(RGBShift, self).__init__(always_apply, p)
1176
+ self.r_shift_limit = to_tuple(r_shift_limit)
1177
+ self.g_shift_limit = to_tuple(g_shift_limit)
1178
+ self.b_shift_limit = to_tuple(b_shift_limit)
1179
+
1180
+ def apply(self, image, r_shift=0, g_shift=0, b_shift=0, **params):
1181
+ if not is_rgb_image(image):
1182
+ raise TypeError("RGBShift transformation expects 3-channel images.")
1183
+ return F.shift_rgb(image, r_shift, g_shift, b_shift)
1184
+
1185
+ def get_params(self):
1186
+ return {
1187
+ "r_shift": random.uniform(self.r_shift_limit[0], self.r_shift_limit[1]),
1188
+ "g_shift": random.uniform(self.g_shift_limit[0], self.g_shift_limit[1]),
1189
+ "b_shift": random.uniform(self.b_shift_limit[0], self.b_shift_limit[1]),
1190
+ }
1191
+
1192
+ def get_transform_init_args_names(self):
1193
+ return ("r_shift_limit", "g_shift_limit", "b_shift_limit")
1194
+
1195
+
1196
+ class RandomBrightnessContrast(ImageOnlyTransform):
1197
+ """Randomly change brightness and contrast of the input image.
1198
+
1199
+ Args:
1200
+ brightness_limit ((float, float) or float): factor range for changing brightness.
1201
+ If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
1202
+ contrast_limit ((float, float) or float): factor range for changing contrast.
1203
+ If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
1204
+ brightness_by_max (bool): If True, adjust brightness by the image dtype maximum,
1205
+ else adjust brightness by the image mean.
1206
+ p (float): probability of applying the transform. Default: 0.5.
1207
+
1208
+ Targets:
1209
+ image
1210
+
1211
+ Image types:
1212
+ uint8, float32
1213
+ """
1214
+
1215
+ def __init__(
1216
+ self,
1217
+ brightness_limit=0.2,
1218
+ contrast_limit=0.2,
1219
+ brightness_by_max=True,
1220
+ always_apply=False,
1221
+ p=0.5,
1222
+ ):
1223
+ super(RandomBrightnessContrast, self).__init__(always_apply, p)
1224
+ self.brightness_limit = to_tuple(brightness_limit)
1225
+ self.contrast_limit = to_tuple(contrast_limit)
1226
+ self.brightness_by_max = brightness_by_max
1227
+
1228
+ def apply(self, img, alpha=1.0, beta=0.0, **params):
1229
+ return F.brightness_contrast_adjust(img, alpha, beta, self.brightness_by_max)
1230
+
1231
+ def get_params(self):
1232
+ return {
1233
+ "alpha": 1.0 + random.uniform(self.contrast_limit[0], self.contrast_limit[1]),
1234
+ "beta": 0.0 + random.uniform(self.brightness_limit[0], self.brightness_limit[1]),
1235
+ }
1236
+
1237
+ def get_transform_init_args_names(self):
1238
+ return ("brightness_limit", "contrast_limit", "brightness_by_max")
1239
+
1240
+
1241
+ class RandomBrightness(RandomBrightnessContrast):
1242
+ """Randomly change brightness of the input image.
1243
+
1244
+ Args:
1245
+ limit ((float, float) or float): factor range for changing brightness.
1246
+ If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
1247
+ p (float): probability of applying the transform. Default: 0.5.
1248
+
1249
+ Targets:
1250
+ image
1251
+
1252
+ Image types:
1253
+ uint8, float32
1254
+ """
1255
+
1256
+ def __init__(self, limit=0.2, always_apply=False, p=0.5):
1257
+ super(RandomBrightness, self).__init__(brightness_limit=limit, contrast_limit=0, always_apply=always_apply, p=p)
1258
+ warnings.warn(
1259
+ "This class has been deprecated. Please use RandomBrightnessContrast",
1260
+ FutureWarning,
1261
+ )
1262
+
1263
+ def get_transform_init_args(self):
1264
+ return {"limit": self.brightness_limit}
1265
+
1266
+
1267
+ class RandomContrast(RandomBrightnessContrast):
1268
+ """Randomly change contrast of the input image.
1269
+
1270
+ Args:
1271
+ limit ((float, float) or float): factor range for changing contrast.
1272
+ If limit is a single float, the range will be (-limit, limit). Default: (-0.2, 0.2).
1273
+ p (float): probability of applying the transform. Default: 0.5.
1274
+
1275
+ Targets:
1276
+ image
1277
+
1278
+ Image types:
1279
+ uint8, float32
1280
+ """
1281
+
1282
+ def __init__(self, limit=0.2, always_apply=False, p=0.5):
1283
+ super(RandomContrast, self).__init__(brightness_limit=0, contrast_limit=limit, always_apply=always_apply, p=p)
1284
+ warnings.warn(
1285
+ f"{self.__class__.__name__} has been deprecated. Please use RandomBrightnessContrast",
1286
+ FutureWarning,
1287
+ )
1288
+
1289
+ def get_transform_init_args(self):
1290
+ return {"limit": self.contrast_limit}
1291
+
1292
+
1293
+ class GaussNoise(ImageOnlyTransform):
1294
+ """Apply gaussian noise to the input image.
1295
+
1296
+ Args:
1297
+ var_limit ((float, float) or float): variance range for noise. If var_limit is a single float, the range
1298
+ will be (0, var_limit). Default: (10.0, 50.0).
1299
+ mean (float): mean of the noise. Default: 0
1300
+ per_channel (bool): if set to True, noise will be sampled for each channel independently.
1301
+ Otherwise, the noise will be sampled once for all channels. Default: True
1302
+ p (float): probability of applying the transform. Default: 0.5.
1303
+
1304
+ Targets:
1305
+ image
1306
+
1307
+ Image types:
1308
+ uint8, float32
1309
+ """
1310
+
1311
+ def __init__(self, var_limit=(10.0, 50.0), mean=0, per_channel=True, always_apply=False, p=0.5):
1312
+ super(GaussNoise, self).__init__(always_apply, p)
1313
+ if isinstance(var_limit, (tuple, list)):
1314
+ if var_limit[0] < 0:
1315
+ raise ValueError("Lower var_limit should be non negative.")
1316
+ if var_limit[1] < 0:
1317
+ raise ValueError("Upper var_limit should be non negative.")
1318
+ self.var_limit = var_limit
1319
+ elif isinstance(var_limit, (int, float)):
1320
+ if var_limit < 0:
1321
+ raise ValueError("var_limit should be non negative.")
1322
+
1323
+ self.var_limit = (0, var_limit)
1324
+ else:
1325
+ raise TypeError(
1326
+ "Expected var_limit type to be one of (int, float, tuple, list), got {}".format(type(var_limit))
1327
+ )
1328
+
1329
+ self.mean = mean
1330
+ self.per_channel = per_channel
1331
+
1332
+ def apply(self, img, gauss=None, **params):
1333
+ return F.gauss_noise(img, gauss=gauss)
1334
+
1335
+ def get_params_dependent_on_targets(self, params):
1336
+ image = params["image"]
1337
+ var = random.uniform(self.var_limit[0], self.var_limit[1])
1338
+ sigma = var**0.5
1339
+
1340
+ if self.per_channel:
1341
+ gauss = random_utils.normal(self.mean, sigma, image.shape)
1342
+ else:
1343
+ gauss = random_utils.normal(self.mean, sigma, image.shape[:2])
1344
+ if len(image.shape) == 3:
1345
+ gauss = np.expand_dims(gauss, -1)
1346
+
1347
+ return {"gauss": gauss}
1348
+
1349
+ @property
1350
+ def targets_as_params(self):
1351
+ return ["image"]
1352
+
1353
+ def get_transform_init_args_names(self):
1354
+ return ("var_limit", "per_channel", "mean")
1355
+
1356
+
1357
+ class ISONoise(ImageOnlyTransform):
1358
+ """
1359
+ Apply camera sensor noise.
1360
+
1361
+ Args:
1362
+ color_shift (float, float): variance range for color hue change.
1363
+ Measured as a fraction of 360 degree Hue angle in HLS colorspace.
1364
+ intensity ((float, float)): Multiplicative factor that controls the strength
1365
+ of color and luminance noise.
1366
+ p (float): probability of applying the transform. Default: 0.5.
1367
+
1368
+ Targets:
1369
+ image
1370
+
1371
+ Image types:
1372
+ uint8
1373
+ """
1374
+
1375
+ def __init__(self, color_shift=(0.01, 0.05), intensity=(0.1, 0.5), always_apply=False, p=0.5):
1376
+ super(ISONoise, self).__init__(always_apply, p)
1377
+ self.intensity = intensity
1378
+ self.color_shift = color_shift
1379
+
1380
+ def apply(self, img, color_shift=0.05, intensity=1.0, random_state=None, **params):
1381
+ return F.iso_noise(img, color_shift, intensity, np.random.RandomState(random_state))
1382
+
1383
+ def get_params(self):
1384
+ return {
1385
+ "color_shift": random.uniform(self.color_shift[0], self.color_shift[1]),
1386
+ "intensity": random.uniform(self.intensity[0], self.intensity[1]),
1387
+ "random_state": random.randint(0, 65536),
1388
+ }
1389
+
1390
+ def get_transform_init_args_names(self):
1391
+ return ("intensity", "color_shift")
1392
+
1393
+
1394
+ class CLAHE(ImageOnlyTransform):
1395
+ """Apply Contrast Limited Adaptive Histogram Equalization to the input image.
1396
+
1397
+ Args:
1398
+ clip_limit (float or (float, float)): upper threshold value for contrast limiting.
1399
+ If clip_limit is a single float value, the range will be (1, clip_limit). Default: (1, 4).
1400
+ tile_grid_size ((int, int)): size of grid for histogram equalization. Default: (8, 8).
1401
+ p (float): probability of applying the transform. Default: 0.5.
1402
+
1403
+ Targets:
1404
+ image
1405
+
1406
+ Image types:
1407
+ uint8
1408
+ """
1409
+
1410
+ def __init__(self, clip_limit=4.0, tile_grid_size=(8, 8), always_apply=False, p=0.5):
1411
+ super(CLAHE, self).__init__(always_apply, p)
1412
+ self.clip_limit = to_tuple(clip_limit, 1)
1413
+ self.tile_grid_size = tuple(tile_grid_size)
1414
+
1415
+ def apply(self, img, clip_limit=2, **params):
1416
+ if not is_rgb_image(img) and not is_grayscale_image(img):
1417
+ raise TypeError("CLAHE transformation expects 1-channel or 3-channel images.")
1418
+
1419
+ return F.clahe(img, clip_limit, self.tile_grid_size)
1420
+
1421
+ def get_params(self):
1422
+ return {"clip_limit": random.uniform(self.clip_limit[0], self.clip_limit[1])}
1423
+
1424
+ def get_transform_init_args_names(self):
1425
+ return ("clip_limit", "tile_grid_size")
1426
+
1427
+
1428
+ class ChannelShuffle(ImageOnlyTransform):
1429
+ """Randomly rearrange channels of the input RGB image.
1430
+
1431
+ Args:
1432
+ p (float): probability of applying the transform. Default: 0.5.
1433
+
1434
+ Targets:
1435
+ image
1436
+
1437
+ Image types:
1438
+ uint8, float32
1439
+ """
1440
+
1441
+ @property
1442
+ def targets_as_params(self):
1443
+ return ["image"]
1444
+
1445
+ def apply(self, img, channels_shuffled=(0, 1, 2), **params):
1446
+ return F.channel_shuffle(img, channels_shuffled)
1447
+
1448
+ def get_params_dependent_on_targets(self, params):
1449
+ img = params["image"]
1450
+ ch_arr = list(range(img.shape[2]))
1451
+ random.shuffle(ch_arr)
1452
+ return {"channels_shuffled": ch_arr}
1453
+
1454
+ def get_transform_init_args_names(self):
1455
+ return ()
1456
+
1457
+
1458
+ class InvertImg(ImageOnlyTransform):
1459
+ """Invert the input image by subtracting pixel values from max values of the image types,
1460
+ i.e., 255 for uint8 and 1.0 for float32.
1461
+
1462
+ Args:
1463
+ p (float): probability of applying the transform. Default: 0.5.
1464
+
1465
+ Targets:
1466
+ image
1467
+
1468
+ Image types:
1469
+ uint8, float32
1470
+ """
1471
+
1472
+ def apply(self, img, **params):
1473
+ return F.invert(img)
1474
+
1475
+ def get_transform_init_args_names(self):
1476
+ return ()
1477
+
1478
+
1479
+ class RandomGamma(ImageOnlyTransform):
1480
+ """
1481
+ Args:
1482
+ gamma_limit (float or (float, float)): If gamma_limit is a single float value,
1483
+ the range will be (-gamma_limit, gamma_limit). Default: (80, 120).
1484
+ eps: Deprecated.
1485
+
1486
+ Targets:
1487
+ image
1488
+
1489
+ Image types:
1490
+ uint8, float32
1491
+ """
1492
+
1493
+ def __init__(self, gamma_limit=(80, 120), eps=None, always_apply=False, p=0.5):
1494
+ super(RandomGamma, self).__init__(always_apply, p)
1495
+ self.gamma_limit = to_tuple(gamma_limit)
1496
+ self.eps = eps
1497
+
1498
+ def apply(self, img, gamma=1, **params):
1499
+ return F.gamma_transform(img, gamma=gamma)
1500
+
1501
+ def get_params(self):
1502
+ return {"gamma": random.uniform(self.gamma_limit[0], self.gamma_limit[1]) / 100.0}
1503
+
1504
+ def get_transform_init_args_names(self):
1505
+ return ("gamma_limit", "eps")
1506
+
1507
+
1508
+ class ToGray(ImageOnlyTransform):
1509
+ """Convert the input RGB image to grayscale. If the mean pixel value for the resulting image is greater
1510
+ than 127, invert the resulting grayscale image.
1511
+
1512
+ Args:
1513
+ p (float): probability of applying the transform. Default: 0.5.
1514
+
1515
+ Targets:
1516
+ image
1517
+
1518
+ Image types:
1519
+ uint8, float32
1520
+ """
1521
+
1522
+ def apply(self, img, **params):
1523
+ if is_grayscale_image(img):
1524
+ warnings.warn("The image is already gray.")
1525
+ return img
1526
+ if not is_rgb_image(img):
1527
+ raise TypeError("ToGray transformation expects 3-channel images.")
1528
+
1529
+ return F.to_gray(img)
1530
+
1531
+ def get_transform_init_args_names(self):
1532
+ return ()
1533
+
1534
+
1535
+ class ToRGB(ImageOnlyTransform):
1536
+ """Convert the input grayscale image to RGB.
1537
+
1538
+ Args:
1539
+ p (float): probability of applying the transform. Default: 1.
1540
+
1541
+ Targets:
1542
+ image
1543
+
1544
+ Image types:
1545
+ uint8, float32
1546
+ """
1547
+
1548
+ def __init__(self, always_apply=True, p=1.0):
1549
+ super(ToRGB, self).__init__(always_apply=always_apply, p=p)
1550
+
1551
+ def apply(self, img, **params):
1552
+ if is_rgb_image(img):
1553
+ warnings.warn("The image is already an RGB.")
1554
+ return img
1555
+ if not is_grayscale_image(img):
1556
+ raise TypeError("ToRGB transformation expects 2-dim images or 3-dim with the last dimension equal to 1.")
1557
+
1558
+ return F.gray_to_rgb(img)
1559
+
1560
+ def get_transform_init_args_names(self):
1561
+ return ()
1562
+
1563
+
1564
+ class ToSepia(ImageOnlyTransform):
1565
+ """Applies sepia filter to the input RGB image
1566
+
1567
+ Args:
1568
+ p (float): probability of applying the transform. Default: 0.5.
1569
+
1570
+ Targets:
1571
+ image
1572
+
1573
+ Image types:
1574
+ uint8, float32
1575
+ """
1576
+
1577
+ def __init__(self, always_apply=False, p=0.5):
1578
+ super(ToSepia, self).__init__(always_apply, p)
1579
+ self.sepia_transformation_matrix = np.array(
1580
+ [[0.393, 0.769, 0.189], [0.349, 0.686, 0.168], [0.272, 0.534, 0.131]]
1581
+ )
1582
+
1583
+ def apply(self, image, **params):
1584
+ if not is_rgb_image(image):
1585
+ raise TypeError("ToSepia transformation expects 3-channel images.")
1586
+ return F.linear_transformation_rgb(image, self.sepia_transformation_matrix)
1587
+
1588
+ def get_transform_init_args_names(self):
1589
+ return ()
1590
+
1591
+
1592
+ class ToFloat(ImageOnlyTransform):
1593
+ """Divide pixel values by `max_value` to get a float32 output array where all values lie in the range [0, 1.0].
1594
+ If `max_value` is None the transform will try to infer the maximum value by inspecting the data type of the input
1595
+ image.
1596
+
1597
+ See Also:
1598
+ :class:`~albumentations.augmentations.transforms.FromFloat`
1599
+
1600
+ Args:
1601
+ max_value (float): maximum possible input value. Default: None.
1602
+ p (float): probability of applying the transform. Default: 1.0.
1603
+
1604
+ Targets:
1605
+ image
1606
+
1607
+ Image types:
1608
+ any type
1609
+
1610
+ """
1611
+
1612
+ def __init__(self, max_value=None, always_apply=False, p=1.0):
1613
+ super(ToFloat, self).__init__(always_apply, p)
1614
+ self.max_value = max_value
1615
+
1616
+ def apply(self, img, **params):
1617
+ return F.to_float(img, self.max_value)
1618
+
1619
+ def get_transform_init_args_names(self):
1620
+ return ("max_value",)
1621
+
1622
+
1623
+ class FromFloat(ImageOnlyTransform):
1624
+ """Take an input array where all values should lie in the range [0, 1.0], multiply them by `max_value` and then
1625
+ cast the resulted value to a type specified by `dtype`. If `max_value` is None the transform will try to infer
1626
+ the maximum value for the data type from the `dtype` argument.
1627
+
1628
+ This is the inverse transform for :class:`~albumentations.augmentations.transforms.ToFloat`.
1629
+
1630
+ Args:
1631
+ max_value (float): maximum possible input value. Default: None.
1632
+ dtype (string or numpy data type): data type of the output. See the `'Data types' page from the NumPy docs`_.
1633
+ Default: 'uint16'.
1634
+ p (float): probability of applying the transform. Default: 1.0.
1635
+
1636
+ Targets:
1637
+ image
1638
+
1639
+ Image types:
1640
+ float32
1641
+
1642
+ .. _'Data types' page from the NumPy docs:
1643
+ https://docs.scipy.org/doc/numpy/user/basics.types.html
1644
+ """
1645
+
1646
+ def __init__(self, dtype="uint16", max_value=None, always_apply=False, p=1.0):
1647
+ super(FromFloat, self).__init__(always_apply, p)
1648
+ self.dtype = np.dtype(dtype)
1649
+ self.max_value = max_value
1650
+
1651
+ def apply(self, img, **params):
1652
+ return F.from_float(img, self.dtype, self.max_value)
1653
+
1654
+ def get_transform_init_args(self):
1655
+ return {"dtype": self.dtype.name, "max_value": self.max_value}
1656
+
1657
+
1658
+ class Downscale(ImageOnlyTransform):
1659
+ """Decreases image quality by downscaling and upscaling back.
1660
+
1661
+ Args:
1662
+ scale_min (float): lower bound on the image scale. Should be < 1.
1663
+ scale_max (float): upper bound on the image scale. Should be < 1.
1664
+ interpolation: cv2 interpolation method. Could be:
1665
+ - single cv2 interpolation flag - selected method will be used for downscale and upscale.
1666
+ - dict(downscale=flag, upscale=flag)
1667
+ - Downscale.Interpolation(downscale=flag, upscale=flag).
1668
+ Default: Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)
1669
+
1670
+ Targets:
1671
+ image
1672
+
1673
+ Image types:
1674
+ uint8, float32
1675
+ """
1676
+
1677
+ class Interpolation:
1678
+ def __init__(self, *, downscale: int = cv2.INTER_NEAREST, upscale: int = cv2.INTER_NEAREST):
1679
+ self.downscale = downscale
1680
+ self.upscale = upscale
1681
+
1682
+ def __init__(
1683
+ self,
1684
+ scale_min: float = 0.25,
1685
+ scale_max: float = 0.25,
1686
+ interpolation: Optional[Union[int, Interpolation, Dict[str, int]]] = None,
1687
+ always_apply: bool = False,
1688
+ p: float = 0.5,
1689
+ ):
1690
+ super(Downscale, self).__init__(always_apply, p)
1691
+ if interpolation is None:
1692
+ self.interpolation = self.Interpolation(downscale=cv2.INTER_NEAREST, upscale=cv2.INTER_NEAREST)
1693
+ warnings.warn(
1694
+ "Using default interpolation INTER_NEAREST, which is sub-optimal."
1695
+ "Please specify interpolation mode for downscale and upscale explicitly."
1696
+ "For additional information see this PR https://github.com/albumentations-team/albumentations/pull/584"
1697
+ )
1698
+ elif isinstance(interpolation, int):
1699
+ self.interpolation = self.Interpolation(downscale=interpolation, upscale=interpolation)
1700
+ elif isinstance(interpolation, self.Interpolation):
1701
+ self.interpolation = interpolation
1702
+ elif isinstance(interpolation, dict):
1703
+ self.interpolation = self.Interpolation(**interpolation)
1704
+ else:
1705
+ raise ValueError(
1706
+ "Wrong interpolation data type. Supported types: `Optional[Union[int, Interpolation, Dict[str, int]]]`."
1707
+ f" Got: {type(interpolation)}"
1708
+ )
1709
+
1710
+ if scale_min > scale_max:
1711
+ raise ValueError("Expected scale_min be less or equal scale_max, got {} {}".format(scale_min, scale_max))
1712
+ if scale_max >= 1:
1713
+ raise ValueError("Expected scale_max to be less than 1, got {}".format(scale_max))
1714
+ self.scale_min = scale_min
1715
+ self.scale_max = scale_max
1716
+
1717
+ def apply(self, img: np.ndarray, scale: Optional[float] = None, **params) -> np.ndarray:
1718
+ return F.downscale(
1719
+ img,
1720
+ scale=scale,
1721
+ down_interpolation=self.interpolation.downscale,
1722
+ up_interpolation=self.interpolation.upscale,
1723
+ )
1724
+
1725
+ def get_params(self) -> Dict[str, Any]:
1726
+ return {"scale": random.uniform(self.scale_min, self.scale_max)}
1727
+
1728
+ def get_transform_init_args_names(self) -> Tuple[str, str]:
1729
+ return "scale_min", "scale_max"
1730
+
1731
+ def _to_dict(self) -> Dict[str, Any]:
1732
+ result = super()._to_dict()
1733
+ result["interpolation"] = {"upscale": self.interpolation.upscale, "downscale": self.interpolation.downscale}
1734
+ return result
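+ 
+ # Usage sketch (illustration only): interpolation may be given as a plain dict,
+ # which the constructor above converts into the Interpolation helper.
+ def _example_downscale():
+     image = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
+     aug = Downscale(
+         scale_min=0.25,
+         scale_max=0.5,
+         interpolation={"downscale": cv2.INTER_AREA, "upscale": cv2.INTER_LINEAR},
+         p=1.0,
+     )
+     return aug(image=image)["image"]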
1735
+
1736
+
1737
+ class Lambda(NoOp):
1738
+ """A flexible transformation class for using user-defined transformation functions per targets.
1739
+ Function signature must include **kwargs to accept optional arguments like interpolation method, image size, etc.:
1740
+
1741
+ Args:
1742
+ image (callable): Image transformation function.
1743
+ mask (callable): Mask transformation function.
1744
+ keypoint (callable): Keypoint transformation function.
1745
+ bbox (callable): BBox transformation function.
1746
+ always_apply (bool): Indicates whether this transformation should be always applied.
1747
+ p (float): probability of applying the transform. Default: 1.0.
1748
+
1749
+ Targets:
1750
+ image, mask, bboxes, keypoints
1751
+
1752
+ Image types:
1753
+ Any
1754
+ """
1755
+
1756
+ def __init__(
1757
+ self,
1758
+ image=None,
1759
+ mask=None,
1760
+ keypoint=None,
1761
+ bbox=None,
1762
+ name=None,
1763
+ always_apply=False,
1764
+ p=1.0,
1765
+ ):
1766
+ super(Lambda, self).__init__(always_apply, p)
1767
+
1768
+ self.name = name
1769
+ self.custom_apply_fns = {target_name: F.noop for target_name in ("image", "mask", "keypoint", "bbox")}
1770
+ for target_name, custom_apply_fn in {
1771
+ "image": image,
1772
+ "mask": mask,
1773
+ "keypoint": keypoint,
1774
+ "bbox": bbox,
1775
+ }.items():
1776
+ if custom_apply_fn is not None:
1777
+ if isinstance(custom_apply_fn, LambdaType) and custom_apply_fn.__name__ == "<lambda>":
1778
+ warnings.warn(
1779
+ "Using lambda is incompatible with multiprocessing. "
1780
+ "Consider using regular functions or partial()."
1781
+ )
1782
+
1783
+ self.custom_apply_fns[target_name] = custom_apply_fn
1784
+
1785
+ def apply(self, img, **params):
1786
+ fn = self.custom_apply_fns["image"]
1787
+ return fn(img, **params)
1788
+
1789
+ def apply_to_mask(self, mask, **params):
1790
+ fn = self.custom_apply_fns["mask"]
1791
+ return fn(mask, **params)
1792
+
1793
+ def apply_to_bbox(self, bbox, **params):
1794
+ fn = self.custom_apply_fns["bbox"]
1795
+ return fn(bbox, **params)
1796
+
1797
+ def apply_to_keypoint(self, keypoint, **params):
1798
+ fn = self.custom_apply_fns["keypoint"]
1799
+ return fn(keypoint, **params)
1800
+
1801
+ @classmethod
1802
+ def is_serializable(cls):
1803
+ return False
1804
+
1805
+ def _to_dict(self):
1806
+ if self.name is None:
1807
+ raise ValueError(
1808
+ "To make a Lambda transform serializable you should provide the `name` argument, "
1809
+ "e.g. `Lambda(name='my_transform', image=<some func>, ...)`."
1810
+ )
1811
+ return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}
1812
+
1813
+ def __repr__(self):
1814
+ state = {"name": self.name}
1815
+ state.update(self.custom_apply_fns.items())
1816
+ state.update(self.get_base_init_args())
1817
+ return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state))
1818
+
1819
+
1820
+ class MultiplicativeNoise(ImageOnlyTransform):
1821
+ """Multiply image to random number or array of numbers.
1822
+
1823
+ Args:
1824
+ multiplier (float or tuple of floats): If a single float, the image will be multiplied by this number.
1825
+ If a tuple of floats, the multiplier will be sampled from the range `[multiplier[0], multiplier[1])`. Default: (0.9, 1.1).
1826
+ per_channel (bool): If `False`, the same values will be used for all channels.
1827
+ If `True`, sample values for each channel independently. Default: False.
1828
+ elementwise (bool): If `False`, multiply all pixels in the image by a single random value sampled once.
1829
+ If `True`, multiply image pixels by values that are randomly sampled per pixel. Default: False.
1830
+
1831
+ Targets:
1832
+ image
1833
+
1834
+ Image types:
1835
+ Any
1836
+ """
1837
+
1838
+ def __init__(
1839
+ self,
1840
+ multiplier=(0.9, 1.1),
1841
+ per_channel=False,
1842
+ elementwise=False,
1843
+ always_apply=False,
1844
+ p=0.5,
1845
+ ):
1846
+ super(MultiplicativeNoise, self).__init__(always_apply, p)
1847
+ self.multiplier = to_tuple(multiplier, multiplier)
1848
+ self.per_channel = per_channel
1849
+ self.elementwise = elementwise
1850
+
1851
+ def apply(self, img, multiplier=np.array([1]), **kwargs):
1852
+ return F.multiply(img, multiplier)
1853
+
1854
+ def get_params_dependent_on_targets(self, params):
1855
+ if self.multiplier[0] == self.multiplier[1]:
1856
+ return {"multiplier": np.array([self.multiplier[0]])}
1857
+
1858
+ img = params["image"]
1859
+
1860
+ h, w = img.shape[:2]
1861
+
1862
+ if self.per_channel:
1863
+ c = 1 if is_grayscale_image(img) else img.shape[-1]
1864
+ else:
1865
+ c = 1
1866
+
1867
+ if self.elementwise:
1868
+ shape = [h, w, c]
1869
+ else:
1870
+ shape = [c]
1871
+
1872
+ multiplier = random_utils.uniform(self.multiplier[0], self.multiplier[1], shape)
1873
+ if is_grayscale_image(img) and img.ndim == 2:
1874
+ multiplier = np.squeeze(multiplier)
1875
+
1876
+ return {"multiplier": multiplier}
1877
+
1878
+ @property
1879
+ def targets_as_params(self):
1880
+ return ["image"]
1881
+
1882
+ def get_transform_init_args_names(self):
1883
+ return "multiplier", "per_channel", "elementwise"
1884
+
1885
+
1886
+ class FancyPCA(ImageOnlyTransform):
1887
+ """Augment RGB image using FancyPCA from Krizhevsky's paper
1888
+ "ImageNet Classification with Deep Convolutional Neural Networks"
1889
+
1890
+ Args:
1891
+ alpha (float): how much to perturb/scale the eigenvectors and eigenvalues.
1892
+ The scale is sampled from a Gaussian distribution (mu=0, sigma=alpha).
1893
+
1894
+ Targets:
1895
+ image
1896
+
1897
+ Image types:
1898
+ 3-channel uint8 images only
1899
+
1900
+ Credit:
1901
+ http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks.pdf
1902
+ https://deshanadesai.github.io/notes/Fancy-PCA-with-Scikit-Image
1903
+ https://pixelatedbrian.github.io/2018-04-29-fancy_pca/
1904
+ """
1905
+
1906
+ def __init__(self, alpha=0.1, always_apply=False, p=0.5):
1907
+ super(FancyPCA, self).__init__(always_apply=always_apply, p=p)
1908
+ self.alpha = alpha
1909
+
1910
+ def apply(self, img, alpha=0.1, **params):
1911
+ img = F.fancy_pca(img, alpha)
1912
+ return img
1913
+
1914
+ def get_params(self):
1915
+ return {"alpha": random.gauss(0, self.alpha)}
1916
+
1917
+ def get_transform_init_args_names(self):
1918
+ return ("alpha",)
1919
+
1920
+
1921
+ class ColorJitter(ImageOnlyTransform):
1922
+ """Randomly changes the brightness, contrast, and saturation of an image. Compared to ColorJitter from torchvision,
1923
+ this transform gives a little bit different results because Pillow (used in torchvision) and OpenCV (used in
1924
+ Albumentations) transform an image to HSV format by different formulas. Another difference - Pillow uses uint8
1925
+ overflow, but we use value saturation.
1926
+
1927
+ Args:
1928
+ brightness (float or tuple of float (min, max)): How much to jitter brightness.
1929
+ brightness_factor is chosen uniformly from [max(0, 1 - brightness), 1 + brightness]
1930
+ or the given [min, max]. Should be non negative numbers.
1931
+ contrast (float or tuple of float (min, max)): How much to jitter contrast.
1932
+ contrast_factor is chosen uniformly from [max(0, 1 - contrast), 1 + contrast]
1933
+ or the given [min, max]. Should be non negative numbers.
1934
+ saturation (float or tuple of float (min, max)): How much to jitter saturation.
1935
+ saturation_factor is chosen uniformly from [max(0, 1 - saturation), 1 + saturation]
1936
+ or the given [min, max]. Should be non negative numbers.
1937
+ hue (float or tuple of float (min, max)): How much to jitter hue.
1938
+ hue_factor is chosen uniformly from [-hue, hue] or the given [min, max].
1939
+ Should have 0 <= hue <= 0.5 or -0.5 <= min <= max <= 0.5.
1940
+ """
1941
+
1942
+ def __init__(
1943
+ self,
1944
+ brightness=0.2,
1945
+ contrast=0.2,
1946
+ saturation=0.2,
1947
+ hue=0.2,
1948
+ always_apply=False,
1949
+ p=0.5,
1950
+ ):
1951
+ super(ColorJitter, self).__init__(always_apply=always_apply, p=p)
1952
+
1953
+ self.brightness = self.__check_values(brightness, "brightness")
1954
+ self.contrast = self.__check_values(contrast, "contrast")
1955
+ self.saturation = self.__check_values(saturation, "saturation")
1956
+ self.hue = self.__check_values(hue, "hue", offset=0, bounds=[-0.5, 0.5], clip=False)
1957
+
1958
+ self.transforms = [
1959
+ F.adjust_brightness_torchvision,
1960
+ F.adjust_contrast_torchvision,
1961
+ F.adjust_saturation_torchvision,
1962
+ F.adjust_hue_torchvision,
1963
+ ]
1964
+
1965
+ @staticmethod
1966
+ def __check_values(value, name, offset=1, bounds=(0, float("inf")), clip=True):
1967
+ if isinstance(value, numbers.Number):
1968
+ if value < 0:
1969
+ raise ValueError("If {} is a single number, it must be non negative.".format(name))
1970
+ value = [offset - value, offset + value]
1971
+ if clip:
1972
+ value[0] = max(value[0], 0)
1973
+ elif isinstance(value, (tuple, list)) and len(value) == 2:
1974
+ if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
1975
+ raise ValueError("{} values should be between {}".format(name, bounds))
1976
+ else:
1977
+ raise TypeError("{} should be a single number or a list/tuple with length 2.".format(name))
1978
+
1979
+ return value
1980
+
1981
+ def get_params(self):
1982
+ brightness = random.uniform(self.brightness[0], self.brightness[1])
1983
+ contrast = random.uniform(self.contrast[0], self.contrast[1])
1984
+ saturation = random.uniform(self.saturation[0], self.saturation[1])
1985
+ hue = random.uniform(self.hue[0], self.hue[1])
1986
+
1987
+ order = [0, 1, 2, 3]
1988
+ random.shuffle(order)
1989
+
1990
+ return {
1991
+ "brightness": brightness,
1992
+ "contrast": contrast,
1993
+ "saturation": saturation,
1994
+ "hue": hue,
1995
+ "order": order,
1996
+ }
1997
+
1998
+ def apply(self, img, brightness=1.0, contrast=1.0, saturation=1.0, hue=0, order=(0, 1, 2, 3), **params):
1999
+ if not is_rgb_image(img) and not is_grayscale_image(img):
2000
+ raise TypeError("ColorJitter transformation expects 1-channel or 3-channel images.")
2001
+ params = [brightness, contrast, saturation, hue]
2002
+ for i in order:
2003
+ img = self.transforms[i](img, params[i])
2004
+ return img
2005
+
2006
+ def get_transform_init_args_names(self):
2007
+ return ("brightness", "contrast", "saturation", "hue")
2008
+
2009
+
2010
+ class Sharpen(ImageOnlyTransform):
2011
+ """Sharpen the input image and overlays the result with the original image.
2012
+
2013
+ Args:
2014
+ alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is
2015
+ visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5).
2016
+ lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0).
2017
+ p (float): probability of applying the transform. Default: 0.5.
2018
+
2019
+ Targets:
2020
+ image
2021
+ """
2022
+
2023
+ def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5):
2024
+ super(Sharpen, self).__init__(always_apply, p)
2025
+ self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
2026
+ self.lightness = self.__check_values(to_tuple(lightness, 0.0), name="lightness")
2027
+
2028
+ @staticmethod
2029
+ def __check_values(value, name, bounds=(0, float("inf"))):
2030
+ if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
2031
+ raise ValueError("{} values should be between {}".format(name, bounds))
2032
+ return value
2033
+
2034
+ @staticmethod
2035
+ def __generate_sharpening_matrix(alpha_sample, lightness_sample):
2036
+ matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
2037
+ matrix_effect = np.array(
2038
+ [[-1, -1, -1], [-1, 8 + lightness_sample, -1], [-1, -1, -1]],
2039
+ dtype=np.float32,
2040
+ )
2041
+
2042
+ matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect
2043
+ return matrix
2044
+
2045
+ def get_params(self):
2046
+ alpha = random.uniform(*self.alpha)
2047
+ lightness = random.uniform(*self.lightness)
2048
+ sharpening_matrix = self.__generate_sharpening_matrix(alpha_sample=alpha, lightness_sample=lightness)
2049
+ return {"sharpening_matrix": sharpening_matrix}
2050
+
2051
+ def apply(self, img, sharpening_matrix=None, **params):
2052
+ return F.convolve(img, sharpening_matrix)
2053
+
2054
+ def get_transform_init_args_names(self):
2055
+ return ("alpha", "lightness")
2056
+
2057
+
2058
+ class Emboss(ImageOnlyTransform):
2059
+ """Emboss the input image and overlays the result with the original image.
2060
+
2061
+ Args:
2062
+ alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is
2063
+ visible, at 1.0 only its embossed version is visible. Default: (0.2, 0.5).
2064
+ strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7).
2065
+ p (float): probability of applying the transform. Default: 0.5.
2066
+
2067
+ Targets:
2068
+ image
2069
+ """
2070
+
2071
+ def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5):
2072
+ super(Emboss, self).__init__(always_apply, p)
2073
+ self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
2074
+ self.strength = self.__check_values(to_tuple(strength, 0.0), name="strength")
2075
+
2076
+ @staticmethod
2077
+ def __check_values(value, name, bounds=(0, float("inf"))):
2078
+ if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
2079
+ raise ValueError("{} values should be between {}".format(name, bounds))
2080
+ return value
2081
+
2082
+ @staticmethod
2083
+ def __generate_emboss_matrix(alpha_sample, strength_sample):
2084
+ matrix_nochange = np.array([[0, 0, 0], [0, 1, 0], [0, 0, 0]], dtype=np.float32)
2085
+ matrix_effect = np.array(
2086
+ [
2087
+ [-1 - strength_sample, 0 - strength_sample, 0],
2088
+ [0 - strength_sample, 1, 0 + strength_sample],
2089
+ [0, 0 + strength_sample, 1 + strength_sample],
2090
+ ],
2091
+ dtype=np.float32,
2092
+ )
2093
+ matrix = (1 - alpha_sample) * matrix_nochange + alpha_sample * matrix_effect
2094
+ return matrix
2095
+
2096
+ def get_params(self):
2097
+ alpha = random.uniform(*self.alpha)
2098
+ strength = random.uniform(*self.strength)
2099
+ emboss_matrix = self.__generate_emboss_matrix(alpha_sample=alpha, strength_sample=strength)
2100
+ return {"emboss_matrix": emboss_matrix}
2101
+
2102
+ def apply(self, img, emboss_matrix=None, **params):
2103
+ return F.convolve(img, emboss_matrix)
2104
+
2105
+ def get_transform_init_args_names(self):
2106
+ return ("alpha", "strength")
2107
+
2108
+
2109
+ class Superpixels(ImageOnlyTransform):
2110
+ """Transform images partially/completely to their superpixel representation.
2111
+ This implementation uses skimage's version of the SLIC algorithm.
2112
+
2113
+ Args:
2114
+ p_replace (float or tuple of float): Defines for any segment the probability that the pixels within that
2115
+ segment are replaced by their average color (otherwise, the pixels are not changed).
2116
+ Examples:
2117
+ * A probability of ``0.0`` would mean, that the pixels in no
2118
+ segment are replaced by their average color (image is not
2119
+ changed at all).
2120
+ * A probability of ``0.5`` would mean, that around half of all
2121
+ segments are replaced by their average color.
2122
+ * A probability of ``1.0`` would mean, that all segments are
2123
+ replaced by their average color (resulting in a voronoi
2124
+ image).
2125
+ Behaviour based on chosen data types for this parameter:
2126
+ * If a ``float``, then that ``float`` will always be used.
2127
+ * If ``tuple`` ``(a, b)``, then a random probability will be
2128
+ sampled from the interval ``[a, b]`` per image.
2129
+ n_segments (int, or tuple of int): Rough target number of how many superpixels to generate (the algorithm
2130
+ may deviate from this number). Lower value will lead to coarser superpixels.
2131
+ Higher values are computationally more intensive and will hence lead to a slowdown.
2132
+ * If a single ``int``, then that value will always be used as the
2133
+ number of segments.
2134
+ * If a ``tuple`` ``(a, b)``, then a value from the discrete
2135
+ interval ``[a..b]`` will be sampled per image.
2136
+ max_size (int or None): Maximum image size at which the augmentation is performed.
2137
+ If the width or height of an image exceeds this value, it will be
2138
+ downscaled before the augmentation so that the longest side matches `max_size`.
2139
+ This is done to speed up the process. The final output image has the same size as the input image.
2140
+ Note that in case `p_replace` is below ``1.0``,
2141
+ the down-/upscaling will affect the not-replaced pixels too.
2142
+ Use ``None`` to apply no down-/upscaling.
2143
+ interpolation (OpenCV flag): flag that is used to specify the interpolation algorithm. Should be one of:
2144
+ cv2.INTER_NEAREST, cv2.INTER_LINEAR, cv2.INTER_CUBIC, cv2.INTER_AREA, cv2.INTER_LANCZOS4.
2145
+ Default: cv2.INTER_LINEAR.
2146
+ p (float): probability of applying the transform. Default: 0.5.
2147
+
2148
+ Targets:
2149
+ image
2150
+ """
2151
+
2152
+ def __init__(
2153
+ self,
2154
+ p_replace: Union[float, Sequence[float]] = 0.1,
2155
+ n_segments: Union[int, Sequence[int]] = 100,
2156
+ max_size: Optional[int] = 128,
2157
+ interpolation: int = cv2.INTER_LINEAR,
2158
+ always_apply: bool = False,
2159
+ p: float = 0.5,
2160
+ ):
2161
+ super().__init__(always_apply=always_apply, p=p)
2162
+ self.p_replace = to_tuple(p_replace, p_replace)
2163
+ self.n_segments = to_tuple(n_segments, n_segments)
2164
+ self.max_size = max_size
2165
+ self.interpolation = interpolation
2166
+
2167
+ if min(self.n_segments) < 1:
2168
+ raise ValueError(f"n_segments must be >= 1. Got: {n_segments}")
2169
+
2170
+ def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
2171
+ return ("p_replace", "n_segments", "max_size", "interpolation")
2172
+
2173
+ def get_params(self) -> dict:
2174
+ n_segments = random.randint(*self.n_segments)
2175
+ p = random.uniform(*self.p_replace)
2176
+ return {"replace_samples": random_utils.random(n_segments) < p, "n_segments": n_segments}
2177
+
2178
+ def apply(self, img: np.ndarray, replace_samples: Sequence[bool] = (False,), n_segments: int = 1, **kwargs):
2179
+ return F.superpixels(img, n_segments, replace_samples, self.max_size, self.interpolation)
2180
+
2181
+
2182
+ class TemplateTransform(ImageOnlyTransform):
2183
+ """
2184
+ Apply blending of the input image with the specified templates.
2185
+ Args:
2186
+ templates (numpy array or list of numpy arrays): Images as template for transform.
2187
+ img_weight ((float, float) or float): If single float will be used as weight for input image.
2188
+ If tuple of float img_weight will be in range `[img_weight[0], img_weight[1])`. Default: 0.5.
2189
+ template_weight ((float, float) or float): If single float will be used as weight for template.
2190
+ If tuple of float template_weight will be in range `[template_weight[0], template_weight[1])`.
2191
+ Default: 0.5.
2192
+ template_transform: transformation object which could be applied to template,
2193
+ must produce template the same size as input image.
2194
+ name (string): (Optional) Name of transform, used only for deserialization.
2195
+ p (float): probability of applying the transform. Default: 0.5.
2196
+ Targets:
2197
+ image
2198
+ Image types:
2199
+ uint8, float32
2200
+ """
2201
+
2202
+ def __init__(
2203
+ self,
2204
+ templates,
2205
+ img_weight=0.5,
2206
+ template_weight=0.5,
2207
+ template_transform=None,
2208
+ name=None,
2209
+ always_apply=False,
2210
+ p=0.5,
2211
+ ):
2212
+ super().__init__(always_apply, p)
2213
+
2214
+ self.templates = templates if isinstance(templates, (list, tuple)) else [templates]
2215
+ self.img_weight = to_tuple(img_weight, img_weight)
2216
+ self.template_weight = to_tuple(template_weight, template_weight)
2217
+ self.template_transform = template_transform
2218
+ self.name = name
2219
+
2220
+ def apply(self, img, template=None, img_weight=0.5, template_weight=0.5, **params):
2221
+ return F.add_weighted(img, img_weight, template, template_weight)
2222
+
2223
+ def get_params(self):
2224
+ return {
2225
+ "img_weight": random.uniform(self.img_weight[0], self.img_weight[1]),
2226
+ "template_weight": random.uniform(self.template_weight[0], self.template_weight[1]),
2227
+ }
2228
+
2229
+ def get_params_dependent_on_targets(self, params):
2230
+ img = params["image"]
2231
+ template = random.choice(self.templates)
2232
+
2233
+ if self.template_transform is not None:
2234
+ template = self.template_transform(image=template)["image"]
2235
+
2236
+ if get_num_channels(template) not in [1, get_num_channels(img)]:
2237
+ raise ValueError(
2238
+ "Template must be a single channel or "
2239
+ "has the same number of channels as input image ({}), got {}".format(
2240
+ get_num_channels(img), get_num_channels(template)
2241
+ )
2242
+ )
2243
+
2244
+ if template.dtype != img.dtype:
2245
+ raise ValueError("Image and template must be the same image type")
2246
+
2247
+ if img.shape[:2] != template.shape[:2]:
2248
+ raise ValueError(
2249
+ "Image and template must be the same size, got {} and {}".format(img.shape[:2], template.shape[:2])
2250
+ )
2251
+
2252
+ if get_num_channels(template) == 1 and get_num_channels(img) > 1:
2253
+ template = np.stack((template,) * get_num_channels(img), axis=-1)
2254
+
2255
+ # in order to support grayscale image with dummy dim
2256
+ template = template.reshape(img.shape)
2257
+
2258
+ return {"template": template}
2259
+
2260
+ @classmethod
2261
+ def is_serializable(cls):
2262
+ return False
2263
+
2264
+ @property
2265
+ def targets_as_params(self):
2266
+ return ["image"]
2267
+
2268
+ def _to_dict(self):
2269
+ if self.name is None:
2270
+ raise ValueError(
2271
+ "To make a TemplateTransform serializable you should provide the `name` argument, "
2272
+ "e.g. `TemplateTransform(name='my_transform', ...)`."
2273
+ )
2274
+ return {"__class_fullname__": self.get_class_fullname(), "__name__": self.name}
2275
+
2276
+
2277
+ class RingingOvershoot(ImageOnlyTransform):
2278
+ """Create ringing or overshoot artefacts by conlvolving image with 2D sinc filter.
2279
+
2280
+ Args:
2281
+ blur_limit (int, (int, int)): maximum kernel size for sinc filter.
2282
+ Should be in range [3, inf). Default: (7, 15).
2283
+ cutoff (float, (float, float)): range to choose the cutoff frequency in radians.
2284
+ Should be in range (0, np.pi)
2285
+ Default: (np.pi / 4, np.pi / 2).
2286
+ p (float): probability of applying the transform. Default: 0.5.
2287
+
2288
+ Reference:
2289
+ dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
2290
+ https://arxiv.org/abs/2107.10833
2291
+
2292
+ Targets:
2293
+ image
2294
+ """
2295
+
2296
+ def __init__(
2297
+ self,
2298
+ blur_limit: Union[int, Sequence[int]] = (7, 15),
2299
+ cutoff: Union[float, Sequence[float]] = (np.pi / 4, np.pi / 2),
2300
+ always_apply=False,
2301
+ p=0.5,
2302
+ ):
2303
+ super(RingingOvershoot, self).__init__(always_apply, p)
2304
+ self.blur_limit = to_tuple(blur_limit, 3)
2305
+ self.cutoff = self.__check_values(to_tuple(cutoff, np.pi / 2), name="cutoff", bounds=(0, np.pi))
2306
+
2307
+ @staticmethod
2308
+ def __check_values(value, name, bounds=(0, float("inf"))):
2309
+ if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
2310
+ raise ValueError(f"{name} values should be between {bounds}")
2311
+ return value
2312
+
2313
+ def get_params(self):
2314
+ ksize = random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2)
2315
+ if ksize % 2 == 0:
2316
+ raise ValueError(f"Kernel size must be odd. Got: {ksize}")
2317
+
2318
+ cutoff = random.uniform(*self.cutoff)
2319
+
2320
+ # From dsp.stackexchange.com/questions/58301/2-d-circularly-symmetric-low-pass-filter
2321
+ with np.errstate(divide="ignore", invalid="ignore"):
2322
+ kernel = np.fromfunction(
2323
+ lambda x, y: cutoff
2324
+ * special.j1(cutoff * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2))
2325
+ / (2 * np.pi * np.sqrt((x - (ksize - 1) / 2) ** 2 + (y - (ksize - 1) / 2) ** 2)),
2326
+ [ksize, ksize],
2327
+ )
2328
+ kernel[(ksize - 1) // 2, (ksize - 1) // 2] = cutoff**2 / (4 * np.pi)
2329
+
2330
+ # Normalize kernel
2331
+ kernel = kernel.astype(np.float32) / np.sum(kernel)
2332
+
2333
+ return {"kernel": kernel}
2334
+
2335
+ def apply(self, img, kernel=None, **params):
2336
+ return F.convolve(img, kernel)
2337
+
2338
+ def get_transform_init_args_names(self):
2339
+ return ("blur_limit", "cutoff")
2340
+
2341
+
2342
+ class UnsharpMask(ImageOnlyTransform):
2343
+ """
2344
+ Sharpen the input image using Unsharp Masking processing and overlays the result with the original image.
2345
+
2346
+ Args:
2347
+ blur_limit (int, (int, int)): maximum Gaussian kernel size for blurring the input image.
2348
+ Must be zero or odd and in range [0, inf). If set to 0 it will be computed from sigma
2349
+ as `round(sigma * (3 if img.dtype == np.uint8 else 4) * 2 + 1) + 1`.
2350
+ If set single value `blur_limit` will be in range (0, blur_limit).
2351
+ Default: (3, 7).
2352
+ sigma_limit (float, (float, float)): Gaussian kernel standard deviation. Must be in range [0, inf).
2353
+ If set single value `sigma_limit` will be in range (0, sigma_limit).
2354
+ If set to 0 sigma will be computed as `sigma = 0.3*((ksize-1)*0.5 - 1) + 0.8`. Default: 0.
2355
+ alpha (float, (float, float)): range to choose the visibility of the sharpened image.
2356
+ At 0, only the original image is visible, at 1.0 only its sharpened version is visible.
2357
+ Default: (0.2, 0.5).
2358
+ threshold (int): Value to limit sharpening only for areas with high pixel difference between original image
2359
+ and its smoothed version. Higher threshold means less sharpening on flat areas.
2360
+ Must be in range [0, 255]. Default: 10.
2361
+ p (float): probability of applying the transform. Default: 0.5.
2362
+
2363
+ Reference:
2364
+ arxiv.org/pdf/2107.10833.pdf
2365
+
2366
+ Targets:
2367
+ image
2368
+ """
2369
+
2370
+ def __init__(
2371
+ self,
2372
+ blur_limit: Union[int, Sequence[int]] = (3, 7),
2373
+ sigma_limit: Union[float, Sequence[float]] = 0.0,
2374
+ alpha: Union[float, Sequence[float]] = (0.2, 0.5),
2375
+ threshold: int = 10,
2376
+ always_apply=False,
2377
+ p=0.5,
2378
+ ):
2379
+ super(UnsharpMask, self).__init__(always_apply, p)
2380
+ self.blur_limit = to_tuple(blur_limit, 3)
2381
+ self.sigma_limit = self.__check_values(to_tuple(sigma_limit, 0.0), name="sigma_limit")
2382
+ self.alpha = self.__check_values(to_tuple(alpha, 0.0), name="alpha", bounds=(0.0, 1.0))
2383
+ self.threshold = threshold
2384
+
2385
+ if self.blur_limit[0] == 0 and self.sigma_limit[0] == 0:
2386
+ self.blur_limit = 3, max(3, self.blur_limit[1])
2387
+ raise ValueError("blur_limit and sigma_limit minimum value can not be both equal to 0.")
2388
+
2389
+ if (self.blur_limit[0] != 0 and self.blur_limit[0] % 2 != 1) or (
2390
+ self.blur_limit[1] != 0 and self.blur_limit[1] % 2 != 1
2391
+ ):
2392
+ raise ValueError("UnsharpMask supports only odd blur limits.")
2393
+
2394
+ @staticmethod
2395
+ def __check_values(value, name, bounds=(0, float("inf"))):
2396
+ if not bounds[0] <= value[0] <= value[1] <= bounds[1]:
2397
+ raise ValueError(f"{name} values should be between {bounds}")
2398
+ return value
2399
+
2400
+ def get_params(self):
2401
+ return {
2402
+ "ksize": random.randrange(self.blur_limit[0], self.blur_limit[1] + 1, 2),
2403
+ "sigma": random.uniform(*self.sigma_limit),
2404
+ "alpha": random.uniform(*self.alpha),
2405
+ }
2406
+
2407
+ def apply(self, img, ksize=3, sigma=0, alpha=0.2, **params):
2408
+ return F.unsharp_mask(img, ksize, sigma=sigma, alpha=alpha, threshold=self.threshold)
2409
+
2410
+ def get_transform_init_args_names(self):
2411
+ return ("blur_limit", "sigma_limit", "alpha", "threshold")
2412
+
2413
+
2414
+ class PixelDropout(DualTransform):
2415
+ """Set pixels to 0 with some probability.
2416
+
2417
+ Args:
2418
+ dropout_prob (float): pixel drop probability. Default: 0.01
2419
+ per_channel (bool): if set to `True` the drop mask will be sampled for each channel,
2420
+ otherwise the same mask will be sampled for all channels. Default: False
2421
+ drop_value (number or sequence of numbers or None): Value that will be set in dropped place.
2422
+ If set to None value will be sampled randomly, default ranges will be used:
2423
+ - uint8 - [0, 255]
2424
+ - uint16 - [0, 65535]
2425
+ - uint32 - [0, 4294967295]
2426
+ - float, double - [0, 1]
2427
+ Default: 0
2428
+ mask_drop_value (number or sequence of numbers or None): Value that will be set in dropped place in masks.
2429
+ If set to None masks will be unchanged. Default: 0
2430
+ p (float): probability of applying the transform. Default: 0.5.
2431
+
2432
+ Targets:
2433
+ image, mask
2434
+ Image types:
2435
+ any
2436
+ """
2437
+
2438
+ def __init__(
2439
+ self,
2440
+ dropout_prob: float = 0.01,
2441
+ per_channel: bool = False,
2442
+ drop_value: Optional[Union[float, Sequence[float]]] = 0,
2443
+ mask_drop_value: Optional[Union[float, Sequence[float]]] = None,
2444
+ always_apply: bool = False,
2445
+ p: float = 0.5,
2446
+ ):
2447
+ super().__init__(always_apply, p)
2448
+ self.dropout_prob = dropout_prob
2449
+ self.per_channel = per_channel
2450
+ self.drop_value = drop_value
2451
+ self.mask_drop_value = mask_drop_value
2452
+
2453
+ if self.mask_drop_value is not None and self.per_channel:
2454
+ raise ValueError("PixelDropout supports mask only with per_channel=False")
2455
+
2456
+ def apply(
2457
+ self,
2458
+ img: np.ndarray,
2459
+ drop_mask: np.ndarray = np.array(None),
2460
+ drop_value: Union[float, Sequence[float]] = (),
2461
+ **params
2462
+ ) -> np.ndarray:
2463
+ return F.pixel_dropout(img, drop_mask, drop_value)
2464
+
2465
+ def apply_to_mask(self, img: np.ndarray, drop_mask: np.ndarray = np.array(None), **params) -> np.ndarray:
2466
+ if self.mask_drop_value is None:
2467
+ return img
2468
+
2469
+ if img.ndim == 2:
2470
+ drop_mask = np.squeeze(drop_mask)
2471
+
2472
+ return F.pixel_dropout(img, drop_mask, self.mask_drop_value)
2473
+
2474
+ def apply_to_bbox(self, bbox, **params):
2475
+ return bbox
2476
+
2477
+ def apply_to_keypoint(self, keypoint, **params):
2478
+ return keypoint
2479
+
2480
+ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
2481
+ img = params["image"]
2482
+ shape = img.shape if self.per_channel else img.shape[:2]
2483
+
2484
+ rnd = np.random.RandomState(random.randint(0, 1 << 31))
2485
+ # Use choice to create boolean matrix, if we will use binomial after that we will need type conversion
2486
+ drop_mask = rnd.choice([True, False], shape, p=[self.dropout_prob, 1 - self.dropout_prob])
2487
+
2488
+ drop_value: Union[float, Sequence[float], np.ndarray]
2489
+ if drop_mask.ndim != img.ndim:
2490
+ drop_mask = np.expand_dims(drop_mask, -1)
2491
+ if self.drop_value is None:
2492
+ drop_shape = 1 if is_grayscale_image(img) else int(img.shape[-1])
2493
+
2494
+ if img.dtype in (np.uint8, np.uint16, np.uint32):
2495
+ drop_value = rnd.randint(0, int(F.MAX_VALUES_BY_DTYPE[img.dtype]), drop_shape, img.dtype)
2496
+ elif img.dtype in [np.float32, np.double]:
2497
+ drop_value = rnd.uniform(0, 1, drop_shape).astype(img.dtype)
2498
+ else:
2499
+ raise ValueError(f"Unsupported dtype: {img.dtype}")
2500
+ else:
2501
+ drop_value = self.drop_value
2502
+
2503
+ return {"drop_mask": drop_mask, "drop_value": drop_value}
2504
+
2505
+ @property
2506
+ def targets_as_params(self) -> List[str]:
2507
+ return ["image"]
2508
+
2509
+ def get_transform_init_args_names(self) -> Tuple[str, str, str, str]:
2510
+ return ("dropout_prob", "per_channel", "drop_value", "mask_drop_value")
2511
+
2512
+
2513
+ class Spatter(ImageOnlyTransform):
2514
+ """
2515
+ Apply spatter transform. It simulates corruption which can occlude a lens in the form of rain or mud.
2516
+
2517
+ Args:
2518
+ mean (float, or tuple of floats): Mean value of normal distribution for generating liquid layer.
2519
+ If single float it will be used as mean.
2520
+ If tuple of float mean will be sampled from range `[mean[0], mean[1])`. Default: (0.65).
2521
+ std (float, or tuple of floats): Standard deviation value of normal distribution for generating liquid layer.
2522
+ If single float it will be used as std.
2523
+ If tuple of float std will be sampled from range `[std[0], std[1])`. Default: (0.3).
2524
+ gauss_sigma (float, or tuple of floats): Sigma value for gaussian filtering of liquid layer.
2525
+ If single float it will be used as gauss_sigma.
2526
+ If tuple of float gauss_sigma will be sampled from range `[sigma[0], sigma[1])`. Default: (2).
2527
+ cutout_threshold (float, or tuple of floats): Threshold for filtering the liquid layer
2528
+ (determines the number of drops). If single float it will be used as cutout_threshold.
2529
+ If tuple of float cutout_threshold will be sampled from range `[cutout_threshold[0], cutout_threshold[1])`.
2530
+ Default: (0.68).
2531
+ intensity (float, or tuple of floats): Intensity of corruption.
2532
+ If single float it will be used as intensity.
2533
+ If tuple of float intensity will be sampled from range `[intensity[0], intensity[1])`. Default: (0.6).
2534
+ mode (string, or list of strings): Type of corruption. Currently, supported options are 'rain' and 'mud'.
2535
+ If a list is provided, the type of corruption will be sampled from that list. Default: ("rain").
2536
+ color (list of (r, g, b) or dict or None): Corruption elements color.
2537
+ If a list is provided, it is used as the color for the specified mode.
2538
+ If a dict is provided, it maps modes to colors; a color must be given for every specified mode.
2539
+ If None, default colors are used (rain: (238, 238, 175), mud: (20, 42, 63)).
2540
+ p (float): probability of applying the transform. Default: 0.5.
2541
+
2542
+ Targets:
2543
+ image
2544
+
2545
+ Image types:
2546
+ uint8, float32
2547
+
2548
+ Reference:
2549
+ | https://arxiv.org/pdf/1903.12261.pdf
2550
+ | https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py
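+
+ Example:
+ A minimal usage sketch (parameter values are illustrative; assumes Spatter is
+ exported at the package top level, as in upstream albumentations):
+
+ >>> import numpy as np
+ >>> import custom_albumentations as A
+ >>> aug = A.Compose([A.Spatter(mode="mud", p=1.0)])
+ >>> corrupted = aug(image=np.zeros((64, 64, 3), dtype=np.uint8))["image"]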
2551
+ """
2552
+
2553
+ def __init__(
2554
+ self,
2555
+ mean: ScaleFloatType = 0.65,
2556
+ std: ScaleFloatType = 0.3,
2557
+ gauss_sigma: ScaleFloatType = 2,
2558
+ cutout_threshold: ScaleFloatType = 0.68,
2559
+ intensity: ScaleFloatType = 0.6,
2560
+ mode: Union[str, Sequence[str]] = "rain",
2561
+ color: Optional[Union[Sequence[int], Dict[str, Sequence[int]]]] = None,
2562
+ always_apply: bool = False,
2563
+ p: float = 0.5,
2564
+ ):
2565
+ super().__init__(always_apply=always_apply, p=p)
2566
+
2567
+ self.mean = to_tuple(mean, mean)
2568
+ self.std = to_tuple(std, std)
2569
+ self.gauss_sigma = to_tuple(gauss_sigma, gauss_sigma)
2570
+ self.intensity = to_tuple(intensity, intensity)
2571
+ self.cutout_threshold = to_tuple(cutout_threshold, cutout_threshold)
2572
+ self.color = (
2573
+ color
2574
+ if color is not None
2575
+ else {
2576
+ "rain": [238, 238, 175],
2577
+ "mud": [20, 42, 63],
2578
+ }
2579
+ )
2580
+ self.mode = mode if isinstance(mode, (list, tuple)) else [mode]
2581
+
2582
+ if len(set(self.mode)) > 1 and not isinstance(self.color, dict):
2583
+ raise ValueError(f"Unsupported color: {self.color}. Please specify color for each mode (use dict for it).")
2584
+
2585
+ for i in self.mode:
2586
+ if i not in ["rain", "mud"]:
2587
+ raise ValueError(f"Unsupported color mode: {mode}. Transform supports only `rain` and `mud` mods.")
2588
+ if isinstance(self.color, dict):
2589
+ if i not in self.color:
2590
+ raise ValueError(f"Wrong color definition: {self.color}. Color for mode: {i} not specified.")
2591
+ if len(self.color[i]) != 3:
2592
+ raise ValueError(
2593
+ f"Unsupported color: {self.color[i]} for mode {i}. Color should be presented in RGB format."
2594
+ )
2595
+
2596
+ if isinstance(self.color, (list, tuple)):
2597
+ if len(self.color) != 3:
2598
+ raise ValueError(f"Unsupported color: {self.color}. Color should be presented in RGB format.")
2599
+ self.color = {self.mode[0]: self.color}
2600
+
2601
+ def apply(
2602
+ self,
2603
+ img: np.ndarray,
2604
+ non_mud: Optional[np.ndarray] = None,
2605
+ mud: Optional[np.ndarray] = None,
2606
+ drops: Optional[np.ndarray] = None,
2607
+ mode: str = "",
2608
+ **params
2609
+ ) -> np.ndarray:
2610
+ return F.spatter(img, non_mud, mud, drops, mode)
2611
+
2612
+ @property
2613
+ def targets_as_params(self) -> List[str]:
2614
+ return ["image"]
2615
+
2616
+ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
2617
+ h, w = params["image"].shape[:2]
2618
+
2619
+ mean = random.uniform(self.mean[0], self.mean[1])
2620
+ std = random.uniform(self.std[0], self.std[1])
2621
+ cutout_threshold = random.uniform(self.cutout_threshold[0], self.cutout_threshold[1])
2622
+ sigma = random.uniform(self.gauss_sigma[0], self.gauss_sigma[1])
2623
+ mode = random.choice(self.mode)
2624
+ intensity = random.uniform(self.intensity[0], self.intensity[1])
2625
+ color = np.array(self.color[mode]) / 255.0
2626
+
2627
+ liquid_layer = random_utils.normal(size=(h, w), loc=mean, scale=std)
2628
+ liquid_layer = gaussian_filter(liquid_layer, sigma=sigma, mode="nearest")
2629
+ liquid_layer[liquid_layer < cutout_threshold] = 0
2630
+
2631
+ if mode == "rain":
2632
+ liquid_layer = (liquid_layer * 255).astype(np.uint8)
2633
+ dist = 255 - cv2.Canny(liquid_layer, 50, 150)
2634
+ dist = cv2.distanceTransform(dist, cv2.DIST_L2, 5)
2635
+ _, dist = cv2.threshold(dist, 20, 20, cv2.THRESH_TRUNC)
2636
+ dist = blur(dist, 3).astype(np.uint8)
2637
+ dist = F.equalize(dist)
2638
+
2639
+ ker = np.array([[-2, -1, 0], [-1, 1, 1], [0, 1, 2]])
2640
+ dist = F.convolve(dist, ker)
2641
+ dist = blur(dist, 3).astype(np.float32)
2642
+
2643
+ m = liquid_layer * dist
2644
+ m *= 1 / np.max(m, axis=(0, 1))
2645
+
2646
+ drops = m[:, :, None] * color * intensity
2647
+ mud = None
2648
+ non_mud = None
2649
+ else:
2650
+ m = np.where(liquid_layer > cutout_threshold, 1, 0)
2651
+ m = gaussian_filter(m.astype(np.float32), sigma=sigma, mode="nearest")
2652
+ m[m < 1.2 * cutout_threshold] = 0
2653
+ m = m[..., np.newaxis]
2654
+
2655
+ mud = m * color
2656
+ non_mud = 1 - m
2657
+ drops = None
2658
+
2659
+ return {
2660
+ "non_mud": non_mud,
2661
+ "mud": mud,
2662
+ "drops": drops,
2663
+ "mode": mode,
2664
+ }
2665
+
2666
+ def get_transform_init_args_names(self) -> Tuple[str, str, str, str, str, str, str]:
2667
+ return "mean", "std", "gauss_sigma", "intensity", "cutout_threshold", "mode", "color"
custom_albumentations/augmentations/utils.py ADDED
@@ -0,0 +1,211 @@
1
+ from functools import wraps
2
+ from typing import Callable, Union
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from typing_extensions import Concatenate, ParamSpec
7
+
8
+ from custom_albumentations.core.keypoints_utils import angle_to_2pi_range
9
+ from custom_albumentations.core.transforms_interface import KeypointInternalType
10
+
11
+ __all__ = [
12
+ "read_bgr_image",
13
+ "read_rgb_image",
14
+ "MAX_VALUES_BY_DTYPE",
15
+ "NPDTYPE_TO_OPENCV_DTYPE",
16
+ "clipped",
17
+ "get_opencv_dtype_from_numpy",
18
+ "angle_2pi_range",
19
+ "clip",
20
+ "preserve_shape",
21
+ "preserve_channel_dim",
22
+ "ensure_contiguous",
23
+ "is_rgb_image",
24
+ "is_grayscale_image",
25
+ "is_multispectral_image",
26
+ "get_num_channels",
27
+ "non_rgb_warning",
28
+ "_maybe_process_in_chunks",
29
+ ]
30
+
31
+ P = ParamSpec("P")
32
+
33
+ MAX_VALUES_BY_DTYPE = {
34
+ np.dtype("uint8"): 255,
35
+ np.dtype("uint16"): 65535,
36
+ np.dtype("uint32"): 4294967295,
37
+ np.dtype("float32"): 1.0,
38
+ }
39
+
40
+ NPDTYPE_TO_OPENCV_DTYPE = {
41
+ np.uint8: cv2.CV_8U, # type: ignore[attr-defined]
42
+ np.uint16: cv2.CV_16U, # type: ignore[attr-defined]
43
+ np.int32: cv2.CV_32S, # type: ignore[attr-defined]
44
+ np.float32: cv2.CV_32F, # type: ignore[attr-defined]
45
+ np.float64: cv2.CV_64F, # type: ignore[attr-defined]
46
+ np.dtype("uint8"): cv2.CV_8U, # type: ignore[attr-defined]
47
+ np.dtype("uint16"): cv2.CV_16U, # type: ignore[attr-defined]
48
+ np.dtype("int32"): cv2.CV_32S, # type: ignore[attr-defined]
49
+ np.dtype("float32"): cv2.CV_32F, # type: ignore[attr-defined]
50
+ np.dtype("float64"): cv2.CV_64F, # type: ignore[attr-defined]
51
+ }
52
+
53
+
54
+ def read_bgr_image(path):
55
+ return cv2.imread(path, cv2.IMREAD_COLOR)
56
+
57
+
58
+ def read_rgb_image(path):
59
+ image = cv2.imread(path, cv2.IMREAD_COLOR)
60
+ return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
61
+
62
+
63
+ def clipped(func: Callable[Concatenate[np.ndarray, P], np.ndarray]) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
64
+ @wraps(func)
65
+ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
66
+ dtype = img.dtype
67
+ maxval = MAX_VALUES_BY_DTYPE.get(dtype, 1.0)
68
+ return clip(func(img, *args, **kwargs), dtype, maxval)
69
+
70
+ return wrapped_function
71
+
72
+
73
+ def clip(img: np.ndarray, dtype: np.dtype, maxval: float) -> np.ndarray:
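+ """Clip values to the range [0, maxval] and cast back to `dtype`.
+
+ A worked sketch (values are illustrative):
+
+ >>> clip(np.array([-1.0, 0.5, 300.0]), np.uint8, 255).tolist()
+ [0, 0, 255]
+ """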
74
+ return np.clip(img, 0, maxval).astype(dtype)
75
+
76
+
77
+ def get_opencv_dtype_from_numpy(value: Union[np.ndarray, int, np.dtype, object]) -> int:
78
+ """
79
+ Return the corresponding OpenCV dtype for a numpy dtype
80
+ :param value: Input dtype of numpy array
81
+ :return: Corresponding dtype for OpenCV
82
+ """
83
+ if isinstance(value, np.ndarray):
84
+ value = value.dtype
85
+ return NPDTYPE_TO_OPENCV_DTYPE[value]
86
+
87
+
88
+ def angle_2pi_range(
89
+ func: Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]
90
+ ) -> Callable[Concatenate[KeypointInternalType, P], KeypointInternalType]:
91
+ @wraps(func)
92
+ def wrapped_function(keypoint: KeypointInternalType, *args: P.args, **kwargs: P.kwargs) -> KeypointInternalType:
93
+ (x, y, a, s) = func(keypoint, *args, **kwargs)[:4]
94
+ return (x, y, angle_to_2pi_range(a), s)
95
+
96
+ return wrapped_function
97
+
98
+
99
+ def preserve_shape(
100
+ func: Callable[Concatenate[np.ndarray, P], np.ndarray]
101
+ ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
102
+ """Preserve shape of the image"""
103
+
104
+ @wraps(func)
105
+ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
106
+ shape = img.shape
107
+ result = func(img, *args, **kwargs)
108
+ result = result.reshape(shape)
109
+ return result
110
+
111
+ return wrapped_function
112
+
113
+
114
+ def preserve_channel_dim(
115
+ func: Callable[Concatenate[np.ndarray, P], np.ndarray]
116
+ ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
117
+ """Preserve dummy channel dim."""
118
+
119
+ @wraps(func)
120
+ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
121
+ shape = img.shape
122
+ result = func(img, *args, **kwargs)
123
+ if len(shape) == 3 and shape[-1] == 1 and len(result.shape) == 2:
124
+ result = np.expand_dims(result, axis=-1)
125
+ return result
126
+
127
+ return wrapped_function
128
+
129
+
130
+ def ensure_contiguous(
131
+ func: Callable[Concatenate[np.ndarray, P], np.ndarray]
132
+ ) -> Callable[Concatenate[np.ndarray, P], np.ndarray]:
133
+ """Ensure that input img is contiguous."""
134
+
135
+ @wraps(func)
136
+ def wrapped_function(img: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> np.ndarray:
137
+ img = np.require(img, requirements=["C_CONTIGUOUS"])
138
+ result = func(img, *args, **kwargs)
139
+ return result
140
+
141
+ return wrapped_function
142
+
143
+
144
+ def is_rgb_image(image: np.ndarray) -> bool:
145
+ return len(image.shape) == 3 and image.shape[-1] == 3
146
+
147
+
148
+ def is_grayscale_image(image: np.ndarray) -> bool:
149
+ return (len(image.shape) == 2) or (len(image.shape) == 3 and image.shape[-1] == 1)
150
+
151
+
152
+ def is_multispectral_image(image: np.ndarray) -> bool:
153
+ return len(image.shape) == 3 and image.shape[-1] not in [1, 3]
154
+
155
+
156
+ def get_num_channels(image: np.ndarray) -> int:
157
+ return image.shape[2] if len(image.shape) == 3 else 1
158
+
159
+
160
+ def non_rgb_warning(image: np.ndarray) -> None:
161
+ if not is_rgb_image(image):
162
+ message = "This transformation expects 3-channel images"
163
+ if is_grayscale_image(image):
164
+ message += "\nYou can convert your grayscale image to RGB using cv2.cvtColor(image, cv2.COLOR_GRAY2RGB))"
165
+ if is_multispectral_image(image): # Any image with a number of channels other than 1 and 3
166
+ message += "\nThis transformation cannot be applied to multi-spectral images"
167
+
168
+ raise ValueError(message)
169
+
170
+
171
+ def _maybe_process_in_chunks(
172
+ process_fn: Callable[Concatenate[np.ndarray, P], np.ndarray], **kwargs
173
+ ) -> Callable[[np.ndarray], np.ndarray]:
174
+ """
175
+ Wrap OpenCV function to enable processing images with more than 4 channels.
176
+
177
+ Limitations:
178
+ This wrapper requires the image to be the first argument; all other arguments must be passed by name.
179
+
180
+ Args:
181
+ process_fn: Transform function (e.g. cv2.resize).
182
+ kwargs: Additional parameters.
183
+
184
+ Returns:
185
+ Callable: A wrapped function that takes an image and returns the transformed image, processing images with more than 4 channels in chunks.
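+
+ Example:
+ A chunking sketch (shapes are illustrative): a 6-channel image is processed
+ in chunks of at most 4 channels and reassembled.
+
+ >>> resize_fn = _maybe_process_in_chunks(cv2.resize, dsize=(32, 32))
+ >>> resize_fn(np.zeros((64, 64, 6), dtype=np.uint8)).shape
+ (32, 32, 6)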
186
+
187
+ """
188
+
189
+ @wraps(process_fn)
190
+ def __process_fn(img: np.ndarray) -> np.ndarray:
191
+ num_channels = get_num_channels(img)
192
+ if num_channels > 4:
193
+ chunks = []
194
+ for index in range(0, num_channels, 4):
195
+ if num_channels - index == 2:
196
+ # Many OpenCV functions cannot work with 2-channel images
197
+ for i in range(2):
198
+ chunk = img[:, :, index + i : index + i + 1]
199
+ chunk = process_fn(chunk, **kwargs)
200
+ chunk = np.expand_dims(chunk, -1)
201
+ chunks.append(chunk)
202
+ else:
203
+ chunk = img[:, :, index : index + 4]
204
+ chunk = process_fn(chunk, **kwargs)
205
+ chunks.append(chunk)
206
+ img = np.dstack(chunks)
207
+ else:
208
+ img = process_fn(img, **kwargs)
209
+ return img
210
+
211
+ return __process_fn
custom_albumentations/core/__init__.py ADDED
File without changes
custom_albumentations/core/bbox_utils.py ADDED
@@ -0,0 +1,522 @@
1
+ from __future__ import division
2
+
3
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, TypeVar, cast
4
+
5
+ import numpy as np
6
+
7
+ from .transforms_interface import BoxInternalType, BoxType
8
+ from .utils import DataProcessor, Params
9
+
10
+ __all__ = [
11
+ "normalize_bbox",
12
+ "denormalize_bbox",
13
+ "normalize_bboxes",
14
+ "denormalize_bboxes",
15
+ "calculate_bbox_area",
16
+ "filter_bboxes_by_visibility",
17
+ "convert_bbox_to_albumentations",
18
+ "convert_bbox_from_albumentations",
19
+ "convert_bboxes_to_albumentations",
20
+ "convert_bboxes_from_albumentations",
21
+ "check_bbox",
22
+ "check_bboxes",
23
+ "filter_bboxes",
24
+ "union_of_bboxes",
25
+ "BboxProcessor",
26
+ "BboxParams",
27
+ ]
28
+
29
+ TBox = TypeVar("TBox", BoxType, BoxInternalType)
30
+
31
+
32
+ class BboxParams(Params):
33
+ """
34
+ Parameters of bounding boxes
35
+
36
+ Args:
37
+ format (str): format of bounding boxes. Should be 'coco', 'pascal_voc', 'albumentations' or 'yolo'.
38
+
39
+ The `coco` format
40
+ `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
41
+ The `pascal_voc` format
42
+ `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212].
43
+ The `albumentations` format
44
+ is like `pascal_voc`, but normalized,
45
+ in other words: `[x_min, y_min, x_max, y_max]`, e.g. [0.2, 0.3, 0.4, 0.5].
46
+ The `yolo` format
47
+ `[x, y, width, height]`, e.g. [0.1, 0.2, 0.3, 0.4];
48
+ `x`, `y` - normalized bbox center; `width`, `height` - normalized bbox width and height.
49
+ label_fields (list): list of fields that are joined with boxes, e.g labels.
50
+ Should be same type as boxes.
51
+ min_area (float): minimum area of a bounding box. All bounding boxes whose
52
+ visible area in pixels is less than this value will be removed. Default: 0.0.
53
+ min_visibility (float): minimum fraction of area for a bounding box
54
+ to remain this box in list. Default: 0.0.
55
+ min_width (float): Minimum width of a bounding box. All bounding boxes whose width is
56
+ less than this value will be removed. Default: 0.0.
57
+ min_height (float): Minimum height of a bounding box. All bounding boxes whose height is
58
+ less than this value will be removed. Default: 0.0.
59
+ check_each_transform (bool): if `True`, then bboxes will be checked after each dual transform.
60
+ Default: `True`
61
+ """
62
+
63
+ def __init__(
64
+ self,
65
+ format: str,
66
+ label_fields: Optional[Sequence[str]] = None,
67
+ min_area: float = 0.0,
68
+ min_visibility: float = 0.0,
69
+ min_width: float = 0.0,
70
+ min_height: float = 0.0,
71
+ check_each_transform: bool = True,
72
+ ):
73
+ super(BboxParams, self).__init__(format, label_fields)
74
+ self.min_area = min_area
75
+ self.min_visibility = min_visibility
76
+ self.min_width = min_width
77
+ self.min_height = min_height
78
+ self.check_each_transform = check_each_transform
79
+
80
+ def _to_dict(self) -> Dict[str, Any]:
81
+ data = super(BboxParams, self)._to_dict()
82
+ data.update(
83
+ {
84
+ "min_area": self.min_area,
85
+ "min_visibility": self.min_visibility,
86
+ "min_width": self.min_width,
87
+ "min_height": self.min_height,
88
+ "check_each_transform": self.check_each_transform,
89
+ }
90
+ )
91
+ return data
92
+
93
+ @classmethod
94
+ def is_serializable(cls) -> bool:
95
+ return True
96
+
97
+ @classmethod
98
+ def get_class_fullname(cls) -> str:
99
+ return "BboxParams"
100
+
101
+
102
+ class BboxProcessor(DataProcessor):
103
+ def __init__(self, params: BboxParams, additional_targets: Optional[Dict[str, str]] = None):
104
+ super().__init__(params, additional_targets)
105
+
106
+ @property
107
+ def default_data_name(self) -> str:
108
+ return "bboxes"
109
+
110
+ def ensure_data_valid(self, data: Dict[str, Any]) -> None:
111
+ for data_name in self.data_fields:
112
+ data_exists = data_name in data and len(data[data_name])
113
+ if data_exists and len(data[data_name][0]) < 5:
114
+ if self.params.label_fields is None:
115
+ raise ValueError(
116
+ "Please specify 'label_fields' in 'bbox_params' or add labels to the end of bbox "
117
+ "because bboxes must have labels"
118
+ )
119
+ if self.params.label_fields:
120
+ if not all(i in data.keys() for i in self.params.label_fields):
121
+ raise ValueError("Your 'label_fields' are not valid - them must have same names as params in dict")
122
+
123
+ def filter(self, data: Sequence, rows: int, cols: int) -> List:
124
+ self.params: BboxParams
125
+ return filter_bboxes(
126
+ data,
127
+ rows,
128
+ cols,
129
+ min_area=self.params.min_area,
130
+ min_visibility=self.params.min_visibility,
131
+ min_width=self.params.min_width,
132
+ min_height=self.params.min_height,
133
+ )
134
+
135
+ def check(self, data: Sequence, rows: int, cols: int) -> None:
136
+ check_bboxes(data)
137
+
138
+ def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> List[BoxType]:
139
+ return convert_bboxes_from_albumentations(data, self.params.format, rows, cols, check_validity=True)
140
+
141
+ def convert_to_albumentations(self, data: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
142
+ return convert_bboxes_to_albumentations(data, self.params.format, rows, cols, check_validity=True)
143
+
144
+
145
+ def normalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox:
146
+ """Normalize coordinates of a bounding box. Divide x-coordinates by image width and y-coordinates
147
+ by image height.
148
+
149
+ Args:
150
+ bbox: Denormalized bounding box `(x_min, y_min, x_max, y_max)`.
151
+ rows: Image height.
152
+ cols: Image width.
153
+
154
+ Returns:
155
+ Normalized bounding box `(x_min, y_min, x_max, y_max)`.
156
+
157
+ Raises:
158
+ ValueError: If rows or cols is less or equal zero
159
+
160
+ """
161
+
162
+ if rows <= 0:
163
+ raise ValueError("Argument rows must be positive integer")
164
+ if cols <= 0:
165
+ raise ValueError("Argument cols must be positive integer")
166
+
167
+ tail: Tuple[Any, ...]
168
+ (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])
169
+
170
+ x_min, x_max = x_min / cols, x_max / cols
171
+ y_min, y_max = y_min / rows, y_max / rows
172
+
173
+ return cast(BoxType, (x_min, y_min, x_max, y_max) + tail) # type: ignore
174
+
175
+
176
+ def denormalize_bbox(bbox: TBox, rows: int, cols: int) -> TBox:
177
+ """Denormalize coordinates of a bounding box. Multiply x-coordinates by image width and y-coordinates
178
+ by image height. This is an inverse operation for :func:`~albumentations.augmentations.bbox.normalize_bbox`.
179
+
180
+ Args:
181
+ bbox: Normalized bounding box `(x_min, y_min, x_max, y_max)`.
182
+ rows: Image height.
183
+ cols: Image width.
184
+
185
+ Returns:
186
+ Denormalized bounding box `(x_min, y_min, x_max, y_max)`.
187
+
188
+ Raises:
189
+ ValueError: If rows or cols is less or equal zero
190
+
191
+ """
192
+ tail: Tuple[Any, ...]
193
+ (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])
194
+
195
+ if rows <= 0:
196
+ raise ValueError("Argument rows must be positive integer")
197
+ if cols <= 0:
198
+ raise ValueError("Argument cols must be positive integer")
199
+
200
+ x_min, x_max = x_min * cols, x_max * cols
201
+ y_min, y_max = y_min * rows, y_max * rows
202
+
203
+ return cast(BoxType, (x_min, y_min, x_max, y_max) + tail) # type: ignore
204
+
205
+
206
+ def normalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
207
+ """Normalize a list of bounding boxes.
208
+
209
+ Args:
210
+ bboxes: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
211
+ rows: Image height.
212
+ cols: Image width.
213
+
214
+ Returns:
215
+ Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
216
+
217
+ """
218
+ return [normalize_bbox(bbox, rows, cols) for bbox in bboxes]
219
+
220
+
221
+ def denormalize_bboxes(bboxes: Sequence[BoxType], rows: int, cols: int) -> List[BoxType]:
222
+ """Denormalize a list of bounding boxes.
223
+
224
+ Args:
225
+ bboxes: Normalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
226
+ rows: Image height.
227
+ cols: Image width.
228
+
229
+ Returns:
230
+ List: Denormalized bounding boxes `[(x_min, y_min, x_max, y_max)]`.
231
+
232
+ """
233
+ return [denormalize_bbox(bbox, rows, cols) for bbox in bboxes]
234
+
235
+
236
+ def calculate_bbox_area(bbox: BoxType, rows: int, cols: int) -> float:
237
+ """Calculate the area of a bounding box in (fractional) pixels.
238
+
239
+ Args:
240
+ bbox: A bounding box `(x_min, y_min, x_max, y_max)`.
241
+ rows: Image height.
242
+ cols: Image width.
243
+
244
+ Return:
245
+ Area in (fractional) pixels of the (denormalized) bounding box.
246
+
247
+ """
248
+ bbox = denormalize_bbox(bbox, rows, cols)
249
+ x_min, y_min, x_max, y_max = bbox[:4]
250
+ area = (x_max - x_min) * (y_max - y_min)
251
+ return area
252
+
253
+
254
+ def filter_bboxes_by_visibility(
255
+ original_shape: Sequence[int],
256
+ bboxes: Sequence[BoxType],
257
+ transformed_shape: Sequence[int],
258
+ transformed_bboxes: Sequence[BoxType],
259
+ threshold: float = 0.0,
260
+ min_area: float = 0.0,
261
+ ) -> List[BoxType]:
262
+ """Filter bounding boxes and return only those boxes whose visibility after transformation is above
263
+ the threshold and minimal area of bounding box in pixels is more then min_area.
264
+
265
+ Args:
266
+ original_shape: Original image shape `(height, width, ...)`.
267
+ bboxes: Original bounding boxes `[(x_min, y_min, x_max, y_max)]`.
268
+ transformed_shape: Transformed image shape `(height, width)`.
269
+ transformed_bboxes: Transformed bounding boxes `[(x_min, y_min, x_max, y_max)]`.
270
+ threshold: visibility threshold. Should be a value in the range [0.0, 1.0].
271
+ min_area: Minimal area threshold.
272
+
273
+ Returns:
274
+ Filtered bounding boxes `[(x_min, y_min, x_max, y_max)]`.
275
+
276
+ """
277
+ img_height, img_width = original_shape[:2]
278
+ transformed_img_height, transformed_img_width = transformed_shape[:2]
279
+
280
+ visible_bboxes = []
281
+ for bbox, transformed_bbox in zip(bboxes, transformed_bboxes):
282
+ if not all(0.0 <= value <= 1.0 for value in transformed_bbox[:4]):
283
+ continue
284
+ bbox_area = calculate_bbox_area(bbox, img_height, img_width)
285
+ transformed_bbox_area = calculate_bbox_area(transformed_bbox, transformed_img_height, transformed_img_width)
286
+ if transformed_bbox_area < min_area:
287
+ continue
288
+ visibility = transformed_bbox_area / bbox_area
289
+ if visibility >= threshold:
290
+ visible_bboxes.append(transformed_bbox)
291
+ return visible_bboxes
292
+
293
+
294
+ def convert_bbox_to_albumentations(
295
+ bbox: BoxType, source_format: str, rows: int, cols: int, check_validity: bool = False
296
+ ) -> BoxType:
297
+ """Convert a bounding box from a format specified in `source_format` to the format used by albumentations:
298
+ normalized coordinates of top-left and bottom-right corners of the bounding box in a form of
299
+ `(x_min, y_min, x_max, y_max)` e.g. `(0.15, 0.27, 0.67, 0.5)`.
300
+
301
+ Args:
302
+ bbox: A bounding box tuple.
303
+ source_format: format of the bounding box. Should be 'coco', 'pascal_voc', or 'yolo'.
304
+ check_validity: Check if all boxes are valid boxes.
305
+ rows: Image height.
306
+ cols: Image width.
307
+
308
+ Returns:
309
+ tuple: A bounding box `(x_min, y_min, x_max, y_max)`.
310
+
311
+ Note:
312
+ The `coco` format of a bounding box looks like `(x_min, y_min, width, height)`, e.g. (97, 12, 150, 200).
313
+ The `pascal_voc` format of a bounding box looks like `(x_min, y_min, x_max, y_max)`, e.g. (97, 12, 247, 212).
314
+ The `yolo` format of a bounding box looks like `(x, y, width, height)`, e.g. (0.3, 0.1, 0.05, 0.07);
315
+ where `x`, `y` coordinates of the center of the box, all values normalized to 1 by image height and width.
316
+
317
+ Raises:
318
+ ValueError: if `target_format` is not equal to `coco` or `pascal_voc`, or `yolo`.
319
+ ValueError: If in YOLO format all labels not in range (0, 1).
320
+
321
+ """
322
+ if source_format not in {"coco", "pascal_voc", "yolo"}:
323
+ raise ValueError(
324
+ f"Unknown source_format {source_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'"
325
+ )
326
+
327
+ if source_format == "coco":
328
+ (x_min, y_min, width, height), tail = bbox[:4], bbox[4:]
329
+ x_max = x_min + width
330
+ y_max = y_min + height
331
+ elif source_format == "yolo":
332
+ # https://github.com/pjreddie/darknet/blob/f6d861736038da22c9eb0739dca84003c5a5e275/scripts/voc_label.py#L12
333
+ _bbox = np.array(bbox[:4])
334
+ if check_validity and np.any((_bbox <= 0) | (_bbox > 1)):
335
+ raise ValueError("In YOLO format all coordinates must be float and in range (0, 1]")
336
+
337
+ (x, y, w, h), tail = bbox[:4], bbox[4:]
338
+
339
+ w_half, h_half = w / 2, h / 2
340
+ x_min = x - w_half
341
+ y_min = y - h_half
342
+ x_max = x_min + w
343
+ y_max = y_min + h
344
+ else:
345
+ (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:]
346
+
347
+ bbox = (x_min, y_min, x_max, y_max) + tuple(tail) # type: ignore
348
+
349
+ if source_format != "yolo":
350
+ bbox = normalize_bbox(bbox, rows, cols)
351
+ if check_validity:
352
+ check_bbox(bbox)
353
+ return bbox
354
+
355
+
356
+ def convert_bbox_from_albumentations(
357
+ bbox: BoxType, target_format: str, rows: int, cols: int, check_validity: bool = False
358
+ ) -> BoxType:
359
+ """Convert a bounding box from the format used by albumentations to a format, specified in `target_format`.
360
+
361
+ Args:
362
+ bbox: An albumentations bounding box `(x_min, y_min, x_max, y_max)`.
363
+ target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'.
364
+ rows: Image height.
365
+ cols: Image width.
366
+ check_validity: Check if all boxes are valid boxes.
367
+
368
+ Returns:
369
+ tuple: A bounding box.
370
+
371
+ Note:
372
+ The `coco` format of a bounding box looks like `[x_min, y_min, width, height]`, e.g. [97, 12, 150, 200].
373
+ The `pascal_voc` format of a bounding box looks like `[x_min, y_min, x_max, y_max]`, e.g. [97, 12, 247, 212].
374
+ The `yolo` format of a bounding box looks like `[x, y, width, height]`, e.g. [0.3, 0.1, 0.05, 0.07].
375
+
376
+ Raises:
377
+ ValueError: if `target_format` is not equal to `coco`, `pascal_voc` or `yolo`.
378
+
379
+ """
380
+ if target_format not in {"coco", "pascal_voc", "yolo"}:
381
+ raise ValueError(
382
+ f"Unknown target_format {target_format}. Supported formats are: 'coco', 'pascal_voc' and 'yolo'"
383
+ )
384
+ if check_validity:
385
+ check_bbox(bbox)
386
+
387
+ if target_format != "yolo":
388
+ bbox = denormalize_bbox(bbox, rows, cols)
389
+ if target_format == "coco":
390
+ (x_min, y_min, x_max, y_max), tail = bbox[:4], tuple(bbox[4:])
391
+ width = x_max - x_min
392
+ height = y_max - y_min
393
+ bbox = cast(BoxType, (x_min, y_min, width, height) + tail)
394
+ elif target_format == "yolo":
395
+ (x_min, y_min, x_max, y_max), tail = bbox[:4], bbox[4:]
396
+ x = (x_min + x_max) / 2.0
397
+ y = (y_min + y_max) / 2.0
398
+ w = x_max - x_min
399
+ h = y_max - y_min
400
+ bbox = cast(BoxType, (x, y, w, h) + tail)
401
+ return bbox
402
+
403
+
404
+ def convert_bboxes_to_albumentations(
405
+ bboxes: Sequence[BoxType], source_format, rows, cols, check_validity=False
406
+ ) -> List[BoxType]:
407
+ """Convert a list bounding boxes from a format specified in `source_format` to the format used by albumentations"""
408
+ return [convert_bbox_to_albumentations(bbox, source_format, rows, cols, check_validity) for bbox in bboxes]
409
+
410
+
411
+ def convert_bboxes_from_albumentations(
412
+ bboxes: Sequence[BoxType], target_format: str, rows: int, cols: int, check_validity: bool = False
413
+ ) -> List[BoxType]:
414
+ """Convert a list of bounding boxes from the format used by albumentations to a format, specified
415
+ in `target_format`.
416
+
417
+ Args:
418
+ bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
419
+ target_format: required format of the output bounding box. Should be 'coco', 'pascal_voc' or 'yolo'.
420
+ rows: Image height.
421
+ cols: Image width.
422
+ check_validity: Check if all boxes are valid boxes.
423
+
424
+ Returns:
425
+ List of bounding boxes.
426
+
427
+ """
428
+ return [convert_bbox_from_albumentations(bbox, target_format, rows, cols, check_validity) for bbox in bboxes]
429
+
430
+
431
+ def check_bbox(bbox: BoxType) -> None:
432
+ """Check if bbox boundaries are in range 0, 1 and minimums are lesser then maximums"""
433
+ for name, value in zip(["x_min", "y_min", "x_max", "y_max"], bbox[:4]):
434
+ if not 0 <= value <= 1 and not np.isclose(value, 0) and not np.isclose(value, 1):
435
+ raise ValueError(f"Expected {name} for bbox {bbox} to be in the range [0.0, 1.0], got {value}.")
436
+ x_min, y_min, x_max, y_max = bbox[:4]
437
+ if x_max <= x_min:
438
+ raise ValueError(f"x_max is less than or equal to x_min for bbox {bbox}.")
439
+ if y_max <= y_min:
440
+ raise ValueError(f"y_max is less than or equal to y_min for bbox {bbox}.")
441
+
442
+
443
+ def check_bboxes(bboxes: Sequence[BoxType]) -> None:
444
+ """Check if bboxes boundaries are in range 0, 1 and minimums are lesser then maximums"""
445
+ for bbox in bboxes:
446
+ check_bbox(bbox)
447
+
448
+
449
+ def filter_bboxes(
450
+ bboxes: Sequence[BoxType],
451
+ rows: int,
452
+ cols: int,
453
+ min_area: float = 0.0,
454
+ min_visibility: float = 0.0,
455
+ min_width: float = 0.0,
456
+ min_height: float = 0.0,
457
+ ) -> List[BoxType]:
458
+ """Remove bounding boxes that either lie outside of the visible area by more then min_visibility
459
+ or whose area in pixels is under the threshold set by `min_area`. Also it crops boxes to final image size.
460
+
461
+ Args:
462
+ bboxes: List of albumentation bounding box `(x_min, y_min, x_max, y_max)`.
463
+ rows: Image height.
464
+ cols: Image width.
465
+ min_area: Minimum area of a bounding box. All bounding boxes whose visible area in pixels.
466
+ is less than this value will be removed. Default: 0.0.
467
+ min_visibility: Minimum fraction of area for a bounding box to remain this box in list. Default: 0.0.
468
+ min_width: Minimum width of a bounding box. All bounding boxes whose width is
469
+ less than this value will be removed. Default: 0.0.
470
+ min_height: Minimum height of a bounding box. All bounding boxes whose height is
471
+ less than this value will be removed. Default: 0.0.
472
+
473
+ Returns:
474
+ List of bounding boxes.
475
+
476
+ """
477
+ resulting_boxes: List[BoxType] = []
478
+ for bbox in bboxes:
479
+ # Calculate areas of bounding box before and after clipping.
480
+ transformed_box_area = calculate_bbox_area(bbox, rows, cols)
481
+ bbox, tail = cast(BoxType, tuple(np.clip(bbox[:4], 0, 1.0))), tuple(bbox[4:])
482
+ clipped_box_area = calculate_bbox_area(bbox, rows, cols)
483
+
484
+ # Calculate width and height of the clipped bounding box.
485
+ x_min, y_min, x_max, y_max = denormalize_bbox(bbox, rows, cols)[:4]
486
+ clipped_width, clipped_height = x_max - x_min, y_max - y_min
487
+
488
+ if (
489
+ clipped_box_area != 0 # to ensure transformed_box_area!=0 and to handle min_area=0 or min_visibility=0
490
+ and clipped_box_area >= min_area
491
+ and clipped_box_area / transformed_box_area >= min_visibility
492
+ and clipped_width >= min_width
493
+ and clipped_height >= min_height
494
+ ):
495
+ resulting_boxes.append(cast(BoxType, bbox + tail))
496
+ return resulting_boxes
497
+
498
+
499
+ def union_of_bboxes(height: int, width: int, bboxes: Sequence[BoxType], erosion_rate: float = 0.0) -> BoxType:
500
+ """Calculate union of bounding boxes.
501
+
502
+ Args:
503
+ height (float): Height of image or space.
504
+ width (float): Width of image or space.
505
+ bboxes (List[tuple]): List like bounding boxes. Format is `[(x_min, y_min, x_max, y_max)]`.
506
+ erosion_rate (float): How much each bounding box can be shrinked, useful for erosive cropping.
507
+ Set this in range [0, 1]. 0 will not be erosive at all, 1.0 can make any bbox to lose its volume.
508
+
509
+ Returns:
510
+ tuple: A bounding box `(x_min, y_min, x_max, y_max)`.
511
+
512
+ """
513
+ x1, y1 = width, height
514
+ x2, y2 = 0, 0
515
+ for bbox in bboxes:
516
+ x_min, y_min, x_max, y_max = bbox[:4]
517
+ w, h = x_max - x_min, y_max - y_min
518
+ lim_x1, lim_y1 = x_min + erosion_rate * w, y_min + erosion_rate * h
519
+ lim_x2, lim_y2 = x_max - erosion_rate * w, y_max - erosion_rate * h
520
+ x1, y1 = np.min([x1, lim_x1]), np.min([y1, lim_y1])
521
+ x2, y2 = np.max([x2, lim_x2]), np.max([y2, lim_y2])
522
+ return x1, y1, x2, y2
custom_albumentations/core/composition.py ADDED
@@ -0,0 +1,552 @@
1
+ from __future__ import division
2
+
3
+ import random
4
+ import typing
5
+ import warnings
6
+ from collections import defaultdict
7
+
8
+ import numpy as np
9
+
10
+ from .. import random_utils
11
+ from .bbox_utils import BboxParams, BboxProcessor
12
+ from .keypoints_utils import KeypointParams, KeypointsProcessor
13
+ from .serialization import (
14
+ SERIALIZABLE_REGISTRY,
15
+ Serializable,
16
+ get_shortest_class_fullname,
17
+ instantiate_nonserializable,
18
+ )
19
+ from .transforms_interface import BasicTransform
20
+ from .utils import format_args, get_shape
21
+
22
+ __all__ = [
23
+ "BaseCompose",
24
+ "Compose",
25
+ "SomeOf",
26
+ "OneOf",
27
+ "OneOrOther",
28
+ "BboxParams",
29
+ "KeypointParams",
30
+ "ReplayCompose",
31
+ "Sequential",
32
+ ]
33
+
34
+
35
+ REPR_INDENT_STEP = 2
36
+ TransformType = typing.Union[BasicTransform, "BaseCompose"]
37
+ TransformsSeqType = typing.Sequence[TransformType]
38
+
39
+
40
+ def get_always_apply(transforms: typing.Union["BaseCompose", TransformsSeqType]) -> TransformsSeqType:
41
+ new_transforms: typing.List[TransformType] = []
42
+ for transform in transforms: # type: ignore
43
+ if isinstance(transform, BaseCompose):
44
+ new_transforms.extend(get_always_apply(transform))
45
+ elif transform.always_apply:
46
+ new_transforms.append(transform)
47
+ return new_transforms
48
+
49
+
50
+ class BaseCompose(Serializable):
51
+ def __init__(self, transforms: TransformsSeqType, p: float):
52
+ if isinstance(transforms, (BaseCompose, BasicTransform)):
53
+ warnings.warn(
54
+ "transforms is single transform, but a sequence is expected! Transform will be wrapped into list."
55
+ )
56
+ transforms = [transforms]
57
+
58
+ self.transforms = transforms
59
+ self.p = p
60
+
61
+ self.replay_mode = False
62
+ self.applied_in_replay = False
63
+
64
+ def __len__(self) -> int:
65
+ return len(self.transforms)
66
+
67
+ def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
68
+ raise NotImplementedError
69
+
70
+ def __getitem__(self, item: int) -> TransformType: # type: ignore
71
+ return self.transforms[item]
72
+
73
+ def __repr__(self) -> str:
74
+ return self.indented_repr()
75
+
76
+ def indented_repr(self, indent: int = REPR_INDENT_STEP) -> str:
77
+ args = {k: v for k, v in self._to_dict().items() if not (k.startswith("__") or k == "transforms")}
78
+ repr_string = self.__class__.__name__ + "(["
79
+ for t in self.transforms:
80
+ repr_string += "\n"
81
+ if hasattr(t, "indented_repr"):
82
+ t_repr = t.indented_repr(indent + REPR_INDENT_STEP) # type: ignore
83
+ else:
84
+ t_repr = repr(t)
85
+ repr_string += " " * indent + t_repr + ","
86
+ repr_string += "\n" + " " * (indent - REPR_INDENT_STEP) + "], {args})".format(args=format_args(args))
87
+ return repr_string
88
+
89
+ @classmethod
90
+ def get_class_fullname(cls) -> str:
91
+ return get_shortest_class_fullname(cls)
92
+
93
+ @classmethod
94
+ def is_serializable(cls) -> bool:
95
+ return True
96
+
97
+ def _to_dict(self) -> typing.Dict[str, typing.Any]:
98
+ return {
99
+ "__class_fullname__": self.get_class_fullname(),
100
+ "p": self.p,
101
+ "transforms": [t._to_dict() for t in self.transforms], # skipcq: PYL-W0212
102
+ }
103
+
104
+ def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
105
+ return {
106
+ "__class_fullname__": self.get_class_fullname(),
107
+ "id": id(self),
108
+ "params": None,
109
+ "transforms": [t.get_dict_with_id() for t in self.transforms],
110
+ }
111
+
112
+ def add_targets(self, additional_targets: typing.Optional[typing.Dict[str, str]]) -> None:
113
+ if additional_targets:
114
+ for t in self.transforms:
115
+ t.add_targets(additional_targets)
116
+
117
+ def set_deterministic(self, flag: bool, save_key: str = "replay") -> None:
118
+ for t in self.transforms:
119
+ t.set_deterministic(flag, save_key)
120
+
121
+
122
+ class Compose(BaseCompose):
123
+ """Compose transforms and handle all transformations regarding bounding boxes
124
+
125
+ Args:
126
+ transforms (list): list of transformations to compose.
127
+ bbox_params (BboxParams): Parameters for bounding boxes transforms
128
+ keypoint_params (KeypointParams): Parameters for keypoints transforms
129
+ additional_targets (dict): Dict with keys - new target name, values - old target name. ex: {'image2': 'image'}
130
+ p (float): probability of applying all list of transforms. Default: 1.0.
131
+ is_check_shapes (bool): If True, the shape consistency of images/mask/masks is checked on each call. If you
132
+ would like to disable this check, pass False (do it only if you are sure of your data consistency).
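+
+ Example:
+ A minimal pipeline sketch (the transform choice is illustrative):
+
+ >>> import numpy as np
+ >>> import custom_albumentations as A
+ >>> transform = A.Compose([A.HorizontalFlip(p=0.5)], p=1.0)
+ >>> augmented = transform(image=np.zeros((8, 8, 3), dtype=np.uint8))["image"]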
133
+ """
134
+
135
+ def __init__(
136
+ self,
137
+ transforms: TransformsSeqType,
138
+ bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
139
+ keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
140
+ additional_targets: typing.Optional[typing.Dict[str, str]] = None,
141
+ p: float = 1.0,
142
+ is_check_shapes: bool = True,
143
+ ):
144
+ super(Compose, self).__init__(transforms, p)
145
+
146
+ self.processors: typing.Dict[str, typing.Union[BboxProcessor, KeypointsProcessor]] = {}
147
+ if bbox_params:
148
+ if isinstance(bbox_params, dict):
149
+ b_params = BboxParams(**bbox_params)
150
+ elif isinstance(bbox_params, BboxParams):
151
+ b_params = bbox_params
152
+ else:
153
+ raise ValueError("unknown format of bbox_params, please use `dict` or `BboxParams`")
154
+ self.processors["bboxes"] = BboxProcessor(b_params, additional_targets)
155
+
156
+ if keypoint_params:
157
+ if isinstance(keypoint_params, dict):
158
+ k_params = KeypointParams(**keypoint_params)
159
+ elif isinstance(keypoint_params, KeypointParams):
160
+ k_params = keypoint_params
161
+ else:
162
+ raise ValueError("unknown format of keypoint_params, please use `dict` or `KeypointParams`")
163
+ self.processors["keypoints"] = KeypointsProcessor(k_params, additional_targets)
164
+
165
+ if additional_targets is None:
166
+ additional_targets = {}
167
+
168
+ self.additional_targets = additional_targets
169
+
170
+ for proc in self.processors.values():
171
+ proc.ensure_transforms_valid(self.transforms)
172
+
173
+ self.add_targets(additional_targets)
174
+
175
+ self.is_check_args = True
176
+ self._disable_check_args_for_transforms(self.transforms)
177
+
178
+ self.is_check_shapes = is_check_shapes
179
+
180
+ @staticmethod
181
+ def _disable_check_args_for_transforms(transforms: TransformsSeqType) -> None:
182
+ for transform in transforms:
183
+ if isinstance(transform, BaseCompose):
184
+ Compose._disable_check_args_for_transforms(transform.transforms)
185
+ if isinstance(transform, Compose):
186
+ transform._disable_check_args()
187
+
188
+ def _disable_check_args(self) -> None:
189
+ self.is_check_args = False
190
+
191
+ def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
192
+ if args:
193
+ raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
194
+ if self.is_check_args:
195
+ self._check_args(**data)
196
+ assert isinstance(force_apply, (bool, int)), "force_apply must have bool or int type"
197
+ need_to_run = force_apply or random.random() < self.p
198
+ for p in self.processors.values():
199
+ p.ensure_data_valid(data)
200
+ transforms = self.transforms if need_to_run else get_always_apply(self.transforms)
201
+
202
+ check_each_transform = any(
203
+ getattr(item.params, "check_each_transform", False) for item in self.processors.values()
204
+ )
205
+
206
+ for p in self.processors.values():
207
+ p.preprocess(data)
208
+
209
+ for idx, t in enumerate(transforms):
210
+ data = t(**data)
211
+
212
+ if check_each_transform:
213
+ data = self._check_data_post_transform(data)
214
+ data = Compose._make_targets_contiguous(data) # ensure output targets are contiguous
215
+
216
+ for p in self.processors.values():
217
+ p.postprocess(data)
218
+
219
+ return data
220
+
221
+ def _check_data_post_transform(self, data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
222
+ rows, cols = get_shape(data["image"])
223
+
224
+ for p in self.processors.values():
225
+ if not getattr(p.params, "check_each_transform", False):
226
+ continue
227
+
228
+ for data_name in p.data_fields:
229
+ data[data_name] = p.filter(data[data_name], rows, cols)
230
+ return data
231
+
232
+ def _to_dict(self) -> typing.Dict[str, typing.Any]:
233
+ dictionary = super(Compose, self)._to_dict()
234
+ bbox_processor = self.processors.get("bboxes")
235
+ keypoints_processor = self.processors.get("keypoints")
236
+ dictionary.update(
237
+ {
238
+ "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212
239
+ "keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212
240
+ if keypoints_processor
241
+ else None,
242
+ "additional_targets": self.additional_targets,
243
+ "is_check_shapes": self.is_check_shapes,
244
+ }
245
+ )
246
+ return dictionary
247
+
248
+ def get_dict_with_id(self) -> typing.Dict[str, typing.Any]:
249
+ dictionary = super().get_dict_with_id()
250
+ bbox_processor = self.processors.get("bboxes")
251
+ keypoints_processor = self.processors.get("keypoints")
252
+ dictionary.update(
253
+ {
254
+ "bbox_params": bbox_processor.params._to_dict() if bbox_processor else None, # skipcq: PYL-W0212
255
+ "keypoint_params": keypoints_processor.params._to_dict() # skipcq: PYL-W0212
256
+ if keypoints_processor
257
+ else None,
258
+ "additional_targets": self.additional_targets,
259
+ "params": None,
260
+ "is_check_shapes": self.is_check_shapes,
261
+ }
262
+ )
263
+ return dictionary
264
+
265
+ def _check_args(self, **kwargs) -> None:
266
+ checked_single = ["image", "mask"]
267
+ checked_multi = ["masks"]
268
+ check_bbox_param = ["bboxes"]
269
+ # ["bboxes", "keypoints"] could be almost any type, no need to check them
270
+ shapes = []
271
+ for data_name, data in kwargs.items():
272
+ internal_data_name = self.additional_targets.get(data_name, data_name)
273
+ if internal_data_name in checked_single:
274
+ if not isinstance(data, np.ndarray):
275
+ raise TypeError("{} must be numpy array type".format(data_name))
276
+ shapes.append(data.shape[:2])
277
+ if internal_data_name in checked_multi:
278
+ if data is not None and len(data):
279
+ if not isinstance(data[0], np.ndarray):
280
+ raise TypeError("{} must be list of numpy arrays".format(data_name))
281
+ shapes.append(data[0].shape[:2])
282
+ if internal_data_name in check_bbox_param and self.processors.get("bboxes") is None:
283
+ raise ValueError("bbox_params must be specified for bbox transformations")
284
+
285
+ if self.is_check_shapes and shapes and shapes.count(shapes[0]) != len(shapes):
286
+ raise ValueError(
287
+ "Height and Width of image, mask or masks should be equal. You can disable shapes check "
288
+ "by setting a parameter is_check_shapes=False of Compose class (do it only if you are sure "
289
+ "about your data consistency)."
290
+ )
291
+
292
+ @staticmethod
293
+ def _make_targets_contiguous(data: typing.Dict[str, typing.Any]) -> typing.Dict[str, typing.Any]:
294
+ result = {}
295
+ for key, value in data.items():
296
+ if isinstance(value, np.ndarray):
297
+ value = np.ascontiguousarray(value)
298
+ result[key] = value
299
+ return result
300
+
301
+
302
+ class OneOf(BaseCompose):
303
+ """Select one of transforms to apply. Selected transform will be called with `force_apply=True`.
304
+ Transform probabilities will be normalized to sum to 1, so they effectively work as weights.
305
+
306
+ Args:
307
+ transforms (list): list of transformations to compose.
308
+ p (float): probability of applying selected transform. Default: 0.5.
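+
+ Example:
+ A weighting sketch (transform choice is illustrative): with inner p values of
+ 0.9 and 0.1, the first transform is chosen ~90% of the times OneOf fires.
+
+ >>> import custom_albumentations as A
+ >>> one_of = A.OneOf([A.Blur(p=0.9), A.ToGray(p=0.1)], p=0.5)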
309
+ """
310
+
311
+ def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
312
+ super(OneOf, self).__init__(transforms, p)
313
+ transforms_ps = [t.p for t in self.transforms]
314
+ s = sum(transforms_ps)
315
+ self.transforms_ps = [t / s for t in transforms_ps]
316
+
317
+ def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
318
+ if self.replay_mode:
319
+ for t in self.transforms:
320
+ data = t(**data)
321
+ return data
322
+
323
+ if self.transforms_ps and (force_apply or random.random() < self.p):
324
+ idx: int = random_utils.choice(len(self.transforms), p=self.transforms_ps)
325
+ t = self.transforms[idx]
326
+ data = t(force_apply=True, **data)
327
+ return data
328
+
329
+
330
+ class SomeOf(BaseCompose):
331
+ """Select N transforms to apply. Selected transforms will be called with `force_apply=True`.
332
+ Transform probabilities will be normalized to sum to 1, so they effectively work as weights.
333
+
334
+ Args:
335
+ transforms (list): list of transformations to compose.
336
+ n (int): number of transforms to apply.
337
+ replace (bool): Whether the sampled transforms are with or without replacement. Default: True.
338
+ p (float): probability of applying selected transform. Default: 1.
339
+ """
340
+
341
+ def __init__(self, transforms: TransformsSeqType, n: int, replace: bool = True, p: float = 1):
342
+ super(SomeOf, self).__init__(transforms, p)
343
+ self.n = n
344
+ self.replace = replace
345
+ transforms_ps = [t.p for t in self.transforms]
346
+ s = sum(transforms_ps)
347
+ self.transforms_ps = [t / s for t in transforms_ps]
348
+
349
+ def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
350
+ if self.replay_mode:
351
+ for t in self.transforms:
352
+ data = t(**data)
353
+ return data
354
+
355
+ if self.transforms_ps and (force_apply or random.random() < self.p):
356
+ idx = random_utils.choice(len(self.transforms), size=self.n, replace=self.replace, p=self.transforms_ps)
357
+ for i in idx: # type: ignore
358
+ t = self.transforms[i]
359
+ data = t(force_apply=True, **data)
360
+ return data
361
+
362
+ def _to_dict(self) -> typing.Dict[str, typing.Any]:
363
+ dictionary = super(SomeOf, self)._to_dict()
364
+ dictionary.update({"n": self.n, "replace": self.replace})
365
+ return dictionary
366
+
367
+
368
+ class OneOrOther(BaseCompose):
369
+ """Select one or another transform to apply. Selected transform will be called with `force_apply=True`."""
370
+
371
+ def __init__(
372
+ self,
373
+ first: typing.Optional[TransformType] = None,
374
+ second: typing.Optional[TransformType] = None,
375
+ transforms: typing.Optional[TransformsSeqType] = None,
376
+ p: float = 0.5,
377
+ ):
378
+ if transforms is None:
379
+ if first is None or second is None:
380
+ raise ValueError("You must set both first and second or set transforms argument.")
381
+ transforms = [first, second]
382
+ super(OneOrOther, self).__init__(transforms, p)
383
+ if len(self.transforms) != 2:
384
+ warnings.warn("Length of transforms is not equal to 2.")
385
+
386
+ def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
387
+ if self.replay_mode:
388
+ for t in self.transforms:
389
+ data = t(**data)
390
+ return data
391
+
392
+ if random.random() < self.p:
393
+ return self.transforms[0](force_apply=True, **data)
394
+
395
+ return self.transforms[-1](force_apply=True, **data)
396
+
397
+
398
+ class PerChannel(BaseCompose):
399
+ """Apply transformations per-channel
400
+
401
+ Args:
402
+ transforms (list): list of transformations to compose.
403
+ channels (sequence): channels to apply the transform to. Pass None to apply to all.
404
+ Default: None (apply to all)
405
+ p (float): probability of applying the transform. Default: 0.5.
406
+ """
407
+
408
+ def __init__(
409
+ self, transforms: TransformsSeqType, channels: typing.Optional[typing.Sequence[int]] = None, p: float = 0.5
410
+ ):
411
+ super(PerChannel, self).__init__(transforms, p)
412
+ self.channels = channels
413
+
414
+ def __call__(self, *args, force_apply: bool = False, **data) -> typing.Dict[str, typing.Any]:
415
+ if force_apply or random.random() < self.p:
416
+ image = data["image"]
417
+
418
+ # Expand mono images to have a single channel
419
+ if len(image.shape) == 2:
420
+ image = np.expand_dims(image, -1)
421
+
422
+ if self.channels is None:
423
+ self.channels = range(image.shape[2])
424
+
425
+ for c in self.channels:
426
+ for t in self.transforms:
427
+ image[:, :, c] = t(image=image[:, :, c])["image"]
428
+
429
+ data["image"] = image
430
+
431
+ return data
432
+
433
+
434
+ class ReplayCompose(Compose):
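+ """Compose variant that records the parameters of every applied transform under
+ `save_key` so the exact same augmentation can be replayed on other data.
+
+ Example:
+ A replay sketch (`image1` and `image2` stand in for numpy images):
+
+ >>> import custom_albumentations as A
+ >>> aug = A.ReplayCompose([A.HorizontalFlip(p=0.5)])
+ >>> first = aug(image=image1)
+ >>> second = A.ReplayCompose.replay(first["replay"], image=image2)
+ """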
435
+ def __init__(
436
+ self,
437
+ transforms: TransformsSeqType,
438
+ bbox_params: typing.Optional[typing.Union[dict, "BboxParams"]] = None,
439
+ keypoint_params: typing.Optional[typing.Union[dict, "KeypointParams"]] = None,
440
+ additional_targets: typing.Optional[typing.Dict[str, str]] = None,
441
+ p: float = 1.0,
442
+ is_check_shapes: bool = True,
443
+ save_key: str = "replay",
444
+ ):
445
+ super(ReplayCompose, self).__init__(
446
+ transforms, bbox_params, keypoint_params, additional_targets, p, is_check_shapes
447
+ )
448
+ self.set_deterministic(True, save_key=save_key)
449
+ self.save_key = save_key
450
+
451
+ def __call__(self, *args, force_apply: bool = False, **kwargs) -> typing.Dict[str, typing.Any]:
452
+ kwargs[self.save_key] = defaultdict(dict)
453
+ result = super(ReplayCompose, self).__call__(force_apply=force_apply, **kwargs)
454
+ serialized = self.get_dict_with_id()
455
+ self.fill_with_params(serialized, result[self.save_key])
456
+ self.fill_applied(serialized)
457
+ result[self.save_key] = serialized
458
+ return result
459
+
460
+ @staticmethod
461
+ def replay(saved_augmentations: typing.Dict[str, typing.Any], **kwargs) -> typing.Dict[str, typing.Any]:
462
+ augs = ReplayCompose._restore_for_replay(saved_augmentations)
463
+ return augs(force_apply=True, **kwargs)
464
+
465
+ @staticmethod
466
+ def _restore_for_replay(
467
+ transform_dict: typing.Dict[str, typing.Any], lambda_transforms: typing.Optional[dict] = None
468
+ ) -> TransformType:
469
+ """
470
+ Args:
471
+ lambda_transforms (dict): A dictionary that contains lambda transforms, that
472
+ is instances of the Lambda class.
473
+ This dictionary is required when you are restoring a pipeline that contains lambda transforms. Keys
474
+ in that dictionary should be named same as `name` arguments in respective lambda transforms from
475
+ a serialized pipeline.
476
+ """
477
+ applied = transform_dict["applied"]
478
+ params = transform_dict["params"]
479
+ lmbd = instantiate_nonserializable(transform_dict, lambda_transforms)
480
+ if lmbd:
481
+ transform = lmbd
482
+ else:
483
+ name = transform_dict["__class_fullname__"]
484
+ args = {k: v for k, v in transform_dict.items() if k not in ["__class_fullname__", "applied", "params"]}
485
+ cls = SERIALIZABLE_REGISTRY[name]
486
+ if "transforms" in args:
487
+ args["transforms"] = [
488
+ ReplayCompose._restore_for_replay(t, lambda_transforms=lambda_transforms)
489
+ for t in args["transforms"]
490
+ ]
491
+ transform = cls(**args)
492
+
493
+ transform = typing.cast(BasicTransform, transform)
494
+ if isinstance(transform, BasicTransform):
495
+ transform.params = params
496
+ transform.replay_mode = True
497
+ transform.applied_in_replay = applied
498
+ return transform
499
+
500
+ def fill_with_params(self, serialized: dict, all_params: dict) -> None:
501
+ params = all_params.get(serialized.get("id"))
502
+ serialized["params"] = params
503
+ del serialized["id"]
504
+ for transform in serialized.get("transforms", []):
505
+ self.fill_with_params(transform, all_params)
506
+
507
+ def fill_applied(self, serialized: typing.Dict[str, typing.Any]) -> bool:
508
+ if "transforms" in serialized:
509
+ applied = [self.fill_applied(t) for t in serialized["transforms"]]
510
+ serialized["applied"] = any(applied)
511
+ else:
512
+ serialized["applied"] = serialized.get("params") is not None
513
+ return serialized["applied"]
514
+
515
+ def _to_dict(self) -> typing.Dict[str, typing.Any]:
516
+ dictionary = super(ReplayCompose, self)._to_dict()
517
+ dictionary.update({"save_key": self.save_key})
518
+ return dictionary
519
+
520
+
521
+ class Sequential(BaseCompose):
522
+ """Sequentially applies all transforms to targets.
523
+
524
+ Note:
525
+ This transform is not intended to be a replacement for `Compose`. Instead, it should be used inside `Compose`
526
+ the same way `OneOf` or `OneOrOther` are used. For instance, you can combine `OneOf` with `Sequential` to
527
+ create an augmentation pipeline that contains multiple sequences of augmentations and applies one randomly
528
+ chosen sequence to the input data (see the `Example` section for an example definition of such a pipeline).
529
+
530
+ Example:
531
+ >>> import custom_albumentations as albumentations as A
532
+ >>> transform = A.Compose([
533
+ >>> A.OneOf([
534
+ >>> A.Sequential([
535
+ >>> A.HorizontalFlip(p=0.5),
536
+ >>> A.ShiftScaleRotate(p=0.5),
537
+ >>> ]),
538
+ >>> A.Sequential([
539
+ >>> A.VerticalFlip(p=0.5),
540
+ >>> A.RandomBrightnessContrast(p=0.5),
541
+ >>> ]),
542
+ >>> ], p=1)
543
+ >>> ])
544
+ """
545
+
546
+ def __init__(self, transforms: TransformsSeqType, p: float = 0.5):
547
+ super().__init__(transforms, p)
548
+
549
+ def __call__(self, *args, **data) -> typing.Dict[str, typing.Any]:
550
+ for t in self.transforms:
551
+ data = t(**data)
552
+ return data
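
To make the replay machinery above concrete, here is a minimal usage sketch (illustrative, and assuming this fork re-exports ReplayCompose and the basic transforms at package level, as upstream albumentations does). ReplayCompose records the sampled parameters of every transform under `save_key` ('replay' by default), and the static `replay` method rebuilds the pipeline via `_restore_for_replay` and re-applies it with `force_apply=True`:

import numpy as np
import custom_albumentations as A

aug = A.ReplayCompose([
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
])

image1 = np.zeros((100, 100, 3), dtype=np.uint8)
image2 = np.zeros((100, 100, 3), dtype=np.uint8)

first = aug(image=image1)    # samples parameters and records them
saved = first["replay"]      # serialized transforms plus their params
second = A.ReplayCompose.replay(saved, image=image2)  # same params, new image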
custom_albumentations/core/keypoints_utils.py ADDED
@@ -0,0 +1,286 @@
1
+ from __future__ import division
2
+
3
+ import math
4
+ import typing
5
+ import warnings
6
+ from typing import Any, Dict, List, Optional, Sequence, Tuple
7
+
8
+ from .utils import DataProcessor, Params
9
+
10
+ __all__ = [
11
+ "angle_to_2pi_range",
12
+ "check_keypoints",
13
+ "convert_keypoints_from_albumentations",
14
+ "convert_keypoints_to_albumentations",
15
+ "filter_keypoints",
16
+ "KeypointsProcessor",
17
+ "KeypointParams",
18
+ ]
19
+
20
+ keypoint_formats = {"xy", "yx", "xya", "xys", "xyas", "xysa"}
21
+
22
+
23
+ def angle_to_2pi_range(angle: float) -> float:
24
+ two_pi = 2 * math.pi
25
+ return angle % two_pi
26
+
27
+
28
+ class KeypointParams(Params):
29
+ """
30
+ Parameters of keypoints
31
+
32
+ Args:
33
+ format (str): format of keypoints. Should be 'xy', 'yx', 'xya', 'xys', 'xyas', 'xysa'.
34
+
35
+ x - X coordinate,
36
+
37
+ y - Y coordinate
38
+
39
+ s - Keypoint scale
40
+
41
+ a - Keypoint orientation in radians or degrees (depending on KeypointParams.angle_in_degrees)
42
+ label_fields (list): list of fields that are joined with keypoints, e.g labels.
43
+ Should be same type as keypoints.
44
+ remove_invisible (bool): to remove invisible points after transform or not
45
+ angle_in_degrees (bool): angle in degrees or radians in 'xya', 'xyas', 'xysa' keypoints
46
+ check_each_transform (bool): if `True`, then keypoints will be checked after each dual transform.
47
+ Default: `True`
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ format: str, # skipcq: PYL-W0622
53
+ label_fields: Optional[Sequence[str]] = None,
54
+ remove_invisible: bool = True,
55
+ angle_in_degrees: bool = True,
56
+ check_each_transform: bool = True,
57
+ ):
58
+ super(KeypointParams, self).__init__(format, label_fields)
59
+ self.remove_invisible = remove_invisible
60
+ self.angle_in_degrees = angle_in_degrees
61
+ self.check_each_transform = check_each_transform
62
+
63
+ def _to_dict(self) -> Dict[str, Any]:
64
+ data = super(KeypointParams, self)._to_dict()
65
+ data.update(
66
+ {
67
+ "remove_invisible": self.remove_invisible,
68
+ "angle_in_degrees": self.angle_in_degrees,
69
+ "check_each_transform": self.check_each_transform,
70
+ }
71
+ )
72
+ return data
73
+
74
+ @classmethod
75
+ def is_serializable(cls) -> bool:
76
+ return True
77
+
78
+ @classmethod
79
+ def get_class_fullname(cls) -> str:
80
+ return "KeypointParams"
81
+
82
+
83
+ class KeypointsProcessor(DataProcessor):
84
+ def __init__(self, params: KeypointParams, additional_targets: Optional[Dict[str, str]] = None):
85
+ super().__init__(params, additional_targets)
86
+
87
+ @property
88
+ def default_data_name(self) -> str:
89
+ return "keypoints"
90
+
91
+ def ensure_data_valid(self, data: Dict[str, Any]) -> None:
92
+ if self.params.label_fields:
93
+ if not all(i in data.keys() for i in self.params.label_fields):
94
+ raise ValueError(
95
+ "Your 'label_fields' are not valid - they must have the same names as params in "
96
+ "'keypoint_params' dict"
97
+ )
98
+
99
+ def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
100
+ # IAA-based augmentations support only transformation of xy keypoints.
101
+ # If your keypoint format is other than 'xy' we emit a warning to let the user
102
+ # be aware that angle and size will not be modified.
103
+
104
+ try:
105
+ from custom_albumentations.imgaug.transforms import DualIAATransform
106
+ except ImportError:
107
+ # imgaug is not installed so we skip imgaug checks.
108
+ return
109
+
110
+ if self.params.format is not None and self.params.format != "xy":
111
+ for transform in transforms:
112
+ if isinstance(transform, DualIAATransform):
113
+ warnings.warn(
114
+ "{} transformation supports only 'xy' keypoints "
115
+ "augmentation. You have '{}' keypoints format. Scale "
116
+ "and angle WILL NOT BE transformed.".format(transform.__class__.__name__, self.params.format)
117
+ )
118
+ break
119
+
120
+ def filter(self, data: Sequence[Sequence], rows: int, cols: int) -> Sequence[Sequence]:
121
+ self.params: KeypointParams
122
+ return filter_keypoints(data, rows, cols, remove_invisible=self.params.remove_invisible)
123
+
124
+ def check(self, data: Sequence[Sequence], rows: int, cols: int) -> None:
125
+ check_keypoints(data, rows, cols)
126
+
127
+ def convert_from_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
128
+ params = self.params
129
+ return convert_keypoints_from_albumentations(
130
+ data,
131
+ params.format,
132
+ rows,
133
+ cols,
134
+ check_validity=params.remove_invisible,
135
+ angle_in_degrees=params.angle_in_degrees,
136
+ )
137
+
138
+ def convert_to_albumentations(self, data: Sequence[Sequence], rows: int, cols: int) -> List[Tuple]:
139
+ params = self.params
140
+ return convert_keypoints_to_albumentations(
141
+ data,
142
+ params.format,
143
+ rows,
144
+ cols,
145
+ check_validity=params.remove_invisible,
146
+ angle_in_degrees=params.angle_in_degrees,
147
+ )
148
+
149
+
150
+ def check_keypoint(kp: Sequence, rows: int, cols: int) -> None:
151
+ """Check that the keypoint coordinates are within the image bounds"""
152
+ for name, value, size in zip(["x", "y"], kp[:2], [cols, rows]):
153
+ if not 0 <= value < size:
154
+ raise ValueError(
155
+ "Expected {name} for keypoint {kp} "
156
+ "to be in the range [0, {size}), got {value}.".format(kp=kp, name=name, value=value, size=size)
157
+ )
158
+
159
+ angle = kp[2]
160
+ if not (0 <= angle < 2 * math.pi):
161
+ raise ValueError("Keypoint angle must be in range [0, 2 * PI). Got: {angle}".format(angle=angle))
162
+
163
+
164
+ def check_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int) -> None:
165
+ """Check that all keypoints are within the image bounds"""
166
+ for kp in keypoints:
167
+ check_keypoint(kp, rows, cols)
168
+
169
+
170
+ def filter_keypoints(keypoints: Sequence[Sequence], rows: int, cols: int, remove_invisible: bool) -> Sequence[Sequence]:
171
+ if not remove_invisible:
172
+ return keypoints
173
+
174
+ resulting_keypoints = []
175
+ for kp in keypoints:
176
+ x, y = kp[:2]
177
+ if x < 0 or x >= cols:
178
+ continue
179
+ if y < 0 or y >= rows:
180
+ continue
181
+ resulting_keypoints.append(kp)
182
+ return resulting_keypoints
183
+
184
+
185
+ def convert_keypoint_to_albumentations(
186
+ keypoint: Sequence,
187
+ source_format: str,
188
+ rows: int,
189
+ cols: int,
190
+ check_validity: bool = False,
191
+ angle_in_degrees: bool = True,
192
+ ) -> Tuple:
193
+ if source_format not in keypoint_formats:
194
+ raise ValueError("Unknown source_format {}. Supported formats are: {}".format(source_format, keypoint_formats))
195
+
196
+ if source_format == "xy":
197
+ (x, y), tail = keypoint[:2], tuple(keypoint[2:])
198
+ a, s = 0.0, 0.0
199
+ elif source_format == "yx":
200
+ (y, x), tail = keypoint[:2], tuple(keypoint[2:])
201
+ a, s = 0.0, 0.0
202
+ elif source_format == "xya":
203
+ (x, y, a), tail = keypoint[:3], tuple(keypoint[3:])
204
+ s = 0.0
205
+ elif source_format == "xys":
206
+ (x, y, s), tail = keypoint[:3], tuple(keypoint[3:])
207
+ a = 0.0
208
+ elif source_format == "xyas":
209
+ (x, y, a, s), tail = keypoint[:4], tuple(keypoint[4:])
210
+ elif source_format == "xysa":
211
+ (x, y, s, a), tail = keypoint[:4], tuple(keypoint[4:])
212
+ else:
213
+ raise ValueError(f"Unsupported source format. Got {source_format}")
214
+
215
+ if angle_in_degrees:
216
+ a = math.radians(a)
217
+
218
+ keypoint = (x, y, angle_to_2pi_range(a), s) + tail
219
+ if check_validity:
220
+ check_keypoint(keypoint, rows, cols)
221
+ return keypoint
222
+
223
+
224
+ def convert_keypoint_from_albumentations(
225
+ keypoint: Sequence,
226
+ target_format: str,
227
+ rows: int,
228
+ cols: int,
229
+ check_validity: bool = False,
230
+ angle_in_degrees: bool = True,
231
+ ) -> Tuple:
232
+ if target_format not in keypoint_formats:
233
+ raise ValueError("Unknown target_format {}. Supported formats are: {}".format(target_format, keypoint_formats))
234
+
235
+ (x, y, angle, scale), tail = keypoint[:4], tuple(keypoint[4:])
236
+ angle = angle_to_2pi_range(angle)
237
+ if check_validity:
238
+ check_keypoint((x, y, angle, scale), rows, cols)
239
+ if angle_in_degrees:
240
+ angle = math.degrees(angle)
241
+
242
+ kp: Tuple
243
+ if target_format == "xy":
244
+ kp = (x, y)
245
+ elif target_format == "yx":
246
+ kp = (y, x)
247
+ elif target_format == "xya":
248
+ kp = (x, y, angle)
249
+ elif target_format == "xys":
250
+ kp = (x, y, scale)
251
+ elif target_format == "xyas":
252
+ kp = (x, y, angle, scale)
253
+ elif target_format == "xysa":
254
+ kp = (x, y, scale, angle)
255
+ else:
256
+ raise ValueError(f"Invalid target format. Got: {target_format}")
257
+
258
+ return kp + tail
259
+
260
+
261
+ def convert_keypoints_to_albumentations(
262
+ keypoints: Sequence[Sequence],
263
+ source_format: str,
264
+ rows: int,
265
+ cols: int,
266
+ check_validity: bool = False,
267
+ angle_in_degrees: bool = True,
268
+ ) -> List[Tuple]:
269
+ return [
270
+ convert_keypoint_to_albumentations(kp, source_format, rows, cols, check_validity, angle_in_degrees)
271
+ for kp in keypoints
272
+ ]
273
+
274
+
275
+ def convert_keypoints_from_albumentations(
276
+ keypoints: Sequence[Sequence],
277
+ target_format: str,
278
+ rows: int,
279
+ cols: int,
280
+ check_validity: bool = False,
281
+ angle_in_degrees: bool = True,
282
+ ) -> List[Tuple]:
283
+ return [
284
+ convert_keypoint_from_albumentations(kp, target_format, rows, cols, check_validity, angle_in_degrees)
285
+ for kp in keypoints
286
+ ]
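
A short sketch of how the converters above behave (the keypoint values are illustrative): an 'xyas' keypoint with its angle in degrees is normalized to the internal (x, y, angle in radians within [0, 2*pi), scale) layout and converted back losslessly.

from custom_albumentations.core.keypoints_utils import (
    convert_keypoint_from_albumentations,
    convert_keypoint_to_albumentations,
)

kp = (10.0, 20.0, 90.0, 2.0)  # x, y, angle in degrees, scale
internal = convert_keypoint_to_albumentations(kp, "xyas", rows=100, cols=100)
# internal == (10.0, 20.0, 1.5707963267948966, 2.0) -- angle now in radians

restored = convert_keypoint_from_albumentations(internal, "xyas", rows=100, cols=100)
# restored == (10.0, 20.0, 90.0, 2.0)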
custom_albumentations/core/serialization.py ADDED
@@ -0,0 +1,247 @@
1
+ from __future__ import absolute_import
2
+
3
+ import json
4
+ import typing
5
+ import warnings
6
+ from abc import ABC, ABCMeta, abstractmethod
7
+ from typing import IO, Any, Callable, Dict, Optional, Tuple, Type, Union
8
+
9
+ try:
10
+ import yaml
11
+
12
+ yaml_available = True
13
+ except ImportError:
14
+ yaml_available = False
15
+
16
+
17
+ from custom_albumentations import __version__
18
+
19
+ __all__ = ["to_dict", "from_dict", "save", "load"]
20
+
21
+
22
+ SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
23
+ NON_SERIALIZABLE_REGISTRY: Dict[str, "SerializableMeta"] = {}
24
+
25
+
26
+ def shorten_class_name(class_fullname: str) -> str:
27
+ splitted = class_fullname.split(".")
28
+ if len(splitted) == 1:
29
+ return class_fullname
30
+ top_module, *_, class_name = splitted
31
+ if top_module == "albumentations":
32
+ return class_name
33
+ return class_fullname
34
+
35
+
36
+ def get_shortest_class_fullname(cls: Type) -> str:
37
+ class_fullname = "{cls.__module__}.{cls.__name__}".format(cls=cls)
38
+ return shorten_class_name(class_fullname)
39
+
40
+
41
+ class SerializableMeta(ABCMeta):
42
+ """
43
+ A metaclass that is used to register classes in `SERIALIZABLE_REGISTRY` or `NON_SERIALIZABLE_REGISTRY`
44
+ so they can be found later when deserializing a transformation pipeline using the classes' full names.
45
+ """
46
+
47
+ def __new__(mcs, name: str, bases: Tuple[type, ...], *args, **kwargs) -> "SerializableMeta":
48
+ cls_obj = super().__new__(mcs, name, bases, *args, **kwargs)
49
+ if name != "Serializable" and ABC not in bases:
50
+ if cls_obj.is_serializable():
51
+ SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
52
+ else:
53
+ NON_SERIALIZABLE_REGISTRY[cls_obj.get_class_fullname()] = cls_obj
54
+ return cls_obj
55
+
56
+ @classmethod
57
+ def is_serializable(mcs) -> bool:
58
+ return False
59
+
60
+ @classmethod
61
+ def get_class_fullname(mcs) -> str:
62
+ return get_shortest_class_fullname(mcs)
63
+
64
+ @classmethod
65
+ def _to_dict(mcs) -> Dict[str, Any]:
66
+ return {}
67
+
68
+
69
+ class Serializable(metaclass=SerializableMeta):
70
+ @classmethod
71
+ @abstractmethod
72
+ def is_serializable(cls) -> bool:
73
+ raise NotImplementedError
74
+
75
+ @classmethod
76
+ @abstractmethod
77
+ def get_class_fullname(cls) -> str:
78
+ raise NotImplementedError
79
+
80
+ @abstractmethod
81
+ def _to_dict(self) -> Dict[str, Any]:
82
+ raise NotImplementedError
83
+
84
+ def to_dict(self, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
85
+ """
86
+ Take a transform pipeline and convert it to a serializable representation that uses only standard
87
+ python data types: dictionaries, lists, strings, integers, and floats.
88
+
89
+ Args:
90
+ self: A transform that should be serialized. If the transform doesn't implement the `to_dict`
91
+ method and `on_not_implemented_error` equals 'raise' then `NotImplementedError` is raised.
92
+ If `on_not_implemented_error` equals 'warn' then `NotImplementedError` will be ignored
93
+ but no transform parameters will be serialized.
94
+ on_not_implemented_error (str): `raise` or `warn`.
95
+ """
96
+ if on_not_implemented_error not in {"raise", "warn"}:
97
+ raise ValueError(
98
+ "Unknown on_not_implemented_error value: {}. Supported values are: 'raise' and 'warn'".format(
99
+ on_not_implemented_error
100
+ )
101
+ )
102
+ try:
103
+ transform_dict = self._to_dict()
104
+ except NotImplementedError as e:
105
+ if on_not_implemented_error == "raise":
106
+ raise e
107
+
108
+ transform_dict = {}
109
+ warnings.warn(
110
+ "Got NotImplementedError while trying to serialize {obj}. Object arguments are not preserved. "
111
+ "Implement either '{cls_name}.get_transform_init_args_names' or '{cls_name}.get_transform_init_args' "
112
+ "method to make the transform serializable".format(obj=self, cls_name=self.__class__.__name__)
113
+ )
114
+ return {"__version__": __version__, "transform": transform_dict}
115
+
116
+
117
+ def to_dict(transform: Serializable, on_not_implemented_error: str = "raise") -> Dict[str, Any]:
118
+ """
119
+ Take a transform pipeline and convert it to a serializable representation that uses only standard
120
+ python data types: dictionaries, lists, strings, integers, and floats.
121
+
122
+ Args:
123
+ transform: A transform that should be serialized. If the transform doesn't implement the `to_dict`
124
+ method and `on_not_implemented_error` equals 'raise' then `NotImplementedError` is raised.
125
+ If `on_not_implemented_error` equals 'warn' then `NotImplementedError` will be ignored
126
+ but no transform parameters will be serialized.
127
+ on_not_implemented_error (str): `raise` or `warn`.
128
+ """
129
+ return transform.to_dict(on_not_implemented_error)
130
+
131
+
132
+ def instantiate_nonserializable(
133
+ transform: Dict[str, Any], nonserializable: Optional[Dict[str, Any]] = None
134
+ ) -> Optional[Serializable]:
135
+ if transform.get("__class_fullname__") in NON_SERIALIZABLE_REGISTRY:
136
+ name = transform["__name__"]
137
+ if nonserializable is None:
138
+ raise ValueError(
139
+ "To deserialize a non-serializable transform with name {name} you need to pass a dict with "
140
+ "this transform as the `nonserializable` argument".format(name=name)
141
+ )
142
+ result_transform = nonserializable.get(name)
143
+ if result_transform is None:
144
+ raise ValueError(
145
+ "Non-serializable transform with name {name} was not found in `nonserializable`".format(name=name)
146
+ )
147
+ return result_transform
148
+ return None
149
+
150
+
151
+ def from_dict(
152
+ transform_dict: Dict[str, Any],
153
+ nonserializable: Optional[Dict[str, Any]] = None,
154
+ lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
155
+ ) -> Optional[Serializable]:
156
+ """
157
+ Args:
158
+ transform_dict (dict): A dictionary with serialized transform pipeline.
159
+ nonserializable (dict): A dictionary that contains non-serializable transforms.
160
+ This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
161
+ Keys in that dictionary should be named the same as the `name` arguments in the respective transforms from
162
+ a serialized pipeline.
163
+ lambda_transforms (dict): Deprecated. Use 'nonserializable' instead.
164
+ """
165
+ if lambda_transforms != "deprecated":
166
+ warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
167
+ nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)
168
+
169
+ register_additional_transforms()
170
+ transform = transform_dict["transform"]
171
+ lmbd = instantiate_nonserializable(transform, nonserializable)
172
+ if lmbd:
173
+ return lmbd
174
+ name = transform["__class_fullname__"]
175
+ args = {k: v for k, v in transform.items() if k != "__class_fullname__"}
176
+ cls = SERIALIZABLE_REGISTRY[shorten_class_name(name)]
177
+ if "transforms" in args:
178
+ args["transforms"] = [from_dict({"transform": t}, nonserializable=nonserializable) for t in args["transforms"]]
179
+ return cls(**args)
180
+
181
+
182
+ def check_data_format(data_format: str) -> None:
183
+ if data_format not in {"json", "yaml"}:
184
+ raise ValueError("Unknown data_format {}. Supported formats are: 'json' and 'yaml'".format(data_format))
185
+
186
+
187
+ def save(
188
+ transform: Serializable, filepath: str, data_format: str = "json", on_not_implemented_error: str = "raise"
189
+ ) -> None:
190
+ """
191
+ Take a transform pipeline, serialize it and save a serialized version to a file
192
+ using either json or yaml format.
193
+
194
+ Args:
195
+ transform (obj): Transform to serialize.
196
+ filepath (str): Filepath to write to.
197
+ data_format (str): Serialization format. Should be either 'json' or 'yaml'.
198
+ on_not_implemented_error (str): Parameter that describes what to do if a transform doesn't implement
199
+ the `to_dict` method. If 'raise' then `NotImplementedError` is raised, if `warn` then the exception will be
200
+ ignored and no transform arguments will be saved.
201
+ """
202
+ check_data_format(data_format)
203
+ transform_dict = transform.to_dict(on_not_implemented_error=on_not_implemented_error)
204
+ dump_fn = json.dump if data_format == "json" else yaml.safe_dump
205
+ with open(filepath, "w") as f:
206
+ dump_fn(transform_dict, f) # type: ignore
207
+
208
+
209
+ def load(
210
+ filepath: str,
211
+ data_format: str = "json",
212
+ nonserializable: Optional[Dict[str, Any]] = None,
213
+ lambda_transforms: Union[Optional[Dict[str, Any]], str] = "deprecated",
214
+ ) -> object:
215
+ """
216
+ Load a serialized pipeline from a json or yaml file and construct a transform pipeline.
217
+
218
+ Args:
219
+ filepath (str): Filepath to read from.
220
+ data_format (str): Serialization format. Should be either 'json' or 'yaml'.
221
+ nonserializable (dict): A dictionary that contains non-serializable transforms.
222
+ This dictionary is required when you are restoring a pipeline that contains non-serializable transforms.
223
+ Keys in that dictionary should be named the same as the `name` arguments in the respective transforms from
224
+ a serialized pipeline.
225
+ lambda_transforms (dict): Deprecated. Use 'nonserializable' instead.
226
+ """
227
+ if lambda_transforms != "deprecated":
228
+ warnings.warn("lambda_transforms argument is deprecated, please use 'nonserializable'", DeprecationWarning)
229
+ nonserializable = typing.cast(Optional[Dict[str, Any]], lambda_transforms)
230
+
231
+ check_data_format(data_format)
232
+ load_fn = json.load if data_format == "json" else yaml.safe_load
233
+ with open(filepath) as f:
234
+ transform_dict = load_fn(f) # type: ignore
235
+
236
+ return from_dict(transform_dict, nonserializable=nonserializable)
237
+
238
+
239
+ def register_additional_transforms() -> None:
240
+ """
241
+ Register transforms that are not imported directly into the `albumentations` module.
242
+ """
243
+ try:
244
+ # This import will result in ImportError if `torch` is not installed
245
+ import custom_albumentations.pytorch
246
+ except ImportError:
247
+ pass
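
A minimal round trip through the helpers above (the pipeline and file path are illustrative): save serializes a pipeline to JSON or YAML, and load reconstructs it through from_dict and the SERIALIZABLE_REGISTRY.

import custom_albumentations as A
from custom_albumentations.core.serialization import load, save

transform = A.Compose([A.HorizontalFlip(p=0.5)])
save(transform, "/tmp/pipeline.json", data_format="json")
restored = load("/tmp/pipeline.json", data_format="json")
# `restored` behaves the same as `transform`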
custom_albumentations/core/transforms_interface.py ADDED
@@ -0,0 +1,293 @@
1
+ from __future__ import absolute_import
2
+
3
+ import random
4
+ from copy import deepcopy
5
+ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union, cast
6
+ from warnings import warn
7
+
8
+ import cv2
9
+ import numpy as np
10
+
11
+ from .serialization import Serializable, get_shortest_class_fullname
12
+ from .utils import format_args
13
+
14
+ __all__ = [
15
+ "to_tuple",
16
+ "BasicTransform",
17
+ "DualTransform",
18
+ "ImageOnlyTransform",
19
+ "NoOp",
20
+ "BoxType",
21
+ "KeypointType",
22
+ "ImageColorType",
23
+ "ScaleFloatType",
24
+ "ScaleIntType",
25
+ "ImageColorType",
26
+ ]
27
+
28
+ NumType = Union[int, float, np.ndarray]
29
+ BoxInternalType = Tuple[float, float, float, float]
30
+ BoxType = Union[BoxInternalType, Tuple[float, float, float, float, Any]]
31
+ KeypointInternalType = Tuple[float, float, float, float]
32
+ KeypointType = Union[KeypointInternalType, Tuple[float, float, float, float, Any]]
33
+ ImageColorType = Union[float, Sequence[float]]
34
+
35
+ ScaleFloatType = Union[float, Tuple[float, float]]
36
+ ScaleIntType = Union[int, Tuple[int, int]]
37
+
38
+ FillValueType = Optional[Union[int, float, Sequence[int], Sequence[float]]]
39
+
40
+
41
+ def to_tuple(param, low=None, bias=None):
42
+ """Convert input argument to min-max tuple
43
+ Args:
44
+ param (scalar, tuple or list of two elements): Input value.
45
+ If the value is a scalar, the return value is (-value, value), offset by `bias` if it is given.
46
+ If the value is a tuple, `bias` is added to each element (broadcast).
47
+ low: If param is a scalar, build the sorted tuple (low, param) instead of (-param, param).
48
+ bias: An offset added to each element. Mutually exclusive with `low`.
49
+ """
50
+ if low is not None and bias is not None:
51
+ raise ValueError("Arguments low and bias are mutually exclusive")
52
+
53
+ if param is None:
54
+ return param
55
+
56
+ if isinstance(param, (int, float)):
57
+ if low is None:
58
+ param = -param, +param
59
+ else:
60
+ param = (low, param) if low < param else (param, low)
61
+ elif isinstance(param, Sequence):
62
+ if len(param) != 2:
63
+ raise ValueError("to_tuple expects 1 or 2 values")
64
+ param = tuple(param)
65
+ else:
66
+ raise ValueError("Argument param must be either scalar (int, float) or tuple")
67
+
68
+ if bias is not None:
69
+ return tuple(bias + x for x in param)
70
+
71
+ return tuple(param)
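
Concrete values for the branches above (a quick sketch):

to_tuple(0.1)            # -> (-0.1, 0.1)
to_tuple(10, low=2)      # -> (2, 10), sorted
to_tuple((0.5, 1.5))     # -> (0.5, 1.5)
to_tuple(0.1, bias=1.0)  # -> (0.9, 1.1)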
72
+
73
+
74
+ class BasicTransform(Serializable):
75
+ call_backup = None
76
+ interpolation: Any
77
+ fill_value: Any
78
+ mask_fill_value: Any
79
+
80
+ def __init__(self, always_apply: bool = False, p: float = 0.5):
81
+ self.p = p
82
+ self.always_apply = always_apply
83
+ self._additional_targets: Dict[str, str] = {}
84
+
85
+ # replay mode params
86
+ self.deterministic = False
87
+ self.save_key = "replay"
88
+ self.params: Dict[Any, Any] = {}
89
+ self.replay_mode = False
90
+ self.applied_in_replay = False
91
+
92
+ def __call__(self, *args, force_apply: bool = False, **kwargs) -> Dict[str, Any]:
93
+ if args:
94
+ raise KeyError("You have to pass data to augmentations as named arguments, for example: aug(image=image)")
95
+ if self.replay_mode:
96
+ if self.applied_in_replay:
97
+ return self.apply_with_params(self.params, **kwargs)
98
+
99
+ return kwargs
100
+
101
+ if (random.random() < self.p) or self.always_apply or force_apply:
102
+ params = self.get_params()
103
+
104
+ if self.targets_as_params:
105
+ assert all(key in kwargs for key in self.targets_as_params), "{} requires {}".format(
106
+ self.__class__.__name__, self.targets_as_params
107
+ )
108
+ targets_as_params = {k: kwargs[k] for k in self.targets_as_params}
109
+ params_dependent_on_targets = self.get_params_dependent_on_targets(targets_as_params)
110
+ params.update(params_dependent_on_targets)
111
+ if self.deterministic:
112
+ if self.targets_as_params:
113
+ warn(
114
+ self.get_class_fullname() + " could work incorrectly in ReplayMode for other input data"
115
+ " because its params depend on targets."
116
+ )
117
+ kwargs[self.save_key][id(self)] = deepcopy(params)
118
+ return self.apply_with_params(params, **kwargs)
119
+
120
+ return kwargs
121
+
122
+ def apply_with_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]: # skipcq: PYL-W0613
123
+ if params is None:
124
+ return kwargs
125
+ params = self.update_params(params, **kwargs)
126
+ res = {}
127
+ for key, arg in kwargs.items():
128
+ if arg is not None:
129
+ target_function = self._get_target_function(key)
130
+ target_dependencies = {k: kwargs[k] for k in self.target_dependence.get(key, [])}
131
+ res[key] = target_function(arg, **dict(params, **target_dependencies))
132
+ else:
133
+ res[key] = None
134
+ return res
135
+
136
+ def set_deterministic(self, flag: bool, save_key: str = "replay") -> "BasicTransform":
137
+ assert save_key != "params", "params save_key is reserved"
138
+ self.deterministic = flag
139
+ self.save_key = save_key
140
+ return self
141
+
142
+ def __repr__(self) -> str:
143
+ state = self.get_base_init_args()
144
+ state.update(self.get_transform_init_args())
145
+ return "{name}({args})".format(name=self.__class__.__name__, args=format_args(state))
146
+
147
+ def _get_target_function(self, key: str) -> Callable:
148
+ transform_key = key
149
+ if key in self._additional_targets:
150
+ transform_key = self._additional_targets.get(key, key)
151
+
152
+ target_function = self.targets.get(transform_key, lambda x, **p: x)
153
+ return target_function
154
+
155
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
156
+ raise NotImplementedError
157
+
158
+ def get_params(self) -> Dict:
159
+ return {}
160
+
161
+ @property
162
+ def targets(self) -> Dict[str, Callable]:
163
+ # you must specify targets in subclass
164
+ # for example: ('image', 'mask')
165
+ # ('image', 'boxes')
166
+ raise NotImplementedError
167
+
168
+ def update_params(self, params: Dict[str, Any], **kwargs) -> Dict[str, Any]:
169
+ if hasattr(self, "interpolation"):
170
+ params["interpolation"] = self.interpolation
171
+ if hasattr(self, "fill_value"):
172
+ params["fill_value"] = self.fill_value
173
+ if hasattr(self, "mask_fill_value"):
174
+ params["mask_fill_value"] = self.mask_fill_value
175
+ params.update({"cols": kwargs["image"].shape[1], "rows": kwargs["image"].shape[0]})
176
+ return params
177
+
178
+ @property
179
+ def target_dependence(self) -> Dict:
180
+ return {}
181
+
182
+ def add_targets(self, additional_targets: Dict[str, str]):
183
+ """Add targets to transform them the same way as one of the existing targets,
184
+ e.g. {'target_image': 'image'}
185
+ or {'obj1_mask': 'mask', 'obj2_mask': 'mask'}.
186
+ Note that you must still pass at least one target with the key 'image'.
187
+
188
+ Args:
189
+ additional_targets (dict): keys - new target name, values - old target name. ex: {'image2': 'image'}
190
+ """
191
+ self._additional_targets = additional_targets
192
+
193
+ @property
194
+ def targets_as_params(self) -> List[str]:
195
+ return []
196
+
197
+ def get_params_dependent_on_targets(self, params: Dict[str, Any]) -> Dict[str, Any]:
198
+ raise NotImplementedError(
199
+ "Method get_params_dependent_on_targets is not implemented in class " + self.__class__.__name__
200
+ )
201
+
202
+ @classmethod
203
+ def get_class_fullname(cls) -> str:
204
+ return get_shortest_class_fullname(cls)
205
+
206
+ @classmethod
207
+ def is_serializable(cls):
208
+ return True
209
+
210
+ def get_transform_init_args_names(self) -> Tuple[str, ...]:
211
+ raise NotImplementedError(
212
+ "Class {name} is not serializable because the `get_transform_init_args_names` method is not "
213
+ "implemented".format(name=self.get_class_fullname())
214
+ )
215
+
216
+ def get_base_init_args(self) -> Dict[str, Any]:
217
+ return {"always_apply": self.always_apply, "p": self.p}
218
+
219
+ def get_transform_init_args(self) -> Dict[str, Any]:
220
+ return {k: getattr(self, k) for k in self.get_transform_init_args_names()}
221
+
222
+ def _to_dict(self) -> Dict[str, Any]:
223
+ state = {"__class_fullname__": self.get_class_fullname()}
224
+ state.update(self.get_base_init_args())
225
+ state.update(self.get_transform_init_args())
226
+ return state
227
+
228
+ def get_dict_with_id(self) -> Dict[str, Any]:
229
+ d = self._to_dict()
230
+ d["id"] = id(self)
231
+ return d
232
+
233
+
234
+ class DualTransform(BasicTransform):
235
+ """Transform for segmentation task."""
236
+
237
+ @property
238
+ def targets(self) -> Dict[str, Callable]:
239
+ return {
240
+ "image": self.apply,
241
+ "mask": self.apply_to_mask,
242
+ "masks": self.apply_to_masks,
243
+ "bboxes": self.apply_to_bboxes,
244
+ "keypoints": self.apply_to_keypoints,
245
+ }
246
+
247
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
248
+ raise NotImplementedError("Method apply_to_bbox is not implemented in class " + self.__class__.__name__)
249
+
250
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
251
+ raise NotImplementedError("Method apply_to_keypoint is not implemented in class " + self.__class__.__name__)
252
+
253
+ def apply_to_bboxes(self, bboxes: Sequence[BoxType], **params) -> List[BoxType]:
254
+ return [self.apply_to_bbox(tuple(bbox[:4]), **params) + tuple(bbox[4:]) for bbox in bboxes] # type: ignore
255
+
256
+ def apply_to_keypoints(self, keypoints: Sequence[KeypointType], **params) -> List[KeypointType]:
257
+ return [ # type: ignore
258
+ self.apply_to_keypoint(tuple(keypoint[:4]), **params) + tuple(keypoint[4:]) # type: ignore
259
+ for keypoint in keypoints
260
+ ]
261
+
262
+ def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
263
+ return self.apply(img, **{k: cv2.INTER_NEAREST if k == "interpolation" else v for k, v in params.items()})
264
+
265
+ def apply_to_masks(self, masks: Sequence[np.ndarray], **params) -> List[np.ndarray]:
266
+ return [self.apply_to_mask(mask, **params) for mask in masks]
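
The `targets` mapping above is what routes each named argument ('image', 'mask', 'bboxes', 'keypoints') to its handler. A minimal custom transform written against that interface (the class and its flip formulas are an illustrative sketch, not part of the library):

import math
import numpy as np

class IllustrativeHFlip(DualTransform):
    def apply(self, img: np.ndarray, **params) -> np.ndarray:
        return np.ascontiguousarray(img[:, ::-1, ...])

    def apply_to_bbox(self, bbox, **params):
        x_min, y_min, x_max, y_max = bbox  # normalized albumentations format
        return 1 - x_max, y_min, 1 - x_min, y_max

    def apply_to_keypoint(self, keypoint, **params):
        x, y, angle, scale = keypoint      # internal (x, y, angle, scale) format
        return (params["cols"] - 1) - x, y, math.pi - angle, scale

    def get_transform_init_args_names(self):
        return ()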
267
+
268
+
269
+ class ImageOnlyTransform(BasicTransform):
270
+ """Transform applied to image only."""
271
+
272
+ @property
273
+ def targets(self) -> Dict[str, Callable]:
274
+ return {"image": self.apply}
275
+
276
+
277
+ class NoOp(DualTransform):
278
+ """Does nothing"""
279
+
280
+ def apply_to_keypoint(self, keypoint: KeypointInternalType, **params) -> KeypointInternalType:
281
+ return keypoint
282
+
283
+ def apply_to_bbox(self, bbox: BoxInternalType, **params) -> BoxInternalType:
284
+ return bbox
285
+
286
+ def apply(self, img: np.ndarray, **params) -> np.ndarray:
287
+ return img
288
+
289
+ def apply_to_mask(self, img: np.ndarray, **params) -> np.ndarray:
290
+ return img
291
+
292
+ def get_transform_init_args_names(self) -> Tuple:
293
+ return ()
custom_albumentations/core/utils.py ADDED
@@ -0,0 +1,137 @@
1
+ from __future__ import absolute_import
2
+
3
+ from abc import ABC, abstractmethod
4
+ from typing import Any, Dict, Optional, Sequence, Tuple
5
+
6
+ import numpy as np
7
+
8
+ from .serialization import Serializable
9
+
10
+
11
+ def get_shape(img: Any) -> Tuple[int, int]:
12
+ if isinstance(img, np.ndarray):
13
+ rows, cols = img.shape[:2]
14
+ return rows, cols
15
+
16
+ try:
17
+ import torch
18
+
19
+ if torch.is_tensor(img):
20
+ rows, cols = img.shape[-2:]
21
+ return rows, cols
22
+ except ImportError:
23
+ pass
24
+
25
+ raise RuntimeError(
26
+ f"Albumentations supports only numpy.ndarray and torch.Tensor data type for image. Got: {type(img)}"
27
+ )
28
+
29
+
30
+ def format_args(args_dict: Dict):
31
+ formatted_args = []
32
+ for k, v in args_dict.items():
33
+ if isinstance(v, str):
34
+ v = f"'{v}'"
35
+ formatted_args.append(f"{k}={v}")
36
+ return ", ".join(formatted_args)
37
+
38
+
39
+ class Params(Serializable, ABC):
40
+ def __init__(self, format: str, label_fields: Optional[Sequence[str]] = None):
41
+ self.format = format
42
+ self.label_fields = label_fields
43
+
44
+ def _to_dict(self) -> Dict[str, Any]:
45
+ return {"format": self.format, "label_fields": self.label_fields}
46
+
47
+
48
+ class DataProcessor(ABC):
49
+ def __init__(self, params: Params, additional_targets: Optional[Dict[str, str]] = None):
50
+ self.params = params
51
+ self.data_fields = [self.default_data_name]
52
+ if additional_targets is not None:
53
+ for k, v in additional_targets.items():
54
+ if v == self.default_data_name:
55
+ self.data_fields.append(k)
56
+
57
+ @property
58
+ @abstractmethod
59
+ def default_data_name(self) -> str:
60
+ raise NotImplementedError
61
+
62
+ def ensure_data_valid(self, data: Dict[str, Any]) -> None:
63
+ pass
64
+
65
+ def ensure_transforms_valid(self, transforms: Sequence[object]) -> None:
66
+ pass
67
+
68
+ def postprocess(self, data: Dict[str, Any]) -> Dict[str, Any]:
69
+ rows, cols = get_shape(data["image"])
70
+
71
+ for data_name in self.data_fields:
72
+ data[data_name] = self.filter(data[data_name], rows, cols)
73
+ data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="from")
74
+
75
+ data = self.remove_label_fields_from_data(data)
76
+ return data
77
+
78
+ def preprocess(self, data: Dict[str, Any]) -> None:
79
+ data = self.add_label_fields_to_data(data)
80
+
81
+ rows, cols = data["image"].shape[:2]
82
+ for data_name in self.data_fields:
83
+ data[data_name] = self.check_and_convert(data[data_name], rows, cols, direction="to")
84
+
85
+ def check_and_convert(self, data: Sequence, rows: int, cols: int, direction: str = "to") -> Sequence:
86
+ if self.params.format == "albumentations":
87
+ self.check(data, rows, cols)
88
+ return data
89
+
90
+ if direction == "to":
91
+ return self.convert_to_albumentations(data, rows, cols)
92
+ elif direction == "from":
93
+ return self.convert_from_albumentations(data, rows, cols)
94
+ else:
95
+ raise ValueError(f"Invalid direction. Must be `to` or `from`. Got `{direction}`")
96
+
97
+ @abstractmethod
98
+ def filter(self, data: Sequence, rows: int, cols: int) -> Sequence:
99
+ pass
100
+
101
+ @abstractmethod
102
+ def check(self, data: Sequence, rows: int, cols: int) -> None:
103
+ pass
104
+
105
+ @abstractmethod
106
+ def convert_to_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
107
+ pass
108
+
109
+ @abstractmethod
110
+ def convert_from_albumentations(self, data: Sequence, rows: int, cols: int) -> Sequence:
111
+ pass
112
+
113
+ def add_label_fields_to_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
114
+ if self.params.label_fields is None:
115
+ return data
116
+ for data_name in self.data_fields:
117
+ for field in self.params.label_fields:
118
+ assert len(data[data_name]) == len(data[field])
119
+ data_with_added_field = []
120
+ for d, field_value in zip(data[data_name], data[field]):
121
+ data_with_added_field.append(list(d) + [field_value])
122
+ data[data_name] = data_with_added_field
123
+ return data
124
+
125
+ def remove_label_fields_from_data(self, data: Dict[str, Any]) -> Dict[str, Any]:
126
+ if self.params.label_fields is None:
127
+ return data
128
+ for data_name in self.data_fields:
129
+ label_fields_len = len(self.params.label_fields)
130
+ for idx, field in enumerate(self.params.label_fields):
131
+ field_values = []
132
+ for bbox in data[data_name]:
133
+ field_values.append(bbox[-label_fields_len + idx])
134
+ data[field] = field_values
135
+ if label_fields_len:
136
+ data[data_name] = [d[:-label_fields_len] for d in data[data_name]]
137
+ return data
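
A sketch of the label-field bookkeeping above (the image, boxes, and labels are illustrative, and it assumes this fork re-exports Compose and BboxParams at package level like upstream albumentations): add_label_fields_to_data zips each label onto its item before the transforms run, and remove_label_fields_from_data splits the labels back out afterwards.

import numpy as np
import custom_albumentations as A

image = np.zeros((100, 100, 3), dtype=np.uint8)
transform = A.Compose(
    [A.HorizontalFlip(p=1.0)],
    bbox_params=A.BboxParams(format="pascal_voc", label_fields=["class_labels"]),
)
out = transform(image=image, bboxes=[(10, 10, 50, 50)], class_labels=["cat"])
# while the transforms run, each bbox is carried as (10, 10, 50, 50, "cat");
# afterwards out["bboxes"] and out["class_labels"] are separate again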
custom_albumentations/imgaug/__init__.py ADDED
File without changes
custom_albumentations/imgaug/stubs.py ADDED
@@ -0,0 +1,77 @@
1
+ __all__ = [
2
+ "IAAEmboss",
3
+ "IAASuperpixels",
4
+ "IAASharpen",
5
+ "IAAAdditiveGaussianNoise",
6
+ "IAACropAndPad",
7
+ "IAAFliplr",
8
+ "IAAFlipud",
9
+ "IAAAffine",
10
+ "IAAPiecewiseAffine",
11
+ "IAAPerspective",
12
+ ]
13
+
14
+
15
+ class IAAStub:
16
+ def __init__(self, *args, **kwargs):
17
+ cls_name = self.__class__.__name__
18
+ doc_link = "https://albumentations.ai/docs/api_reference/augmentations" + self.doc_link
19
+ raise RuntimeError(
20
+ f"You are trying to use a deprecated augmentation '{cls_name}' which depends on the imgaug library, "
21
+ f"but imgaug is not installed.\n\n"
22
+ "There are two options to fix this error:\n"
23
+ "1. [Recommended]. Switch to the Albumentations' implementation of the augmentation with the same API: "
24
+ f"{self.alternative} - {doc_link}\n"
25
+ "2. Install a version of Albumentations that contains imgaug by running "
26
+ "'pip install -U albumentations[imgaug]'."
27
+ )
28
+
29
+
30
+ class IAACropAndPad(IAAStub):
31
+ alternative = "CropAndPad"
32
+ doc_link = "/crops/transforms/#albumentations.augmentations.crops.transforms.CropAndPad"
33
+
34
+
35
+ class IAAFliplr(IAAStub):
36
+ alternative = "HorizontalFlip"
37
+ doc_link = "/transforms/#albumentations.augmentations.transforms.HorizontalFlip"
38
+
39
+
40
+ class IAAFlipud(IAAStub):
41
+ alternative = "VerticalFlip"
42
+ doc_link = "/transforms/#albumentations.augmentations.transforms.VerticalFlip"
43
+
44
+
45
+ class IAAEmboss(IAAStub):
46
+ alternative = "Emboss"
47
+ doc_link = "/transforms/#albumentations.augmentations.transforms.Emboss"
48
+
49
+
50
+ class IAASuperpixels(IAAStub):
51
+ alternative = "Superpixels"
52
+ doc_link = "/transforms/#albumentations.augmentations.transforms.Superpixels"
53
+
54
+
55
+ class IAASharpen(IAAStub):
56
+ alternative = "Sharpen"
57
+ doc_link = "/transforms/#albumentations.augmentations.transforms.Sharpen"
58
+
59
+
60
+ class IAAAdditiveGaussianNoise(IAAStub):
61
+ alternative = "GaussNoise"
62
+ doc_link = "/transforms/#albumentations.augmentations.transforms.GaussNoise"
63
+
64
+
65
+ class IAAPiecewiseAffine(IAAStub):
66
+ alternative = "PiecewiseAffine"
67
+ doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.PiecewiseAffine"
68
+
69
+
70
+ class IAAAffine(IAAStub):
71
+ alternative = "Affine"
72
+ doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Affine"
73
+
74
+
75
+ class IAAPerspective(IAAStub):
76
+ alternative = "Perspective"
77
+ doc_link = "/geometric/transforms/#albumentations.augmentations.geometric.transforms.Perspective"
custom_albumentations/imgaug/transforms.py ADDED
@@ -0,0 +1,391 @@
1
+ try:
2
+ import imgaug as ia
3
+ except ImportError as e:
4
+ raise ImportError(
5
+ "You are trying to import an augmentation that depends on the imgaug library, but imgaug is not installed. To "
6
+ "install a version of Albumentations that contains imgaug please run 'pip install -U albumentations[imgaug]'"
7
+ ) from e
8
+
9
+ try:
10
+ from imgaug import augmenters as iaa
11
+ except ImportError:
12
+ import imgaug.imgaug.augmenters as iaa
13
+
14
+ import warnings
15
+
16
+ from custom_albumentations.core.bbox_utils import (
17
+ convert_bboxes_from_albumentations,
18
+ convert_bboxes_to_albumentations,
19
+ )
20
+ from custom_albumentations.core.keypoints_utils import (
21
+ convert_keypoints_from_albumentations,
22
+ convert_keypoints_to_albumentations,
23
+ )
24
+
25
+ from ..augmentations import Perspective
26
+ from ..core.transforms_interface import (
27
+ BasicTransform,
28
+ DualTransform,
29
+ ImageOnlyTransform,
30
+ to_tuple,
31
+ )
32
+
33
+ __all__ = [
34
+ "BasicIAATransform",
35
+ "DualIAATransform",
36
+ "ImageOnlyIAATransform",
37
+ "IAAEmboss",
38
+ "IAASuperpixels",
39
+ "IAASharpen",
40
+ "IAAAdditiveGaussianNoise",
41
+ "IAACropAndPad",
42
+ "IAAFliplr",
43
+ "IAAFlipud",
44
+ "IAAAffine",
45
+ "IAAPiecewiseAffine",
46
+ "IAAPerspective",
47
+ ]
48
+
49
+
50
+ class BasicIAATransform(BasicTransform):
51
+ def __init__(self, always_apply=False, p=0.5):
52
+ super(BasicIAATransform, self).__init__(always_apply, p)
53
+
54
+ @property
55
+ def processor(self):
56
+ return iaa.Noop()
57
+
58
+ def update_params(self, params, **kwargs):
59
+ params = super(BasicIAATransform, self).update_params(params, **kwargs)
60
+ params["deterministic_processor"] = self.processor.to_deterministic()
61
+ return params
62
+
63
+ def apply(self, img, deterministic_processor=None, **params):
64
+ return deterministic_processor.augment_image(img)
65
+
66
+
67
+ class DualIAATransform(DualTransform, BasicIAATransform):
68
+ def apply_to_bboxes(self, bboxes, deterministic_processor=None, rows=0, cols=0, **params):
69
+ if len(bboxes) > 0:
70
+ bboxes = convert_bboxes_from_albumentations(bboxes, "pascal_voc", rows=rows, cols=cols)
71
+
72
+ bboxes_t = ia.BoundingBoxesOnImage([ia.BoundingBox(*bbox[:4]) for bbox in bboxes], (rows, cols))
73
+ bboxes_t = deterministic_processor.augment_bounding_boxes([bboxes_t])[0].bounding_boxes
74
+ bboxes_t = [
75
+ [bbox.x1, bbox.y1, bbox.x2, bbox.y2] + list(bbox_orig[4:])
76
+ for (bbox, bbox_orig) in zip(bboxes_t, bboxes)
77
+ ]
78
+
79
+ bboxes = convert_bboxes_to_albumentations(bboxes_t, "pascal_voc", rows=rows, cols=cols)
80
+ return bboxes
81
+
82
+ def apply_to_keypoints(self, keypoints, deterministic_processor=None, rows=0, cols=0, **params):
+ """Applies the transformation to keypoints.
+ Notes:
+ Since IAA supports only xy keypoints, scale and orientation will remain unchanged.
+ TODO:
+ Emit a warning message if child classes of DualIAATransform are instantiated
+ inside Compose with a keypoints format other than 'xy'.
+ """
91
+ if len(keypoints) > 0:
92
+ keypoints = convert_keypoints_from_albumentations(keypoints, "xy", rows=rows, cols=cols)
93
+ keypoints_t = ia.KeypointsOnImage([ia.Keypoint(*kp[:2]) for kp in keypoints], (rows, cols))
94
+ keypoints_t = deterministic_processor.augment_keypoints([keypoints_t])[0].keypoints
95
+
96
+ keypoints_t = [[kp.x, kp.y] + list(kp_orig[2:]) for (kp, kp_orig) in zip(keypoints_t, keypoints)]
97
+
98
+ keypoints = convert_keypoints_to_albumentations(keypoints_t, "xy", rows=rows, cols=cols)
99
+ return keypoints
100
+
101
+
102
+ class ImageOnlyIAATransform(ImageOnlyTransform, BasicIAATransform):
103
+ pass
104
+
105
+
106
+ class IAACropAndPad(DualIAATransform):
107
+ """This augmentation is deprecated. Please use CropAndPad instead."""
108
+
109
+ def __init__(self, px=None, percent=None, pad_mode="constant", pad_cval=0, keep_size=True, always_apply=False, p=1):
110
+ super(IAACropAndPad, self).__init__(always_apply, p)
111
+ self.px = px
112
+ self.percent = percent
113
+ self.pad_mode = pad_mode
114
+ self.pad_cval = pad_cval
115
+ self.keep_size = keep_size
116
+ warnings.warn("IAACropAndPad is deprecated. Please use CropAndPad instead", FutureWarning)
117
+
118
+ @property
119
+ def processor(self):
120
+ return iaa.CropAndPad(self.px, self.percent, self.pad_mode, self.pad_cval, self.keep_size)
121
+
122
+ def get_transform_init_args_names(self):
123
+ return ("px", "percent", "pad_mode", "pad_cval", "keep_size")
124
+
125
+
126
+ class IAAFliplr(DualIAATransform):
127
+ """This augmentation is deprecated. Please use HorizontalFlip instead."""
128
+
129
+ def __init__(self, always_apply=False, p=0.5):
130
+ super().__init__(always_apply, p)
131
+ warnings.warn("IAAFliplr is deprecated. Please use HorizontalFlip instead.", FutureWarning)
132
+
133
+ @property
134
+ def processor(self):
135
+ return iaa.Fliplr(1)
136
+
137
+ def get_transform_init_args_names(self):
138
+ return ()
139
+
140
+
141
+ class IAAFlipud(DualIAATransform):
142
+ """This augmentation is deprecated. Please use VerticalFlip instead."""
143
+
144
+ def __init__(self, always_apply=False, p=0.5):
145
+ super().__init__(always_apply, p)
146
+ warnings.warn("IAAFlipud is deprecated. Please use VerticalFlip instead.", FutureWarning)
147
+
148
+ @property
149
+ def processor(self):
150
+ return iaa.Flipud(1)
151
+
152
+ def get_transform_init_args_names(self):
153
+ return ()
154
+
155
+
156
+ class IAAEmboss(ImageOnlyIAATransform):
157
+ """Emboss the input image and overlay the result with the original image.
158
+ This augmentation is deprecated. Please use Emboss instead.
159
+
160
+ Args:
161
+ alpha ((float, float)): range to choose the visibility of the embossed image. At 0, only the original image is
162
+ visible, at 1.0 only its embossed version is visible. Default: (0.2, 0.5).
163
+ strength ((float, float)): strength range of the embossing. Default: (0.2, 0.7).
164
+ p (float): probability of applying the transform. Default: 0.5.
165
+
166
+ Targets:
167
+ image
168
+ """
169
+
170
+ def __init__(self, alpha=(0.2, 0.5), strength=(0.2, 0.7), always_apply=False, p=0.5):
171
+ super(IAAEmboss, self).__init__(always_apply, p)
172
+ self.alpha = to_tuple(alpha, 0.0)
173
+ self.strength = to_tuple(strength, 0.0)
174
+ warnings.warn("This augmentation is deprecated. Please use Emboss instead", FutureWarning)
175
+
176
+ @property
177
+ def processor(self):
178
+ return iaa.Emboss(self.alpha, self.strength)
179
+
180
+ def get_transform_init_args_names(self):
181
+ return ("alpha", "strength")
182
+
183
+
184
+ class IAASuperpixels(ImageOnlyIAATransform):
185
+ """Completely or partially transform the input image to its superpixel representation. Uses skimage's version
186
+ of the SLIC algorithm. May be slow.
187
+
188
+ This augmentation is deprecated. Please use Superpixels instead.
189
+
190
+ Args:
191
+ p_replace (float): defines the probability of any superpixel area being replaced by the superpixel, i.e. by
192
+ the average pixel color within its area. Default: 0.1.
193
+ n_segments (int): target number of superpixels to generate. Default: 100.
194
+ p (float): probability of applying the transform. Default: 0.5.
195
+
196
+ Targets:
197
+ image
198
+ """
199
+
200
+ def __init__(self, p_replace=0.1, n_segments=100, always_apply=False, p=0.5):
201
+ super(IAASuperpixels, self).__init__(always_apply, p)
202
+ self.p_replace = p_replace
203
+ self.n_segments = n_segments
204
+ warnings.warn("IAASuperpixels is deprecated. Please use Superpixels instead.", FutureWarning)
205
+
206
+ @property
207
+ def processor(self):
208
+ return iaa.Superpixels(p_replace=self.p_replace, n_segments=self.n_segments)
209
+
210
+ def get_transform_init_args_names(self):
211
+ return ("p_replace", "n_segments")
212
+
213
+
214
+ class IAASharpen(ImageOnlyIAATransform):
215
+ """Sharpen the input image and overlay the result with the original image.
216
+ This augmentation is deprecated. Please use Sharpen instead
217
+ Args:
218
+ alpha ((float, float)): range to choose the visibility of the sharpened image. At 0, only the original image is
219
+ visible, at 1.0 only its sharpened version is visible. Default: (0.2, 0.5).
220
+ lightness ((float, float)): range to choose the lightness of the sharpened image. Default: (0.5, 1.0).
221
+ p (float): probability of applying the transform. Default: 0.5.
222
+
223
+ Targets:
224
+ image
225
+ """
226
+
227
+ def __init__(self, alpha=(0.2, 0.5), lightness=(0.5, 1.0), always_apply=False, p=0.5):
228
+ super(IAASharpen, self).__init__(always_apply, p)
229
+ self.alpha = to_tuple(alpha, 0)
230
+ self.lightness = to_tuple(lightness, 0)
231
+ warnings.warn("IAASharpen is deprecated. Please use Sharpen instead", FutureWarning)
232
+
233
+ @property
234
+ def processor(self):
235
+ return iaa.Sharpen(self.alpha, self.lightness)
236
+
237
+ def get_transform_init_args_names(self):
238
+ return ("alpha", "lightness")
239
+
240
+
241
+ class IAAAdditiveGaussianNoise(ImageOnlyIAATransform):
242
+ """Add gaussian noise to the input image.
243
+
244
+ This augmentation is deprecated. Please use GaussNoise instead.
245
+
246
+ Args:
247
+ loc (int): mean of the normal distribution that generates the noise. Default: 0.
248
+ scale ((float, float)): standard deviation of the normal distribution that generates the noise.
249
+ Default: (0.01 * 255, 0.05 * 255).
250
+ p (float): probability of applying the transform. Default: 0.5.
251
+
252
+ Targets:
253
+ image
254
+ """
255
+
256
+ def __init__(self, loc=0, scale=(0.01 * 255, 0.05 * 255), per_channel=False, always_apply=False, p=0.5):
257
+ super(IAAAdditiveGaussianNoise, self).__init__(always_apply, p)
258
+ self.loc = loc
259
+ self.scale = to_tuple(scale, 0.0)
260
+ self.per_channel = per_channel
261
+ warnings.warn("IAAAdditiveGaussianNoise is deprecated. Please use GaussNoise instead", FutureWarning)
262
+
263
+ @property
264
+ def processor(self):
265
+ return iaa.AdditiveGaussianNoise(self.loc, self.scale, self.per_channel)
266
+
267
+ def get_transform_init_args_names(self):
268
+ return ("loc", "scale", "per_channel")
269
+
270
+
271
+ class IAAPiecewiseAffine(DualIAATransform):
272
+ """Place a regular grid of points on the input and randomly move the neighbourhood of these points around
273
+ via affine transformations.
274
+
275
+ This augmentation is deprecated. Please use PiecewiseAffine instead.
276
+
277
+ Note: This class introduces interpolation artifacts to a mask if it has values other than {0;1}
278
+
279
+ Args:
280
+ scale ((float, float)): factor range that determines how far each point is moved. Default: (0.03, 0.05).
281
+ nb_rows (int): number of rows of points that the regular grid should have. Default: 4.
282
+ nb_cols (int): number of columns of points that the regular grid should have. Default: 4.
283
+ p (float): probability of applying the transform. Default: 0.5.
284
+
285
+ Targets:
286
+ image, mask
287
+ """
288
+
289
+ def __init__(
290
+ self, scale=(0.03, 0.05), nb_rows=4, nb_cols=4, order=1, cval=0, mode="constant", always_apply=False, p=0.5
291
+ ):
292
+ super(IAAPiecewiseAffine, self).__init__(always_apply, p)
293
+ self.scale = to_tuple(scale, 0.0)
294
+ self.nb_rows = nb_rows
295
+ self.nb_cols = nb_cols
296
+ self.order = order
297
+ self.cval = cval
298
+ self.mode = mode
299
+ warnings.warn("IAAPiecewiseAffine is deprecated. Please use PiecewiseAffine instead", FutureWarning)
300
+
301
+ @property
302
+ def processor(self):
303
+ return iaa.PiecewiseAffine(self.scale, self.nb_rows, self.nb_cols, self.order, self.cval, self.mode)
304
+
305
+ def get_transform_init_args_names(self):
306
+ return ("scale", "nb_rows", "nb_cols", "order", "cval", "mode")
307
+
308
+
309
+ class IAAAffine(DualIAATransform):
310
+ """Apply affine transformations to the input: scaling, translation,
311
+ rotation, and shearing.
312
+
313
+ This augmentation is deprecated. Please use Affine instead.
314
+
315
+ Note: This class introduces interpolation artifacts to a mask if it has values other than {0;1}
316
+
317
+ Args:
318
+ p (float): probability of applying the transform. Default: 0.5.
319
+
320
+ Targets:
321
+ image, mask
322
+ """
323
+
324
+ def __init__(
325
+ self,
326
+ scale=1.0,
327
+ translate_percent=None,
328
+ translate_px=None,
329
+ rotate=0.0,
330
+ shear=0.0,
331
+ order=1,
332
+ cval=0,
333
+ mode="reflect",
334
+ always_apply=False,
335
+ p=0.5,
336
+ ):
337
+ super(IAAAffine, self).__init__(always_apply, p)
338
+ self.scale = to_tuple(scale, 1.0)
339
+ self.translate_percent = to_tuple(translate_percent, 0)
340
+ self.translate_px = to_tuple(translate_px, 0)
341
+ self.rotate = to_tuple(rotate)
342
+ self.shear = to_tuple(shear)
343
+ self.order = order
344
+ self.cval = cval
345
+ self.mode = mode
346
+ warnings.warn("IAAAffine is deprecated. Please use Affine instead", FutureWarning)
347
+
348
+ @property
349
+ def processor(self):
350
+ return iaa.Affine(
351
+ self.scale,
352
+ self.translate_percent,
353
+ self.translate_px,
354
+ self.rotate,
355
+ self.shear,
356
+ self.order,
357
+ self.cval,
358
+ self.mode,
359
+ )
360
+
361
+ def get_transform_init_args_names(self):
362
+ return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")
363
+
364
+
365
+ class IAAPerspective(Perspective):
366
+ """Perform a random four point perspective transform of the input.
367
+ This augmentation is deprecated. Please use Perspective instead.
368
+
369
+ Note: This class introduces interpolation artifacts to a mask if it has values other than {0;1}
370
+
371
+ Args:
372
+ scale ((float, float)): standard deviation of the normal distributions. These are used to sample
373
+ the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1).
374
+ p (float): probability of applying the transform. Default: 0.5.
375
+
376
+ Targets:
377
+ image, mask
378
+ """
379
+
380
+ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5):
381
+ super(IAAPerspective, self).__init__(always_apply, p)
382
+ self.scale = to_tuple(scale, 1.0)
383
+ self.keep_size = keep_size
384
+ warnings.warn("IAAPerspective is deprecated. Please use Perspective instead", FutureWarning)
385
+
386
+ @property
387
+ def processor(self):
388
+ return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size)
389
+
390
+ def get_transform_init_args_names(self):
391
+ return ("scale", "keep_size")
custom_albumentations/pytorch/__init__.py ADDED
@@ -0,0 +1,3 @@
1
+ from __future__ import absolute_import
2
+
3
+ from .transforms import *
custom_albumentations/pytorch/functional.py ADDED
@@ -0,0 +1,31 @@
1
+ from __future__ import division
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torchvision.transforms.functional as F
6
+
7
+
8
+ def img_to_tensor(im, normalize=None):
9
+ tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
10
+ if normalize is not None:
11
+ return F.normalize(tensor, **normalize)
12
+ return tensor
13
+
14
+
15
+ def mask_to_tensor(mask, num_classes, sigmoid):
16
+ if num_classes > 1:
17
+ if not sigmoid:
18
+ # softmax
19
+ long_mask = np.zeros((mask.shape[:2]), dtype=np.int64)
20
+ if len(mask.shape) == 3:
21
+ for c in range(mask.shape[2]):
22
+ long_mask[mask[..., c] > 0] = c
23
+ else:
24
+ long_mask[mask > 127] = 1
25
+ long_mask[mask == 0] = 0
26
+ mask = long_mask
27
+ else:
28
+ mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32)
29
+ else:
30
+ mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32)
31
+ return torch.from_numpy(mask)
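
A quick sketch of the HWC-to-CHW conversion these helpers perform. The `normalize` dict mirrors torchvision's mean/std convention (the ImageNet values below are illustrative, not part of this file):

import numpy as np
from custom_albumentations.pytorch.functional import img_to_tensor, mask_to_tensor

image = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)
tensor = img_to_tensor(image, normalize={"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]})
print(tensor.shape)  # torch.Size([3, 64, 64]); uint8 input is scaled from [0, 255] to [0, 1]

mask = np.random.randint(0, 2, (64, 64), dtype=np.uint8) * 255
print(mask_to_tensor(mask, num_classes=1, sigmoid=True).shape)  # torch.Size([1, 64, 64])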
custom_albumentations/pytorch/transforms.py ADDED
@@ -0,0 +1,104 @@
+ from __future__ import absolute_import
+ 
+ import warnings
+ 
+ import numpy as np
+ import torch
+ from torchvision.transforms import functional as F
+ 
+ from ..core.transforms_interface import BasicTransform
+ 
+ __all__ = ["ToTensorV2"]
+ 
+ 
+ def img_to_tensor(im, normalize=None):
+     tensor = torch.from_numpy(np.moveaxis(im / (255.0 if im.dtype == np.uint8 else 1), -1, 0).astype(np.float32))
+     if normalize is not None:
+         return F.normalize(tensor, **normalize)
+     return tensor
+ 
+ 
+ def mask_to_tensor(mask, num_classes, sigmoid):
+     if num_classes > 1:
+         if not sigmoid:
+             # softmax
+             long_mask = np.zeros((mask.shape[:2]), dtype=np.int64)
+             if len(mask.shape) == 3:
+                 for c in range(mask.shape[2]):
+                     long_mask[mask[..., c] > 0] = c
+             else:
+                 long_mask[mask > 127] = 1
+                 long_mask[mask == 0] = 0
+             mask = long_mask
+         else:
+             mask = np.moveaxis(mask / (255.0 if mask.dtype == np.uint8 else 1), -1, 0).astype(np.float32)
+     else:
+         mask = np.expand_dims(mask / (255.0 if mask.dtype == np.uint8 else 1), 0).astype(np.float32)
+     return torch.from_numpy(mask)
+ 
+ 
+ class ToTensor(BasicTransform):
+     """Convert image and mask to `torch.Tensor` and divide by 255 if the image or mask is of `uint8` type.
+     This transform is now removed from custom_albumentations. If you need it, downgrade the library to version 0.5.2.
+ 
+     Args:
+         num_classes (int): only for segmentation
+         sigmoid (bool, optional): only for segmentation, transform mask to LongTensor or not.
+         normalize (dict, optional): dict with keys [mean, std] to pass into torchvision normalize
+ 
+     """
+ 
+     def __init__(self, num_classes=1, sigmoid=True, normalize=None):
+         raise RuntimeError(
+             "`ToTensor` is obsolete and it was removed from custom_albumentations. Please use `ToTensorV2` instead - "
+             "https://albumentations.ai/docs/api_reference/pytorch/transforms/"
+             "#albumentations.pytorch.transforms.ToTensorV2. "
+             "\n\nIf you need `ToTensor` downgrade Albumentations to version 0.5.2."
+         )
+ 
+ 
+ class ToTensorV2(BasicTransform):
+     """Convert image and mask to `torch.Tensor`. The numpy `HWC` image is converted to a pytorch `CHW` tensor.
+     If the image is in `HW` format (grayscale image), it will be converted to a pytorch `HW` tensor.
+     This is a simplified and improved version of the old `ToTensor` transform (`ToTensor` was deprecated and
+     is no longer present in Albumentations; use `ToTensorV2` instead).
+ 
+     Args:
+         transpose_mask (bool): If True and an input mask has three dimensions, this transform will transpose dimensions
+             so the shape `[height, width, num_channels]` becomes `[num_channels, height, width]`. The latter format is a
+             standard format for PyTorch Tensors. Default: False.
+         always_apply (bool): Indicates whether this transformation should be always applied. Default: True.
+         p (float): Probability of applying the transform. Default: 1.0.
+     """
+ 
+     def __init__(self, transpose_mask=False, always_apply=True, p=1.0):
+         super(ToTensorV2, self).__init__(always_apply=always_apply, p=p)
+         self.transpose_mask = transpose_mask
+ 
+     @property
+     def targets(self):
+         return {"image": self.apply, "mask": self.apply_to_mask, "masks": self.apply_to_masks}
+ 
+     def apply(self, img, **params):  # skipcq: PYL-W0613
+         if len(img.shape) not in [2, 3]:
+             raise ValueError("Albumentations only supports images in HW or HWC format")
+ 
+         if len(img.shape) == 2:
+             img = np.expand_dims(img, 2)
+ 
+         return torch.from_numpy(img.transpose(2, 0, 1))
+ 
+     def apply_to_mask(self, mask, **params):  # skipcq: PYL-W0613
+         if self.transpose_mask and mask.ndim == 3:
+             mask = mask.transpose(2, 0, 1)
+         return torch.from_numpy(mask)
+ 
+     def apply_to_masks(self, masks, **params):
+         return [self.apply_to_mask(mask, **params) for mask in masks]
+ 
+     def get_transform_init_args_names(self):
+         return ("transpose_mask",)
+ 
+     def get_params_dependent_on_targets(self, params):
+         return {}
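
A minimal pipeline sketch for `ToTensorV2`, assuming the package re-exports `Compose` and `Resize` at the top level as upstream albumentations does. Note that, unlike the removed `ToTensor`, `ToTensorV2` does not divide by 255; it only changes layout and wraps the arrays:

import numpy as np
import custom_albumentations as A  # assumes Compose/Resize at the root
from custom_albumentations.pytorch import ToTensorV2

transform = A.Compose([A.Resize(256, 256), ToTensorV2(transpose_mask=True)])
image = np.random.randint(0, 256, (300, 400, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (300, 400, 1), dtype=np.uint8)
out = transform(image=image, mask=mask)
print(out["image"].shape, out["mask"].shape)  # torch.Size([3, 256, 256]) torch.Size([1, 256, 256])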
custom_albumentations/random_utils.py ADDED
@@ -0,0 +1,96 @@
+ # Use `Any` as the return type to avoid mypy problems with Union data types,
+ # because numpy can return single number and ndarray
+ 
+ import random as py_random
+ from typing import Any, Optional, Sequence, Type, Union
+ 
+ import numpy as np
+ 
+ from .core.transforms_interface import NumType
+ 
+ IntNumType = Union[int, np.ndarray]
+ Size = Union[int, Sequence[int]]
+ 
+ 
+ def get_random_state() -> np.random.RandomState:
+     return np.random.RandomState(py_random.randint(0, (1 << 32) - 1))
+ 
+ 
+ def uniform(
+     low: NumType = 0.0,
+     high: NumType = 1.0,
+     size: Optional[Size] = None,
+     random_state: Optional[np.random.RandomState] = None,
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.uniform(low, high, size)
+ 
+ 
+ def rand(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.rand(d0, d1, *more, **kwargs)  # type: ignore
+ 
+ 
+ def randn(d0: NumType, d1: NumType, *more, random_state: Optional[np.random.RandomState] = None, **kwargs) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.randn(d0, d1, *more, **kwargs)  # type: ignore
+ 
+ 
+ def normal(
+     loc: NumType = 0.0,
+     scale: NumType = 1.0,
+     size: Optional[Size] = None,
+     random_state: Optional[np.random.RandomState] = None,
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.normal(loc, scale, size)
+ 
+ 
+ def poisson(
+     lam: NumType = 1.0, size: Optional[Size] = None, random_state: Optional[np.random.RandomState] = None
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.poisson(lam, size)
+ 
+ 
+ def permutation(
+     x: Union[int, Sequence[float], np.ndarray], random_state: Optional[np.random.RandomState] = None
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.permutation(x)
+ 
+ 
+ def randint(
+     low: IntNumType,
+     high: Optional[IntNumType] = None,
+     size: Optional[Size] = None,
+     dtype: Type = np.int32,
+     random_state: Optional[np.random.RandomState] = None,
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.randint(low, high, size, dtype)
+ 
+ 
+ def random(size: Optional[NumType] = None, random_state: Optional[np.random.RandomState] = None) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.random(size)  # type: ignore
+ 
+ 
+ def choice(
+     a: NumType,
+     size: Optional[Size] = None,
+     replace: bool = True,
+     p: Optional[Union[Sequence[float], np.ndarray]] = None,
+     random_state: Optional[np.random.RandomState] = None,
+ ) -> Any:
+     if random_state is None:
+         random_state = get_random_state()
+     return random_state.choice(a, size, replace, p)  # type: ignore
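
The point of these wrappers is that each numpy draw derives its `RandomState` from Python's `random` module, so seeding `random` makes the numpy sampling reproducible. A small sketch:

import random
from custom_albumentations import random_utils

random.seed(42)
a = random_utils.uniform(0.0, 1.0, size=3)
random.seed(42)
b = random_utils.uniform(0.0, 1.0, size=3)
assert (a == b).all()  # same Python-level seed, same numpy samples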
custom_controlnet_aux/__init__.py ADDED
@@ -0,0 +1 @@
+ # Dummy file ensuring this package will be recognized
custom_controlnet_aux/anime_face_segment/__init__.py ADDED
@@ -0,0 +1,66 @@
+ from .network import UNet
+ from .util import seg2img
+ import torch
+ import os
+ import cv2
+ from custom_controlnet_aux.util import HWC3, resize_image_with_pad, common_input_validate, custom_hf_download, BDS_MODEL_NAME
+ from huggingface_hub import hf_hub_download
+ from PIL import Image
+ from einops import rearrange
+ from .anime_segmentation import AnimeSegmentation
+ import numpy as np
+ 
+ class AnimeFaceSegmentor:
+     def __init__(self, model, seg_model):
+         self.model = model
+         self.seg_model = seg_model
+         self.device = "cpu"
+ 
+     @classmethod
+     def from_pretrained(cls, pretrained_model_or_path=BDS_MODEL_NAME, filename="UNet.pth", seg_filename="isnetis.ckpt"):
+         model_path = custom_hf_download(pretrained_model_or_path, filename, subfolder="Annotators")
+         seg_model_path = custom_hf_download("skytnt/anime-seg", seg_filename)
+ 
+         model = UNet()
+         ckpt = torch.load(model_path, map_location="cpu")
+         model.load_state_dict(ckpt)
+         model.eval()
+ 
+         seg_model = AnimeSegmentation(seg_model_path)
+         seg_model.net.eval()
+         return cls(model, seg_model)
+ 
+     def to(self, device):
+         self.model.to(device)
+         self.seg_model.net.to(device)
+         self.device = device
+         return self
+ 
+     def __call__(self, input_image, detect_resolution=512, output_type="pil", upscale_method="INTER_CUBIC", remove_background=True, **kwargs):
+         input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+         input_image, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ 
+         with torch.no_grad():
+             if remove_background:
+                 print(input_image.shape)
+                 mask, input_image = self.seg_model(input_image, 0)  # Don't resize again; the image is already resized
+             image_feed = torch.from_numpy(input_image).float().to(self.device)
+             image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+             image_feed = image_feed / 255
+             seg = self.model(image_feed).squeeze(dim=0)
+             result = seg2img(seg.cpu().detach().numpy())
+ 
+         detected_map = HWC3(result)
+         detected_map = remove_pad(detected_map)
+         if remove_background:
+             mask = remove_pad(mask)
+             H, W, C = detected_map.shape
+             tmp = np.zeros([H, W, C + 1])
+             tmp[:, :, :C] = detected_map
+             tmp[:, :, 3:] = mask
+             detected_map = tmp
+ 
+         if output_type == "pil":
+             detected_map = Image.fromarray(detected_map[..., :3])
+ 
+         return detected_map
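
A minimal usage sketch for the detector above. It relies only on the defaults defined in this file (`from_pretrained` downloads the weights via `custom_hf_download`); the input file name is hypothetical:

from PIL import Image
from custom_controlnet_aux.anime_face_segment import AnimeFaceSegmentor

segmentor = AnimeFaceSegmentor.from_pretrained()  # optionally .to("cuda")
image = Image.open("anime_face.png").convert("RGB")  # hypothetical input
seg_map = segmentor(image, detect_resolution=512, remove_background=False)
seg_map.save("anime_face_seg.png")  # per-pixel palette colors from util.seg2img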
custom_controlnet_aux/anime_face_segment/anime_segmentation.py ADDED
@@ -0,0 +1,58 @@
+ # https://github.com/SkyTNT/anime-segmentation/tree/main
+ # Only adapts isnet_is (https://huggingface.co/skytnt/anime-seg/blob/main/isnetis.ckpt)
+ import torch.nn as nn
+ import torch
+ from .isnet import ISNetDIS
+ import numpy as np
+ import cv2
+ from comfy.model_management import get_torch_device
+ DEVICE = get_torch_device()
+ 
+ class AnimeSegmentation:
+     def __init__(self, ckpt_path):
+         super().__init__()
+         sd = torch.load(ckpt_path, map_location="cpu")
+         self.net = ISNetDIS()
+         # gt_encoder isn't used during inference
+         self.net.load_state_dict({k.replace("net.", ''): v for k, v in sd.items() if k.startswith("net.")})
+         self.net = self.net.to(DEVICE)
+         self.net.eval()
+ 
+     def get_mask(self, input_img, s=640):
+         input_img = (input_img / 255).astype(np.float32)
+         if s == 0:
+             img_input = np.transpose(input_img, (2, 0, 1))
+             img_input = img_input[np.newaxis, :]
+             tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
+             with torch.no_grad():
+                 pred = self.net(tmpImg)[0][0].sigmoid()  # https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
+                 pred = pred.cpu().numpy()[0]
+                 pred = np.transpose(pred, (1, 2, 0))
+                 # pred = pred[:, :, np.newaxis]
+                 return pred
+ 
+         h, w = h0, w0 = input_img.shape[:-1]
+         h, w = (s, int(s * w / h)) if h > w else (int(s * h / w), s)
+         ph, pw = s - h, s - w
+         img_input = np.zeros([s, s, 3], dtype=np.float32)
+         img_input[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w] = cv2.resize(input_img, (w, h))
+         img_input = np.transpose(img_input, (2, 0, 1))
+         img_input = img_input[np.newaxis, :]
+         tmpImg = torch.from_numpy(img_input).float().to(DEVICE)
+         with torch.no_grad():
+             pred = self.net(tmpImg)[0][0].sigmoid()  # https://github.com/SkyTNT/anime-segmentation/blob/main/train.py#L92C20-L92C47
+             pred = pred.cpu().numpy()[0]
+             pred = np.transpose(pred, (1, 2, 0))
+             pred = pred[ph // 2:ph // 2 + h, pw // 2:pw // 2 + w]
+             # pred = cv2.resize(pred, (w0, h0))[:, :, np.newaxis]
+             pred = cv2.resize(pred, (w0, h0))
+             return pred
+ 
+     def __call__(self, np_img, img_size):
+         mask = self.get_mask(np_img, int(img_size))
+         np_img = (mask * np_img + 255 * (1 - mask)).astype(np.uint8)
+         mask = (mask * 255).astype(np.uint8)
+         # np_img = np.concatenate([np_img, mask], axis=2, dtype=np.uint8)
+         # mask = mask.repeat(3, axis=2)
+         return mask, np_img
+ 
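
A toy check of the letterbox arithmetic in `get_mask` (taken when `s != 0`): the long side is scaled to `s` and the short side is centered inside a square canvas with zero padding.

s = 640
h0, w0 = 480, 720  # landscape input, so w0 > h0
h, w = (s, int(s * w0 / h0)) if h0 > w0 else (int(s * h0 / w0), s)
ph, pw = s - h, s - w
print(h, w)    # 426 640 -> size of the resized content inside the 640x640 canvas
print(ph, pw)  # 214 0   -> vertical padding, split evenly above and below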
custom_controlnet_aux/anime_face_segment/isnet.py ADDED
@@ -0,0 +1,619 @@
+ # Codes are borrowed from
+ # https://github.com/xuebinqin/DIS/blob/main/IS-Net/models/isnet.py
+ 
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torchvision import models
+ 
+ bce_loss = nn.BCEWithLogitsLoss(reduction="mean")
+ 
+ 
+ def muti_loss_fusion(preds, target):
+     loss0 = 0.0
+     loss = 0.0
+ 
+     for i in range(0, len(preds)):
+         if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+             tmp_target = F.interpolate(
+                 target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+             )
+             loss = loss + bce_loss(preds[i], tmp_target)
+         else:
+             loss = loss + bce_loss(preds[i], target)
+         if i == 0:
+             loss0 = loss
+     return loss0, loss
+ 
+ 
+ fea_loss = nn.MSELoss(reduction="mean")
+ kl_loss = nn.KLDivLoss(reduction="mean")
+ l1_loss = nn.L1Loss(reduction="mean")
+ smooth_l1_loss = nn.SmoothL1Loss(reduction="mean")
+ 
+ 
+ def muti_loss_fusion_kl(preds, target, dfs, fs, mode="MSE"):
+     loss0 = 0.0
+     loss = 0.0
+ 
+     for i in range(0, len(preds)):
+         if preds[i].shape[2] != target.shape[2] or preds[i].shape[3] != target.shape[3]:
+             tmp_target = F.interpolate(
+                 target, size=preds[i].size()[2:], mode="bilinear", align_corners=True
+             )
+             loss = loss + bce_loss(preds[i], tmp_target)
+         else:
+             loss = loss + bce_loss(preds[i], target)
+         if i == 0:
+             loss0 = loss
+ 
+     for i in range(0, len(dfs)):
+         df = dfs[i]
+         fs_i = fs[i]
+         if mode == "MSE":
+             loss = loss + fea_loss(
+                 df, fs_i
+             )  ### add the mse loss of features as additional constraints
+         elif mode == "KL":
+             loss = loss + kl_loss(F.log_softmax(df, dim=1), F.softmax(fs_i, dim=1))
+         elif mode == "MAE":
+             loss = loss + l1_loss(df, fs_i)
+         elif mode == "SmoothL1":
+             loss = loss + smooth_l1_loss(df, fs_i)
+ 
+     return loss0, loss
+ 
+ 
+ class REBNCONV(nn.Module):
+     def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
+         super(REBNCONV, self).__init__()
+ 
+         self.conv_s1 = nn.Conv2d(
+             in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate, stride=stride
+         )
+         self.bn_s1 = nn.BatchNorm2d(out_ch)
+         self.relu_s1 = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         hx = x
+         xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
+ 
+         return xout
+ 
+ 
+ ## upsample tensor 'src' to have the same spatial size with tensor 'tar'
+ def _upsample_like(src, tar):
+     src = F.interpolate(src, size=tar.shape[2:], mode="bilinear", align_corners=False)
+ 
+     return src
+ 
+ 
+ ### RSU-7 ###
+ class RSU7(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
+         super(RSU7, self).__init__()
+ 
+         self.in_ch = in_ch
+         self.mid_ch = mid_ch
+         self.out_ch = out_ch
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)  ## 1 -> 1/2
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         b, c, h, w = x.shape
+ 
+         hx = x
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+ 
+         hx5 = self.rebnconv5(hx)
+         hx = self.pool5(hx5)
+ 
+         hx6 = self.rebnconv6(hx)
+ 
+         hx7 = self.rebnconv7(hx6)
+ 
+         hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
+         hx6dup = _upsample_like(hx6d, hx5)
+ 
+         hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-6 ###
+ class RSU6(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU6, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+         hx = self.pool4(hx4)
+ 
+         hx5 = self.rebnconv5(hx)
+ 
+         hx6 = self.rebnconv6(hx5)
+ 
+         hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-5 ###
+ class RSU5(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU5, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+         hx = self.pool3(hx3)
+ 
+         hx4 = self.rebnconv4(hx)
+ 
+         hx5 = self.rebnconv5(hx4)
+ 
+         hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-4 ###
+ class RSU4(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
+         self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
+ 
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx = self.pool1(hx1)
+ 
+         hx2 = self.rebnconv2(hx)
+         hx = self.pool2(hx2)
+ 
+         hx3 = self.rebnconv3(hx)
+ 
+         hx4 = self.rebnconv4(hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ ### RSU-4F ###
+ class RSU4F(nn.Module):
+     def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
+         super(RSU4F, self).__init__()
+ 
+         self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
+ 
+         self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
+         self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
+         self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
+ 
+         self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
+ 
+         self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
+         self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
+         self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.rebnconvin(hx)
+ 
+         hx1 = self.rebnconv1(hxin)
+         hx2 = self.rebnconv2(hx1)
+         hx3 = self.rebnconv3(hx2)
+ 
+         hx4 = self.rebnconv4(hx3)
+ 
+         hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
+         hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
+         hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
+ 
+         return hx1d + hxin
+ 
+ 
+ class myrebnconv(nn.Module):
+     def __init__(
+         self,
+         in_ch=3,
+         out_ch=1,
+         kernel_size=3,
+         stride=1,
+         padding=1,
+         dilation=1,
+         groups=1,
+     ):
+         super(myrebnconv, self).__init__()
+ 
+         self.conv = nn.Conv2d(
+             in_ch,
+             out_ch,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=padding,
+             dilation=dilation,
+             groups=groups,
+         )
+         self.bn = nn.BatchNorm2d(out_ch)
+         self.rl = nn.ReLU(inplace=True)
+ 
+     def forward(self, x):
+         return self.rl(self.bn(self.conv(x)))
+ 
+ 
+ class ISNetGTEncoder(nn.Module):
+     def __init__(self, in_ch=1, out_ch=1):
+         super(ISNetGTEncoder, self).__init__()
+ 
+         self.conv_in = myrebnconv(
+             in_ch, 16, 3, stride=2, padding=1
+         )  # nn.Conv2d(in_ch,64,3,stride=2,padding=1)
+ 
+         self.stage1 = RSU7(16, 16, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage2 = RSU6(64, 16, 64)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage3 = RSU5(64, 32, 128)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage4 = RSU4(128, 32, 256)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage5 = RSU4F(256, 64, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage6 = RSU4F(512, 64, 512)
+ 
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+ 
+     @staticmethod
+     def compute_loss(args):
+         preds, targets = args
+         return muti_loss_fusion(preds, targets)
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.conv_in(hx)
+         # hx = self.pool_in(hxin)
+ 
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+ 
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+ 
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+ 
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+ 
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+ 
+         # stage 6
+         hx6 = self.stage6(hx)
+ 
+         # side output
+         d1 = self.side1(hx1)
+         d1 = _upsample_like(d1, x)
+ 
+         d2 = self.side2(hx2)
+         d2 = _upsample_like(d2, x)
+ 
+         d3 = self.side3(hx3)
+         d3 = _upsample_like(d3, x)
+ 
+         d4 = self.side4(hx4)
+         d4 = _upsample_like(d4, x)
+ 
+         d5 = self.side5(hx5)
+         d5 = _upsample_like(d5, x)
+ 
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+ 
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+ 
+         # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1, hx2, hx3, hx4, hx5, hx6]
+         return [d1, d2, d3, d4, d5, d6], [hx1, hx2, hx3, hx4, hx5, hx6]
+ 
+ 
+ class ISNetDIS(nn.Module):
+     def __init__(self, in_ch=3, out_ch=1):
+         super(ISNetDIS, self).__init__()
+ 
+         self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
+         self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage1 = RSU7(64, 32, 64)
+         self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage2 = RSU6(64, 32, 128)
+         self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage3 = RSU5(128, 64, 256)
+         self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage4 = RSU4(256, 128, 512)
+         self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage5 = RSU4F(512, 256, 512)
+         self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
+ 
+         self.stage6 = RSU4F(512, 256, 512)
+ 
+         # decoder
+         self.stage5d = RSU4F(1024, 256, 512)
+         self.stage4d = RSU4(1024, 128, 256)
+         self.stage3d = RSU5(512, 64, 128)
+         self.stage2d = RSU6(256, 32, 64)
+         self.stage1d = RSU7(128, 16, 64)
+ 
+         self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
+         self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
+         self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
+         self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
+         self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
+ 
+         # self.outconv = nn.Conv2d(6*out_ch,out_ch,1)
+ 
+     @staticmethod
+     def compute_loss_kl(preds, targets, dfs, fs, mode="MSE"):
+         return muti_loss_fusion_kl(preds, targets, dfs, fs, mode=mode)
+ 
+     @staticmethod
+     def compute_loss(args):
+         if len(args) == 3:
+             ds, dfs, labels = args
+             return muti_loss_fusion(ds, labels)
+         else:
+             ds, dfs, labels, fs = args
+             return muti_loss_fusion_kl(ds, labels, dfs, fs, mode="MSE")
+ 
+     def forward(self, x):
+         hx = x
+ 
+         hxin = self.conv_in(hx)
+         hx = self.pool_in(hxin)
+ 
+         # stage 1
+         hx1 = self.stage1(hxin)
+         hx = self.pool12(hx1)
+ 
+         # stage 2
+         hx2 = self.stage2(hx)
+         hx = self.pool23(hx2)
+ 
+         # stage 3
+         hx3 = self.stage3(hx)
+         hx = self.pool34(hx3)
+ 
+         # stage 4
+         hx4 = self.stage4(hx)
+         hx = self.pool45(hx4)
+ 
+         # stage 5
+         hx5 = self.stage5(hx)
+         hx = self.pool56(hx5)
+ 
+         # stage 6
+         hx6 = self.stage6(hx)
+         hx6up = _upsample_like(hx6, hx5)
+ 
+         # -------------------- decoder --------------------
+         hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
+         hx5dup = _upsample_like(hx5d, hx4)
+ 
+         hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
+         hx4dup = _upsample_like(hx4d, hx3)
+ 
+         hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
+         hx3dup = _upsample_like(hx3d, hx2)
+ 
+         hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
+         hx2dup = _upsample_like(hx2d, hx1)
+ 
+         hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
+ 
+         # side output
+         d1 = self.side1(hx1d)
+         d1 = _upsample_like(d1, x)
+ 
+         d2 = self.side2(hx2d)
+         d2 = _upsample_like(d2, x)
+ 
+         d3 = self.side3(hx3d)
+         d3 = _upsample_like(d3, x)
+ 
+         d4 = self.side4(hx4d)
+         d4 = _upsample_like(d4, x)
+ 
+         d5 = self.side5(hx5d)
+         d5 = _upsample_like(d5, x)
+ 
+         d6 = self.side6(hx6)
+         d6 = _upsample_like(d6, x)
+ 
+         # d0 = self.outconv(torch.cat((d1,d2,d3,d4,d5,d6),1))
+ 
+         # return [torch.sigmoid(d1), torch.sigmoid(d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(d5), torch.sigmoid(d6)], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
+         return [d1, d2, d3, d4, d5, d6], [hx1d, hx2d, hx3d, hx4d, hx5d, hx6]
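
A quick shape sanity check (sketch, randomly initialized weights): `ISNetDIS` returns six side-output logits, each upsampled back to the input resolution, plus six intermediate feature maps. The `[0][0].sigmoid()` indexing in anime_segmentation.py picks the finest side output `d1`:

import torch
from custom_controlnet_aux.anime_face_segment.isnet import ISNetDIS

net = ISNetDIS(in_ch=3, out_ch=1).eval()
with torch.no_grad():
    sides, feats = net(torch.randn(1, 3, 256, 256))
print(len(sides), sides[0].shape)  # 6 torch.Size([1, 1, 256, 256])
print(len(feats))                  # 6 decoder/bridge feature maps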
custom_controlnet_aux/anime_face_segment/network.py ADDED
@@ -0,0 +1,100 @@
+ # https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/network.py
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision
+ 
+ from custom_controlnet_aux.util import custom_torch_download
+ 
+ class UNet(nn.Module):
+     def __init__(self):
+         super(UNet, self).__init__()
+         self.NUM_SEG_CLASSES = 7  # Background, hair, face, eye, mouth, skin, clothes
+ 
+         mobilenet_v2 = torchvision.models.mobilenet_v2(pretrained=False)
+         mobilenet_v2.load_state_dict(torch.load(custom_torch_download(filename="mobilenet_v2-b0353104.pth")), strict=True)
+         mob_blocks = mobilenet_v2.features
+ 
+         # Encoder
+         self.en_block0 = nn.Sequential(  # in_ch=3 out_ch=16
+             mob_blocks[0],
+             mob_blocks[1]
+         )
+         self.en_block1 = nn.Sequential(  # in_ch=16 out_ch=24
+             mob_blocks[2],
+             mob_blocks[3],
+         )
+         self.en_block2 = nn.Sequential(  # in_ch=24 out_ch=32
+             mob_blocks[4],
+             mob_blocks[5],
+             mob_blocks[6],
+         )
+         self.en_block3 = nn.Sequential(  # in_ch=32 out_ch=96
+             mob_blocks[7],
+             mob_blocks[8],
+             mob_blocks[9],
+             mob_blocks[10],
+             mob_blocks[11],
+             mob_blocks[12],
+             mob_blocks[13],
+         )
+         self.en_block4 = nn.Sequential(  # in_ch=96 out_ch=160
+             mob_blocks[14],
+             mob_blocks[15],
+             mob_blocks[16],
+         )
+ 
+         # Decoder
+         self.de_block4 = nn.Sequential(  # in_ch=160 out_ch=96
+             nn.UpsamplingNearest2d(scale_factor=2),
+             nn.Conv2d(160, 96, kernel_size=3, padding=1),
+             nn.InstanceNorm2d(96),
+             nn.LeakyReLU(0.1),
+             nn.Dropout(p=0.2)
+         )
+         self.de_block3 = nn.Sequential(  # in_ch=96x2 out_ch=32
+             nn.UpsamplingNearest2d(scale_factor=2),
+             nn.Conv2d(96 * 2, 32, kernel_size=3, padding=1),
+             nn.InstanceNorm2d(32),
+             nn.LeakyReLU(0.1),
+             nn.Dropout(p=0.2)
+         )
+         self.de_block2 = nn.Sequential(  # in_ch=32x2 out_ch=24
+             nn.UpsamplingNearest2d(scale_factor=2),
+             nn.Conv2d(32 * 2, 24, kernel_size=3, padding=1),
+             nn.InstanceNorm2d(24),
+             nn.LeakyReLU(0.1),
+             nn.Dropout(p=0.2)
+         )
+         self.de_block1 = nn.Sequential(  # in_ch=24x2 out_ch=16
+             nn.UpsamplingNearest2d(scale_factor=2),
+             nn.Conv2d(24 * 2, 16, kernel_size=3, padding=1),
+             nn.InstanceNorm2d(16),
+             nn.LeakyReLU(0.1),
+             nn.Dropout(p=0.2)
+         )
+ 
+         self.de_block0 = nn.Sequential(  # in_ch=16x2 out_ch=7
+             nn.UpsamplingNearest2d(scale_factor=2),
+             nn.Conv2d(16 * 2, self.NUM_SEG_CLASSES, kernel_size=3, padding=1),
+             nn.Softmax2d()
+         )
+ 
+     def forward(self, x):
+         e0 = self.en_block0(x)
+         e1 = self.en_block1(e0)
+         e2 = self.en_block2(e1)
+         e3 = self.en_block3(e2)
+         e4 = self.en_block4(e3)
+ 
+         d4 = self.de_block4(e4)
+         c4 = torch.cat((d4, e3), 1)
+         d3 = self.de_block3(c4)
+         c3 = torch.cat((d3, e2), 1)
+         d2 = self.de_block2(c3)
+         c2 = torch.cat((d2, e1), 1)
+         d1 = self.de_block1(c2)
+         c1 = torch.cat((d1, e0), 1)
+         y = self.de_block0(c1)
+ 
+         return y
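
A shape walk-through sketch of the encoder/decoder pairing: each MobileNetV2 encoder block halves the resolution (512 -> 256 -> 128 -> 64 -> 32 -> 16) and the decoder mirrors it with nearest-neighbour upsampling plus skip concatenation. Note that constructing `UNet()` fetches the MobileNetV2 weights via `custom_torch_download`:

import torch
from custom_controlnet_aux.anime_face_segment.network import UNet

net = UNet().eval()  # downloads mobilenet_v2-b0353104.pth on first use
with torch.no_grad():
    y = net(torch.randn(1, 3, 512, 512))
print(y.shape)  # torch.Size([1, 7, 512, 512]) -> per-pixel class probabilities
# The final Softmax2d makes the 7 class channels sum to ~1 at every pixel.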
custom_controlnet_aux/anime_face_segment/util.py ADDED
@@ -0,0 +1,40 @@
+ # https://github.com/siyeong0/Anime-Face-Segmentation/blob/main/util.py
+ # The color palette is changed according to https://github.com/Mikubill/sd-webui-controlnet/blob/91f67ddcc7bc47537a6285864abfc12590f46c3f/annotator/anime_face_segment/__init__.py
+ import cv2 as cv
+ import glob
+ import numpy as np
+ import os
+ 
+ """
+ Original palette:
+ COLOR_BACKGROUND = (0,255,255)
+ COLOR_HAIR = (255,0,0)
+ COLOR_EYE = (0,0,255)
+ COLOR_MOUTH = (255,255,255)
+ COLOR_FACE = (0,255,0)
+ COLOR_SKIN = (255,255,0)
+ COLOR_CLOTHES = (255,0,255)
+ """
+ COLOR_BACKGROUND = (255, 255, 0)
+ COLOR_HAIR = (0, 0, 255)
+ COLOR_EYE = (255, 0, 0)
+ COLOR_MOUTH = (255, 255, 255)
+ COLOR_FACE = (0, 255, 0)
+ COLOR_SKIN = (0, 255, 255)
+ COLOR_CLOTHES = (255, 0, 255)
+ PALETTE = [COLOR_BACKGROUND, COLOR_HAIR, COLOR_EYE, COLOR_MOUTH, COLOR_FACE, COLOR_SKIN, COLOR_CLOTHES]
+ 
+ def img2seg(path):
+     src = cv.imread(path)
+     src = src.reshape(-1, 3)
+     seg_list = []
+     for color in PALETTE:
+         seg_list.append(np.where(np.all(src == color, axis=1), 1.0, 0.0))
+     dst = np.stack(seg_list, axis=1).reshape(512, 512, 7)
+ 
+     return dst.astype(np.float32)
+ 
+ def seg2img(src):
+     src = np.moveaxis(src, 0, 2)
+     dst = [[PALETTE[np.argmax(val)] for val in buf] for buf in src]
+ 
+     return np.array(dst).astype(np.uint8)
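
Sketch: `seg2img` takes the network's `(7, H, W)` class scores and paints each pixel with the palette color of its argmax class.

import numpy as np
from custom_controlnet_aux.anime_face_segment.util import seg2img, PALETTE

scores = np.zeros((7, 2, 2), dtype=np.float32)
scores[1, 0, 0] = 1.0  # top-left pixel -> class 1 (hair)
rgb = seg2img(scores)
print(rgb[0, 0])  # [  0   0 255] == PALETTE[1]
print(rgb[1, 1])  # argmax of an all-zero column is class 0 -> background color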
custom_controlnet_aux/binary/__init__.py ADDED
@@ -0,0 +1,38 @@
+ import warnings
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from custom_controlnet_aux.util import HWC3, resize_image_with_pad
+ 
+ class BinaryDetector:
+     def __call__(self, input_image=None, bin_threshold=0, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+         if "img" in kwargs:
+             warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+             input_image = kwargs.pop("img")
+ 
+         if input_image is None:
+             raise ValueError("input_image must be defined.")
+ 
+         if not isinstance(input_image, np.ndarray):
+             input_image = np.array(input_image, dtype=np.uint8)
+             output_type = output_type or "pil"
+         else:
+             output_type = output_type or "np"
+ 
+         detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+ 
+         img_gray = cv2.cvtColor(detected_map, cv2.COLOR_RGB2GRAY)
+         if bin_threshold == 0 or bin_threshold == 255:
+             # Otsu's threshold
+             otsu_threshold, img_bin = cv2.threshold(img_gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+             print("Otsu threshold:", otsu_threshold)
+         else:
+             _, img_bin = cv2.threshold(img_gray, bin_threshold, 255, cv2.THRESH_BINARY_INV)
+ 
+         detected_map = cv2.cvtColor(img_bin, cv2.COLOR_GRAY2RGB)
+         detected_map = HWC3(remove_pad(255 - detected_map))
+ 
+         if output_type == "pil":
+             detected_map = Image.fromarray(detected_map)
+ 
+         return detected_map
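
Usage sketch: passing `bin_threshold=0` (or `255`) switches to Otsu's automatic threshold selection; any other value is used as a fixed grayscale cutoff.

import numpy as np
from custom_controlnet_aux.binary import BinaryDetector

detector = BinaryDetector()
image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)
out = detector(input_image=image, bin_threshold=0, detect_resolution=512)  # Otsu
print(out.shape, out.dtype)  # HWC uint8 black-and-white map near detect_resolution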
custom_controlnet_aux/canny/__init__.py ADDED
@@ -0,0 +1,17 @@
+ import warnings
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from custom_controlnet_aux.util import resize_image_with_pad, common_input_validate, HWC3
+ 
+ class CannyDetector:
+     def __call__(self, input_image=None, low_threshold=100, high_threshold=200, detect_resolution=512, output_type=None, upscale_method="INTER_CUBIC", **kwargs):
+         input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+         detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+         detected_map = cv2.Canny(detected_map, low_threshold, high_threshold)
+         detected_map = HWC3(remove_pad(detected_map))
+ 
+         if output_type == "pil":
+             detected_map = Image.fromarray(detected_map)
+ 
+         return detected_map
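
Usage sketch of the two-threshold Canny call this detector wraps: edges with gradient above `high_threshold` are kept, and edges between the two thresholds survive only if connected to a strong edge.

import numpy as np
from custom_controlnet_aux.canny import CannyDetector

detector = CannyDetector()
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
edges = detector(input_image=image, low_threshold=100, high_threshold=200)
print(edges.shape)  # HWC uint8 edge map, edges replicated across 3 channels by HWC3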
custom_controlnet_aux/color/__init__.py ADDED
@@ -0,0 +1,37 @@
+ import warnings
+ import cv2
+ import numpy as np
+ from PIL import Image
+ from custom_controlnet_aux.util import HWC3, safer_memory, common_input_validate
+ 
+ def cv2_resize_shortest_edge(image, size):
+     h, w = image.shape[:2]
+     if h < w:
+         new_h = size
+         new_w = int(round(w / h * size))
+     else:
+         new_w = size
+         new_h = int(round(h / w * size))
+     resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
+     return resized_image
+ 
+ def apply_color(img, res=512):
+     img = cv2_resize_shortest_edge(img, res)
+     h, w = img.shape[:2]
+ 
+     input_img_color = cv2.resize(img, (w // 64, h // 64), interpolation=cv2.INTER_CUBIC)
+     input_img_color = cv2.resize(input_img_color, (w, h), interpolation=cv2.INTER_NEAREST)
+     return input_img_color
+ 
+ # Color T2I-Adapter-style multiples-of-64 color grid; the upscale methods are fixed.
+ class ColorDetector:
+     def __call__(self, input_image=None, detect_resolution=512, output_type=None, **kwargs):
+         input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+         input_image = HWC3(input_image)
+         detected_map = HWC3(apply_color(input_image, detect_resolution))
+ 
+         if output_type == "pil":
+             detected_map = Image.fromarray(detected_map)
+ 
+         return detected_map
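
Sketch of the grid arithmetic in `apply_color`: at `res=512` the image is shrunk to an 8x8 color grid (512 // 64) and blown back up with nearest-neighbour interpolation, yielding flat 64x64-pixel color cells.

import numpy as np
from custom_controlnet_aux.color import ColorDetector

image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
grid = ColorDetector()(input_image=image, detect_resolution=512)
# Every 64x64 cell is a single flat color:
assert (grid[:64, :64] == grid[0, 0]).all()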