File size: 4,278 Bytes
f2dbf59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
import scipy.ndimage
import torch
from PIL import ImageFilter
from scipy.ndimage import binary_closing, binary_fill_holes

from .image_convert import pil2tensor, tensor2pil


def combine_mask(destination, source, x, y):
    """Subtract ``source`` from ``destination`` with source's top-left corner at (x, y).

    Both masks are flattened to ``(batch, height, width)`` using their last two
    axes as height and width.  The part of ``source`` that falls outside the
    destination is ignored; negative offsets crop the source's leading rows /
    columns.  The destination is not modified in place.

    Args:
        destination: mask tensor whose last two axes are (height, width).
        source: mask tensor whose last two axes are (height, width).
        x: column offset of the source's left edge inside the destination.
        y: row offset of the source's top edge inside the destination.

    Returns:
        A ``(batch, height, width)`` tensor, ``destination - source`` over the
        overlapping region, clamped to [0, 1].
    """
    output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone()
    source = source.reshape((-1, source.shape[-2], source.shape[-1]))

    dest_height, dest_width = output.shape[-2], output.shape[-1]
    src_height, src_width = source.shape[-2], source.shape[-1]

    # Clip the paste rectangle to the destination bounds.
    left, top = max(x, 0), max(y, 0)
    right = min(x + src_width, dest_width)
    bottom = min(y + src_height, dest_height)

    # Skip entirely when the source does not overlap the destination at all
    # (otherwise the slices below would have mismatched, negative extents).
    if right > left and bottom > top:
        # Where in the source the visible window starts (non-zero only for negative offsets).
        src_left, src_top = left - x, top - y
        source_portion = source[
            :,
            src_top:src_top + (bottom - top),
            src_left:src_left + (right - left),
        ]
        # Read from the reshaped copy so 2-D and 4-D destinations are indexed
        # on the correct (height, width) axes.
        # operation == "subtract"
        output[:, top:bottom, left:right] = output[:, top:bottom, left:right] - source_portion

    return torch.clamp(output, 0.0, 1.0)


def grow_mask(mask, expand, tapered_corners):
    """Grow (``expand > 0``) or shrink (``expand < 0``) a mask by ``|expand|`` pixels on the GPU-friendly torch path.

    Each step convolves with a 3x3 kernel (a plus shape when ``tapered_corners``
    is true, a full square otherwise) and clamps to [0, 1].  Steps are repeated
    ``|expand|`` times so the result keeps the input's spatial size.  Pixels
    outside the image are treated as 0 in both directions, so shrinking also
    erodes from the image border.

    Args:
        mask: tensor whose last two axes are (height, width).
        expand: number of pixels to grow (positive) or shrink (negative).
        tapered_corners: if true, exclude diagonal neighbours from each step.

    Returns:
        ``mask`` unchanged when ``expand == 0``; otherwise a
        ``(batch, height, width)`` tensor clamped to [0, 1].
    """
    if expand == 0:
        return mask

    functional = torch.nn.functional
    device = mask.device
    # Add a channel axis: conv2d expects (batch, channels, height, width).
    mask = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1]))
    if not torch.is_floating_point(mask):
        mask = mask.float()

    # 3x3 structuring kernel, matching the mask's dtype so conv2d accepts it.
    c = 0.0 if tapered_corners else 1.0
    kernel = torch.tensor(
        [[c, 1.0, c], [1.0, 1.0, 1.0], [c, 1.0, c]], dtype=mask.dtype, device=device
    ).reshape(1, 1, 3, 3)

    # Erosion is dilation of the inverted mask; padding the inverted mask with
    # 1 means "outside the image is 0" in the original mask.
    erode = expand < 0
    pad_value = 1.0 if erode else 0.0
    if erode:
        mask = 1 - mask

    for _ in range(abs(expand)):
        padded = functional.pad(mask, (1, 1, 1, 1), mode='constant', value=pad_value)
        # Clamp each step so the next step (and the final inversion) stays in [0, 1].
        mask = torch.clamp(functional.conv2d(padded, kernel), 0.0, 1.0)

    if erode:
        mask = 1 - mask

    # Drop the channel axis again.
    output = mask.squeeze(1)
    return torch.clamp(output, 0.0, 1.0)


