# Domain adaptation transforms: histogram matching, Fourier domain adaptation (FDA),
# and pixel-distribution adaptation.
| import random | |
| from typing import Any, Callable, Literal, Sequence, Tuple | |
| import cv2 | |
| import numpy as np | |
| from custom_qudida import DomainAdapter | |
| from skimage.exposure import match_histograms | |
| from sklearn.decomposition import PCA | |
| from sklearn.preprocessing import MinMaxScaler, StandardScaler | |
| from custom_albumentations.augmentations.utils import ( | |
| clipped, | |
| get_opencv_dtype_from_numpy, | |
| is_grayscale_image, | |
| is_multispectral_image, | |
| preserve_shape, | |
| read_rgb_image, | |
| ) | |
| from ..core.transforms_interface import ImageOnlyTransform, ScaleFloatType, to_tuple | |
# Public API of this module (consumed by `from ... import *`).
__all__ = [
    "HistogramMatching",
    "FDA",
    "PixelDistributionAdaptation",
    "fourier_domain_adaptation",
    "apply_histogram",
    "adapt_pixel_distribution",
]
def fourier_domain_adaptation(img: np.ndarray, target_img: np.ndarray, beta: float) -> np.ndarray:
    """
    Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA

    Swaps the low-frequency amplitude of the source image's spectrum with the
    target image's, keeps the source phase, and inverts the FFT.

    Args:
        img: source image
        target_img: target image for domain adaptation
        beta: coefficient from source paper; controls the half-size of the
            low-frequency window that is swapped

    Returns:
        transformed image (real-valued array of the squeezed spatial shape)

    Raises:
        ValueError: if the squeezed source and target shapes differ.
    """
    src = np.squeeze(img)
    trg = np.squeeze(target_img)
    if trg.shape != src.shape:
        raise ValueError(
            "The source and target images must have the same shape,"
            " but got {} and {} respectively.".format(src.shape, trg.shape)
        )

    # Spectra of both images, split into amplitude and phase.
    src_fft = np.fft.fft2(src.astype(np.float32), axes=(0, 1))
    trg_fft = np.fft.fft2(trg.astype(np.float32), axes=(0, 1))
    src_phase = np.angle(src_fft)

    # Shift so the zero frequency sits at the center of the spectrum.
    src_amp = np.fft.fftshift(np.abs(src_fft), axes=(0, 1))
    trg_amp = np.fft.fftshift(np.abs(trg_fft), axes=(0, 1))

    # Central (low-frequency) square window; half-size is beta * min(H, W).
    height, width = src_amp.shape[:2]
    half = int(np.floor(min(height, width) * beta))
    center_y = int(np.floor(height / 2.0))
    center_x = int(np.floor(width / 2.0))
    y_lo, y_hi = center_y - half, center_y + half + 1
    x_lo, x_hi = center_x - half, center_x + half + 1

    # Replace the source's low frequencies with the target's, then undo the shift.
    src_amp[y_lo:y_hi, x_lo:x_hi] = trg_amp[y_lo:y_hi, x_lo:x_hi]
    src_amp = np.fft.ifftshift(src_amp, axes=(0, 1))

    # Rebuild the image from the mixed amplitude and the original phase.
    mixed_spectrum = src_amp * np.exp(1j * src_phase)
    return np.real(np.fft.ifft2(mixed_spectrum, axes=(0, 1)))
def apply_histogram(img: np.ndarray, reference_image: np.ndarray, blend_ratio: float) -> np.ndarray:
    """
    Match the histogram of ``img`` to ``reference_image`` and blend the result
    with the original image.

    Args:
        img: input image.
        reference_image: reference image; resized to ``img``'s spatial size
            when the sizes differ.
        blend_ratio: weight of the matched image in the blend
            (``1 - blend_ratio`` goes to the original).

    Returns:
        blended image with the same dtype as ``img``.

    Raises:
        RuntimeError: if the two images have different dtypes.
    """
    if img.dtype != reference_image.dtype:
        raise RuntimeError(
            f"Dtype of image and reference image must be the same. Got {img.dtype} and {reference_image.dtype}"
        )
    if img.shape[:2] != reference_image.shape[:2]:
        reference_image = cv2.resize(reference_image, dsize=(img.shape[1], img.shape[0]))

    # Drop size-1 axes so grayscale images are handled as 2D arrays.
    img, reference_image = np.squeeze(img), np.squeeze(reference_image)

    try:
        matched = match_histograms(img, reference_image, channel_axis=2 if len(img.shape) == 3 else None)
    except TypeError:
        # scikit-image<0.19.1 uses `multichannel` instead of `channel_axis`.
        # Bug fix: only request multichannel matching when a channel axis
        # actually exists; `multichannel=True` on a 2D grayscale image made old
        # scikit-image treat each column as a separate channel.
        matched = match_histograms(img, reference_image, multichannel=len(img.shape) == 3)

    # Blend matched and original; the output dtype is forced back to img's dtype.
    img = cv2.addWeighted(
        matched,
        blend_ratio,
        img,
        1 - blend_ratio,
        0,
        dtype=get_opencv_dtype_from_numpy(img.dtype),
    )
    return img
