# Dreamspire's picture
# custom_nodes
# f2dbf59
import numpy as np
import scipy.ndimage
import torch
from PIL import ImageFilter
from scipy.ndimage import binary_closing, binary_fill_holes
from .image_convert import pil2tensor, tensor2pil
def combine_mask(destination, source, x, y):
    """Subtract ``source`` from ``destination`` with its top-left corner at ``(x, y)``.

    Both masks are flattened to ``(batch, height, width)``, using their last
    two dimensions as height and width. Only the part of ``source`` that
    overlaps ``destination`` takes part in the subtraction; every other pixel
    of ``destination`` is copied through. ``destination`` is never modified.

    Args:
        destination: Mask tensor to subtract from (at least 2 dimensions).
        source: Mask tensor to subtract (at least 2 dimensions).
        x: Column of ``destination`` where the left edge of ``source`` lands.
            May be negative or lie beyond the right edge.
        y: Row of ``destination`` where the top edge of ``source`` lands.
            May be negative or lie beyond the bottom edge.

    Returns:
        A new ``(batch, height, width)`` tensor clamped to ``[0.0, 1.0]``.
    """
    output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
    source = source.reshape((-1, source.shape[-2], source.shape[-1]))

    dest_height, dest_width = output.shape[-2:]
    src_height, src_width = source.shape[-2:]

    # Clip the pasted rectangle to the destination canvas on all four sides.
    left, top = max(x, 0), max(y, 0)
    right = min(x + src_width, dest_width)
    bottom = min(y + src_height, dest_height)

    # Nothing to do when the source lies completely outside the canvas.
    if right > left and bottom > top:
        # (left - x, top - y) is where the visible window starts inside source;
        # it is (0, 0) unless the offset is negative.
        source_portion = source[:, top - y:bottom - y, left - x:right - x]
        # Slice the reshaped copy rather than `destination`: the raw input may
        # be 2-D or 4-D, in which case a 3-index slice fails or hits wrong axes.
        destination_portion = output[:, top:bottom, left:right]
        # The only supported operation is "subtract".
        output[:, top:bottom, left:right] = destination_portion - source_portion

    return torch.clamp(output, 0.0, 1.0)
def grow_mask(mask, expand, tapered_corners):
    """Grow (``expand > 0``) or shrink (``expand < 0``) a mask by ``|expand|`` pixels.

    The mask is processed with a 3x3 convolution applied ``|expand|`` times,
    so the output keeps the input's height and width for every ``expand``.
    Each step sums the neighbourhood and clamps to ``[0, 1]``; for a 0/1 mask
    this is a morphological dilation (or an erosion, computed as a dilation of
    the inverted mask). Pixels outside the image count as empty (0).

    Args:
        mask: Mask tensor with at least 2 dimensions; the last two are
            treated as height and width.
        expand: Number of pixels to grow (positive) or shrink (negative).
            ``0`` returns ``mask`` itself, untouched.
        tapered_corners: If true, use a cross-shaped neighbourhood (corners
            excluded); otherwise use the full 3x3 square.

    Returns:
        ``mask`` itself when ``expand == 0``; otherwise a
        ``(batch, height, width)`` tensor clamped to ``[0.0, 1.0]``.
    """
    if expand == 0:
        return mask

    # Add a channel dimension: conv2d expects (batch, channels, height, width).
    mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))

    # Build the convolution kernel; match the mask dtype so conv2d accepts it.
    corner = 0.0 if tapered_corners else 1.0
    kernel = torch.tensor(
        [[corner, 1.0, corner], [1.0, 1.0, 1.0], [corner, 1.0, corner]],
        dtype=mask.dtype,
        device=mask.device,
    ).reshape(1, 1, 3, 3)

    conv2d = torch.nn.functional.conv2d
    steps = abs(expand)

    if expand > 0:
        # Dilation: one-pixel steps keep the grown region solid, unlike a
        # single convolution with a dilated (sparse) kernel.
        for _ in range(steps):
            # padding=1 zero-pads, i.e. the outside of the image is empty.
            mask = torch.clamp(conv2d(mask, kernel, padding=1), 0.0, 1.0)
    else:
        # Erosion: dilate the inverted mask, then invert back.
        inverted = 1 - mask
        for _ in range(steps):
            # Outside is empty in the original mask, hence 1 once inverted.
            inverted = torch.nn.functional.pad(inverted, (1, 1, 1, 1), mode='constant', value=1)
            inverted = torch.clamp(conv2d(inverted, kernel), 0.0, 1.0)
        mask = 1 - inverted

    # Drop the channel dimension again.
    return torch.clamp(mask.squeeze(1), 0.0, 1.0)