def fill_holes(mask):
    """Binarize each mask in a batch, close small gaps, then fill enclosed holes.

    Every pixel with value > 0 counts as foreground.  A binary closing with a
    5x5 square structuring element (border treated as foreground) bridges
    small gaps, then ``binary_fill_holes`` fills regions not connected to the
    image border.

    Args:
        mask: tensor whose last two axes are (height, width).

    Returns:
        A float32 ``(batch, height, width)`` tensor of 0.0 / 1.0 values on the
        same device as ``mask``.  An empty batch yields an empty tensor.
    """
    device = mask.device
    # detach() so masks that require grad can still be converted to NumPy.
    frames = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).detach().cpu()
    if frames.shape[0] == 0:
        # torch.stack([]) would raise; return an empty batch instead.
        return torch.zeros(frames.shape, dtype=torch.float32, device=device)

    struct = np.ones((5, 5))
    filled = []
    for frame in frames:
        binary_mask = frame.numpy() > 0
        closed_mask = binary_closing(binary_mask, structure=struct, border_value=1)
        filled_mask = binary_fill_holes(closed_mask)
        # Boolean -> 0.0 / 1.0 directly; no scaling and clamping needed.
        filled.append(torch.from_numpy(filled_mask.astype(np.float32)))

    # Return on the caller's device, consistent with expand_mask.
    return torch.stack(filled, dim=0).to(device)


def invert_mask(mask):
    """Return the elementwise complement of ``mask``, i.e. ``1.0 - mask``."""
    inverted = 1.0 - mask
    return inverted


def expand_mask(mask, expand, tapered_corners):
    """Grow or shrink each mask in a batch with SciPy grey morphology.

    A 3x3 footprint is used (plus-shaped when ``tapered_corners`` is true,
    square otherwise).  Grey dilation (``expand > 0``) or grey erosion
    (``expand < 0``) is applied ``|expand|`` times per mask, using SciPy's
    default boundary handling.

    Args:
        mask: tensor whose last two axes are (height, width).
        expand: number of morphology steps; the sign selects grow vs. shrink.
        tapered_corners: if true, exclude diagonal neighbours from the footprint.

    Returns:
        A ``(batch, height, width)`` tensor on the same device as ``mask``.
    """
    corner = 0 if tapered_corners else 1
    footprint = np.array([[corner, 1, corner], [1, 1, 1], [corner, 1, corner]])
    morph = scipy.ndimage.grey_erosion if expand < 0 else scipy.ndimage.grey_dilation
    original_device = mask.device
    frames = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).cpu()

    def _apply(frame):
        # Run the chosen morphology |expand| times on one 2-D frame.
        arr = frame.numpy()
        for _ in range(abs(expand)):
            arr = morph(arr, footprint=footprint)
        return torch.from_numpy(arr)

    return torch.stack([_apply(frame) for frame in frames], dim=0).to(original_device)


def blur_mask(mask, radius):
    """Gaussian-blur ``mask`` with the given radius via PIL.

    The mask is converted with ``tensor2pil``, filtered with
    ``ImageFilter.GaussianBlur(radius)``, and converted back with
    ``pil2tensor``.  NOTE(review): whether batched masks are supported depends
    on ``tensor2pil`` — confirm before passing more than one mask.
    """
    as_image = tensor2pil(mask)
    blurred = as_image.filter(ImageFilter.GaussianBlur(radius))
    return pil2tensor(blurred)


def solid_mask(width, height, value=1):
    """Return a float32 CPU tensor of shape ``(1, height, width)`` filled with ``value``."""
    shape = (1, height, width)
    return torch.full(size=shape, fill_value=value, dtype=torch.float32, device="cpu")


def mask_floor(mask, threshold: float = 0.99):
    """Binarize ``mask``: values >= ``threshold`` become 1, all others 0.

    The result keeps ``mask``'s dtype.
    """
    above = mask >= threshold
    return above.to(dtype=mask.dtype)


def mask_unsqueeze(mask):
    """Ensure ``mask`` has batch and channel axes (B1HW layout).

    A 2-D ``(H, W)`` mask becomes ``(1, 1, H, W)``; a 3-D ``(B, H, W)`` mask
    becomes ``(B, 1, H, W)``; any other rank is returned unchanged.
    """
    rank = len(mask.shape)
    if rank == 2:
        return mask.unsqueeze(0).unsqueeze(0)
    if rank == 3:
        return mask.unsqueeze(1)
    return mask