Spaces:
Running
on
Zero
Running
on
Zero
Add tab of batch inference with saving function.
Browse files
app.py
CHANGED
|
@@ -57,15 +57,37 @@ birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7',
|
|
| 57 |
birefnet.to(device)
|
| 58 |
birefnet.eval()
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
|
| 61 |
@spaces.GPU
|
| 62 |
-
def predict(
|
| 63 |
-
assert (
|
| 64 |
-
|
| 65 |
-
if isinstance(image, str):
|
| 66 |
-
response = requests.get(image)
|
| 67 |
-
image_data = BytesIO(response.content)
|
| 68 |
-
image = np.array(Image.open(image_data))
|
| 69 |
global birefnet
|
| 70 |
# Load BiRefNet with chosen weights
|
| 71 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
|
@@ -74,33 +96,63 @@ def predict(image, resolution, weights_file):
|
|
| 74 |
birefnet.to(device)
|
| 75 |
birefnet.eval()
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
# Preprocess the image
|
| 84 |
-
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 85 |
-
image_proc = image_preprocessor.proc(image_pil)
|
| 86 |
-
image_proc = image_proc.unsqueeze(0)
|
| 87 |
-
|
| 88 |
-
# Perform the prediction
|
| 89 |
-
with torch.no_grad():
|
| 90 |
-
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
| 91 |
-
|
| 92 |
-
if device == 'cuda':
|
| 93 |
-
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
| 94 |
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 102 |
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
return image, image_pred
|
| 106 |
|
|
@@ -118,6 +170,11 @@ examples_url = [
|
|
| 118 |
for idx_example_url, example_url in enumerate(examples_url):
|
| 119 |
examples_url[idx_example_url].append('1024x1024')
|
| 120 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
tab_image = gr.Interface(
|
| 122 |
fn=predict,
|
| 123 |
inputs=[
|
|
@@ -128,10 +185,7 @@ tab_image = gr.Interface(
|
|
| 128 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
| 129 |
examples=examples,
|
| 130 |
api_name="image",
|
| 131 |
-
description=
|
| 132 |
-
' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
|
| 133 |
-
' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
|
| 134 |
-
' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.'),
|
| 135 |
)
|
| 136 |
|
| 137 |
tab_text = gr.Interface(
|
|
@@ -144,15 +198,20 @@ tab_text = gr.Interface(
|
|
| 144 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
| 145 |
examples=examples_url,
|
| 146 |
api_name="text",
|
| 147 |
-
description=
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
)
|
| 152 |
|
| 153 |
demo = gr.TabbedInterface(
|
| 154 |
-
[tab_image, tab_text],
|
| 155 |
-
[
|
| 156 |
title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
|
| 157 |
)
|
| 158 |
|
|
|
|
| 57 |
birefnet.to(device)
|
| 58 |
birefnet.eval()
|
| 59 |
|
| 60 |
+
# for idx, image_path in enumerate(images):
|
| 61 |
+
# im = load_img(image_path, output_type="pil")
|
| 62 |
+
# if im is None:
|
| 63 |
+
# continue
|
| 64 |
+
|
| 65 |
+
# im = im.convert("RGB")
|
| 66 |
+
# image_size = im.size
|
| 67 |
+
# input_images = transform_image(im).unsqueeze(0).to("cpu")
|
| 68 |
+
|
| 69 |
+
# with torch.no_grad():
|
| 70 |
+
# preds = birefnet(input_images)[-1].sigmoid().cpu()
|
| 71 |
+
# pred = preds[0].squeeze()
|
| 72 |
+
# pred_pil = transforms.ToPILImage()(pred)
|
| 73 |
+
# mask = pred_pil.resize(image_size)
|
| 74 |
+
|
| 75 |
+
# im.putalpha(mask)
|
| 76 |
+
# output_file_path = os.path.join(save_dir, f"output_image_batch_{idx + 1}.png")
|
| 77 |
+
# im.save(output_file_path)
|
| 78 |
+
# output_paths.append(output_file_path)
|
| 79 |
+
|
| 80 |
+
# zip_file_path = os.path.join(save_dir, "processed_images.zip")
|
| 81 |
+
# with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 82 |
+
# for file in output_paths:
|
| 83 |
+
# zipf.write(file, os.path.basename(file))
|
| 84 |
+
|
| 85 |
+
# return output_paths, zip_file_path
|
| 86 |
|
| 87 |
@spaces.GPU
|
| 88 |
+
def predict(images, resolution, weights_file):
|
| 89 |
+
assert (images is not None), 'AssertionError: images cannot be None.'
|
| 90 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
global birefnet
|
| 92 |
# Load BiRefNet with chosen weights
|
| 93 |
_weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
|
|
|
|
| 96 |
birefnet.to(device)
|
| 97 |
birefnet.eval()
|
| 98 |
|
| 99 |
+
try:
|
| 100 |
+
resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
|
| 101 |
+
except:
|
| 102 |
+
resolution = [1024, 1024]
|
| 103 |
+
print('Invalid resolution input. Automatically changed to 1024x1024.')
|
| 104 |
+
|
| 105 |
+
if isinstance(images, list):
|
| 106 |
+
save_dir = 'preds-BiRefNet'
|
| 107 |
+
if not os.path.exists(save_dir):
|
| 108 |
+
os.makedirs(save_dir)
|
| 109 |
+
else:
|
| 110 |
+
# For tab_batch
|
| 111 |
+
save_paths = []
|
| 112 |
+
images = [images]
|
| 113 |
+
|
| 114 |
+
for idx_image, image_src in enumerate(images):
|
| 115 |
+
if isinstance(image_src, str):
|
| 116 |
+
response = requests.get(image_src)
|
| 117 |
+
image_data = BytesIO(response.content)
|
| 118 |
+
image = np.array(Image.open(image_data))
|
| 119 |
+
else:
|
| 120 |
+
image = image_src
|
| 121 |
|
| 122 |
+
image_shape = image.shape[:2]
|
| 123 |
+
image_pil = array_to_pil_image(image, tuple(resolution))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
+
# Preprocess the image
|
| 126 |
+
image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
|
| 127 |
+
image_proc = image_preprocessor.proc(image_pil)
|
| 128 |
+
image_proc = image_proc.unsqueeze(0)
|
| 129 |
+
|
| 130 |
+
# Perform the prediction
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
scaled_pred_tensor = birefnet(image_proc.to(device))[-1].sigmoid()
|
| 133 |
+
|
| 134 |
+
if device == 'cuda':
|
| 135 |
+
scaled_pred_tensor = scaled_pred_tensor.cpu()
|
| 136 |
+
|
| 137 |
+
# Resize the prediction to match the original image shape
|
| 138 |
+
pred = torch.nn.functional.interpolate(scaled_pred_tensor, size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy()
|
| 139 |
+
|
| 140 |
+
# Apply the prediction mask to the original image
|
| 141 |
+
image_pil = image_pil.resize(pred.shape[::-1])
|
| 142 |
+
pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
|
| 143 |
+
image_pred = (pred * np.array(image_pil)).astype(np.uint8)
|
| 144 |
+
|
| 145 |
+
torch.cuda.empty_cache()
|
| 146 |
+
|
| 147 |
+
save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
|
| 148 |
+
cv2.imwrite(save_file_path)
|
| 149 |
+
save_paths.append(save_file_path)
|
| 150 |
|
| 151 |
+
if len(images) > 1:
|
| 152 |
+
zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
|
| 153 |
+
with zipfile.ZipFile(zip_file_path, 'w') as zipf:
|
| 154 |
+
for file in save_paths:
|
| 155 |
+
zipf.write(file, os.path.basename(file))
|
| 156 |
|
| 157 |
return image, image_pred
|
| 158 |
|
|
|
|
| 170 |
for idx_example_url, example_url in enumerate(examples_url):
|
| 171 |
examples_url[idx_example_url].append('1024x1024')
|
| 172 |
|
| 173 |
+
descriptions = ('Upload a picture, our model will extract a highly accurate segmentation of the subject in it.\n)'
|
| 174 |
+
' The resolution used in our training was `1024x1024`, thus the suggested resolution to obtain good results!\n'
|
| 175 |
+
' Our codes can be found at https://github.com/ZhengPeng7/BiRefNet.\n'
|
| 176 |
+
' We also maintain the HF model of BiRefNet at https://huggingface.co/ZhengPeng7/BiRefNet for easier access.')
|
| 177 |
+
|
| 178 |
tab_image = gr.Interface(
|
| 179 |
fn=predict,
|
| 180 |
inputs=[
|
|
|
|
| 185 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
| 186 |
examples=examples,
|
| 187 |
api_name="image",
|
| 188 |
+
description=descriptions,
|
|
|
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
tab_text = gr.Interface(
|
|
|
|
| 198 |
outputs=ImageSlider(label="BiRefNet's prediction", type="pil"),
|
| 199 |
examples=examples_url,
|
| 200 |
api_name="text",
|
| 201 |
+
description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
tab_batch = gr.Interface(
|
| 205 |
+
fn=predict,
|
| 206 |
+
inputs=gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
|
| 207 |
+
outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
|
| 208 |
+
api_name="batch",
|
| 209 |
+
description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
|
| 210 |
)
|
| 211 |
|
| 212 |
demo = gr.TabbedInterface(
|
| 213 |
+
[tab_image, tab_text, tab_batch],
|
| 214 |
+
['image', 'text', 'batch'],
|
| 215 |
title="BiRefNet demo for subject extraction (general / salient / camouflaged / portrait).",
|
| 216 |
)
|
| 217 |
|