Spaces: Running on Zero

Commit · e391bd4
Parent(s): c389a57

Update and sync the codes for running locally.

app_local.py CHANGED (+22 -11)
@@ -10,7 +10,7 @@ from typing import Tuple
 
 from PIL import Image
 # from gradio_imageslider import ImageSlider
-
+import transformers
 from torchvision import transforms
 
 import requests
@@ -60,8 +60,9 @@ def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
 
 class ImagePreprocessor():
     def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+        # Input resolution is on WxH.
         self.transform_image = transforms.Compose([
-            transforms.Resize(resolution),
+            transforms.Resize(resolution[::-1]),
             transforms.ToTensor(),
             transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
         ])
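Note: torchvision's transforms.Resize takes a (height, width) pair, while the app passes resolution as (width, height), hence the resolution[::-1] above. A minimal sketch of that behavior, illustrative only and not part of app_local.py:

    from PIL import Image
    from torchvision import transforms

    resolution = (1024, 768)                         # requested as (W, H)
    img = Image.new('RGB', (640, 480))               # PIL size is also (W, H)
    out = transforms.Resize(resolution[::-1])(img)   # Resize expects (H, W)
    print(out.size)                                  # -> (1024, 768)
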
@@ -77,17 +78,19 @@ usage_to_weights_file = {
     'Matting-HR': 'BiRefNet_HR-matting',
     'Matting': 'BiRefNet-matting',
     'Portrait': 'BiRefNet-portrait',
-    'General-reso_512': '
+    'General-reso_512': 'BiRefNet_512x512',
     'General-Lite': 'BiRefNet_lite',
     'General-Lite-2K': 'BiRefNet_lite-2K',
+    'Anime-Lite': 'BiRefNet_lite-Anime',
     'DIS': 'BiRefNet-DIS5K',
     'HRSOD': 'BiRefNet-HRSOD',
     'COD': 'BiRefNet-COD',
     'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
-    'General-legacy': 'BiRefNet-legacy'
+    'General-legacy': 'BiRefNet-legacy',
+    'General-dynamic': 'BiRefNet_dynamic',
 }
 
-birefnet = AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
+birefnet = transformers.AutoModelForImageSegmentation.from_pretrained('/'.join(('zhengpeng7', usage_to_weights_file['General'])), trust_remote_code=True)
 birefnet.to(device)
 birefnet.eval(); birefnet.half()
 
@@ -100,7 +103,7 @@ def predict(images, resolution, weights_file):
     # Load BiRefNet with chosen weights
     _weights_file = '/'.join(('zhengpeng7', usage_to_weights_file[weights_file] if weights_file is not None else usage_to_weights_file['General']))
     print('Using weights: {}.'.format(_weights_file))
-    birefnet = AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
+    birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(_weights_file, trust_remote_code=True)
     birefnet.to(device)
     birefnet.eval(); birefnet.half()
 
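Note: checkpoints are resolved as Hugging Face Hub repo ids of the form zhengpeng7/<weights_file>, loaded once at startup and reloaded inside predict() when a different entry of usage_to_weights_file is selected. A minimal sketch of loading one variant the same way, assuming the repo ids above exist on the Hub:

    import torch
    import transformers

    weights = 'BiRefNet_lite'   # the 'General-Lite' entry, as an example
    model = transformers.AutoModelForImageSegmentation.from_pretrained(
        'zhengpeng7/' + weights, trust_remote_code=True)
    model.to('cuda' if torch.cuda.is_available() else 'cpu')
    model.eval(); model.half()  # same eval + half-precision setup as in the app
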
@@ -114,7 +117,11 @@ def predict(images, resolution, weights_file):
     elif weights_file in ['General-reso_512']:
         resolution = (512, 512)
     else:
-        resolution = (1024, 1024)
+        if weights_file in ['General-dynamic']:
+            resolution = None
+            print('Using the original size (div by 32) for inference.')
+        else:
+            resolution = (1024, 1024)
         print('Invalid resolution input. Automatically changed to 1024x1024 / 2048x2048 / 2560x1440.')
 
     if isinstance(images, list):
@@ -141,6 +148,10 @@ def predict(images, resolution, weights_file):
 
         image = image_ori.convert('RGB')
         # Preprocess the image
+        if resolution is None:
+            resolution_div_by_32 = [int(int(reso)//32*32) for reso in image.size]
+            if resolution_div_by_32 != resolution:
+                resolution = resolution_div_by_32
         image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
         image_proc = image_preprocessor.proc(image)
         image_proc = image_proc.unsqueeze(0)
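Note: for the new 'General-dynamic' weights, resolution is still None at this point, and the added block snaps the original image size down to multiples of 32 before preprocessing. A worked sketch with a hypothetical input size:

    image_size = (2000, 1500)    # PIL Image.size, i.e. (W, H)
    resolution_div_by_32 = [int(int(reso) // 32 * 32) for reso in image_size]
    print(resolution_div_by_32)  # -> [1984, 1472]
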
@@ -169,7 +180,7 @@ def predict(images, resolution, weights_file):
             zipf.write(file, os.path.basename(file))
         return save_paths, zip_file_path
     else:
-        return (image_masked, image_ori)
+        return (image_masked, image_ori)
 
 
 examples = [[_] for _ in glob('examples/*')][:]
@@ -201,7 +212,7 @@ tab_image = gr.Interface(
         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
     ],
-    outputs=gr.
+    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
     examples=examples,
     api_name="image",
     description=descriptions,
@@ -214,7 +225,7 @@ tab_text = gr.Interface(
         gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
         gr.Radio(list(usage_to_weights_file.keys()), value='General', label="Weights", info="Choose the weights you want.")
     ],
-    outputs=gr.
+    outputs=gr.ImageSlider(label="BiRefNet's prediction", type="pil", format='png'),
     examples=examples_url,
     api_name="URL",
     description=descriptions+'\nTab-URL is partially modified from https://huggingface.co/spaces/not-lain/background-removal, thanks to this great work!',
@@ -235,7 +246,7 @@ tab_batch = gr.Interface(
 demo = gr.TabbedInterface(
     [tab_image, tab_text, tab_batch],
     ['image', 'URL', 'batch'],
-    title="
+    title="Official Online Demo of BiRefNet",
 )
 
 if __name__ == "__main__":