Spaces:
Running
on
Zero
Running
on
Zero
Fix the bug in identifying the type of image input.
Browse files
app.py
CHANGED
|
@@ -102,15 +102,16 @@ def predict(images, resolution, weights_file):
|
|
| 102 |
resolution = [1024, 1024]
|
| 103 |
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
| 104 |
|
| 105 |
-
print('type(images):', type(images))
|
| 106 |
if isinstance(images, list):
|
|
|
|
|
|
|
| 107 |
save_dir = 'preds-BiRefNet'
|
| 108 |
if not os.path.exists(save_dir):
|
| 109 |
os.makedirs(save_dir)
|
|
|
|
| 110 |
else:
|
| 111 |
-
# For tab_batch
|
| 112 |
-
save_paths = []
|
| 113 |
images = [images]
|
|
|
|
| 114 |
|
| 115 |
for idx_image, image_src in enumerate(images):
|
| 116 |
if isinstance(image_src, str):
|
|
@@ -119,38 +120,38 @@ def predict(images, resolution, weights_file):
|
|
| 119 |
image = np.array(Image.open(image_data))
|
| 120 |
else:
|
| 121 |
image = image_src
|
| 122 |
-
|
| 123 |
image_shape = image.shape[:2]
|
| 124 |
image_pil = array_to_pil_image(image, tuple(resolution))
|
| 125 |
-
|
| 126 |
# Preprocess the image
|
| 127 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 128 |
image_proc = image_preprocessor.proc(image_pil)
|
| 129 |
image_proc = image_proc.unsqueeze(0)
|
| 130 |
-
|
| 131 |
# Perform the prediction
|
| 132 |
with torch.no_grad():
|
| 133 |
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
| 134 |
-
|
| 135 |
if device == 'cuda':
|
| 136 |
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
| 137 |
-
|
| 138 |
# Resize the prediction to match the original image shape
|
| 139 |
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
| 140 |
-
|
| 141 |
# Apply the prediction mask to the original image
|
| 142 |
image_pil = image_pil.resize(pred.shape[::-1])
|
| 143 |
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
| 144 |
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
| 145 |
-
|
| 146 |
torch.cuda.empty_cache()
|
| 147 |
-
|
| 148 |
-
if
|
| 149 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
| 150 |
cv2.imwrite(save_file_path)
|
| 151 |
save_paths.append(save_file_path)
|
| 152 |
|
| 153 |
-
if
|
| 154 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 155 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 156 |
for file in save_paths:
|
|
|
|
| 102 |
resolution = [1024, 1024]
|
| 103 |
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
| 104 |
|
|
|
|
| 105 |
if isinstance(images, list):
|
| 106 |
+
# For tab_batch
|
| 107 |
+
save_paths = []
|
| 108 |
save_dir = 'preds-BiRefNet'
|
| 109 |
if not os.path.exists(save_dir):
|
| 110 |
os.makedirs(save_dir)
|
| 111 |
+
tab_is_batch = True
|
| 112 |
else:
|
|
|
|
|
|
|
| 113 |
images = [images]
|
| 114 |
+
tab_is_batch = False
|
| 115 |
|
| 116 |
for idx_image, image_src in enumerate(images):
|
| 117 |
if isinstance(image_src, str):
|
|
|
|
| 120 |
image = np.array(Image.open(image_data))
|
| 121 |
else:
|
| 122 |
image = image_src
|
| 123 |
+
|
| 124 |
image_shape = image.shape[:2]
|
| 125 |
image_pil = array_to_pil_image(image, tuple(resolution))
|
| 126 |
+
|
| 127 |
# Preprocess the image
|
| 128 |
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 129 |
image_proc = image_preprocessor.proc(image_pil)
|
| 130 |
image_proc = image_proc.unsqueeze(0)
|
| 131 |
+
|
| 132 |
# Perform the prediction
|
| 133 |
with torch.no_grad():
|
| 134 |
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
| 135 |
+
|
| 136 |
if device == 'cuda':
|
| 137 |
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
| 138 |
+
|
| 139 |
# Resize the prediction to match the original image shape
|
| 140 |
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
| 141 |
+
|
| 142 |
# Apply the prediction mask to the original image
|
| 143 |
image_pil = image_pil.resize(pred.shape[::-1])
|
| 144 |
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
| 145 |
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
| 146 |
+
|
| 147 |
torch.cuda.empty_cache()
|
| 148 |
+
|
| 149 |
+
if tab_is_batch:
|
| 150 |
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
| 151 |
cv2.imwrite(save_file_path)
|
| 152 |
save_paths.append(save_file_path)
|
| 153 |
|
| 154 |
+
if tab_is_batch:
|
| 155 |
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 156 |
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 157 |
for file in save_paths:
|