leonelhs commited on
Commit
1082d5a
·
verified ·
1 Parent(s): 89392fc

Update app.py

Browse files

adding image slicer

Files changed (1) hide show
  1. app.py +24 -16
app.py CHANGED
@@ -25,11 +25,13 @@
25
  #######################################################################################
26
 
27
  # This file implements an API endpoint for DIS background image removal system.
 
28
  #
29
  # Source code is based on or inspired by several projects.
30
  # For more details and proper attribution, please refer to the following resources:
31
  #
32
  # - [DIS] - [https://github.com/xuebinqin/DIS]
 
33
 
34
  import gradio as gr
35
  import numpy as np
@@ -39,6 +41,7 @@ from PIL import Image
39
  from huggingface_hub import hf_hub_download
40
  from torch.autograd import Variable
41
  from torchvision.transforms.functional import normalize
 
42
 
43
  # project imports
44
  from models.isnet import ISNetDIS
@@ -54,11 +57,13 @@ net.load_state_dict(torch.load(model_path, map_location=device))
54
  net.to(device)
55
  net.eval()
56
 
57
- def im_preprocess(im,size):
58
  if len(im.shape) < 3:
59
  im = im[:, :, np.newaxis]
60
  if im.shape[2] == 1:
61
  im = np.repeat(im, 3, axis=2)
 
 
62
  im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
63
  im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
64
  if len(size)<2:
@@ -80,9 +85,8 @@ def predict(image):
80
  Parameters:
81
  image (string): File path to the input image.
82
  Returns:
83
- paths (tuple): paths for background-removed image and cutting mask.
84
  """
85
-
86
  im_tensor, shapes = im_preprocess(image, [1024, 1024])
87
  shapes = torch.from_numpy(np.array(shapes)).unsqueeze(0)
88
 
@@ -101,30 +105,34 @@ def predict(image):
101
  prediction = (prediction - mi) / (ma - mi) # max = 1
102
 
103
  torch.cuda.empty_cache()
104
- mask = (prediction.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need
105
 
 
 
106
  mask = Image.fromarray(mask).convert('L')
107
- image_rgb = Image.fromarray(image).convert("RGB")
108
- image_rgb.putalpha(mask)
109
- return image_rgb, mask
110
 
111
- article = "<div><center>Unofficial demo from:<a href='https://github.com/xuebinqin/DIS'>DIS</<></center></div>"
112
 
113
  with gr.Blocks(title="DIS") as app:
 
114
  gr.Markdown("## Dichotomous Image Segmentation")
115
  with gr.Row():
116
  with gr.Column(scale=1):
117
- inp = gr.Image(type="numpy", label="Upload Image")
118
- btn_predict = gr.Button("Remove background")
119
  with gr.Column(scale=2):
120
  with gr.Row():
121
- with gr.Column(scale=1):
122
- out = gr.Image(type="filepath", label="Output image")
123
- with gr.Accordion("See intermediates", open=False):
124
- out_mask = gr.Image(type="filepath", label="Mask")
 
 
 
 
125
 
126
- btn_predict.click(predict, inputs=inp, outputs=[out, out_mask])
127
- gr.HTML(article)
128
 
129
  app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
130
  app.queue()
 
25
  #######################################################################################
26
 
27
  # This file implements an API endpoint for DIS background image removal system.
28
+ # [Self space] - [https://huggingface.co/spaces/leonelhs/removebg]
29
  #
30
  # Source code is based on or inspired by several projects.
31
  # For more details and proper attribution, please refer to the following resources:
32
  #
33
  # - [DIS] - [https://github.com/xuebinqin/DIS]
34
+ # - [removebg] - [https://huggingface.co/spaces/gaviego/removebg]
35
 
36
  import gradio as gr
37
  import numpy as np
 
41
  from huggingface_hub import hf_hub_download
42
  from torch.autograd import Variable
43
  from torchvision.transforms.functional import normalize
44
+ from itertools import islice
45
 
46
  # project imports
47
  from models.isnet import ISNetDIS
 
57
  net.to(device)
58
  net.eval()
59
 
60
+ def im_preprocess(im, size):
61
  if len(im.shape) < 3:
62
  im = im[:, :, np.newaxis]
63
  if im.shape[2] == 1:
64
  im = np.repeat(im, 3, axis=2)
65
+ if im.shape[2] == 4:
66
+ im = im[:, :, :3]
67
  im_tensor = torch.tensor(im.copy(), dtype=torch.float32)
68
  im_tensor = torch.transpose(torch.transpose(im_tensor,1,2),0,1)
69
  if len(size)<2:
 
85
  Parameters:
86
  image (string): File path to the input image.
87
  Returns:
88
+ image (string): paths for image cutting mask.
89
  """
 
90
  im_tensor, shapes = im_preprocess(image, [1024, 1024])
91
  shapes = torch.from_numpy(np.array(shapes)).unsqueeze(0)
92
 
 
105
  prediction = (prediction - mi) / (ma - mi) # max = 1
106
 
107
  torch.cuda.empty_cache()
108
+ return (prediction.detach().cpu().numpy() * 255).astype(np.uint8) # it is the mask we need
109
 
110
+ def cuts(image):
111
+ mask = predict(image)
112
  mask = Image.fromarray(mask).convert('L')
113
+ cutted = Image.fromarray(image).convert("RGB")
114
+ cutted.putalpha(mask)
115
+ return [image, cutted], mask
116
 
 
117
 
118
  with gr.Blocks(title="DIS") as app:
119
+ navbar = gr.Navbar(visible=True, main_page_name="Workspace")
120
  gr.Markdown("## Dichotomous Image Segmentation")
121
  with gr.Row():
122
  with gr.Column(scale=1):
123
+ inp_image = gr.Image(type="numpy", label="Upload Image")
124
+ btn_predict = gr.Button(variant="primary", value="Remove background")
125
  with gr.Column(scale=2):
126
  with gr.Row():
127
+ preview = gr.ImageSlider(type="filepath", label="Comparer")
128
+
129
+ btn_predict.click(cuts, inputs=[inp_image], outputs=[preview, inp_image])
130
+
131
+ with app.route("Readme", "/readme"):
132
+ with open("README.md") as f:
133
+ for line in islice(f, 12, None):
134
+ gr.Markdown(line.strip())
135
 
 
 
136
 
137
  app.launch(share=False, debug=True, show_error=True, mcp_server=True, pwa=True)
138
  app.queue()