ZhengPeng7 commited on
Commit
e391bd4
·
1 Parent(s): c389a57

Update and sync the code for running locally.

Browse files
Files changed (1) hide show
  1. app_local.py +22 -11
app_local.py CHANGED
@@ -10,7 +10,7 @@ from typing import Tuple
10
 
11
  from PIL import Image
12
  # from gradio_imageslider import ImageSlider
13
- from transformers import AutoModelForImageSegmentation
14
  from torchvision import transforms
15
 
16
  import requests
@@ -60,8 +60,9 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
 
63
  self.transform_image = transforms.Compose([
64
- transforms.Resize(resolution),
65
  transforms.ToTensor(),
66
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
67
  ])
@@ -77,17 +78,19 @@ usage_to_weights_file = {
77
  'Matting-HR': 'BiRefNet_HR-matting',
78
  'Matting': 'BiRefNet-matting',
79
  'Portrait': 'BiRefNet-portrait',
80
- 'General-reso_512': 'BiRefNet-reso_512',
81
  'General-Lite': 'BiRefNet_lite',
82
  'General-Lite-2K': 'BiRefNet_lite-2K',
 
83
  'DIS': 'BiRefNet-DIS5K',
84
  'HRSOD': 'BiRefNet-HRSOD',
85
  'COD': 'BiRefNet-COD',
86
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
87
- 'General-legacy': 'BiRefNet-legacy'
 
88
  }
89
 
90
- birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
91
  birefnet.to(device)
92
  birefnet.eval(); birefnet.half()
93
 
@@ -100,7 +103,7 @@ def predict(images, resolution, weights_file):
100
  # Load BiRefNet with chosen weights
101
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
102
  print('Using weights: {}.'.format(_weights_file))
103
- birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
104
  birefnet.to(device)
105
  birefnet.eval(); birefnet.half()
106
 
@@ -114,7 +117,11 @@ def predict(images, resolution, weights_file):
114
  elif weights_file in ['General-reso_512']:
115
  resolution = (512, 512)
116
  else:
117
- resolution = (1024, 1024)
 
 
 
 
118
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
119
 
120
  if isinstance(images, list):
@@ -141,6 +148,10 @@ def predict(images, resolution, weights_file):
141
 
142
  image = image_ori.convert('RGB')
143
  # Preprocess the image
 
 
 
 
144
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
145
  image_proc = image_preprocessor.proc(image)
146
  image_proc = image_proc.unsqueeze(0)
@@ -169,7 +180,7 @@ def predict(images, resolution, weights_file):
169
  zipf.write(file, os.path.basename(file))
170
  return save_paths, zip_file_path
171
  else:
172
- return (image_masked, image_ori)[0]
173
 
174
 
175
  examples = [[_] for _ in glob('examples/*')][:]
@@ -201,7 +212,7 @@ tab_image = gr.Interface(
201
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
202
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
203
  ],
204
- outputs=gr.Image(label="BiRefNet's prediction", type="pil", format='png'),
205
  examples=examples,
206
  api_name="image",
207
  description=descriptions,
@@ -214,7 +225,7 @@ tab_text = gr.Interface(
214
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
215
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
216
  ],
217
- outputs=gr.Image(label="BiRefNet's prediction", type="pil", format='png'),
218
  examples=examples_url,
219
  api_name="URL",
220
  description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
@@ -235,7 +246,7 @@ tab_batch = gr.Interface(
235
  demo = gr.TabbedInterface(
236
  [tab_image, tab_text, tab_batch],
237
  ['image', 'URL', 'batch'],
238
- title="BiRefNet demo for subject extraction and background removal ([CAAI AIR'24] Bilateral Reference for High-Resolution Dichotomous Image Segmentation).",
239
  )
240
 
241
  if __name__ == "__main__":
 
10
 
11
  from PIL import Image
12
  # from gradio_imageslider import ImageSlider
13
+ import transformers
14
  from torchvision import transforms
15
 
16
  import requests
 
60
 
61
  class ImagePreprocessor():
62
  def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
63
+ # Input resolution is on WxH.
64
  self.transform_image = transforms.Compose([
65
+ transforms.Resize(resolution[::-1]),
66
  transforms.ToTensor(),
67
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
68
  ])
 
78
  'Matting-HR': 'BiRefNet_HR-matting',
79
  'Matting': 'BiRefNet-matting',
80
  'Portrait': 'BiRefNet-portrait',
81
+ 'General-reso_512': 'BiRefNet_512x512',
82
  'General-Lite': 'BiRefNet_lite',
83
  'General-Lite-2K': 'BiRefNet_lite-2K',
84
+ 'Anime-Lite': 'BiRefNet_lite-Anime',
85
  'DIS': 'BiRefNet-DIS5K',
86
  'HRSOD': 'BiRefNet-HRSOD',
87
  'COD': 'BiRefNet-COD',
88
  'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
89
+ 'General-legacy': 'BiRefNet-legacy',
90
+ 'General-dynamic': 'BiRefNet_dynamic',
91
  }
92
 
93
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
94
  birefnet.to(device)
95
  birefnet.eval(); birefnet.half()
96
 
 
103
  # Load BiRefNet with chosen weights
104
  _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
105
  print('Using weights: {}.'.format(_weights_file))
106
+ birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
107
  birefnet.to(device)
108
  birefnet.eval(); birefnet.half()
109
 
 
117
  elif weights_file in ['General-reso_512']:
118
  resolution = (512, 512)
119
  else:
120
+ if weights_file in ['General-dynamic']:
121
+ resolution = None
122
+ print('Using the original size (div by 32) for inference.')
123
+ else:
124
+ resolution = (1024, 1024)
125
  print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
126
 
127
  if isinstance(images, list):
 
148
 
149
  image = image_ori.convert('RGB')
150
  # Preprocess the image
151
+ if resolution is None:
152
+ resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
153
+ if resolution_div_by_32 != resolution:
154
+ resolution = resolution_div_by_32
155
  image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
156
  image_proc = image_preprocessor.proc(image)
157
  image_proc = image_proc.unsqueeze(0)
 
180
  zipf.write(file, os.path.basename(file))
181
  return save_paths, zip_file_path
182
  else:
183
+ return (image_masked, image_ori)
184
 
185
 
186
  examples = [[_] for _ in glob('examples/*')][:]
 
212
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
213
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
214
  ],
215
+ outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
216
  examples=examples,
217
  api_name="image",
218
  description=descriptions,
 
225
  gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
226
  gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
227
  ],
228
+ outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
229
  examples=examples_url,
230
  api_name="URL",
231
  description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
 
246
  demo = gr.TabbedInterface(
247
  [tab_image, tab_text, tab_batch],
248
  ['image', 'URL', 'batch'],
249
+ title="Official Online Demo of BiRefNet",
250
  )
251
 
252
  if __name__ == "__main__":