import torch
import numpy as np
from munch import munchify
from scipy.ndimage import distance_transform_edt
from src.flair.functions.degradation import get_degradation
import torchvision


class BaseDegradation(torch.nn.Module):
    """Base measurement operator: adds Gaussian noise, identity pseudo-inverse.

    Subclasses apply their own degradation and then call ``forward`` of this
    class to add noise with standard deviation ``noise_std``.
    """

    def __init__(self, noise_std=0.0):
        super().__init__()
        self.noise_std = noise_std

    def forward(self, x):
        """Return ``x`` plus zero-mean Gaussian noise scaled by ``noise_std``."""
        return x + self.noise_std * torch.randn_like(x)

    def pseudo_inv(self, y):
        """Return ``y`` unchanged (the base operator is the identity)."""
        return y


def zero_filler(x, scale):
    """Upsample ``x`` by inserting zeros between pixels.

    Args:
        x: Tensor of shape (B, C, H, W).
        scale: Upsampling factor; converted with ``int(scale)``.

    Returns:
        Tensor of shape (B, C, H * scale, W * scale) with the same dtype and
        device as ``x``. Positions ``[::scale, ::scale]`` hold the values of
        ``x``; every other position is zero.
    """
    B, C, H, W = x.shape
    scale = int(scale)
    out = torch.zeros(B, C, H * scale, W * scale, dtype=x.dtype, device=x.device)
    out[:, :, ::scale, ::scale] = x
    return out


class SuperRes(BaseDegradation):
    """Super-resolution degradation backed by the ``"sr_bicubic"`` operator.

    Args:
        scale: Downscaling factor, passed to the operator as ``deg_scale``.
        noise_std: Standard deviation of the noise added in ``forward``.
        img_size: Side length of the (square, 3-channel) full-size image.
        device: Device handed to ``get_degradation``. Defaults to ``"cuda"``,
            which is what this class always used before the parameter existed.
    """

    def __init__(self, scale, noise_std=0.0, img_size=256, device="cuda"):
        super().__init__(noise_std=noise_std)
        self.scale = scale
        self.img_size = img_size
        deg_config = munchify({
            'channels': 3,
            'image_size': img_size,
            'deg_scale': scale,
        })
        self.deg = get_degradation("sr_bicubic", deg_config, device=device)

    def forward(self, x, noise=True):
        """Apply ``self.deg.A`` in float32, optionally add noise, restore dtype."""
        dtype = x.dtype
        y = self.deg.A(x.float())
        if noise:
            y = super().forward(y)
        return y.to(dtype)

    def pseudo_inv(self, y):
        """Map a measurement back to image space with ``self.deg.At``.

        The result is reshaped to (-1, 3, img_size, img_size) and multiplied
        by ``scale ** 2``.
        NOTE(review): the ``scale ** 2`` factor presumably undoes a
        normalisation inside ``At`` -- confirm against the operator.
        """
        # -1 instead of a hard-coded batch of 1: identical for one image,
        # and no longer raises for batched measurements.
        x = self.deg.At(y.float()).reshape(-1, 3, self.img_size, self.img_size)
        x = x * self.scale ** 2
        return x.to(y.dtype)

    def nn(self, y):
        """Nearest-neighbour upsampling of the measurement to full size."""
        low_res = self.img_size // self.scale
        x = torch.nn.functional.interpolate(
            y.reshape(-1, 3, low_res, low_res),
            scale_factor=self.scale,
            mode="nearest",
        )
        return x.to(y.dtype)


class SuperResGradio(BaseDegradation):
    """Super-resolution degradation built on antialiased bilinear resizing.

    Args:
        scale: Downscaling factor.
        noise_std: Standard deviation of the noise added in ``forward``.
        img_size: Side length of the (square, 3-channel) full-size image;
            only used by ``nn``.
    """

    def __init__(self, scale, noise_std=0.0, img_size=256):
        super().__init__(noise_std=noise_std)
        self.scale = scale
        self.img_size = img_size

    # ``downscaler`` / ``upscaler`` are methods rather than lambdas stored on
    # the instance so that the module stays picklable; call syntax is the same.
    def downscaler(self, x):
        """Bilinear, antialiased resize of ``x`` by a factor ``1 / scale``."""
        return torch.nn.functional.interpolate(
            x.float(),
            scale_factor=1 / self.scale,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )

    def upscaler(self, x):
        """Bilinear resize of ``x`` by a factor ``scale``."""
        return torch.nn.functional.interpolate(
            x.float(),
            scale_factor=self.scale,
            mode="bilinear",
            align_corners=False,
            antialias=True,
        )

    def forward(self, x, noise=True):
        """Downscale ``x`` in float32, optionally add noise, restore dtype."""
        dtype = x.dtype
        y = self.downscaler(x.float())
        if noise:
            y = super().forward(y)
        return y.to(dtype)

    def pseudo_inv(self, y):
        """Upscale the measurement back to full size, keeping ``y``'s dtype."""
        x = self.upscaler(y.float())
        return x.to(y.dtype)

    def nn(self, y):
        """Nearest-neighbour ("nearest-exact") upsampling to full size."""
        low_res = self.img_size // self.scale
        x = torch.nn.functional.interpolate(
            y.reshape(-1, 3, low_res, low_res),
            scale_factor=self.scale,
            mode="nearest-exact",
        )
        return x.to(y.dtype)


