Commit bf92df0 (1 parent: 6dfe46a), committed by Zhang-Yang-Sustech

Add multi-point input, foreground/background point input and box input to the EfficientSAM model (#291)


* a

* add efficientsam model and basic demo

* update license

* remove example images

* update readme

* update readme

* update demo

* update demo

* update readme

* update SAM and __init__

* update demo and sam

* update label

* add present gif

* update readme

* add efficientSAM gif to readme of opencvzoo

* cv version 4.10.0, remove camera branch

* 1. add multi-point inference (max: 6)
2. add box prompt (drag) and background points (long press)
3. fix model input size to 1024*1024
4. pad labels with -1
5. update demo

* replace the model with a new model that supports multi-point input, update demo

* update readme

* update readme

* change window size to (800*600); loaded pictures may not exceed it

* add int8 model

* update demo

* update README

* check OpenCV version

* update model name in demo

* update model name in demo

* Add keys to exit ('q' and 'Q'); when clicks reach the maximum, no box is shown; comment out unused prints, delete stray whitespace

* update demo with some ASCII

Files changed (3)
  1. README.md +13 -5
  2. demo.py +152 -42
  3. efficientSAM.py +91 -28
README.md CHANGED

@@ -3,9 +3,16 @@
 EfficientSAM: Leveraged Masked Image Pretraining for Efficient Segment Anything
 
 Notes:
-- The current implementation of the EfficientSAM demo uses the EfficientSAM-Ti model, which is specifically tailored for scenarios requiring higher speed and lightweight.
-- MD5 value of "efficient_sam_vitt.pt" is 7A804DA508F30EFC59EC06711C8DCD62
-- SHA-256 value of "efficient_sam_vitt.pt" is DFF858B19600A46461CBB7DE98F796B23A7A888D9F5E34C0B033F7D6EB9E4E6A
+- The current implementation of the EfficientSAM demo uses the EfficientSAM-Ti model, which is specifically tailored for scenarios requiring higher speed and a lightweight footprint.
+- image_segmentation_efficientsam_ti_2024may.onnx (supports only single-point inference)
+  - MD5 value: 117d6a6cac60039a20b399cc133c2a60
+  - SHA-256 value: e3957d2cd1422855f350aa7b044f47f5b3eafada64b5904ed330b696229e2943
+- image_segmentation_efficientsam_ti_2025april.onnx
+  - MD5 value: f23cecbb344547c960c933ff454536a3
+  - SHA-256 value: 4eb496e0a7259d435b49b66faf1754aa45a5c382a34558ddda9a8c6fe5915d77
+- image_segmentation_efficientsam_ti_2025april_int8.onnx
+  - MD5 value: a1164f44b0495b82e9807c7256e95a50
+  - SHA-256 value: 5ecc8d59a2802c32246e68553e1cf8ce74cf74ba707b84f206eb9181ff774b4e
 
 
 ## Demo
@@ -17,7 +24,7 @@ Run the following command to try the demo:
 python demo.py --input /path/to/image
 ```
 
-Click only **once** on the object you wish to segment in the displayed image. After the click, the segmentation result will be shown in a new window.
+**Click** to select foreground points, **drag** to draw a selection box, and **long-press** to select background points on the object you wish to segment in the displayed image. Press **Enter** to run inference; the segmentation result is shown in a second window. Press **Backspace** to clear all prompts.
 
 ## Result
 
@@ -41,4 +48,5 @@ All files in this directory are licensed under [Apache 2.0 License](./LICENSE).
 ## Reference
 
 - https://arxiv.org/abs/2312.00863
-- https://github.com/yformer/EfficientSAM
+- https://github.com/yformer/EfficientSAM
+- https://github.com/facebookresearch/segment-anything
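For reference, here is a minimal sketch of driving the model programmatically, without the demo GUI. It follows the `infer(image, points, labels)` signature and the label convention visible in the diffs below (1 = foreground point, -1 = background point, 2 and 3 = box top-left and bottom-right corners); the constructor argument `modelPath` and the file names are assumptions, since the model instantiation line is outside the changed hunks.

```python
# Hypothetical direct use of the EfficientSAM wrapper, bypassing demo.py's GUI.
# Assumes the ONNX model file sits next to this script and that the
# constructor takes a modelPath argument (not shown in this diff).
import cv2 as cv
from efficientSAM import EfficientSAM

model = EfficientSAM(modelPath='image_segmentation_efficientsam_ti_2025april.onnx')

image = cv.imread('example.jpg')              # any BGR image
points = [[220, 150], [100, 80], [320, 260]]  # one click, then box corners, in original-image coords
labels = [1, 2, 3]                            # 1: foreground, 2: box top-left, 3: box bottom-right

mask = model.infer(image=image, points=points, labels=labels)  # uint8 mask at the input image size
cv.imwrite('mask.png', mask)
```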
demo.py CHANGED

@@ -20,8 +20,8 @@ backend_target_pairs = [
 parser = argparse.ArgumentParser(description='EfficientSAM Demo')
 parser.add_argument('--input', '-i', type=str,
                     help='Set input path to a certain image.')
-parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2024may.onnx',
-                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2024may.onnx.')
+parser.add_argument('--model', '-m', type=str, default='image_segmentation_efficientsam_ti_2025april.onnx',
+                    help='Set model path, defaults to image_segmentation_efficientsam_ti_2025april.onnx.')
 parser.add_argument('--backend_target', '-bt', type=int, default=0,
                     help='''Choose one of the backend-target pair to run this demo:
                     {:d}: (default) OpenCV implementation + CPU,
@@ -34,10 +34,14 @@ parser.add_argument('--save', '-s', action='store_true',
                     help='Specify to save a file with results. Invalid in case of camera input.')
 args = parser.parse_args()
 
-#global click listener
-clicked_left = False
-#global point record in the window
-point = []
+# Global configuration
+WINDOW_SIZE = (800, 600)  # maximum window size (width, height)
+MAX_POINTS = 6            # maximum allowed prompt points
+points = []               # clicked coordinates (original image scale)
+labels = []               # point labels (-1: background via long press, 1: foreground, 2: box top-left, 3: box bottom-right)
+backend_point = []        # position where the mouse button went down
+rectangle = False         # True while a box is being dragged
+current_img = None        # currently displayed image, with prompt markers drawn
 
 def visualize(image, result):
     """
@@ -55,26 +59,88 @@ def visualize(image, result):
     mask = np.copy(result)
     # change mask to binary image
     t, binary = cv.threshold(mask, 127, 255, cv.THRESH_BINARY)
-    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image"
+    assert set(np.unique(binary)) <= {0, 255}, "The mask must be a binary image."
     # enhance red channel to make the segmentation more obvious
     enhancement_factor = 1.8
     red_channel = vis_result[:, :, 2]
     # update the channel
     red_channel = np.where(binary == 255, np.minimum(red_channel * enhancement_factor, 255), red_channel)
     vis_result[:, :, 2] = red_channel
 
     # draw borders
     contours, hierarchy = cv.findContours(binary, cv.RETR_LIST, cv.CHAIN_APPROX_TC89_L1)
     cv.drawContours(vis_result, contours, contourIdx = -1, color = (255,255,255), thickness=2)
     return vis_result
 
 def select(event, x, y, flags, param):
-    global clicked_left
-    # When the left mouse button is pressed, record the coordinates of the point where it is pressed
-    if event == cv.EVENT_LBUTTONUP:
-        point.append([x,y])
-        print("point:",point[0])
-        clicked_left = True
+    """Handle mouse events to collect point and box prompts."""
+    global points, labels, backend_point, rectangle, current_img
+    orig_img = param['original_img']
+    image_window = param['image_window']
+
+    if event == cv.EVENT_LBUTTONDOWN:
+        param['mouse_down_time'] = cv.getTickCount()
+        backend_point = [x, y]
+
+    elif event == cv.EVENT_MOUSEMOVE:
+        if rectangle:
+            # preview the box while dragging
+            rectangle_change_img = current_img.copy()
+            cv.rectangle(rectangle_change_img, (backend_point[0], backend_point[1]), (x, y), (255,0,0), 2)
+            cv.imshow(image_window, rectangle_change_img)
+        elif len(backend_point) != 0 and len(points) < MAX_POINTS:
+            rectangle = True
+
+    elif event == cv.EVENT_LBUTTONUP:
+        if len(points) >= MAX_POINTS:
+            print(f"Maximum of {MAX_POINTS} points reached.")
+            return
+
+        if not rectangle:
+            duration = (cv.getTickCount() - param['mouse_down_time']) / cv.getTickFrequency()
+            label = -1 if duration > 0.5 else 1  # long press = background point
+
+            points.append([backend_point[0], backend_point[1]])
+            labels.append(label)
+            print(f"Added {'background' if label == -1 else 'foreground'} point {backend_point}.")
+        else:
+            if len(points) + 1 >= MAX_POINTS:
+                rectangle = False
+                backend_point.clear()
+                cv.imshow(image_window, current_img)
+                print(f"Points reached {MAX_POINTS}, could not add box.")
+                return
+            # normalize the drag so the stored corners are top-left and bottom-right,
+            # regardless of the drag direction
+            point_leftup = [min(backend_point[0], x), min(backend_point[1], y)]
+            point_rightdown = [max(backend_point[0], x), max(backend_point[1], y)]
+            points.append(point_leftup)
+            points.append(point_rightdown)
+            print(f"Added box from {point_leftup} to {point_rightdown}.")
+            labels.append(2)
+            labels.append(3)
+            rectangle = False
+            backend_point.clear()
+
+        # redraw all prompt markers on a fresh copy of the original image
+        marked_img = orig_img.copy()
+        top_left = None
+        for (px, py), lbl in zip(points, labels):
+            if lbl == -1:
+                cv.circle(marked_img, (px, py), 5, (0, 0, 255), -1)
+            elif lbl == 1:
+                cv.circle(marked_img, (px, py), 5, (0, 255, 0), -1)
+            elif lbl == 2:
+                top_left = (px, py)
+            elif lbl == 3:
+                bottom_right = (px, py)
+                cv.rectangle(marked_img, top_left, bottom_right, (255,0,0), 2)
+        cv.imshow(image_window, marked_img)
+        current_img = marked_img.copy()
+
 
 if __name__ == '__main__':
     backend_id = backend_target_pairs[args.backend_target][0]
@@ -89,49 +155,93 @@ if __name__ == '__main__':
             print('Could not open or find the image:', args.input)
             exit(0)
         # create window
-        image_window = "image: click on the thing whick you want to segment!"
+        image_window = "Origin image"
         cv.namedWindow(image_window, cv.WINDOW_NORMAL)
         # change window size
-        cv.resizeWindow(image_window, 800 if image.shape[0] > 800 else image.shape[0], 600 if image.shape[1] > 600 else image.shape[1])
+        # scale the window so the image fits inside WINDOW_SIZE while keeping its aspect ratio
+        rate1 = 1
+        rate2 = 1
+        if image.shape[1] > WINDOW_SIZE[0]:
+            rate1 = WINDOW_SIZE[0] / image.shape[1]
+        if image.shape[0] > WINDOW_SIZE[1]:
+            rate2 = WINDOW_SIZE[1] / image.shape[0]
+        rate = min(rate1, rate2)
+        # width, height
+        WINDOW_SIZE = (int(image.shape[1] * rate), int(image.shape[0] * rate))
+        cv.resizeWindow(image_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
         # put the window on the left of the screen
         cv.moveWindow(image_window, 50, 100)
         # set listener to record user's click point
-        cv.setMouseCallback(image_window, select)
+        param = {
+            'original_img': image,
+            'mouse_down_time': 0,
+            'image_window': image_window
+        }
+        cv.setMouseCallback(image_window, select, param)
         # tips in the terminal
-        print("click the picture on the LEFT and see the result on the RIGHT!")
+        print("Click — Select foreground point\n"
+              "Long press — Select background point\n"
+              "Drag — Create selection box\n"
+              "Enter — Infer\n"
+              "Backspace — Clear the prompts\n"
+              "Q — Quit")
         # show image
        cv.imshow(image_window, image)
+        current_img = image.copy()
+        # create window to show visualized result
+        vis_image = image.copy()
+        segmentation_window = "Segment result"
+        cv.namedWindow(segmentation_window, cv.WINDOW_NORMAL)
+        cv.resizeWindow(segmentation_window, WINDOW_SIZE[0], WINDOW_SIZE[1])
+        cv.moveWindow(segmentation_window, WINDOW_SIZE[0] + 51, 100)
+        cv.imshow(segmentation_window, vis_image)
         # waiting for click
-        while cv.waitKey(1) == -1 or clicked_left:
-            # receive click
-            if clicked_left:
-                # put the click point (x,y) into the model to predict
-                result = model.infer(image=image, points=point, labels=[1])
-                # get the visualized result
-                vis_result = visualize(image, result)
-                # create window to show visualized result
-                cv.namedWindow("vis_result", cv.WINDOW_NORMAL)
-                cv.resizeWindow("vis_result", 800 if vis_result.shape[0] > 800 else vis_result.shape[0], 600 if vis_result.shape[1] > 600 else vis_result.shape[1])
-                cv.moveWindow("vis_result", 851, 100)
-                cv.imshow("vis_result", vis_result)
-                # set click false to listen another click
-                clicked_left = False
-            elif cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1:
-                # if click × to close the image window then ending
+        while True:
+            # if the user closes either window with ×, end the demo
+            if (cv.getWindowProperty(image_window, cv.WND_PROP_VISIBLE) < 1 or
+                cv.getWindowProperty(segmentation_window, cv.WND_PROP_VISIBLE) < 1):
                 break
-            else:
-                # when not clicked, set point to empty
-                point = []
+
+            # handle keyboard input
+            key = cv.waitKey(1)
+
+            # Enter: run inference with the collected prompts
+            if key == 13:
+                vis_image = image.copy()
+                cv.putText(vis_image, "inferring...",
+                           (50, vis_image.shape[0] // 2),
+                           cv.FONT_HERSHEY_SIMPLEX, 10, (255,255,255), 5)
+                cv.imshow(segmentation_window, vis_image)
+
+                result = model.infer(image=image, points=points, labels=labels)
+                if len(result) == 0:
+                    print("Clear and select points again!")
+                else:
+                    vis_result = visualize(image, result)
+                    cv.imshow(segmentation_window, vis_result)
+            elif key == 8 or key == 127:  # Backspace or Delete: clear all prompts
+                points.clear()
+                labels.clear()
+                backend_point = []
+                rectangle = False
+                current_img = image.copy()
+                print("Points are cleared.")
+                cv.imshow(image_window, image)
+            elif key == ord('q') or key == ord('Q'):
+                break
+
         cv.destroyAllWindows()
 
         # Save results if save is true
         if args.save:
            cv.imwrite('./example_outputs/vis_result.jpg', vis_result)
            cv.imwrite("./example_outputs/mask.jpg", result)
            print('vis_result.jpg and mask.jpg are saved to ./example_outputs/')
 
    else:
        print('Set input path to a certain image.')
        pass
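The click-versus-long-press distinction in `select()` above comes down to timestamping the button-down event with `cv.getTickCount()` and measuring the elapsed seconds at button-up. Here is a self-contained sketch of just that pattern; the 0.5 s threshold mirrors the demo, while the window name and blank canvas are made up for illustration.

```python
# Minimal click-vs-long-press detector using OpenCV's tick counter,
# the same timing pattern as select() in demo.py.
import cv2 as cv
import numpy as np

state = {'down': 0}  # ticks recorded at button-down

def on_mouse(event, x, y, flags, param):
    if event == cv.EVENT_LBUTTONDOWN:
        state['down'] = cv.getTickCount()
    elif event == cv.EVENT_LBUTTONUP:
        elapsed = (cv.getTickCount() - state['down']) / cv.getTickFrequency()
        kind = 'background (long press)' if elapsed > 0.5 else 'foreground (click)'
        print(f'{kind} point at ({x}, {y}), held {elapsed:.2f}s')

canvas = np.zeros((300, 400, 3), dtype=np.uint8)
cv.namedWindow('press test')
cv.setMouseCallback('press test', on_mouse)
cv.imshow('press test', canvas)
# run until the window is closed or q is pressed
while cv.getWindowProperty('press test', cv.WND_PROP_VISIBLE) >= 1:
    if cv.waitKey(30) == ord('q'):
        break
cv.destroyAllWindows()
```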
efficientSAM.py CHANGED

@@ -11,11 +11,15 @@ class EfficientSAM:
         self._model.setPreferableBackend(self._backendId)
         self._model.setPreferableTarget(self._targetId)
         # 3 inputs
         self._inputNames = ["batched_images", "batched_point_coords", "batched_point_labels"]
-        self._outputNames = ['output_masks'] # actual output layer name
+        self._outputNames = ['output_masks', 'iou_predictions'] # actual output layer names
         self._currentInputSize = None
-        self._inputSize = [640, 640] # input size for the model
+        self._inputSize = [1024, 1024] # input size for the model
+        self._maxPointNums = 6
+        self._frontGroundPoints = []
+        self._backGroundPoints = []
+        self._labels = []
 
     @property
     def name(self):
@@ -28,26 +32,54 @@
         self._model.setPreferableTarget(self._targetId)
 
     def _preprocess(self, image, points, labels):
         image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
         # record the input image size, (width, height)
         self._currentInputSize = (image.shape[1], image.shape[0])
         image = cv.resize(image, self._inputSize)
         image = image.astype(np.float32, copy=False) / 255.0
         image_blob = cv.dnn.blobFromImage(image)
 
-        # convert points to (640*640) size space
-        for p in points:
-            p[0] = int(p[0] * self._inputSize[0]/self._currentInputSize[0])
-            p[1] = int(p[1] * self._inputSize[1]/self._currentInputSize[1])
-
-        points_blob = np.array([[points]], dtype=np.float32)
-
-        labels_blob = np.array([[[labels]]])
+        points = np.array(points, dtype=np.float32)
+        labels = np.array(labels, dtype=np.float32)
+        assert points.shape[0] <= self._maxPointNums, f"Max input points number: {self._maxPointNums}"
+        assert points.shape[0] == labels.shape[0]
+
+        # split the prompts: label -1 marks background points, which are applied in postprocessing
+        frontGroundPoints = []
+        backGroundPoints = []
+        inputLabels = []
+        for i in range(len(points)):
+            if labels[i] == -1:
+                backGroundPoints.append(points[i])
+            else:
+                frontGroundPoints.append(points[i])
+                inputLabels.append(labels[i])
+        self._backGroundPoints = np.uint32(backGroundPoints)
+
+        # convert points to (1024*1024) size space
+        for p in frontGroundPoints:
+            p[0] = np.float32(p[0] * self._inputSize[0]/self._currentInputSize[0])
+            p[1] = np.float32(p[1] * self._inputSize[1]/self._currentInputSize[1])
+
+        # unreachable given the assert above; kept as a guard
+        if len(frontGroundPoints) > self._maxPointNums:
+            return "no"
+
+        # pad the prompts to a fixed size: zero points labelled -1 are ignored by the model
+        pad_num = self._maxPointNums - len(frontGroundPoints)
+        self._frontGroundPoints = np.vstack([frontGroundPoints, np.zeros((pad_num, 2), dtype=np.float32)])
+        inputLabels_arr = np.array(inputLabels, dtype=np.float32).reshape(-1, 1)
+        self._labels = np.vstack([inputLabels_arr, np.full((pad_num, 1), -1, dtype=np.float32)])
+
+        points_blob = np.array([[self._frontGroundPoints]])
+
+        labels_blob = np.array([[self._labels]])
 
         return image_blob, points_blob, labels_blob
 
     def infer(self, image, points, labels):
@@ -57,17 +89,48 @@
         self._model.setInput(imageBlob, self._inputNames[0])
         self._model.setInput(pointsBlob, self._inputNames[1])
         self._model.setInput(labelsBlob, self._inputNames[2])
-        outputBlob = self._model.forward()
+        outputs = self._model.forward(self._outputNames)
+        outputBlob, outputIou = outputs[0], outputs[1]
         # Postprocess
-        results = self._postprocess(outputBlob)
+        results = self._postprocess(outputBlob, outputIou)
         return results
 
-    def _postprocess(self, outputBlob):
-        mask = outputBlob[0, 0, 0, :, :] >= 0
-
-        mask_uint8 = (mask * 255).astype(np.uint8)
-        # change to real image size
-        mask_uint8 = cv.resize(mask_uint8, dsize=self._currentInputSize, interpolation=2)
-
-        return mask_uint8
+    def _postprocess(self, outputBlob, outputIou):
+        # The first dimension is the batch size (a single image, so it is 1).
+        # The second dimension is the number of queries.
+        # The third dimension is the number of candidate masks output by the model.
+        masks = outputBlob[0, 0, :, :, :] >= 0
+        ious = outputIou[0, 0, :]
+
+        # sort candidates by predicted IoU, best first
+        sorted_indices = np.argsort(ious)[::-1]
+        sorted_masks = masks[sorted_indices]
+
+        # alternative: sort by mask area
+        # mask_areas = np.sum(masks, axis=(1, 2))
+        # sorted_indices = np.argsort(mask_areas)
+        # sorted_masks = masks[sorted_indices]
+
+        masks_uint8 = (sorted_masks * 255).astype(np.uint8)
+
+        # change to real image size
+        resized_masks = [
+            cv.resize(mask, dsize=self._currentInputSize,
+                      interpolation=cv.INTER_NEAREST)
+            for mask in masks_uint8
+        ]
+
+        # reject candidate masks that cover a user-selected background point
+        for mask in resized_masks:
+            contains_bg = any(
+                mask[y, x] if (0 <= x < mask.shape[1] and 0 <= y < mask.shape[0])
+                else False
+                for (x, y) in self._backGroundPoints
+            )
+            if not contains_bg:
+                return mask
+
+        return resized_masks[0]
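Because the new model expects a fixed-size prompt tensor, `_preprocess()` above pads the foreground points to `self._maxPointNums` entries, filling the unused slots with zero coordinates and label -1. A small sketch of just the padding and the resulting blob shapes, with made-up coordinate values (the shapes follow the wrapper code above):

```python
# Padding prompts to a fixed size, as in EfficientSAM._preprocess().
import numpy as np

MAX_POINTS = 6
front_points = np.array([[412.0, 288.0], [100.0, 96.0]], dtype=np.float32)  # already mapped to 1024x1024 space
front_labels = np.array([1.0, 1.0], dtype=np.float32)                       # foreground labels

pad = MAX_POINTS - len(front_points)
points_padded = np.vstack([front_points, np.zeros((pad, 2), dtype=np.float32)])
labels_padded = np.vstack([front_labels.reshape(-1, 1),
                           np.full((pad, 1), -1, dtype=np.float32)])  # -1 marks padding

points_blob = points_padded[None, None]  # (1, 1, 6, 2): batch, query, point, xy
labels_blob = labels_padded[None, None]  # (1, 1, 6, 1)
print(points_blob.shape, labels_blob.shape)  # sanity check
```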