import numpy as np
import scipy.ndimage
import torch
from PIL import ImageFilter
from scipy.ndimage import binary_closing, binary_fill_holes

from .image_convert import pil2tensor, tensor2pil


def combine_mask(destination, source, x, y):
    """Subtract ``source`` from ``destination`` at offset ``(x, y)``.

    Both tensors are viewed as ``(-1, H, W)`` using their last two
    dimensions. The part of ``source`` that overlaps ``destination`` when
    its top-left corner is placed at column ``x`` / row ``y`` is subtracted,
    and the result is clamped to ``[0, 1]``.

    Args:
        destination: Mask tensor to subtract from (not modified).
        source: Mask tensor to subtract.
        x: Column offset of ``source`` inside ``destination``.
        y: Row offset of ``source`` inside ``destination``.

    Returns:
        A new ``(-1, H, W)`` tensor with the size of ``destination``.
    """
    output = destination.reshape(
        (-1, destination.shape[-2], destination.shape[-1])
    ).clone()
    source = source.reshape((-1, source.shape[-2], source.shape[-1]))

    dest_height, dest_width = output.shape[-2:]
    src_height, src_width = source.shape[-2:]

    # Clip the pasted rectangle to the destination bounds on all four sides.
    left, top = max(x, 0), max(y, 0)
    right = min(x + src_width, dest_width)
    bottom = min(y + src_height, dest_height)

    # Nothing to do when the source lies completely outside the destination.
    if right > left and bottom > top:
        # Offsets into the source are non-zero only for negative x / y.
        source_portion = source[:, top - y:bottom - y, left - x:right - x]
        # Slice the reshaped copy (not the raw input) so that 2-D and 4-D
        # destinations are indexed on the correct axes.
        destination_portion = output[:, top:bottom, left:right]
        # operation == "subtract"
        output[:, top:bottom, left:right] = destination_portion - source_portion

    return torch.clamp(output, 0.0, 1.0)


def grow_mask(mask, expand, tapered_corners):
    """Grow (``expand > 0``) or shrink (``expand < 0``) a mask with torch ops.

    The mask is grown one pixel at a time with a 3x3 convolution followed by
    a clamp to ``[0, 1]``, repeated ``abs(expand)`` times, so the output keeps
    the input's height and width for every value of ``expand``. Shrinking is
    done by growing the inverted mask, with everything outside the image
    treated as background.

    Args:
        mask: Mask tensor; the last two dimensions are height and width.
            Values are expected to lie in ``[0, 1]``.
        expand: Number of pixels to grow (positive) or shrink (negative).
        tapered_corners: If true, use a cross-shaped kernel (no diagonal
            neighbours); otherwise use the full 3x3 square.

    Returns:
        ``mask`` itself, unchanged, when ``expand == 0``; otherwise a
        ``(-1, H, W)`` tensor on the same device, clamped to ``[0, 1]``.
    """
    if expand == 0:
        return mask

    # Add a channel dimension: conv2d expects (N, C, H, W).
    mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
    if not mask.is_floating_point():
        mask = mask.float()

    # Build the convolution kernel with the mask's dtype and device so that
    # conv2d does not fail on a dtype mismatch.
    c = 0.0 if tapered_corners else 1.0
    kernel = torch.tensor(
        [[c, 1.0, c], [1.0, 1.0, 1.0], [c, 1.0, c]],
        dtype=mask.dtype,
        device=mask.device,
    ).reshape(1, 1, 3, 3)

    if expand > 0:
        # Dilation: a pixel becomes set when any neighbour under the kernel
        # is set. Zero padding keeps the spatial size constant.
        for _ in range(expand):
            mask = torch.nn.functional.conv2d(mask, kernel, padding=1)
            mask = torch.clamp(mask, 0.0, 1.0)
    else:
        # Erosion: dilate the inverted mask. Padding the inverted mask with
        # ones means the area outside the image counts as background.
        inverted = 1 - mask
        for _ in range(-expand):
            inverted = torch.nn.functional.pad(
                inverted, (1, 1, 1, 1), mode="constant", value=1
            )
            inverted = torch.nn.functional.conv2d(inverted, kernel)
            inverted = torch.clamp(inverted, 0.0, 1.0)
        mask = 1 - inverted

    # Drop the channel dimension again.
    output = mask.squeeze(1)
    return torch.clamp(output, 0.0, 1.0)


