NewbloomAI commited on
Commit
736404c
·
verified ·
1 Parent(s): 962aa2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -48
app.py CHANGED
@@ -5,7 +5,6 @@ import torch
5
  import gradio as gr
6
  import spaces
7
 
8
- from glob import glob
9
  from typing import Tuple
10
 
11
  from PIL import Image
@@ -15,7 +14,7 @@ import requests
15
  from io import BytesIO
16
  import zipfile
17
 
18
- # Fix the HF space permission error when using from_pretrained(..., trust_remote_code=True)
19
  os.environ["HF_MODULES_CACHE"] = os.path.join("/tmp/hf_cache", "modules")
20
 
21
  import transformers
@@ -26,7 +25,6 @@ torch.jit.script = lambda f: f
26
 
27
  device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
- ### image_proc.py
30
  def refine_foreground(image, mask, r=90):
31
  if mask.size != image.size:
32
  mask = mask.resize(image.size)
@@ -66,58 +64,26 @@ class ImagePreprocessor():
66
  def proc(self, image: Image.Image) -> torch.Tensor:
67
  return self.transform_image(image)
68
 
69
- usage_to_weights_file = {
70
- 'General': 'BiRefNet',
71
- 'General-HR': 'BiRefNet_HR',
72
- 'Matting-HR': 'BiRefNet_HR-matting',
73
- 'Matting': 'BiRefNet-matting',
74
- 'Portrait': 'BiRefNet-portrait',
75
- 'General-reso_512': 'BiRefNet_512x512',
76
- 'General-Lite': 'BiRefNet_lite',
77
- 'General-Lite-2K': 'BiRefNet_lite-2K',
78
- 'Anime-Lite': 'BiRefNet_lite-Anime',
79
- 'DIS': 'BiRefNet-DIS5K',
80
- 'HRSOD': 'BiRefNet-HRSOD',
81
- 'COD': 'BiRefNet-COD',
82
- 'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
83
- 'General-legacy': 'BiRefNet-legacy',
84
- 'General-dynamic': 'BiRefNet_dynamic',
85
- }
86
-
87
  birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(
88
- '/'.join(('zhengpeng7', usage_to_weights_file['General'])),
89
- trust_remote_code=True
90
  )
91
  birefnet.to(device)
92
  birefnet.eval(); birefnet.half()
93
 
94
  @spaces.GPU
95
- def predict(images, resolution, weights_file):
96
  assert images is not None, 'AssertionError: images cannot be None.'
97
 
98
- global birefnet
99
- _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
100
  print('Using weights: {}.'.format(_weights_file))
101
- birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
102
- birefnet.to(device)
103
- birefnet.eval(); birefnet.half()
104
 
105
  try:
106
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
107
  except:
108
- if weights_file in ['General-HR', 'Matting-HR']:
109
- resolution = (2048, 2048)
110
- elif weights_file in ['General-Lite-2K']:
111
- resolution = (2560, 1440)
112
- elif weights_file in ['General-reso_512']:
113
- resolution = (512, 512)
114
- else:
115
- if weights_file in ['General-dynamic']:
116
- resolution = None
117
- print('Using the original size (div by 32) for inference.')
118
- else:
119
- resolution = (1024, 1024)
120
- print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
121
 
122
  if isinstance(images, list):
123
  save_paths = []
@@ -143,8 +109,7 @@ def predict(images, resolution, weights_file):
143
  image = image_ori.convert('RGB')
144
  if resolution is None:
145
  resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
146
- if resolution_div_by_32 != resolution:
147
- resolution = resolution_div_by_32
148
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
149
  image_proc = image_preprocessor.proc(image).unsqueeze(0)
150
 
@@ -182,7 +147,6 @@ tab_image = gr.Interface(
182
  inputs=[
183
  gr.Image(label='Upload an image'),
184
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
185
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
186
  ],
187
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
188
  api_name="image",
@@ -194,21 +158,21 @@ tab_text = gr.Interface(
194
  inputs=[
195
  gr.Textbox(label="Paste an image URL"),
196
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
197
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
198
  ],
199
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
200
  api_name="URL",
201
- )
 
202
 
203
  tab_batch = gr.Interface(
204
  fn=predict,
205
  inputs=[
206
  gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
207
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
208
- gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
209
  ],
210
  outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
211
  api_name="batch",
 
212
  )
213
 
214
  demo = gr.TabbedInterface(
 
5
  import gradio as gr
6
  import spaces
7
 
 
8
  from typing import Tuple
9
 
10
  from PIL import Image
 
14
  from io import BytesIO
15
  import zipfile
16
 
17
+ # Fix the HF space permission error
18
  os.environ["HF_MODULES_CACHE"] = os.path.join("/tmp/hf_cache", "modules")
19
 
20
  import transformers
 
25
 
26
  device = "cuda" if torch.cuda.is_available() else "cpu"
27
 
 
28
  def refine_foreground(image, mask, r=90):
29
  if mask.size != image.size:
30
  mask = mask.resize(image.size)
 
64
  def proc(self, image: Image.Image) -> torch.Tensor:
65
  return self.transform_image(image)
66
 
67
+ # Fixed weights
68
+ weights_file = 'BiRefNet'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(
70
+ '/'.join(('zhengpeng7', weights_file)), trust_remote_code=True
 
71
  )
72
  birefnet.to(device)
73
  birefnet.eval(); birefnet.half()
74
 
75
  @spaces.GPU
76
+ def predict(images, resolution):
77
  assert images is not None, 'AssertionError: images cannot be None.'
78
 
79
+ _weights_file = '/'.join(('zhengpeng7', weights_file))
 
80
  print('Using weights: {}.'.format(_weights_file))
 
 
 
81
 
82
  try:
83
  resolution = [int(int(reso)//32*32) for reso in resolution.strip().split('x')]
84
  except:
85
+ resolution = (1024, 1024)
86
+ print('Invalid resolution input. Automatically changed to 1024x1024.')
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  if isinstance(images, list):
89
  save_paths = []
 
109
  image = image_ori.convert('RGB')
110
  if resolution is None:
111
  resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
112
+ resolution = resolution_div_by_32
 
113
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
114
  image_proc = image_preprocessor.proc(image).unsqueeze(0)
115
 
 
147
  inputs=[
148
  gr.Image(label='Upload an image'),
149
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
 
150
  ],
151
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
152
  api_name="image",
 
158
  inputs=[
159
  gr.Textbox(label="Paste an image URL"),
160
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
 
161
  ],
162
  outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
163
  api_name="URL",
164
+ description=descriptions + '\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
165
+ )
166
 
167
  tab_batch = gr.Interface(
168
  fn=predict,
169
  inputs=[
170
  gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
171
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
 
172
  ],
173
  outputs=[gr.Gallery(label="BiRefNet's predictions"), gr.File(label="Download masked images.")],
174
  api_name="batch",
175
+ description=descriptions + '\nTab-batch is partially modified from https://huggingface.co/spaces/NegiTurkey/Multi_Birefnetfor_Background_Removal, thanks to this great work!',
176
  )
177
 
178
  demo = gr.TabbedInterface(