def fill_holes(mask):
    """Close small gaps and fill enclosed holes in every mask of a batch.

    The input is flattened to ``(batch, height, width)`` and moved to the CPU.
    Each frame is binarised (``> 0``), closed with a 5x5 structuring element
    (treating the area outside the image as set), and then hole-filled.

    Args:
        mask: Mask tensor with at least 2 dimensions; the last two are
            treated as height and width.

    Returns:
        A float32 CPU tensor of shape ``(batch, height, width)`` containing
        only ``0.0`` and ``1.0``.
    """
    frames = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
    closing_footprint = np.ones((5, 5))

    filled_frames = []
    for frame in frames:
        foreground = frame.numpy() > 0
        closed = binary_closing(foreground, structure=closing_footprint, border_value=1)
        filled = binary_fill_holes(closed)
        # The boolean result maps straight to 0.0 / 1.0, so no rescaling or
        # clamping is needed to stay inside [0, 1].
        filled_frames.append(torch.from_numpy(filled.astype(np.float32)))  # type: ignore

    return torch.stack(filled_frames, dim=0)
def invert_mask(mask):
    """Return the complement of ``mask``, computed element-wise as ``1.0 - mask``."""
    inverted = 1.0 - mask
    return inverted
def expand_mask(mask, expand, tapered_corners):
    """Grow or shrink each mask in a batch with SciPy grey-scale morphology.

    The input is flattened to ``(batch, height, width)`` and processed on the
    CPU one frame at a time: ``|expand|`` rounds of ``grey_dilation`` when
    ``expand`` is non-negative, ``grey_erosion`` when it is negative.

    Args:
        mask: Mask tensor with at least 2 dimensions; the last two are
            treated as height and width.
        expand: Number of 3x3 morphology rounds; the sign picks dilation
            (>= 0) or erosion (< 0).
        tapered_corners: If true, the 3x3 footprint excludes its corners
            (cross shape); otherwise the full square is used.

    Returns:
        A ``(batch, height, width)`` tensor on the same device as ``mask``.
    """
    device = mask.device
    corner = 0 if tapered_corners else 1
    footprint = np.array([[corner, 1, corner], [1, 1, 1], [corner, 1, corner]])

    # The direction depends only on the sign of `expand`, so pick it once.
    morph = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation
    rounds = abs(expand)

    frames = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()
    results = []
    for frame in frames:
        pixels = frame.numpy()
        for _ in range(rounds):
            pixels = morph(pixels, footprint=footprint)
        results.append(torch.from_numpy(pixels))

    return torch.stack(results, dim=0).to(device)
def blur_mask(mask, radius):
    """Apply a Gaussian blur of the given ``radius`` to ``mask``.

    The mask is converted with ``tensor2pil``, filtered with PIL's
    ``ImageFilter.GaussianBlur``, and converted back with ``pil2tensor``.
    The shape and value range of the result are whatever those two project
    helpers produce; they are not defined in this file.
    """
    image = tensor2pil(mask)
    blurred = image.filter(ImageFilter.GaussianBlur(radius))
    return pil2tensor(blurred)
def solid_mask(width, height, value=1):
    """Create a constant mask.

    Args:
        width: Mask width in pixels (last dimension).
        height: Mask height in pixels (second dimension).
        value: Fill value for every pixel; defaults to ``1``.

    Returns:
        A float32 CPU tensor of shape ``(1, height, width)`` filled with
        ``value``.
    """
    shape = (1, height, width)
    return torch.full(shape, value, dtype=torch.float32, device='cpu')
def mask_floor(mask, threshold: float = 0.99):
    """Binarise ``mask`` against ``threshold``.

    Elements greater than or equal to ``threshold`` become 1 and all others
    become 0. The result keeps the dtype of ``mask``.
    """
    at_or_above = mask >= threshold
    return at_or_above.to(mask.dtype)
def mask_unsqueeze(mask):
    """Reshape a mask to B1HW layout.

    A 3-D mask (BHW) gains a channel dimension at index 1; a 2-D mask (HW)
    gains both a batch and a channel dimension. A mask of any other rank is
    returned unchanged.
    """
    rank = len(mask.shape)
    if rank == 2:  # HW -> 11HW
        return mask.unsqueeze(0).unsqueeze(0)
    if rank == 3:  # BHW -> B1HW
        return mask.unsqueeze(1)
    return mask