Shivdutta commited on
Commit
d8c0b8f
·
verified ·
1 Parent(s): 7a64e24

Delete utils_tools_gradio.py

Browse files
Files changed (1) hide show
  1. utils_tools_gradio.py +0 -175
utils_tools_gradio.py DELETED
@@ -1,175 +0,0 @@
1
- import numpy as np
2
- from PIL import Image
3
- import matplotlib.pyplot as plt
4
- import cv2
5
- import torch
6
-
7
-
8
- def fast_process(
9
- annotations,
10
- image,
11
- device,
12
- scale,
13
- better_quality=False,
14
- mask_random_color=True,
15
- bbox=None,
16
- use_retina=True,
17
- withContours=True,
18
- ):
19
- if isinstance(annotations[0], dict):
20
- annotations = [annotation['segmentation'] for annotation in annotations]
21
-
22
- original_h = image.height
23
- original_w = image.width
24
- if better_quality:
25
- if isinstance(annotations[0], torch.Tensor):
26
- annotations = np.array(annotations.cpu())
27
- for i, mask in enumerate(annotations):
28
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
29
- annotations[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
30
- if device == 'cpu':
31
- annotations = np.array(annotations)
32
- inner_mask = fast_show_mask(
33
- annotations,
34
- plt.gca(),
35
- random_color=mask_random_color,
36
- bbox=bbox,
37
- retinamask=use_retina,
38
- target_height=original_h,
39
- target_width=original_w,
40
- )
41
- else:
42
- if isinstance(annotations[0], np.ndarray):
43
- annotations = torch.from_numpy(annotations)
44
- inner_mask = fast_show_mask_gpu(
45
- annotations,
46
- plt.gca(),
47
- random_color=mask_random_color,
48
- bbox=bbox,
49
- retinamask=use_retina,
50
- target_height=original_h,
51
- target_width=original_w,
52
- )
53
- if isinstance(annotations, torch.Tensor):
54
- annotations = annotations.cpu().numpy()
55
-
56
- if withContours:
57
- contour_all = []
58
- temp = np.zeros((original_h, original_w, 1))
59
- for i, mask in enumerate(annotations):
60
- if type(mask) == dict:
61
- mask = mask['segmentation']
62
- annotation = mask.astype(np.uint8)
63
- if use_retina == False:
64
- annotation = cv2.resize(
65
- annotation,
66
- (original_w, original_h),
67
- interpolation=cv2.INTER_NEAREST,
68
- )
69
- contours, _ = cv2.findContours(annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
70
- for contour in contours:
71
- contour_all.append(contour)
72
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
73
- color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
74
- contour_mask = temp / 255 * color.reshape(1, 1, -1)
75
-
76
- image = image.convert('RGBA')
77
- overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), 'RGBA')
78
- image.paste(overlay_inner, (0, 0), overlay_inner)
79
-
80
- if withContours:
81
- overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), 'RGBA')
82
- image.paste(overlay_contour, (0, 0), overlay_contour)
83
-
84
- return image
85
-
86
-
87
- # CPU post process
88
- def fast_show_mask(
89
- annotation,
90
- ax,
91
- random_color=False,
92
- bbox=None,
93
- retinamask=True,
94
- target_height=960,
95
- target_width=960,
96
- ):
97
- mask_sum = annotation.shape[0]
98
- height = annotation.shape[1]
99
- weight = annotation.shape[2]
100
- # 将annotation 按照面积 排序
101
- areas = np.sum(annotation, axis=(1, 2))
102
- sorted_indices = np.argsort(areas)[::1]
103
- annotation = annotation[sorted_indices]
104
-
105
- index = (annotation != 0).argmax(axis=0)
106
- if random_color:
107
- color = np.random.random((mask_sum, 1, 1, 3))
108
- else:
109
- color = np.ones((mask_sum, 1, 1, 3)) * np.array([30 / 255, 144 / 255, 255 / 255])
110
- transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
111
- visual = np.concatenate([color, transparency], axis=-1)
112
- mask_image = np.expand_dims(annotation, -1) * visual
113
-
114
- mask = np.zeros((height, weight, 4))
115
-
116
- h_indices, w_indices = np.meshgrid(np.arange(height), np.arange(weight), indexing='ij')
117
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
118
-
119
- mask[h_indices, w_indices, :] = mask_image[indices]
120
- if bbox is not None:
121
- x1, y1, x2, y2 = bbox
122
- ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor='b', linewidth=1))
123
-
124
- if not retinamask:
125
- mask = cv2.resize(mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST)
126
-
127
- return mask
128
-
129
-
130
- def fast_show_mask_gpu(
131
- annotation,
132
- ax,
133
- random_color=False,
134
- bbox=None,
135
- retinamask=True,
136
- target_height=960,
137
- target_width=960,
138
- ):
139
- device = annotation.device
140
- mask_sum = annotation.shape[0]
141
- height = annotation.shape[1]
142
- weight = annotation.shape[2]
143
- areas = torch.sum(annotation, dim=(1, 2))
144
- sorted_indices = torch.argsort(areas, descending=False)
145
- annotation = annotation[sorted_indices]
146
- # 找每个位置第一个非零值下标
147
- index = (annotation != 0).to(torch.long).argmax(dim=0)
148
- if random_color:
149
- color = torch.rand((mask_sum, 1, 1, 3)).to(device)
150
- else:
151
- color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
152
- [30 / 255, 144 / 255, 255 / 255]
153
- ).to(device)
154
- transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
155
- visual = torch.cat([color, transparency], dim=-1)
156
- mask_image = torch.unsqueeze(annotation, -1) * visual
157
- # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
158
- mask = torch.zeros((height, weight, 4)).to(device)
159
- h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
160
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
161
- # 使用向量化索引更新show的值
162
- mask[h_indices, w_indices, :] = mask_image[indices]
163
- mask_cpu = mask.cpu().numpy()
164
- if bbox is not None:
165
- x1, y1, x2, y2 = bbox
166
- ax.add_patch(
167
- plt.Rectangle(
168
- (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
169
- )
170
- )
171
- if not retinamask:
172
- mask_cpu = cv2.resize(
173
- mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
174
- )
175
- return mask_cpu