def adapt_pixel_distribution(
    img: np.ndarray, ref: np.ndarray, transform_type: str = "pca", weight: float = 0.5
) -> np.ndarray:
    """
    Shift the pixel distribution of ``img`` towards ``ref`` with a simple
    fitted transform (PCA / StandardScaler / MinMaxScaler via DomainAdapter)
    and blend the adapted image with the original.

    Args:
        img: input image.
        ref: reference image whose distribution is adapted to.
        transform_type: one of "pca", "standard", "minmax".
        weight: blend weight of the adapted image.

    Returns:
        blended image cast back to the input's dtype.
    """
    original_dtype = img.dtype
    transformers = {"pca": PCA, "standard": StandardScaler, "minmax": MinMaxScaler}
    adapter = DomainAdapter(transformer=transformers[transform_type](), ref_img=ref)
    adapted = adapter(img).astype("float32")
    # Linear blend in float32, then restore the caller's dtype.
    blended = img.astype("float32") * (1 - weight) + adapted * weight
    return blended.astype(original_dtype)
class HistogramMatching(ImageOnlyTransform):
    """
    Apply histogram matching. The pixels of the input image are remapped so its
    histogram matches that of a randomly chosen reference image; the matched
    image is then blended with the original using a random blend factor. With
    multi-channel images the matching runs per channel, provided the input and
    reference have the same number of channels.

    Histogram matching can serve as a lightweight normalisation step for image
    processing tasks such as feature matching, especially when images come from
    different sources or conditions (i.e. lighting).

    See:
        https://scikit-image.org/docs/dev/auto_examples/color_exposure/plot_histogram_matching.html

    Args:
        reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
            it expects a sequence of paths to images.
        blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
            with random blend factor for increased diversity of generated images.
        read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images`
            and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image

    Image types:
        uint8, uint16, float32
    """

    def __init__(
        self,
        reference_images: Sequence[Any],
        blend_ratio: Tuple[float, float] = (0.5, 1.0),
        read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply=always_apply, p=p)
        self.reference_images = reference_images
        self.read_fn = read_fn
        self.blend_ratio = blend_ratio

    def apply(self, img, reference_image=None, blend_ratio=0.5, **params):
        return apply_histogram(img, reference_image, blend_ratio)

    def get_params(self):
        # Pick a random reference and a random blend factor per call.
        reference = self.read_fn(random.choice(self.reference_images))
        low, high = self.blend_ratio
        return {
            "reference_image": reference,
            "blend_ratio": random.uniform(low, high),
        }

    def get_transform_init_args_names(self):
        return ("reference_images", "blend_ratio", "read_fn")

    def _to_dict(self):
        # Reference images / read_fn are arbitrary objects and cannot be serialized.
        raise NotImplementedError("HistogramMatching can not be serialized.")
