# Dreamspire's picture
# custom_nodes
# f2dbf59
import os
import sys
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import tqdm
from .imagefunc import *
from comfy.utils import ProgressBar
# Import-time side effect: append the vendored 'BiRefNet_v2' directory (next to
# this file) to sys.path. Presumably this lets modules inside that package use
# absolute imports of each other -- TODO confirm against BiRefNet_v2's imports.
sys.path.append(os.path.join(os.path.dirname(__file__), 'BiRefNet_v2'))
def get_models():
    """Collect the ``.pth`` checkpoints stored in ``<models_dir>/BiRefNet/pth``.

    Returns whatever ``get_files`` builds for that folder. Callers use its keys
    as the selectable model names and its values as the paths handed to
    ``torch.load``.
    """
    checkpoint_dir = os.path.join(folder_paths.models_dir, 'BiRefNet', 'pth')
    return get_files(checkpoint_dir, [".pth"])
class LS_LoadBiRefNetModel:
    """ComfyUI node that builds a BiRefNet model from a local ``.pth`` checkpoint.

    The checkpoint choices come from ``get_models()``. The selected file is
    loaded on the CPU, normalized with ``check_state_dict`` and loaded into a
    freshly constructed ``BiRefNet(bb_pretrained=False)``.
    """

    # Listed first in the dropdown whenever it is available.
    _PREFERRED_MODEL = 'BiRefNet-general-epoch_244.pth'

    def __init__(self):
        self.birefnet = None
        # Kept for backward compatibility; the loaded weights dict is not retained.
        self.state_dict = None

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's inputs: a dropdown of the available checkpoints."""
        names = list(get_models().keys())
        if cls._PREFERRED_MODEL in names:
            # Move the preferred checkpoint to the front, keeping the others' order.
            names.remove(cls._PREFERRED_MODEL)
            names.insert(0, cls._PREFERRED_MODEL)
        return {
            "required": {
                "model": (names,),
            },
        }

    RETURN_TYPES = ("BIREFNET_MODEL",)
    RETURN_NAMES = ("birefnet_model",)
    FUNCTION = "load_birefnet_model"
    CATEGORY = '😺dzNodes/LayerMask'

    def load_birefnet_model(self, model):
        """Load checkpoint ``model`` (a key of ``get_models()``) into a BiRefNet.

        Returns:
            A one-element tuple holding the loaded BiRefNet module.

        Raises:
            KeyError: if ``model`` is not one of the checkpoints found on disk.
        """
        from .BiRefNet_v2.models.birefnet import BiRefNet
        from .BiRefNet_v2.utils import check_state_dict

        model_dict = get_models()
        birefnet = BiRefNet(bb_pretrained=False)
        state_dict = torch.load(model_dict[model], map_location='cpu', weights_only=True)
        birefnet.load_state_dict(check_state_dict(state_dict))
        # load_state_dict copies the tensors into the module's own parameters.
        # Storing the dict on this (cached) node would keep a second full copy
        # of the weights alive in CPU memory, so it is dropped here.
        self.birefnet = birefnet
        self.state_dict = None
        return (self.birefnet,)
