from .imagefunc import *
import torch.nn as nn
from torchvision import transforms
from .BiRefNet_legacy.baseline import BiRefNet
from .BiRefNet_legacy.config import Config


class BiRefNet_img_processor:
    """Turn a float image array into a normalized tensor for BiRefNet.

    The target side length is read from ``config.size``; images are resized
    to a ``size x size`` square before being converted to a tensor.
    """

    def __init__(self, config):
        """Build the torchvision preprocessing pipeline.

        Args:
            config: Object exposing a ``size`` attribute (the square input
                side length used for resizing).
        """
        self.config = config
        self.data_size = (config.size, config.size)
        self.transform_image = transforms.Compose([
            transforms.Resize(self.data_size),
            transforms.ToTensor(),
            # Standard ImageNet per-channel mean / std.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __call__(self, _image: np.ndarray):
        """Resize, convert to RGB and normalize ``_image``.

        Args:
            _image: Image array. It is multiplied by 255 and cast to
                ``uint8``, so values are presumably expected in ``[0, 1]``
                (NOTE(review): confirm against callers).

        Returns:
            The tensor produced by ``self.transform_image``.
        """
        side = self.config.size
        resized = cv2.resize(_image, (side, side), interpolation=cv2.INTER_LINEAR)
        rgb_image = Image.fromarray(np.uint8(resized * 255)).convert('RGB')
        return self.transform_image(rgb_image)


class BiRefNetRemoveBackground:
    """Lazy-loading wrapper that produces a foreground mask with BiRefNet."""

    MODEL_FOLDER_NAME = 'BiRefNet'
    MODEL_FILE_NAME = 'BiRefNet-ep480.pth'

    def __init__(self):
        # ``ready`` flips to True once ``load`` has finished successfully.
        self.ready = False
        # Device the model was moved to by ``load``; None until then.
        self.device = None

    @staticmethod
    def _select_device() -> str:
        """Return the preferred torch device name: mps, then cuda, then cpu."""
        if torch.backends.mps.is_available():
            return "mps"
        if torch.cuda.is_available():
            return "cuda"
        return "cpu"

    @classmethod
    def _resolve_weight_path(cls) -> str:
        """Locate the checkpoint file.

        The first path registered in ``folder_paths.folder_names_and_paths``
        for the model folder is tried first; when that lookup fails or the
        file does not exist there, the path under ``folder_paths.models_dir``
        is returned instead.
        """
        candidate = ""
        try:
            registered_dir = folder_paths.folder_names_and_paths[cls.MODEL_FOLDER_NAME][0][0]
            candidate = os.path.join(os.path.normpath(registered_dir), cls.MODEL_FILE_NAME)
        except (KeyError, IndexError, TypeError):
            # Folder not registered (or registered with no usable path):
            # fall through to the default location below.
            pass
        if not os.path.exists(candidate):
            candidate = os.path.join(
                folder_paths.models_dir, cls.MODEL_FOLDER_NAME, cls.MODEL_FILE_NAME)
        return candidate

    def load(self, weight_path, device):
        """Load BiRefNet weights and prepare the image processor.

        Args:
            weight_path: Path to the checkpoint file passed to ``torch.load``.
            device: Device the model is moved to; it is remembered and reused
                for inputs in ``generate_mask``.
        """
        # load model
        self.model = BiRefNet()
        # SECURITY: torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(weight_path, map_location='cpu')
        # Strip the '_orig_mod.' key prefix (presumably left by saving a
        # torch.compile-wrapped model) so keys match the plain module.
        # The dict is mutated in place; iterate over a snapshot of the keys.
        unwanted_prefix = '_orig_mod.'
        for key in list(state_dict):
            if key.startswith(unwanted_prefix):
                state_dict[key[len(unwanted_prefix):]] = state_dict.pop(key)
        self.model.load_state_dict(state_dict)
        self.model = self.model.to(device)
        self.model.eval()
        self.device = device
        # load processor
        self.processor = BiRefNet_img_processor(Config())
        self.ready = True

    def generate_mask(self, image: Image.Image) -> Image.Image:
        """Return a mask image for ``image``.

        The model is loaded on first use. The prediction is resized back to
        the input's height/width and passed through a slight (1.01x)
        brightness enhancement before being returned.

        Args:
            image: Input PIL image.

        Returns:
            A PIL image holding the predicted mask.
        """
        if not self.ready:
            self.load(self._resolve_weight_path(), device=self._select_device())

        # Send inputs to the device the model actually lives on; fall back to
        # auto-detection if ``load`` did not record one.
        device = getattr(self, "device", None)
        if device is None:
            device = self._select_device()

        np_image = pil2tensor(image).squeeze().numpy()
        inputs = self.processor(np_image)[None, ...].to(device)

        with torch.no_grad():
            # The last element of the model output is used as the prediction.
            scaled_preds = self.model(inputs)[-1].sigmoid()

        _mask = nn.functional.interpolate(
            scaled_preds[0].unsqueeze(0),
            size=np_image.shape[:2],
            mode='bilinear',
            align_corners=True,
        )[0]

        brightness_image = ImageEnhance.Brightness(tensor2pil(_mask))
        return brightness_image.enhance(factor=1.01)