class FDA(ImageOnlyTransform):
    """
    Fourier Domain Adaptation from https://github.com/YanchaoYang/FDA
    Simple "style transfer".

    Args:
        reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
            it expects a sequence of paths to images.
        beta_limit (float or tuple of float): coefficient beta from paper. Recommended less 0.3.
        read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images`
            and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.

    Targets:
        image

    Image types:
        uint8, float32

    Reference:
        https://github.com/YanchaoYang/FDA
        https://openaccess.thecvf.com/content_CVPR_2020/papers/Yang_FDA_Fourier_Domain_Adaptation_for_Semantic_Segmentation_CVPR_2020_paper.pdf

    Example:
        >>> import numpy as np
        >>> import custom_albumentations as A
        >>> image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
        >>> target_image = np.random.randint(0, 256, [100, 100, 3], dtype=np.uint8)
        >>> aug = A.Compose([A.FDA([target_image], p=1, read_fn=lambda x: x)])
        >>> result = aug(image=image)
    """

    def __init__(
        self,
        reference_images: Sequence[Any],
        beta_limit: ScaleFloatType = 0.1,
        read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super(FDA, self).__init__(always_apply=always_apply, p=p)
        self.reference_images = reference_images
        self.read_fn = read_fn
        # Normalize a scalar beta into a (low, high) sampling range, low >= 0.
        self.beta_limit = to_tuple(beta_limit, low=0)

    def apply(self, img, target_image=None, beta=0.1, **params):
        return fourier_domain_adaptation(img=img, target_img=target_image, beta=beta)

    def get_params_dependent_on_targets(self, params):
        img = params["image"]
        target_img = self.read_fn(random.choice(self.reference_images))
        # The reference must match the input's spatial size for the FFT swap.
        target_img = cv2.resize(target_img, dsize=(img.shape[1], img.shape[0]))
        return {"target_image": target_img}

    def get_params(self):
        return {"beta": random.uniform(self.beta_limit[0], self.beta_limit[1])}

    @property
    def targets_as_params(self):
        # Bug fix: the transforms core reads `self.targets_as_params` as an
        # attribute (it is a property on the base class); without @property this
        # returned a bound method instead of the list of required target names,
        # so `get_params_dependent_on_targets` never received the image.
        return ["image"]

    def get_transform_init_args_names(self):
        return ("reference_images", "beta_limit", "read_fn")

    def _to_dict(self):
        # Reference images / read_fn are arbitrary objects and cannot be serialized.
        raise NotImplementedError("FDA can not be serialized.")
class PixelDistributionAdaptation(ImageOnlyTransform):
    """
    Another naive and quick pixel-level domain adaptation. It fits a simple transform (such as PCA, StandardScaler
    or MinMaxScaler) on both original and reference image, transforms original image with transform trained on this
    image and then performs inverse transformation using transform fitted on reference image.

    Args:
        reference_images (Sequence[Any]): Sequence of objects that will be converted to images by `read_fn`. By default,
            it expects a sequence of paths to images.
        blend_ratio (float, float): Tuple of min and max blend ratio. Matched image will be blended with original
            with random blend factor for increased diversity of generated images.
        read_fn (Callable): Used-defined function to read image. Function should get an element of `reference_images`
            and return numpy array of image pixels. Default: takes as input a path to an image and returns a numpy array.
        transform_type (str): type of transform; "pca", "standard", "minmax" are allowed.
        p (float): probability of applying the transform. Default: 1.0.

    Targets:
        image

    Image types:
        uint8, float32

    See also: https://github.com/arsenyinfo/qudida
    """

    def __init__(
        self,
        reference_images: Sequence[Any],
        blend_ratio: Tuple[float, float] = (0.25, 1.0),
        read_fn: Callable[[Any], np.ndarray] = read_rgb_image,
        transform_type: Literal["pca", "standard", "minmax"] = "pca",
        always_apply: bool = False,
        p: float = 0.5,
    ):
        super().__init__(always_apply=always_apply, p=p)
        self.reference_images = reference_images
        self.read_fn = read_fn
        self.blend_ratio = blend_ratio
        # Validate eagerly so a bad transform_type fails at construction time.
        expected_transformers = ("pca", "standard", "minmax")
        if transform_type not in expected_transformers:
            raise ValueError(f"Got unexpected transform_type {transform_type}. Expected one of {expected_transformers}")
        self.transform_type = transform_type

    @staticmethod
    def _validate_shape(img: np.ndarray):
        # Bug fix: this is invoked as `self._validate_shape(img)` but the
        # function only accepts `img`; without @staticmethod the bound call
        # passed `self` as `img` and raised TypeError on the extra argument.
        if is_grayscale_image(img) or is_multispectral_image(img):
            raise ValueError(
                f"Unexpected image shape: expected 3 dimensions, got {len(img.shape)}. "
                f"Is it a grayscale or multispectral image? It's not supported for now."
            )

    def ensure_uint8(self, img: np.ndarray) -> Tuple[np.ndarray, bool]:
        """Convert a [0..1] float32 image to uint8; returns (image, was_converted)."""
        if img.dtype == np.float32:
            if img.min() < 0 or img.max() > 1:
                message = (
                    "PixelDistributionAdaptation uses uint8 under the hood, so float32 should be converted,"
                    "Can not do it automatically when the image is out of [0..1] range."
                )
                raise TypeError(message)
            return (img * 255).astype("uint8"), True
        return img, False

    def apply(self, img, reference_image, blend_ratio, **params):
        self._validate_shape(img)
        reference_image, _ = self.ensure_uint8(reference_image)
        img, needs_reconvert = self.ensure_uint8(img)
        adapted = adapt_pixel_distribution(
            img=img,
            ref=reference_image,
            weight=blend_ratio,
            transform_type=self.transform_type,
        )
        if needs_reconvert:
            # Input was float32 in [0..1]; convert the uint8 result back.
            adapted = adapted.astype("float32") * (1 / 255)
        return adapted

    def get_params(self):
        # Pick a random reference and a random blend factor per call.
        return {
            "reference_image": self.read_fn(random.choice(self.reference_images)),
            "blend_ratio": random.uniform(self.blend_ratio[0], self.blend_ratio[1]),
        }

    def get_transform_init_args_names(self):
        return ("reference_images", "blend_ratio", "read_fn", "transform_type")

    def _to_dict(self):
        # Reference images / read_fn are arbitrary objects and cannot be serialized.
        raise NotImplementedError("PixelDistributionAdaptation can not be serialized.")