Shivdutta commited on
Commit
7a64e24
·
verified ·
1 Parent(s): 1f515b4

Delete utils_tools.py

Browse files
Files changed (1) hide show
  1. utils_tools.py +0 -442
utils_tools.py DELETED
@@ -1,442 +0,0 @@
1
- import numpy as np
2
- from PIL import Image
3
- import matplotlib.pyplot as plt
4
- import cv2
5
- import torch
6
- import os
7
- import sys
8
- import clip
9
-
10
-
11
- def convert_box_xywh_to_xyxy(box):
12
- if len(box) == 4:
13
- return [box[0], box[1], box[0] + box[2], box[1] + box[3]]
14
- else:
15
- result = []
16
- for b in box:
17
- b = convert_box_xywh_to_xyxy(b)
18
- result.append(b)
19
- return result
20
-
21
-
22
- def segment_image(image, bbox):
23
- image_array = np.array(image)
24
- segmented_image_array = np.zeros_like(image_array)
25
- x1, y1, x2, y2 = bbox
26
- segmented_image_array[y1:y2, x1:x2] = image_array[y1:y2, x1:x2]
27
- segmented_image = Image.fromarray(segmented_image_array)
28
- black_image = Image.new("RGB", image.size, (255, 255, 255))
29
- # transparency_mask = np.zeros_like((), dtype=np.uint8)
30
- transparency_mask = np.zeros(
31
- (image_array.shape[0], image_array.shape[1]), dtype=np.uint8
32
- )
33
- transparency_mask[y1:y2, x1:x2] = 255
34
- transparency_mask_image = Image.fromarray(transparency_mask, mode="L")
35
- black_image.paste(segmented_image, mask=transparency_mask_image)
36
- return black_image
37
-
38
-
39
- def format_results(result, filter=0):
40
- annotations = []
41
- n = len(result.masks.data)
42
- for i in range(n):
43
- annotation = {}
44
- mask = result.masks.data[i] == 1.0
45
-
46
- if torch.sum(mask) < filter:
47
- continue
48
- annotation["id"] = i
49
- annotation["segmentation"] = mask.cpu().numpy()
50
- annotation["bbox"] = result.boxes.data[i]
51
- annotation["score"] = result.boxes.conf[i]
52
- annotation["area"] = annotation["segmentation"].sum()
53
- annotations.append(annotation)
54
- return annotations
55
-
56
-
57
- def filter_masks(annotations): # filter the overlap mask
58
- annotations.sort(key=lambda x: x["area"], reverse=True)
59
- to_remove = set()
60
- for i in range(0, len(annotations)):
61
- a = annotations[i]
62
- for j in range(i + 1, len(annotations)):
63
- b = annotations[j]
64
- if i != j and j not in to_remove:
65
- # check if
66
- if b["area"] < a["area"]:
67
- if (a["segmentation"] & b["segmentation"]).sum() / b[
68
- "segmentation"
69
- ].sum() > 0.8:
70
- to_remove.add(j)
71
-
72
- return [a for i, a in enumerate(annotations) if i not in to_remove], to_remove
73
-
74
-
75
- def get_bbox_from_mask(mask):
76
- mask = mask.astype(np.uint8)
77
- contours, hierarchy = cv2.findContours(
78
- mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
79
- )
80
- x1, y1, w, h = cv2.boundingRect(contours[0])
81
- x2, y2 = x1 + w, y1 + h
82
- if len(contours) > 1:
83
- for b in contours:
84
- x_t, y_t, w_t, h_t = cv2.boundingRect(b)
85
- # 将多个bbox合并成一个
86
- x1 = min(x1, x_t)
87
- y1 = min(y1, y_t)
88
- x2 = max(x2, x_t + w_t)
89
- y2 = max(y2, y_t + h_t)
90
- h = y2 - y1
91
- w = x2 - x1
92
- return [x1, y1, x2, y2]
93
-
94
-
95
- def fast_process(
96
- annotations, args, mask_random_color, bbox=None, points=None, edges=False
97
- ):
98
- if isinstance(annotations[0], dict):
99
- annotations = [annotation["segmentation"] for annotation in annotations]
100
- result_name = os.path.basename(args.img_path)
101
- image = cv2.imread(args.img_path)
102
- image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
103
- original_h = image.shape[0]
104
- original_w = image.shape[1]
105
- if sys.platform == "darwin":
106
- plt.switch_backend("TkAgg")
107
- plt.figure(figsize=(original_w / 100, original_h / 100))
108
- # Add subplot with no margin.
109
- plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
110
- plt.margins(0, 0)
111
- plt.gca().xaxis.set_major_locator(plt.NullLocator())
112
- plt.gca().yaxis.set_major_locator(plt.NullLocator())
113
- plt.imshow(image)
114
- if args.better_quality == True:
115
- if isinstance(annotations[0], torch.Tensor):
116
- annotations = np.array(annotations.cpu())
117
- for i, mask in enumerate(annotations):
118
- mask = cv2.morphologyEx(
119
- mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
120
- )
121
- annotations[i] = cv2.morphologyEx(
122
- mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
123
- )
124
- if args.device == "cpu":
125
- annotations = np.array(annotations)
126
- fast_show_mask(
127
- annotations,
128
- plt.gca(),
129
- random_color=mask_random_color,
130
- bbox=bbox,
131
- points=points,
132
- point_label=args.point_label,
133
- retinamask=args.retina,
134
- target_height=original_h,
135
- target_width=original_w,
136
- )
137
- else:
138
- if isinstance(annotations[0], np.ndarray):
139
- annotations = torch.from_numpy(annotations)
140
- fast_show_mask_gpu(
141
- annotations,
142
- plt.gca(),
143
- random_color=args.randomcolor,
144
- bbox=bbox,
145
- points=points,
146
- point_label=args.point_label,
147
- retinamask=args.retina,
148
- target_height=original_h,
149
- target_width=original_w,
150
- )
151
- if isinstance(annotations, torch.Tensor):
152
- annotations = annotations.cpu().numpy()
153
- if args.withContours == True:
154
- contour_all = []
155
- temp = np.zeros((original_h, original_w, 1))
156
- for i, mask in enumerate(annotations):
157
- if type(mask) == dict:
158
- mask = mask["segmentation"]
159
- annotation = mask.astype(np.uint8)
160
- if args.retina == False:
161
- annotation = cv2.resize(
162
- annotation,
163
- (original_w, original_h),
164
- interpolation=cv2.INTER_NEAREST,
165
- )
166
- contours, hierarchy = cv2.findContours(
167
- annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
168
- )
169
- for contour in contours:
170
- contour_all.append(contour)
171
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
172
- color = np.array([0 / 255, 0 / 255, 255 / 255, 0.8])
173
- contour_mask = temp / 255 * color.reshape(1, 1, -1)
174
- plt.imshow(contour_mask)
175
-
176
- save_path = args.output
177
- if not os.path.exists(save_path):
178
- os.makedirs(save_path)
179
- plt.axis("off")
180
- fig = plt.gcf()
181
- plt.draw()
182
-
183
- try:
184
- buf = fig.canvas.tostring_rgb()
185
- except AttributeError:
186
- fig.canvas.draw()
187
- buf = fig.canvas.tostring_rgb()
188
-
189
- cols, rows = fig.canvas.get_width_height()
190
- img_array = np.fromstring(buf, dtype=np.uint8).reshape(rows, cols, 3)
191
- cv2.imwrite(os.path.join(save_path, result_name), cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR))
192
-
193
-
194
- # CPU post process
195
- def fast_show_mask(
196
- annotation,
197
- ax,
198
- random_color=False,
199
- bbox=None,
200
- points=None,
201
- point_label=None,
202
- retinamask=True,
203
- target_height=960,
204
- target_width=960,
205
- ):
206
- msak_sum = annotation.shape[0]
207
- height = annotation.shape[1]
208
- weight = annotation.shape[2]
209
- # 将annotation 按照面积 排序
210
- areas = np.sum(annotation, axis=(1, 2))
211
- sorted_indices = np.argsort(areas)
212
- annotation = annotation[sorted_indices]
213
-
214
- index = (annotation != 0).argmax(axis=0)
215
- if random_color == True:
216
- color = np.random.random((msak_sum, 1, 1, 3))
217
- else:
218
- color = np.ones((msak_sum, 1, 1, 3)) * np.array(
219
- [30 / 255, 144 / 255, 255 / 255]
220
- )
221
- transparency = np.ones((msak_sum, 1, 1, 1)) * 0.6
222
- visual = np.concatenate([color, transparency], axis=-1)
223
- mask_image = np.expand_dims(annotation, -1) * visual
224
-
225
- show = np.zeros((height, weight, 4))
226
- h_indices, w_indices = np.meshgrid(
227
- np.arange(height), np.arange(weight), indexing="ij"
228
- )
229
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
230
- # 使用向量化索引更新show的值
231
- show[h_indices, w_indices, :] = mask_image[indices]
232
- if bbox is not None:
233
- x1, y1, x2, y2 = bbox
234
- ax.add_patch(
235
- plt.Rectangle(
236
- (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
237
- )
238
- )
239
- # draw point
240
- if points is not None:
241
- plt.scatter(
242
- [point[0] for i, point in enumerate(points) if point_label[i] == 1],
243
- [point[1] for i, point in enumerate(points) if point_label[i] == 1],
244
- s=20,
245
- c="y",
246
- )
247
- plt.scatter(
248
- [point[0] for i, point in enumerate(points) if point_label[i] == 0],
249
- [point[1] for i, point in enumerate(points) if point_label[i] == 0],
250
- s=20,
251
- c="m",
252
- )
253
-
254
- if retinamask == False:
255
- show = cv2.resize(
256
- show, (target_width, target_height), interpolation=cv2.INTER_NEAREST
257
- )
258
- ax.imshow(show)
259
-
260
-
261
- def fast_show_mask_gpu(
262
- annotation,
263
- ax,
264
- random_color=False,
265
- bbox=None,
266
- points=None,
267
- point_label=None,
268
- retinamask=True,
269
- target_height=960,
270
- target_width=960,
271
- ):
272
- msak_sum = annotation.shape[0]
273
- height = annotation.shape[1]
274
- weight = annotation.shape[2]
275
- areas = torch.sum(annotation, dim=(1, 2))
276
- sorted_indices = torch.argsort(areas, descending=False)
277
- annotation = annotation[sorted_indices]
278
- # 找每个位置第一个非零值下标
279
- index = (annotation != 0).to(torch.long).argmax(dim=0)
280
- if random_color == True:
281
- color = torch.rand((msak_sum, 1, 1, 3)).to(annotation.device)
282
- else:
283
- color = torch.ones((msak_sum, 1, 1, 3)).to(annotation.device) * torch.tensor(
284
- [30 / 255, 144 / 255, 255 / 255]
285
- ).to(annotation.device)
286
- transparency = torch.ones((msak_sum, 1, 1, 1)).to(annotation.device) * 0.6
287
- visual = torch.cat([color, transparency], dim=-1)
288
- mask_image = torch.unsqueeze(annotation, -1) * visual
289
- # 按index取数,index指每个位置选哪个batch的数,把mask_image转成一个batch的形式
290
- show = torch.zeros((height, weight, 4)).to(annotation.device)
291
- h_indices, w_indices = torch.meshgrid(
292
- torch.arange(height), torch.arange(weight), indexing="ij"
293
- )
294
- indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
295
- # 使用向量化索引更新show的值
296
- show[h_indices, w_indices, :] = mask_image[indices]
297
- show_cpu = show.cpu().numpy()
298
- if bbox is not None:
299
- x1, y1, x2, y2 = bbox
300
- ax.add_patch(
301
- plt.Rectangle(
302
- (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
303
- )
304
- )
305
- # draw point
306
- if points is not None:
307
- plt.scatter(
308
- [point[0] for i, point in enumerate(points) if point_label[i] == 1],
309
- [point[1] for i, point in enumerate(points) if point_label[i] == 1],
310
- s=20,
311
- c="y",
312
- )
313
- plt.scatter(
314
- [point[0] for i, point in enumerate(points) if point_label[i] == 0],
315
- [point[1] for i, point in enumerate(points) if point_label[i] == 0],
316
- s=20,
317
- c="m",
318
- )
319
- if retinamask == False:
320
- show_cpu = cv2.resize(
321
- show_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
322
- )
323
- ax.imshow(show_cpu)
324
-
325
-
326
- # clip
327
- @torch.no_grad()
328
- def retriev(
329
- model, preprocess, elements: [Image.Image], search_text: str, device
330
- ):
331
- preprocessed_images = [preprocess(image).to(device) for image in elements]
332
- tokenized_text = clip.tokenize([search_text]).to(device)
333
- stacked_images = torch.stack(preprocessed_images)
334
- image_features = model.encode_image(stacked_images)
335
- text_features = model.encode_text(tokenized_text)
336
- image_features /= image_features.norm(dim=-1, keepdim=True)
337
- text_features /= text_features.norm(dim=-1, keepdim=True)
338
- probs = 100.0 * image_features @ text_features.T
339
- return probs[:, 0].softmax(dim=0)
340
-
341
-
342
- def crop_image(annotations, image_like):
343
- if isinstance(image_like, str):
344
- image = Image.open(image_like)
345
- else:
346
- image = image_like
347
- ori_w, ori_h = image.size
348
- mask_h, mask_w = annotations[0]["segmentation"].shape
349
- if ori_w != mask_w or ori_h != mask_h:
350
- image = image.resize((mask_w, mask_h))
351
- cropped_boxes = []
352
- cropped_images = []
353
- not_crop = []
354
- origin_id = []
355
- for _, mask in enumerate(annotations):
356
- if np.sum(mask["segmentation"]) <= 100:
357
- continue
358
- origin_id.append(_)
359
- bbox = get_bbox_from_mask(mask["segmentation"]) # mask 的 bbox
360
- cropped_boxes.append(segment_image(image, bbox)) # 保存裁剪的图片
361
- # cropped_boxes.append(segment_image(image,mask["segmentation"]))
362
- cropped_images.append(bbox) # 保存裁剪的图片的bbox
363
- return cropped_boxes, cropped_images, not_crop, origin_id, annotations
364
-
365
-
366
- def box_prompt(masks, bbox, target_height, target_width):
367
- h = masks.shape[1]
368
- w = masks.shape[2]
369
- if h != target_height or w != target_width:
370
- bbox = [
371
- int(bbox[0] * w / target_width),
372
- int(bbox[1] * h / target_height),
373
- int(bbox[2] * w / target_width),
374
- int(bbox[3] * h / target_height),
375
- ]
376
- bbox[0] = round(bbox[0]) if round(bbox[0]) > 0 else 0
377
- bbox[1] = round(bbox[1]) if round(bbox[1]) > 0 else 0
378
- bbox[2] = round(bbox[2]) if round(bbox[2]) < w else w
379
- bbox[3] = round(bbox[3]) if round(bbox[3]) < h else h
380
-
381
- # IoUs = torch.zeros(len(masks), dtype=torch.float32)
382
- bbox_area = (bbox[3] - bbox[1]) * (bbox[2] - bbox[0])
383
-
384
- masks_area = torch.sum(masks[:, bbox[1]: bbox[3], bbox[0]: bbox[2]], dim=(1, 2))
385
- orig_masks_area = torch.sum(masks, dim=(1, 2))
386
-
387
- union = bbox_area + orig_masks_area - masks_area
388
- IoUs = masks_area / union
389
- max_iou_index = torch.argmax(IoUs)
390
-
391
- return masks[max_iou_index].cpu().numpy(), max_iou_index
392
-
393
-
394
- def point_prompt(masks, points, point_label, target_height, target_width): # numpy 处理
395
- h = masks[0]["segmentation"].shape[0]
396
- w = masks[0]["segmentation"].shape[1]
397
- if h != target_height or w != target_width:
398
- points = [
399
- [int(point[0] * w / target_width), int(point[1] * h / target_height)]
400
- for point in points
401
- ]
402
- onemask = np.zeros((h, w))
403
- masks = sorted(masks, key=lambda x: x['area'], reverse=True)
404
- for i, annotation in enumerate(masks):
405
- if type(annotation) == dict:
406
- mask = annotation['segmentation']
407
- else:
408
- mask = annotation
409
- for i, point in enumerate(points):
410
- if mask[point[1], point[0]] == 1 and point_label[i] == 1:
411
- onemask[mask] = 1
412
- if mask[point[1], point[0]] == 1 and point_label[i] == 0:
413
- onemask[mask] = 0
414
- onemask = onemask >= 1
415
- return onemask, 0
416
-
417
-
418
- def text_prompt(annotations, text, img_path, device, wider=False, threshold=0.9):
419
- cropped_boxes, cropped_images, not_crop, origin_id, annotations_ = crop_image(
420
- annotations, img_path
421
- )
422
- clip_model, preprocess = clip.load("./weights/CLIP_ViT_B_32.pt", device=device)
423
- scores = retriev(
424
- clip_model, preprocess, cropped_boxes, text, device=device
425
- )
426
- max_idx = scores.argsort()
427
- max_idx = max_idx[-1]
428
- max_idx = origin_id[int(max_idx)]
429
-
430
- # find the biggest mask which contains the mask with max score
431
- if wider:
432
- mask0 = annotations_[max_idx]["segmentation"]
433
- area0 = np.sum(mask0)
434
- areas = [(i, np.sum(mask["segmentation"])) for i, mask in enumerate(annotations_) if i in origin_id]
435
- areas = sorted(areas, key=lambda area: area[1], reverse=True)
436
- indices = [area[0] for area in areas]
437
- for index in indices:
438
- if index == max_idx or np.sum(annotations_[index]["segmentation"] & mask0) / area0 > threshold:
439
- max_idx = index
440
- break
441
-
442
- return annotations_[max_idx]["segmentation"], max_idx