NewbloomAI commited on
Commit
017cd38
·
verified ·
1 Parent(s): 3855ec6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -47
app.py CHANGED
@@ -21,7 +21,6 @@ os.environ["HF_MODULES_CACHE"] = os.path.join("/tmp/hf_cache", "modules")
21
  import transformers
22
  transformers.utils.move_cache()
23
 
24
-
25
  torch.set_float32_matmul_precision('high')
26
  torch.jit.script = lambda f: f
27
 
@@ -37,15 +36,11 @@ def refine_foreground(image, mask, r=90):
37
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
38
  return image_masked
39
 
40
-
41
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
42
- # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
43
  alpha = alpha[:, :, None]
44
- F, blur_B = FB_blur_fusion_foreground_estimator(
45
- image, image, image, alpha, r)
46
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
47
 
48
-
49
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
50
  if isinstance(image, Image.Image):
51
  image = np.array(image) / 255.0
@@ -56,15 +51,12 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
56
 
57
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
58
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
59
- F = blurred_F + alpha * \
60
- (image - alpha * blurred_F - (1 - alpha) * blurred_B)
61
  F = np.clip(F, 0, 1)
62
  return F, blurred_B
63
 
64
-
65
  class ImagePreprocessor():
66
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
67
- # Input resolution is on WxH.
68
  self.transform_image = transforms.Compose([
69
  transforms.Resize(resolution[::-1]),
70
  transforms.ToTensor(),
@@ -72,9 +64,7 @@ class ImagePreprocessor():
72
  ])
73
 
74
  def proc(self, image: Image.Image) -> torch.Tensor:
75
- image = self.transform_image(image)
76
- return image
77
-
78
 