class Inpainting(BaseDegradation):
    """Inpainting degradation: keeps only the pixels where the mask is True.

    Args:
        mask: One of
            * ``list`` -- slice bounds of a missing box. The first two entries
              slice the row (height) axis, the remaining entries slice the
              column (width) axis; everything outside the box is observed.
            * ``str`` -- path of a file readable by ``np.load``.
            * ``torch.Tensor`` -- mask of shape (H, W), (1, H, W) or
              (3, H, W); True marks observed pixels.
        H: Image height.
        W: Image width.
        noise_std: Standard deviation of the noise added in ``forward``.

    Raises:
        ValueError: If ``mask`` is not a list, string or tensor.

    The class assumes 3-channel images.
    """

    def __init__(self, mask, H, W, noise_std=0.0):
        super().__init__(noise_std=noise_std)
        if isinstance(mask, list):
            # Observed region is True; the box described by the list is missing.
            mask_ = torch.ones(H, W, dtype=torch.bool)
            mask_[slice(*mask[0:2]), slice(*mask[2:])] = False
        elif isinstance(mask, str):
            mask_ = torch.tensor(np.load(mask), dtype=torch.bool)
        elif isinstance(mask, torch.Tensor):
            mask_ = mask
        else:
            raise ValueError("Mask must be a list, string (file path), or torch.Tensor.")
        # Boolean dtype guarantees mask (not gather) semantics when indexing.
        self.mask = self._expand_channels(mask_.to(torch.bool))
        self.H, self.W = H, W

    @staticmethod
    def _expand_channels(mask):
        """Repeat a single-channel mask over 3 channels; pass others through."""
        if mask.ndim == 2:
            return mask[None].repeat(3, 1, 1)
        if mask.ndim == 3 and mask.shape[0] == 1:
            return mask.repeat(3, 1, 1)
        return mask

    def forward(self, x, noise=True):
        """Gather the observed pixels of every image in the batch.

        Args:
            x: Tensor of shape (B, 3, H, W).
            noise: Whether to add measurement noise.

        Returns:
            Tensor of shape (B, N), N being the number of True mask entries.
        """
        # Index the non-batch dimensions so that any batch size works; the
        # previous ``x[self.mask[None]]`` raised a shape error for B > 1.
        y = x[:, self.mask]
        if noise:
            y = super().forward(y)
        return y

    def pseudo_inv(self, y):
        """Scatter measurements into a zero image of shape (B, 3, H, W).

        Missing pixels stay zero; ``inpaint_nearest`` can be applied to the
        result if a nearest-neighbour fill is wanted instead.
        """
        batch = y.shape[0]
        x = torch.zeros(batch, 3 * self.H * self.W, dtype=y.dtype, device=y.device)
        # reshape (not view) also accepts non-contiguous user-supplied masks.
        x[:, self.mask.reshape(-1)] = y
        return x.view(batch, 3, self.H, self.W)


def inpaint_nearest(image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Fill missing pixels with the value of the nearest observed pixel.

    Args:
        image: Tensor of shape [C, H, W].
        mask: Tensor of shape [H, W]; non-zero marks observed pixels, zero
            marks missing ones.

    Returns:
        Tensor of shape [C, H, W] on ``image``'s device and with its dtype.
        The computation runs on the CPU in float32.
    """
    # detach() so that tensors which require grad can be converted to NumPy.
    image_np = image.detach().cpu().float().numpy()
    observed = mask.detach().cpu().numpy().astype(bool)

    # On the inverted mask the transform returns, for every pixel, the
    # (row, column) indices of the nearest observed pixel; shape (2, H, W).
    _, nearest = distance_transform_edt(~observed, return_indices=True)

    # One advanced-indexing lookup fills all channels at once: (C, H, W).
    filled_np = image_np[:, nearest[0], nearest[1]]

    return torch.from_numpy(filled_np).to(device=image.device, dtype=image.dtype)


class MotionBlur(BaseDegradation):
    """Motion-blur degradation backed by the ``"deblur_motion"`` operator.

    Args:
        kernel_size: Passed to the operator as ``deg_scale``.
        img_size: Side length of the (square, 3-channel) image.
        noise_std: Standard deviation of the noise added in ``forward``.
        device: Device handed to ``get_degradation``. Defaults to ``"cuda"``,
            which is what this class always used before the parameter existed.
    """

    def __init__(self, kernel_size=5, img_size=256, noise_std=0.0, device="cuda"):
        super().__init__(noise_std=noise_std)
        self.img_size = img_size
        deg_config = munchify({
            'channels': 3,
            'image_size': img_size,
            'deg_scale': kernel_size,
        })
        self.deg = get_degradation("deblur_motion", deg_config, device=device)

    def forward(self, x, noise=True):
        """Apply ``self.deg.A`` in float32, optionally add noise, restore dtype."""
        dtype = x.dtype
        y = self.deg.A(x.float())
        if noise:
            y = super().forward(y)
        return y.to(dtype)

    def pseudo_inv(self, y):
        """Apply ``self.deg.At`` and reshape to (-1, 3, img_size, img_size)."""
        dtype = y.dtype
        # -1 instead of a hard-coded batch of 1 (identical for one image).
        x = self.deg.At(y.float()).reshape(-1, 3, self.img_size, self.img_size)
        return x.to(dtype)