Inoob commited on
Commit
274ab09
·
verified ·
1 Parent(s): 55d15ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -34
app.py CHANGED
@@ -68,35 +68,89 @@ def build_model(hypar,device):
68
  net.eval()
69
  return net
70
 
71
-
72
- def predict(net, inputs_val, shapes_val, hypar, device):
73
- '''
74
- Given an Image, predict the mask
75
- '''
76
- net.eval()
77
-
78
- if(hypar["model_digit"]=="full"):
79
- inputs_val = inputs_val.type(torch.FloatTensor)
 
 
 
 
 
 
 
 
 
 
 
80
  else:
81
- inputs_val = inputs_val.type(torch.HalfTensor)
82
-
83
-
84
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
85
-
86
- ds_val = net(inputs_val_v)[0] # list of 6 results
87
-
88
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W # we want the first one which is the most accurate prediction
89
-
90
- ## recover the prediction spatial size to the orignal image size
91
- pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
92
-
93
- ma = torch.max(pred_val)
94
- mi = torch.min(pred_val)
95
- pred_val = (pred_val-mi)/(ma-mi) # max = 1
96
 
97
- if device == 'cuda': torch.cuda.empty_cache()
98
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  # Set Parameters
101
  hypar = {} # paramters for inferencing
102
 
@@ -124,16 +178,11 @@ net = build_model(hypar, device)
124
  def inference(image):
125
  image_path = image
126
 
127
- image_tensor, orig_size = load_image(image_path, hypar)
128
- mask = predict(net, image_tensor, orig_size, hypar, device)
129
-
130
- pil_mask = Image.fromarray(mask).convert('L')
131
- im_rgb = Image.open(image).convert("RGB")
132
 
133
- im_rgba = im_rgb.copy()
134
- im_rgba.putalpha(pil_mask)
135
 
136
- return [im_rgba, pil_mask]
137
 
138
 
139
  title = "Highly Accurate Dichotomous Image Segmentation"
 
68
  net.eval()
69
  return net
70
 
71
+ def resize_image(image, size=1024):
72
+
73
+ height, width = image.shape[:2]
74
+
75
+ # Check if either dimension is greater than 1120
76
+ if height > size or width > size:
77
+ # Calculate the scale factor
78
+ if height > width:
79
+ scale_factor = size / height
80
+ else:
81
+ scale_factor = size / width
82
+
83
+ # Resize the image
84
+ new_dimensions = (int(width * scale_factor), int(height * scale_factor))
85
+ resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
86
+
87
+ # Save the resized image
88
+
89
+ print(f"Image resized to {new_dimensions}")
90
+ return resized_image
91
  else:
92
+ print("Image is already within the desired size.")
93
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
+ def predict(net, im):
 
96
 
97
+ im = resize_image(im)
98
+ temp = np.ones((1024,1024,3))
99
+ h, w = im.shape[0],im.shape[1]
100
+ temp[:h,:w] = im
101
+ im = temp
102
+ #show_pic(im)
103
+ input_size = [1024,1024]
104
+ if len(im.shape) < 3:
105
+ im = np.stack([im] * 3, axis=-1) # Convert grayscale to RGB
106
+ im_shp = im.shape[0:2]
107
+ im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
108
+ im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
109
+ image = torch.divide(im_tensor, 255.0)
110
+ image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
111
+
112
+ result = net(image)
113
+ result = torch.squeeze(F.upsample(result[0][0], im_shp, mode='bilinear'), 0)
114
+ ma = torch.max(result)
115
+ mi = torch.min(result)
116
+ result = (result - mi) / (ma - mi)
117
+ result = result.unsqueeze(0) if result.dim() == 2 else result # Ensure result has 3 channels
118
+ result = result.repeat(3, 1, 1) if result.shape[0] == 1 else result
119
+ result = 1 - result # Invert the mask here
120
+
121
+
122
+
123
+ #im_name = im_path.split('\\')[-1].split('.')[0]
124
+
125
+ # Resize the image to match result dimensions
126
+ image_resized = F.upsample(image, size=result.shape[1:], mode='bilinear')
127
+
128
+ # Ensure both tensors are 3D
129
+ image_resized = image_resized.squeeze(0) if image_resized.dim() == 4 else image_resized
130
+ result = result.squeeze(0) if result.dim() == 4 else result
131
+
132
+ # Apply threshold to result to ensure only pure black or white pixels
133
+ threshold = 0.50 # Adjust as needed
134
+ result[result < threshold] = 0
135
+ result[result >= threshold] = 1
136
+
137
+ distance = np.sqrt(np.sum((im - [255, 255, 255]) ** 2, axis=-1))
138
+
139
+ # Create a mask where the distance is less than the threshold
140
+ mask = distance < 200
141
+
142
+ # Convert mask to uint8
143
+ mask = mask.astype(np.uint8) * 255
144
+
145
+ mask = np.stack([mask] * 3, axis=-1)
146
+
147
+ result = (result.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
148
+ # result=result.cpu().numpy().astype(np.uint8)
149
+ # io.imsave(result_path + im_name + "_foreground.png", foreground)
150
+ wite = np.ones_like(im) * 255
151
+ cropped = np.where(result == 0, wite, mask)
152
+ #cv2.imwrite(result_path + f, cropped)
153
+ return cropped[:h,:w]
154
  # Set Parameters
155
  hypar = {} # paramters for inferencing
156
 
 
178
  def inference(image):
179
  image_path = image
180
 
181
+ image_tensor, orig_size = cv2.imread(image_path)
182
+ mask = predict(net, image_tensor)
 
 
 
183
 
 
 
184
 
185
+ return [mask,mask]
186
 
187
 
188
  title = "Highly Accurate Dichotomous Image Segmentation"