def fill_holes(mask):
    """Close small gaps and fill enclosed holes in each mask of a batch.

    Every ``(H, W)`` slice is binarised (``> 0``), closed with a 5x5
    structuring element and then hole-filled with SciPy.

    Args:
        mask: Mask tensor; the last two dimensions are height and width.

    Returns:
        A float32 ``(-1, H, W)`` tensor of 0.0 / 1.0 values on the same
        device as ``mask``.
    """
    device = mask.device
    holemask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
    struct = np.ones((5, 5))

    out = []
    for m in holemask:
        binary_mask = m.numpy() > 0
        closed_mask = binary_closing(binary_mask, structure=struct, border_value=1)
        filled_mask = binary_fill_holes(closed_mask)
        out.append(torch.from_numpy(filled_mask.astype(np.float32)))

    return torch.stack(out, dim=0).to(device)


def invert_mask(mask):
    """Return ``1.0 - mask``."""
    return 1.0 - mask


def expand_mask(mask, expand, tapered_corners):
    """Grow or shrink a mask with SciPy grey-scale morphology.

    Applies ``abs(expand)`` passes of a 3x3 grey dilation (``expand >= 0``)
    or grey erosion (``expand < 0``) to every ``(H, W)`` slice.

    Args:
        mask: Mask tensor; the last two dimensions are height and width.
        expand: Number of passes; the sign selects dilation or erosion.
        tapered_corners: If true, the footprint omits the four corners.

    Returns:
        A ``(-1, H, W)`` tensor on the same device as ``mask``.
    """
    c = 0 if tapered_corners else 1
    kernel = np.array([[c, 1, c], [1, 1, 1], [c, 1, c]])
    device = mask.device
    mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()

    # The sign of ``expand`` is loop-invariant, so pick the operation once.
    if expand < 0:
        morph = scipy.ndimage.grey_erosion
    else:
        morph = scipy.ndimage.grey_dilation

    out = []
    for m in mask:
        output = m.numpy()
        for _ in range(abs(expand)):
            output = morph(output, footprint=kernel)
        out.append(torch.from_numpy(output))
    return torch.stack(out, dim=0).to(device)


def blur_mask(mask, radius):
    """Apply a PIL Gaussian blur of the given ``radius`` to ``mask``.

    The mask is passed through the project helpers ``tensor2pil`` and
    ``pil2tensor``; the exact shape and value range of the result are
    whatever those helpers produce (not visible here).
    """
    pil_image = tensor2pil(mask)
    return pil2tensor(pil_image.filter(ImageFilter.GaussianBlur(radius)))


def solid_mask(width, height, value=1):
    """Return a float32 CPU tensor of shape ``(1, height, width)`` filled with ``value``."""
    return torch.full((1, height, width), value, dtype=torch.float32, device='cpu')


def mask_floor(mask, threshold: float = 0.99):
    """Binarise ``mask``: 1 where ``mask >= threshold``, 0 elsewhere.

    The result keeps the dtype of ``mask``.
    """
    return (mask >= threshold).to(mask.dtype)


def mask_unsqueeze(mask):
    """Reshape a mask to ``B1HW`` layout.

    A 3-D ``BHW`` tensor gets a channel dimension inserted at index 1 and a
    2-D ``HW`` tensor gets both batch and channel dimensions prepended.
    Tensors of any other rank are returned unchanged.
    """
    if mask.dim() == 3:
        # BHW -> B1HW
        mask = mask.unsqueeze(1)
    elif mask.dim() == 2:
        # HW -> 11HW
        mask = mask.unsqueeze(0).unsqueeze(0)
    return mask