Update app.py
app.py CHANGED
@@ -68,35 +68,89 @@ def build_model(hypar,device):
     net.eval()
     return net
 
-
-def predict(net, inputs_val, shapes_val, hypar, device):
-    '''
-    Given an Image, predict the mask
-    '''
-    net.eval()
-
-    if(hypar["model_digit"]=="full"):
-        inputs_val = inputs_val.type(torch.FloatTensor)
+def resize_image(image, size=1024):
+
+    height, width = image.shape[:2]
+
+    # Check if either dimension is greater than the target size
+    if height > size or width > size:
+        # Calculate the scale factor from the longer side
+        if height > width:
+            scale_factor = size / height
+        else:
+            scale_factor = size / width
+
+        # Resize the image, preserving the aspect ratio
+        new_dimensions = (int(width * scale_factor), int(height * scale_factor))
+        resized_image = cv2.resize(image, new_dimensions, interpolation=cv2.INTER_AREA)
+
+        # Report the new size
+
+        print(f"Image resized to {new_dimensions}")
+        return resized_image
     else:
-        inputs_val = inputs_val.type(torch.HalfTensor)
-
-
-    inputs_val_v = Variable(inputs_val, requires_grad=False).to(device) # wrap inputs in Variable
-
-    ds_val = net(inputs_val_v)[0] # list of 6 results
-
-    pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W    # we want the first one, which is the most accurate prediction
-
-    ## recover the prediction spatial size to the original image size
-    pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val,0),(shapes_val[0][0],shapes_val[0][1]),mode='bilinear'))
-
-    ma = torch.max(pred_val)
-    mi = torch.min(pred_val)
-    pred_val = (pred_val-mi)/(ma-mi) # max = 1
+        print("Image is already within the desired size.")
+        return image
 
-
-    return (pred_val.detach().cpu().numpy()*255).astype(np.uint8) # it is the mask we need
+def predict(net, im):
 
+    im = resize_image(im)
+    temp = np.ones((1024,1024,3))
+    h, w = im.shape[0], im.shape[1]
+    temp[:h,:w] = im
+    im = temp
+    #show_pic(im)
+    input_size = [1024,1024]
+    if len(im.shape) < 3:
+        im = np.stack([im] * 3, axis=-1) # Convert grayscale to RGB
+    im_shp = im.shape[0:2]
+    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1)
+    im_tensor = F.upsample(torch.unsqueeze(im_tensor, 0), input_size, mode="bilinear").type(torch.uint8)
+    image = torch.divide(im_tensor, 255.0)
+    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
+
+    result = net(image)
+    result = torch.squeeze(F.upsample(result[0][0], im_shp, mode='bilinear'), 0)
+    ma = torch.max(result)
+    mi = torch.min(result)
+    result = (result - mi) / (ma - mi)
+    result = result.unsqueeze(0) if result.dim() == 2 else result # Ensure a channel dimension exists
+    result = result.repeat(3, 1, 1) if result.shape[0] == 1 else result # Expand to 3 channels
+    result = 1 - result # Invert the mask here
+
+
+
+    #im_name = im_path.split('\\')[-1].split('.')[0]
+
+    # Resize the image to match result dimensions
+    image_resized = F.upsample(image, size=result.shape[1:], mode='bilinear')
+
+    # Ensure both tensors are 3D
+    image_resized = image_resized.squeeze(0) if image_resized.dim() == 4 else image_resized
+    result = result.squeeze(0) if result.dim() == 4 else result
+
+    # Apply a threshold so the result holds only pure black or white pixels
+    threshold = 0.50 # Adjust as needed
+    result[result < threshold] = 0
+    result[result >= threshold] = 1
+
+    distance = np.sqrt(np.sum((im - [255, 255, 255]) ** 2, axis=-1))
+
+    # Create a mask where the distance from pure white is less than the threshold
+    mask = distance < 200
+
+    # Convert mask to uint8
+    mask = mask.astype(np.uint8) * 255
+
+    mask = np.stack([mask] * 3, axis=-1)
+
+    result = (result.permute(1, 2, 0) * 255).cpu().numpy().astype(np.uint8)
+    # result = result.cpu().numpy().astype(np.uint8)
+    # io.imsave(result_path + im_name + "_foreground.png", foreground)
+    white = np.ones_like(im) * 255
+    cropped = np.where(result == 0, white, mask)
+    #cv2.imwrite(result_path + f, cropped)
+    return cropped[:h,:w]
 # Set Parameters
 hypar = {} # parameters for inferencing
 
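Note on the new predict() above: it letterboxes the (already downscaled) input onto a fixed 1024x1024 canvas and, at the end, crops the model output back to the original footprint via cropped[:h,:w]. Below is a minimal standalone sketch of that pad-and-crop contract, assuming a 3-channel input that already fits within 1024 on both sides (which resize_image guarantees); pad_to_square is a hypothetical helper name, not part of this commit.

import numpy as np

def pad_to_square(im, size=1024, fill=1.0):
    # Place the image in the top-left corner of a size x size canvas,
    # mirroring `temp = np.ones((1024,1024,3)); temp[:h,:w] = im` above.
    h, w = im.shape[:2]
    canvas = np.full((size, size, 3), fill, dtype=np.float64)
    canvas[:h, :w] = im
    return canvas, (h, w)

# A 300x500 input is padded to 1024x1024; cropping with the saved
# (h, w) recovers the original footprint, as predict() does at the end.
im = np.zeros((300, 500, 3))
padded, (h, w) = pad_to_square(im)
assert padded.shape == (1024, 1024, 3)
assert padded[:h, :w].shape == im.shape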
@@ -124,16 +178,11 @@ net = build_model(hypar, device)
 def inference(image):
     image_path = image
 
-    image_tensor, orig_size = load_image(image_path, hypar)
-    mask = predict(net, image_tensor, orig_size, hypar, device)
-
-    pil_mask = Image.fromarray(mask).convert('L')
-    im_rgb = Image.open(image).convert("RGB")
+    image_tensor = cv2.imread(image_path)
+    mask = predict(net, image_tensor)
 
-    im_rgba = im_rgb.copy()
-    im_rgba.putalpha(pil_mask)
 
-    return [im_rgba, pil_mask]
+    return [mask, mask]
 
 
 title = "Highly Accurate Dichotomous Image Segmentation"
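The background test in the new predict() is a color-distance heuristic layered on top of the network output: pixels whose RGB Euclidean distance from pure white is below 200 go into the mask. A self-contained NumPy illustration of that rule, with made-up pixel values chosen to land on both sides of the threshold:

import numpy as np

im = np.zeros((4, 4, 3))
im[0, 0] = [255, 255, 255]  # pure white: distance 0, inside the mask
im[0, 1] = [250, 250, 250]  # near-white: distance ~8.7, inside the mask
im[3, 3] = [10, 10, 10]     # dark pixel: distance ~424, outside the mask

distance = np.sqrt(np.sum((im - [255, 255, 255]) ** 2, axis=-1))
mask = (distance < 200).astype(np.uint8) * 255

assert mask[0, 0] == 255 and mask[0, 1] == 255 and mask[3, 3] == 0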
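Finally, inference() now takes a file path (it calls cv2.imread directly) and returns the mask twice. How it is wired into the demo sits outside this diff; the sketch below is only a guess at a typical Gradio setup using the names visible above, so gr.Interface and the two-image output list are assumptions, not part of the commit.

import gradio as gr

# Hypothetical wiring; `inference` and `title` come from app.py above.
demo = gr.Interface(
    fn=inference,
    inputs=gr.Image(type="filepath"),  # inference() expects a path for cv2.imread
    outputs=[gr.Image(), gr.Image()],  # matches `return [mask, mask]`
    title=title,
)

if __name__ == "__main__":
    demo.launch()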