class LS_LoadBiRefNetModelV2:
    """ComfyUI node that loads a BiRefNet-family model by version name.

    Behavior depends on the selected version:
    - "BiRefNet-General": if the legacy checkpoint
      ``<models_dir>/BiRefNet/pth/BiRefNet-general-epoch_244.pth`` exists, it is
      loaded into the vendored ``BiRefNet`` class.
    - Otherwise: the Hugging Face snapshot in ``<models_dir>/BiRefNet/<version>``
      is loaded with ``AutoModelForImageSegmentation``. The snapshot is
      downloaded first if that folder does not exist.
    """

    def __init__(self):
        self.model = None
        # Set by the legacy-checkpoint path; initialized here so the attributes always exist.
        self.birefnet = None
        self.state_dict = None

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's inputs: a dropdown of the supported versions."""
        model_list = list(cls.birefnet_model_repos.keys())
        return {
            "required": {
                "version": (model_list, {"default": model_list[0]}),
            },
        }

    RETURN_TYPES = ("BIREFNET_MODEL",)
    RETURN_NAMES = ("birefnet_model",)
    FUNCTION = "load_birefnet_model"
    CATEGORY = '😺dzNodes/LayerMask'

    # Version name -> Hugging Face repo id used for the download.
    birefnet_model_repos = {
        "BiRefNet-General": "ZhengPeng7/BiRefNet",
        "RMBG-2.0": "briaai/RMBG-2.0"
    }

    def load_birefnet_model(self, version):
        """Load ``version`` and return it as a one-element tuple.

        Raises:
            KeyError: if a download is needed and ``version`` is not a key of
                ``birefnet_model_repos``.
        """
        birefnet_path = os.path.join(folder_paths.models_dir, 'BiRefNet')
        os.makedirs(birefnet_path, exist_ok=True)
        model_path = os.path.join(birefnet_path, version)

        if version == "BiRefNet-General":
            old_model_path = os.path.join(birefnet_path, 'pth', "BiRefNet-general-epoch_244.pth")
            if os.path.exists(old_model_path):
                return self._load_legacy_checkpoint(old_model_path)

        # Fetch the snapshot for ANY version whose folder is missing, including
        # "BiRefNet-General" when the legacy checkpoint is absent. Previously an
        # if/elif chain skipped the download for one of the versions.
        if not os.path.exists(model_path):
            log(f"Downloading {version} model...")
            repo_id = self.birefnet_model_repos[version]
            from huggingface_hub import snapshot_download
            snapshot_download(repo_id=repo_id, local_dir=model_path, ignore_patterns=["*.md", "*.txt"])
        self.model = AutoModelForImageSegmentation.from_pretrained(model_path, trust_remote_code=True)
        return (self.model,)

    def _load_legacy_checkpoint(self, checkpoint_path):
        """Build the vendored BiRefNet from a ``.pth`` state dict; return ``(model,)``."""
        from .BiRefNet_v2.models.birefnet import BiRefNet
        from .BiRefNet_v2.utils import check_state_dict

        self.birefnet = BiRefNet(bb_pretrained=False)
        state_dict = torch.load(checkpoint_path, map_location='cpu', weights_only=True)
        self.birefnet.load_state_dict(check_state_dict(state_dict))
        # load_state_dict copies the tensors into the module, so do not keep a
        # second full copy of the weights on this cached node.
        self.state_dict = None
        return (self.birefnet,)
