import os
import cv2
import numpy as np
import torch
import gradio as gr
import spaces

from typing import Tuple

from PIL import Image
from torchvision import transforms

import requests
from io import BytesIO
import zipfile

# Work around the HF Space permission error by redirecting the modules cache to /tmp.
os.environ["HF_MODULES_CACHE"] = os.path.join("/tmp/hf_cache", "modules")

import transformers
transformers.utils.move_cache()  # migrate any legacy-layout Hugging Face cache

torch.set_float32_matmul_precision('high')  # allow TF32 matmuls on recent GPUs
torch.jit.script = lambda f: f  # no-op torch.jit.script so the model's remote code loads unscripted

device = "cuda" if torch.cuda.is_available() else "cpu"

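# Re-estimate foreground colors under the predicted alpha so edge pixels do not
# carry background color bleeding; `r` is the blur radius of the coarse pass.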
def refine_foreground(image, mask, r=90):
    if mask.size != image.size:
        mask = mask.resize(image.size)
    image = np.array(image) / 255.0
    mask = np.array(mask) / 255.0
    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
    return image_masked

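# Two-pass blur-fusion foreground estimation: a coarse pass with a large blur
# radius (default r=90), then a refinement pass with a small one (r=6), in the
# spirit of the fast foreground colour estimation of Germer et al.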
def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
    alpha = alpha[:, :, None]
    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]

def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
    if isinstance(image, Image.Image):
        image = np.array(image) / 255.0
    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]

    blurred_FA = cv2.blur(F * alpha, (r, r))
    blurred_F = blurred_FA / (blurred_alpha + 1e-5)

    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
    F = blurred_F + alpha * (image - alpha * blurred_F - (1 - alpha) * blurred_B)
    F = np.clip(F, 0, 1)
    return F, blurred_B

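# Resize and normalize an RGB PIL image into a model-ready tensor.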
class ImagePreprocessor():
    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
        self.transform_image = transforms.Compose([
            transforms.Resize(resolution[::-1]),  # torchvision expects (H, W); resolution is (W, H)
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),  # ImageNet stats
        ])

    def proc(self, image: Image.Image) -> torch.Tensor:
        return self.transform_image(image)

# Weights are fixed to the stock BiRefNet checkpoint on the Hugging Face Hub.
weights_file = 'BiRefNet'
birefnet = transformers.AutoModelForImageSegmentation.from_pretrained(
    '/'.join(('zhengpeng7', weights_file)), trust_remote_code=True
)
birefnet.to(device)
birefnet.eval()
if device == 'cuda':
    birefnet.half()  # half precision only on GPU; float16 support on CPU is spotty

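# On HF Spaces, @spaces.GPU attaches a ZeroGPU worker for the duration of each call.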
@spaces.GPU
def predict(images, resolution):
    assert images is not None, 'images cannot be None.'

    _weights_file = '/'.join(('zhengpeng7', weights_file))
    print('Using weights: {}.'.format(_weights_file))

    if not (resolution and resolution.strip()):
        resolution = None  # empty input: fall back to each image's own size below
    else:
        try:
            resolution = [int(reso) // 32 * 32 for reso in resolution.strip().split('x')]
        except ValueError:
            resolution = (1024, 1024)
            print('Invalid resolution input. Automatically changed to 1024x1024.')

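    # A list of file paths means the batch tab; wrap single inputs so one loop handles both.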
    if isinstance(images, list):
        save_paths = []
        save_dir = 'preds-BiRefNet'
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        tab_is_batch = True
    else:
        images = [images]
        tab_is_batch = False

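    # Each item may be a local file path, an image URL, or a numpy array from gr.Image.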
    for idx_image, image_src in enumerate(images):
        if isinstance(image_src, str):
            if os.path.isfile(image_src):
                image_ori = Image.open(image_src)
            else:
                response = requests.get(image_src, timeout=30)
                image_data = BytesIO(response.content)
                image_ori = Image.open(image_data)
        else:
            image_ori = Image.fromarray(image_src)

        image = image_ori.convert('RGB')
        if resolution is None:
            # Round the image's own size down to a multiple of 32, as the model expects.
            resolution = [reso // 32 * 32 for reso in image.size]
        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
        image_proc = image_preprocessor.proc(image).unsqueeze(0)

        with torch.no_grad():
            input_image = image_proc.to(device)
            if device == 'cuda':
                input_image = input_image.half()  # match the model's precision on GPU
            preds = birefnet(input_image)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()

        pred_pil = transforms.ToPILImage()(pred)
        image_masked = refine_foreground(image, pred_pil)
        image_masked.putalpha(pred_pil.resize(image.size))

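        # Release cached GPU memory between images.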
        torch.cuda.empty_cache()

        if tab_is_batch:
            save_file_path = os.path.join(save_dir, "{}.png".format(os.path.splitext(os.path.basename(image_src))[0]))
            image_masked.save(save_file_path)
            save_paths.append(save_file_path)

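    # For the batch tab, bundle all outputs into a single zip for download.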
    if tab_is_batch:
        zip_file_path = os.path.join(save_dir, "{}.zip".format(save_dir))
        with zipfile.ZipFile(zip_file_path, 'w') as zipf:
            for file in save_paths:
                zipf.write(file, os.path.basename(file))
        return save_paths, zip_file_path
    else:
        return image_masked, image_ori

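# A sketch of calling the endpoints with the official `gradio_client` package
# (the URL and file name are placeholders; `handle_file` requires a recent client version):
#
#   from gradio_client import Client, handle_file
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(handle_file("photo.jpg"), "1024x1024", api_name="/image")
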
descriptions = (
    "Upload a picture, and we'll remove the background!\n"
    "The default resolution is `1024x1024`; leave the box empty to use the image's own size.\n"
)

tab_image = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label='Upload an image'),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
    ],
    outputs=gr.ImageSlider(label="Lot Lingo's prediction", type="pil", format='png'),
    api_name="image",
    description=descriptions,
)

tab_text = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="Paste an image URL"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
    ],
    outputs=gr.ImageSlider(label="Lot Lingo's prediction", type="pil", format='png'),
    api_name="URL",
)

tab_batch = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(label="Upload multiple images", type="filepath", file_count="multiple"),
        gr.Textbox(lines=1, placeholder="Type the resolution (`WxH`) you want, e.g., `1024x1024`.", label="Resolution"),
    ],
    outputs=[gr.Gallery(label="Lot Lingo's predictions"), gr.File(label="Download masked images.")],
    api_name="batch",
)

demo = gr.TabbedInterface(
    [tab_image, tab_text, tab_batch],
    ['image', 'URL', 'batch'],
    title="Lot Lingo Background Removal Demo",
)

if __name__ == "__main__":
    demo.launch(debug=True)