Update app.py
Browse filesadding image slicer
app.py
CHANGED
@@ -25,11 +25,13 @@
|
|
25 |
#######################################################################################
|
26 |
|
27 |
# This file implements an API endpoint for DIS background image removal system.
|
|
|
28 |
#
|
29 |
# Source code is based on or inspired by several projects.
|
30 |
# For more details and proper attribution, please refer to the following resources:
|
31 |
#
|
32 |
# - [DIS] - [https://github.com/xuebinqin/DIS]
|
|
|
33 |
|
34 |
import gradio as gr
|
35 |
import numpy as np
|
@@ -39,6 +41,7 @@ from PIL import Image
|
|
39 |
from huggingface_hub import hf_hub_download
|
40 |
from torch.autograd import Variable
|
41 |
from torchvision.transforms.functional import normalize
|
|
|
42 |
|
43 |
# project imports
|
44 |
from models.isnet import ISNetDIS
|
@@ -54,11 +57,13 @@ net.load_state_dict(torch.load(model_path, map_location=device))
|
|
54 |
net.to(device)
|
55 |
net.eval()
|
56 |
|
57 |
-
def im_preprocess(im,size):
|
58 |
if len(im.shape) < 3:
|
59 |
im = im[:, :, np.newaxis]
|
60 |
if im.shape[2] == 1:
|
61 |
im = np.repeat(im, 3, axis=2)
|
|
|
|
|
62 |
im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
|
63 |
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
64 |
if len(size)<2:
|
@@ -80,9 +85,8 @@ def predict(image):
|
|
80 |
Parameters:
|
81 |
image (string): File path to the input image.
|
82 |
Returns:
|
83 |
-
|
84 |
"""
|
85 |
-
|
86 |
im_tensor, shapes = im_preprocess(image, [1024, 1024])
|
87 |
shapes = torch.from_numpy(np.array(shapes)).unsqueeze(0)
|
88 |
|
@@ -101,30 +105,34 @@ def predict(image):
|
|
101 |
prediction = (prediction - mi) / (ma - mi) # max = 1
|
102 |
|
103 |
torch.cuda.empty_cache()
|
104 |
-
|
105 |
|
|
|
|
|
106 |
mask = Image.fromarray(mask).convert('L')
|
107 |
-
|
108 |
-
|
109 |
-
return
|
110 |
|
111 |
-
article = "<div><center>Unofficial demo from:<a href='https://github.com/xuebinqin/DIS'>DIS</<></center></div>"
|
112 |
|
113 |
with gr.Blocks(title="DIS") as app:
|
|
|
114 |
gr.Markdown("## Dichotomous Image Segmentation")
|
115 |
with gr.Row():
|
116 |
with gr.Column(scale=1):
|
117 |
-
|
118 |
-
btn_predict = gr.Button("Remove background")
|
119 |
with gr.Column(scale=2):
|
120 |
with gr.Row():
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
|
|
|
|
|
|
|
|
125 |
|
126 |
-
btn_predict.click(predict, inputs=inp, outputs=[out, out_mask])
|
127 |
-
gr.HTML(article)
|
128 |
|
129 |
app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
|
130 |
app.queue()
|
|
|
25 |
#######################################################################################
|
26 |
|
27 |
# This file implements an API endpoint for DIS background image removal system.
|
28 |
+
# [Self space] - [https://huggingface.co/spaces/leonelhs/removebg]
|
29 |
#
|
30 |
# Source code is based on or inspired by several projects.
|
31 |
# For more details and proper attribution, please refer to the following resources:
|
32 |
#
|
33 |
# - [DIS] - [https://github.com/xuebinqin/DIS]
|
34 |
+
# - [removebg] - [https://huggingface.co/spaces/gaviego/removebg]
|
35 |
|
36 |
import gradio as gr
|
37 |
import numpy as np
|
|
|
41 |
from huggingface_hub import hf_hub_download
|
42 |
from torch.autograd import Variable
|
43 |
from torchvision.transforms.functional import normalize
|
44 |
+
from itertools import islice
|
45 |
|
46 |
# project imports
|
47 |
from models.isnet import ISNetDIS
|
|
|
57 |
net.to(device)
|
58 |
net.eval()
|
59 |
|
60 |
+
def im_preprocess(im, size):
|
61 |
if len(im.shape) < 3:
|
62 |
im = im[:, :, np.newaxis]
|
63 |
if im.shape[2] == 1:
|
64 |
im = np.repeat(im, 3, axis=2)
|
65 |
+
if im.shape[2] == 4:
|
66 |
+
im = im[:, :, :3]
|
67 |
im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
|
68 |
im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
|
69 |
if len(size)<2:
|
|
|
85 |
Parameters:
|
86 |
image (string): File path to the input image.
|
87 |
Returns:
|
88 |
+
image (string): paths for image cutting mask.
|
89 |
"""
|
|
|
90 |
im_tensor, shapes = im_preprocess(image, [1024, 1024])
|
91 |
shapes = torch.from_numpy(np.array(shapes)).unsqueeze(0)
|
92 |
|
|
|
105 |
prediction = (prediction - mi) / (ma - mi) # max = 1
|
106 |
|
107 |
torch.cuda.empty_cache()
|
108 |
+
return (prediction.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need
|
109 |
|
110 |
+
def cuts(image):
|
111 |
+
mask = predict(image)
|
112 |
mask = Image.fromarray(mask).convert('L')
|
113 |
+
cutted = Image.fromarray(image).convert("RGB")
|
114 |
+
cutted.putalpha(mask)
|
115 |
+
return [image, cutted], mask
|
116 |
|
|
|
117 |
|
118 |
with gr.Blocks(title="DIS") as app:
|
119 |
+
navbar = gr.Navbar(visible=True, main_page_name="Workspace")
|
120 |
gr.Markdown("## Dichotomous Image Segmentation")
|
121 |
with gr.Row():
|
122 |
with gr.Column(scale=1):
|
123 |
+
inp_image = gr.Image(type="numpy", label="Upload Image")
|
124 |
+
btn_predict = gr.Button(variant="primary", value="Remove background")
|
125 |
with gr.Column(scale=2):
|
126 |
with gr.Row():
|
127 |
+
preview = gr.ImageSlider(type="filepath", label="Comparer")
|
128 |
+
|
129 |
+
btn_predict.click(cuts, inputs=[inp_image], outputs=[preview, inp_image])
|
130 |
+
|
131 |
+
with app.route("Readme", "/readme"):
|
132 |
+
with open("README.md") as f:
|
133 |
+
for line in islice(f, 12, None):
|
134 |
+
gr.Markdown(line.strip())
|
135 |
|
|
|
|
|
136 |
|
137 |
app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
|
138 |
app.queue()
|