79
  usage_to_weights_file = {
80
  'General': 'BiRefNet',
@@ -94,17 +84,18 @@ usage_to_weights_file = {
94
  'General-dynamic': 'BiRefNet_dynamic',
95
  }
96
 
97
- birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
 
 
 
98
  birefnet.to(device)
99
  birefnet.eval(); birefnet.half()
100
 
101
-
102
  @spaces.GPU
103
  def predict(images, resolution, weights_file):
104
- assert (images is not None), 'AssertionError: images cannot be None.'
105
 
106
  global birefnet
107
- # Load BiRefNet with chosen weights
108
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
109
  print('Using weights: {}.'.format(_weights_file))
110
  birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
@@ -129,7 +120,6 @@ def predict(images, resolution, weights_file):
129
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
130
 
131
  if isinstance(images, list):
132
- # For tab_batch
133
  save_paths = []
134
  save_dir = 'preds-BiRefNet'
135
  if not os.path.exists(save_dir):
@@ -151,21 +141,17 @@ def predict(images, resolution, weights_file):
151
  image_ori = Image.fromarray(image_src)
152
 
153
  image = image_ori.convert('RGB')
154
- # Preprocess the image
155
  if resolution is None:
156
  resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
157
  if resolution_div_by_32 != resolution:
158
  resolution = resolution_div_by_32
159
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
160
- image_proc = image_preprocessor.proc(image)
161
- image_proc = image_proc.unsqueeze(0)
162
 
163
- # Prediction
164
  with torch.no_grad():
165
  preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
166
  pred = preds[0].squeeze()
167
 
168
- # Show Results
169
  pred_pil = transforms.ToPILImage()(pred)
170
  image_masked = refine_foreground(image, pred_pil)
171
  image_masked.putalpha(pred_pil.resize(image.size))
@@ -184,32 +170,13 @@ def predict(images, resolution, weights_file):
184
  zipf.write(file, os.path.basename(file))
185
  return save_paths, zip_file_path
186
  else:
187
- return (image_masked, image_ori)
188
-
189
-
190
- examples = [[_] for _ in glob('examples/*')][:]
191
- # Add the option of resolution in a text box.
192
- for idx_example, example in enumerate(examples):
193
- if 'My_' in example[0]:
194
- example_resolution = '2048x2048'
195
- else:
196
- example_resolution = '1024x1024'
197
- examples[idx_example].append(example_resolution)
198
- examples.append(examples[-1].copy())
199
- examples[-1][1] = '512x512'
200
-
201
- examples_url = [
202
- ['https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg'],
203
- ]
204
- for idx_example_url, example_url in enumerate(examples_url):
205
- examples_url[idx_example_url].append('1024x1024')
206
 
207
  descriptions = (
208
  "Upload a picture, and we'll remove the background!\n"
209
  "The resolution used is `1024x1024`\n"
210
  )
211
 
212
-
213
  tab_image = gr.Interface(
214
  fn=predict,
215
  inputs=[
@@ -218,7 +185,6 @@ tab_image = gr.Interface(
218
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
219
  ],
220
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
221
- examples=examples,
222
  api_name="image",
223
  description=descriptions,
224
  )
@@ -231,9 +197,8 @@ tab_text = gr.Interface(
231
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
232
  ],
233
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
234
- examples=examples_url,
235
  api_name="URL",
236
- description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
237
  )
238
 
239
  tab_batch = gr.Interface(
@@ -245,7 +210,7 @@ tab_batch = gr.Interface(
245
  ],
246
  outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
247
  api_name="batch",
248
- description=descriptions+'\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
249
  )
250
 
251
  demo = gr.TabbedInterface(
 
21
  import transformers
22
  transformers.utils.move_cache()
23
 
 
24
  torch.set_float32_matmul_precision('high')
25
  torch.jit.script = lambda f: f
26
 
 
36
  image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
37
  return image_masked
38
 
 
39
  def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
 
40
  alpha = alpha[:, :, None]
41
+ F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
 
42
  return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
43
 
 
44
  def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
45
  if isinstance(image, Image.Image):
46
  image = np.array(image) / 255.0
 
51
 
52
  blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
53
  blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
54
+ F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
 
55
  F = np.clip(F, 0, 1)
56
  return F, blurred_B
57
 
 
58
  class ImagePreprocessor():
59
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
 
60
  self.transform_image = transforms.Compose([
61
  transforms.Resize(resolution[::-1]),
62
  transforms.ToTensor(),
 
64
  ])
65
 
66
  def proc(self, image: Image.Image) -> torch.Tensor:
67
+ return self.transform_image(image)
 
 
68
 
69
  usage_to_weights_file = {
70
  'General': 'BiRefNet',
 
84
  'General-dynamic': 'BiRefNet_dynamic',
85
  }
86
 
87
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(
88
+ '/'.join(('zhengpeng7', usage_to_weights_file['General'])),
89
+ trust_remote_code=True
90
+ )
91
  birefnet.to(device)
92
  birefnet.eval(); birefnet.half()
93
 
 
94
  @spaces.GPU
95
  def predict(images, resolution, weights_file):
96
+ assert images is not None, 'AssertionError: images cannot be None.'
97
 
98
  global birefnet
 
99
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
100
  print('Using weights: {}.'.format(_weights_file))
101
  birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
 
120
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
121
 
122
  if isinstance(images, list):
 
123
  save_paths = []
124
  save_dir = 'preds-BiRefNet'
125
  if not os.path.exists(save_dir):
 
141
  image_ori = Image.fromarray(image_src)
142
 
143
  image = image_ori.convert('RGB')
 
144
  if resolution is None:
145
  resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
146
  if resolution_div_by_32 != resolution:
147
  resolution = resolution_div_by_32
148
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
149
+ image_proc = image_preprocessor.proc(image).unsqueeze(0)
 
150
 
 
151
  with torch.no_grad():
152
  preds = birefnet(image_proc.to(device).half())[-1].sigmoid().cpu()
153
  pred = preds[0].squeeze()
154
 
 
155
  pred_pil = transforms.ToPILImage()(pred)
156
  image_masked = refine_foreground(image, pred_pil)
157
  image_masked.putalpha(pred_pil.resize(image.size))
 
170
  zipf.write(file, os.path.basename(file))
171
  return save_paths, zip_file_path
172
  else:
173
+ return image_masked, image_ori
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
  descriptions = (
176
  "Upload a picture, and we'll remove the background!\n"
177
  "The resolution used is `1024x1024`\n"
178
  )
179
 
 
180
  tab_image = gr.Interface(
181
  fn=predict,
182
  inputs=[
 
185
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
186
  ],
187
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
 
188
  api_name="image",
189
  description=descriptions,
190
  )
 
197
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
198
  ],
199
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
 
200
  api_name="URL",
201
+ description=descriptions + '\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
202
  )
203
 
204
  tab_batch = gr.Interface(
 
210
  ],
211
  outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
212
  api_name="batch",
213
+ description=descriptions + '\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
214
  )
215
 
216
  demo = gr.TabbedInterface(