# Dreamspire's picture
# custom_nodes
# f2dbf59
from .imagefunc import *
# Label prepended to this module's log messages.
NODE_NAME: str = "BiRefNetUltra"
class BiRefNetUltra:
    """Node that builds a foreground mask per image and returns RGBA cut-outs.

    For every image in the incoming batch a mask is produced by
    ``BiRefNetRemoveBackground.generate_mask``; when ``process_detail`` is
    enabled the mask is then refined by the selected ``detail_method``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Return the input specification (names, types, widget options)."""
        method_list = ['VITMatte', 'VITMatte(local)', 'PyMatting', 'GuidedFilter']
        device_list = ['cuda', 'cpu']
        return {
            "required": {
                "image": ("IMAGE",),
                "detail_method": (method_list,),
                "detail_erode": ("INT", {"default": 6, "min": 1, "max": 255, "step": 1}),
                "detail_dilate": ("INT", {"default": 6, "min": 1, "max": 255, "step": 1}),
                "black_point": ("FLOAT", {"default": 0.01, "min": 0.01, "max": 0.98, "step": 0.01, "display": "slider"}),
                "white_point": ("FLOAT", {"default": 0.99, "min": 0.02, "max": 0.99, "step": 0.01, "display": "slider"}),
                "process_detail": ("BOOLEAN", {"default": True}),
                "device": (device_list,),
                "max_megapixels": ("FLOAT", {"default": 2.0, "min": 1, "max": 999, "step": 0.1}),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", )
    RETURN_NAMES = ("image", "mask", )
    FUNCTION = "birefnet_ultra"
    CATEGORY = '😺dzNodes/LayerMask'

    def birefnet_ultra(self, image, detail_method, detail_erode, detail_dilate,
                       black_point, white_point, process_detail, device, max_megapixels):
        """Mask every image in ``image`` and return the cut-outs and masks.

        Args:
            image: Iterable batch of image tensors; each item is given a
                leading dimension with ``torch.unsqueeze`` before use.
            detail_method: One of the names offered by ``INPUT_TYPES``.
                ``'GuidedFilter'`` and ``'PyMatting'`` select their own
                branches; any other value takes the VITMatte path, and
                ``'VITMatte(local)'`` additionally sets
                ``local_files_only=True`` for it.
            detail_erode: Passed to ``generate_VITMatte_trimap``; also
                summed with ``detail_dilate`` to size the other methods.
            detail_dilate: See ``detail_erode``.
            black_point: Forwarded to ``histogram_remap`` /
                ``mask_edge_detail``.
            white_point: Forwarded to ``histogram_remap`` /
                ``mask_edge_detail``.
            process_detail: When false, the raw mask is used without
                refinement.
            device: Forwarded to ``generate_VITMatte``.
            max_megapixels: Forwarded to ``generate_VITMatte``.

        Returns:
            A 2-tuple ``(images, masks)``, each the per-image results
            concatenated along dimension 0 with ``torch.cat``.
        """
        ret_images = []
        ret_masks = []

        # Only the "(local)" variant restricts VITMatte to local files.
        local_files_only = detail_method == 'VITMatte(local)'

        # Imported lazily, at call time rather than at module import.
        from .birefnet_legacy import BiRefNetRemoveBackground
        birefnetrmbg = BiRefNetRemoveBackground()

        # Same for every image in the batch, so compute it once.
        detail_range = detail_erode + detail_dilate

        for i in image:
            i = torch.unsqueeze(i, 0)
            orig_image = tensor2pil(i).convert('RGB')

            _mask = birefnetrmbg.generate_mask(orig_image)
            _mask = image2mask(_mask)

            if process_detail:
                if detail_method == 'GuidedFilter':
                    _mask = guided_filter_alpha(i, _mask, detail_range // 6 + 1)
                    _mask = tensor2pil(histogram_remap(_mask, black_point, white_point))
                elif detail_method == 'PyMatting':
                    _mask = tensor2pil(mask_edge_detail(i, _mask, detail_range // 8 + 1, black_point, white_point))
                else:
                    # 'VITMatte' and 'VITMatte(local)' share this path.
                    _trimap = generate_VITMatte_trimap(_mask, detail_erode, detail_dilate)
                    _mask = generate_VITMatte(orig_image, _trimap, local_files_only=local_files_only, device=device, max_megapixels=max_megapixels)
                    _mask = tensor2pil(histogram_remap(pil2tensor(_mask), black_point, white_point))
            else:
                _mask = tensor2pil(_mask)

            # Every branch above leaves _mask as the output of tensor2pil.
            ret_image = RGB2RGBA(orig_image, _mask.convert('L'))
            ret_images.append(pil2tensor(ret_image))
            ret_masks.append(image2mask(_mask))

        log(f"{NODE_NAME} Processed {len(ret_masks)} image(s).", message_type='finish')
        return (torch.cat(ret_images, dim=0), torch.cat(ret_masks, dim=0),)
# Registration tables for this module's node. Both are keyed by the same
# identifier; presumably the host's node loader reads them — confirm against
# the package's loader.
_NODE_KEY = "LayerMask: BiRefNetUltra"

NODE_CLASS_MAPPINGS = {_NODE_KEY: BiRefNetUltra}

NODE_DISPLAY_NAME_MAPPINGS = {_NODE_KEY: "LayerMask: BiRefNetUltra"}