File size: 3,042 Bytes
f2dbf59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from .imagefunc import *

import torch.nn as nn
from torchvision import transforms
from .BiRefNet_legacy.baseline import BiRefNet
from .BiRefNet_legacy.config import Config

class BiRefNet_img_processor:
    """Convert a NumPy image into a normalized tensor at the model's input size.

    The square target edge length is taken from ``config.size``.
    """

    def __init__(self, config):
        """Keep ``config`` and build the torchvision preprocessing pipeline.

        Args:
            config: Object exposing a ``size`` attribute (edge length in
                pixels of the square model input).
        """
        edge = config.size
        self.config = config
        self.data_size = (edge, edge)
        # Resize -> tensor -> per-channel mean/std normalization.
        steps = [
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
        self.transform_image = transforms.Compose(steps)

    def __call__(self, _image: np.array):
        """Resize ``_image``, convert it to RGB and run the transform pipeline.

        Args:
            _image: NumPy image array. It is multiplied by 255 before the
                uint8 cast, so it presumably holds floats in [0, 1] --
                TODO confirm against callers.

        Returns:
            The output of ``self.transform_image`` for the resized RGB image.
        """
        edge = self.config.size
        resized = cv2.resize(_image, (edge, edge), interpolation=cv2.INTER_LINEAR)
        # NOTE(review): np.uint8 truncates and does not clip out-of-range values.
        as_uint8 = np.uint8(resized * 255)
        rgb_image = Image.fromarray(as_uint8).convert('RGB')
        return self.transform_image(rgb_image)

class BiRefNetRemoveBackground:
    """Produce a mask image for a PIL image using a lazily loaded BiRefNet model.

    The model is loaded on the first ``generate_mask`` call unless ``load``
    was called explicitly beforehand.
    """

    def __init__(self):
        # True once load() has finished successfully.
        self.ready = False
        # Device the model was moved to by load(); None until then.
        self.device = None

    @staticmethod
    def _pick_device():
        """Return "mps", "cuda" or "cpu", in that order of preference."""
        # torch.backends.mps is absent on some torch builds; guard the lookup
        # so those builds fall through to CUDA/CPU instead of raising.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    @staticmethod
    def _strip_prefix(state_dict, prefix='_orig_mod.'):
        """Rename keys starting with ``prefix`` in place and return the mapping.

        The mapping is mutated rather than copied so that any extra
        attributes attached to it are preserved. When both the prefixed and
        the bare key exist, the prefixed entry's value wins.
        """
        for key in list(state_dict):
            if key.startswith(prefix):
                state_dict[key[len(prefix):]] = state_dict.pop(key)
        return state_dict

    @staticmethod
    def _default_weight_path():
        """Locate the default checkpoint file.

        Tries the first folder registered under 'BiRefNet' in
        ``folder_paths.folder_names_and_paths``; if that lookup fails or the
        file is not there, falls back to ``<models_dir>/BiRefNet/``.
        """
        model_folder_name = 'BiRefNet'
        model_name = 'BiRefNet-ep480.pth'
        model_file_path = ""
        try:
            registered_dir = folder_paths.folder_names_and_paths[model_folder_name][0][0]
            model_file_path = os.path.join(os.path.normpath(registered_dir), model_name)
        except (AttributeError, KeyError, IndexError, TypeError):
            # Best effort: the folder may simply not be registered. The
            # models_dir fallback below handles that case.
            pass
        if not os.path.exists(model_file_path):
            model_file_path = os.path.join(folder_paths.models_dir, model_folder_name, model_name)
        return model_file_path

    def load(self, weight_path, device):
        """Load BiRefNet weights from ``weight_path`` and move the model to ``device``.

        Args:
            weight_path: Path to the checkpoint file.
            device: Torch device (or device string) to run the model on.
        """
        # load model
        self.model = BiRefNet()
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        state_dict = torch.load(weight_path, map_location='cpu')
        # '_orig_mod.' is presumably the prefix left by a compiled-module
        # wrapper when the checkpoint was saved -- TODO confirm.
        self.model.load_state_dict(self._strip_prefix(state_dict))
        self.model = self.model.to(device)
        self.model.eval()
        # Remember where the model lives so inference inputs follow it.
        self.device = device
        # load processor
        self.processor = BiRefNet_img_processor(Config())
        self.ready = True

    def generate_mask(self, image: "Image.Image") -> "Image.Image":
        """Run BiRefNet on ``image`` and return the predicted mask.

        Args:
            image: Input PIL image.

        Returns:
            The mask converted by ``tensor2pil`` with a 1.01 brightness
            enhancement applied, resized to the first two dimensions of the
            squeezed input array.
        """
        if not self.ready:
            self.load(self._default_weight_path(), device=self._pick_device())

        # Use the device the model was actually loaded on. Auto-detect only
        # when it is unknown (e.g. a subclass overrode load()).
        device = self.device if self.device is not None else self._pick_device()

        i = pil2tensor(image)
        np_image = i.squeeze().numpy()
        img = self.processor(np_image)
        inputs = img[None, ...].to(device)
        with torch.no_grad():
            # The last element of the model output is used as the prediction.
            scaled_preds = self.model(inputs)[-1].sigmoid()
        _mask = nn.functional.interpolate(scaled_preds[0].unsqueeze(0),
                                          size=np_image.shape[:2],
                                          mode='bilinear',
                                          align_corners=True
                                          )[0]

        brightness_image = ImageEnhance.Brightness(tensor2pil(_mask))

        return brightness_image.enhance(factor=1.01)