class LS_BiRefNetUltraV2:
    """ComfyUI node: BiRefNet foreground segmentation with optional edge refinement.

    Each image of the batch is resized to 1024x1024 and normalized, then run
    through the supplied BIREFNET_MODEL. The sigmoid of its last output is
    resized back to the source size. The mask is optionally refined with
    GuidedFilter, PyMatting or VITMatte, and combined with the image via
    ``RGB2RGBA``.
    """

    def __init__(self):
        self.NODE_NAME = 'BiRefNetUltraV2'

    @classmethod
    def INPUT_TYPES(cls):
        """Return the node's required inputs (model, refinement and device options)."""
        method_list = ['VITMatte', 'VITMatte(local)', 'PyMatting', 'GuidedFilter', ]
        device_list = ['cuda', 'cpu']
        return {
            "required": {
                "image": ("IMAGE",),
                "birefnet_model": ("BIREFNET_MODEL",),
                "detail_method": (method_list,),
                "detail_erode": ("INT", {"default": 4, "min": 1, "max": 255, "step": 1}),
                "detail_dilate": ("INT", {"default": 2, "min": 1, "max": 255, "step": 1}),
                "black_point": ("FLOAT", {"default": 0.01, "min": 0.01, "max": 0.98, "step": 0.01, "display": "slider"}),
                "white_point": ("FLOAT", {"default": 0.99, "min": 0.02, "max": 0.99, "step": 0.01, "display": "slider"}),
                "process_detail": ("BOOLEAN", {"default": False}),
                "device": (device_list,),
                "max_megapixels": ("FLOAT", {"default": 2.0, "min": 1, "max": 999, "step": 0.1}),
            },
            "optional": {
            }
        }

    RETURN_TYPES = ("IMAGE", "MASK", )
    RETURN_NAMES = ("image", "mask", )
    FUNCTION = "birefnet_ultra_v2"
    CATEGORY = '😺dzNodes/LayerMask'

    def birefnet_ultra_v2(self, image, birefnet_model, detail_method, detail_erode, detail_dilate,
                          black_point, white_point, process_detail, device, max_megapixels):
        """Segment every image in ``image`` and return ``(images, masks)``.

        Both outputs are the per-frame results concatenated along dim 0.
        ``birefnet_model`` is moved to ``device`` and put in eval mode.
        """
        # The module-level `import tqdm` binds the package, which is not
        # callable; import the progress-bar class explicitly.
        from tqdm import tqdm as tqdm_bar

        ret_images = []
        ret_masks = []
        inference_image_size = (1024, 1024)
        local_files_only = detail_method == 'VITMatte(local)'
        # Frame-independent preprocessing: build it once, not per frame.
        transform_image = transforms.Compose([
            transforms.Resize(inference_image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        torch.set_float32_matmul_precision('high')
        birefnet_model.to(device)
        birefnet_model.eval()

        comfy_pbar = ProgressBar(len(image))
        # The context manager closes the console bar even if a frame raises.
        with tqdm_bar(total=len(image), desc="Processing BiRefNet") as tqdm_pbar:
            for frame in image:
                frame = torch.unsqueeze(frame, 0)
                orig_image = tensor2pil(frame).convert('RGB')
                inference_image = transform_image(orig_image).unsqueeze(0).to(device)

                # Prediction
                with torch.no_grad():
                    preds = birefnet_model(inference_image)[-1].sigmoid().cpu()
                pred_pil = transforms.ToPILImage()(preds[0].squeeze())
                # NOTE(review): the prediction is presumably already at inference
                # size; this resize is kept to match the original output exactly.
                _mask = pred_pil.resize(inference_image_size)
                _mask = _mask.resize(orig_image.size, Image.BILINEAR)
                _mask = ImageEnhance.Brightness(_mask).enhance(factor=1.08)
                _mask = image2mask(_mask)

                if process_detail:
                    _mask = self._refine_mask(frame, orig_image, _mask, detail_method,
                                              detail_erode, detail_dilate, black_point,
                                              white_point, local_files_only, device,
                                              max_megapixels)
                else:
                    _mask = tensor2pil(_mask)

                ret_image = RGB2RGBA(orig_image, _mask.convert('L'))
                ret_images.append(pil2tensor(ret_image))
                ret_masks.append(image2mask(_mask))
                comfy_pbar.update(1)
                tqdm_pbar.update(1)

        log(f"{self.NODE_NAME} Processed {len(ret_masks)} image(s).", message_type='finish')
        return (torch.cat(ret_images, dim=0), torch.cat(ret_masks, dim=0),)

    @staticmethod
    def _refine_mask(frame, orig_image, mask, detail_method, detail_erode, detail_dilate,
                     black_point, white_point, local_files_only, device, max_megapixels):
        """Refine ``mask`` edges with ``detail_method`` and return the mask as a PIL image."""
        detail_range = detail_erode + detail_dilate
        if detail_method == 'GuidedFilter':
            mask = guided_filter_alpha(frame, mask, detail_range // 6 + 1)
            return tensor2pil(histogram_remap(mask, black_point, white_point))
        if detail_method == 'PyMatting':
            return tensor2pil(mask_edge_detail(frame, mask, detail_range // 8 + 1, black_point, white_point))
        # 'VITMatte' and 'VITMatte(local)' (the latter only sets local_files_only).
        trimap = generate_VITMatte_trimap(mask, detail_erode, detail_dilate)
        mask = generate_VITMatte(orig_image, trimap, local_files_only=local_files_only,
                                 device=device, max_megapixels=max_megapixels)
        return tensor2pil(histogram_remap(pil2tensor(mask), black_point, white_point))
# One row per exported node: (registry key, node class, display title).
# Both ComfyUI registration dicts are derived from this table so they stay in sync.
_NODE_TABLE = (
    ("LayerMask: BiRefNetUltraV2", LS_BiRefNetUltraV2, "LayerMask: BiRefNet Ultra V2"),
    ("LayerMask: LoadBiRefNetModel", LS_LoadBiRefNetModel, "LayerMask: Load BiRefNet Model"),
    ("LayerMask: LoadBiRefNetModelV2", LS_LoadBiRefNetModelV2, "LayerMask: Load BiRefNet Model V2"),
)
NODE_CLASS_MAPPINGS = {key: node_cls for key, node_cls, _title in _NODE_TABLE}
NODE_DISPLAY_NAME_MAPPINGS = {key: title for key, _node_cls, title in _NODE_TABLE}