|  | import numpy as np | 
					
						
						|  | import scipy.ndimage | 
					
						
						|  | import torch | 
					
						
						|  | import comfy.utils | 
					
						
						|  | import node_helpers | 
					
						
						|  | import folder_paths | 
					
						
						|  | import random | 
					
						
						|  |  | 
					
						
						|  | import nodes | 
					
						
						|  | from nodes import MAX_RESOLUTION | 
					
						
						|  |  | 
					
						
						|  | def composite(destination, source, x, y, mask = None, multiplier = 8, resize_source = False): | 
					
						
						|  | source = source.to(destination.device) | 
					
						
						|  | if resize_source: | 
					
						
						|  | source = torch.nn.functional.interpolate(source, size=(destination.shape[2], destination.shape[3]), mode="bilinear") | 
					
						
						|  |  | 
					
						
						|  | source = comfy.utils.repeat_to_batch_size(source, destination.shape[0]) | 
					
						
						|  |  | 
					
						
						|  | x = max(-source.shape[3] * multiplier, min(x, destination.shape[3] * multiplier)) | 
					
						
						|  | y = max(-source.shape[2] * multiplier, min(y, destination.shape[2] * multiplier)) | 
					
						
						|  |  | 
					
						
						|  | left, top = (x // multiplier, y // multiplier) | 
					
						
						|  | right, bottom = (left + source.shape[3], top + source.shape[2],) | 
					
						
						|  |  | 
					
						
						|  | if mask is None: | 
					
						
						|  | mask = torch.ones_like(source) | 
					
						
						|  | else: | 
					
						
						|  | mask = mask.to(destination.device, copy=True) | 
					
						
						|  | mask = torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(source.shape[2], source.shape[3]), mode="bilinear") | 
					
						
						|  | mask = comfy.utils.repeat_to_batch_size(mask, source.shape[0]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | visible_width, visible_height = (destination.shape[3] - left + min(0, x), destination.shape[2] - top + min(0, y),) | 
					
						
						|  |  | 
					
						
						|  | mask = mask[:, :, :visible_height, :visible_width] | 
					
						
						|  | inverse_mask = torch.ones_like(mask) - mask | 
					
						
						|  |  | 
					
						
						|  | source_portion = mask * source[:, :, :visible_height, :visible_width] | 
					
						
						|  | destination_portion = inverse_mask  * destination[:, :, top:bottom, left:right] | 
					
						
						|  |  | 
					
						
						|  | destination[:, :, top:bottom, left:right] = source_portion + destination_portion | 
					
						
						|  | return destination | 
					
						
						|  |  | 
					
						
						|  | class LatentCompositeMasked: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "destination": ("LATENT",), | 
					
						
						|  | "source": ("LATENT",), | 
					
						
						|  | "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), | 
					
						
						|  | "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 8}), | 
					
						
						|  | "resize_source": ("BOOLEAN", {"default": False}), | 
					
						
						|  | }, | 
					
						
						|  | "optional": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | RETURN_TYPES = ("LATENT",) | 
					
						
						|  | FUNCTION = "composite" | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "latent" | 
					
						
						|  |  | 
					
						
						|  | def composite(self, destination, source, x, y, resize_source, mask = None): | 
					
						
						|  | output = destination.copy() | 
					
						
						|  | destination = destination["samples"].clone() | 
					
						
						|  | source = source["samples"] | 
					
						
						|  | output["samples"] = composite(destination, source, x, y, mask, 8, resize_source) | 
					
						
						|  | return (output,) | 
					
						
						|  |  | 
					
						
						|  | class ImageCompositeMasked: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "destination": ("IMAGE",), | 
					
						
						|  | "source": ("IMAGE",), | 
					
						
						|  | "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "resize_source": ("BOOLEAN", {"default": False}), | 
					
						
						|  | }, | 
					
						
						|  | "optional": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  | RETURN_TYPES = ("IMAGE",) | 
					
						
						|  | FUNCTION = "composite" | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "image" | 
					
						
						|  |  | 
					
						
						|  | def composite(self, destination, source, x, y, resize_source, mask = None): | 
					
						
						|  | destination, source = node_helpers.image_alpha_fix(destination, source) | 
					
						
						|  | destination = destination.clone().movedim(-1, 1) | 
					
						
						|  | output = composite(destination, source.movedim(-1, 1), x, y, mask, 1, resize_source).movedim(1, -1) | 
					
						
						|  | return (output,) | 
					
						
						|  |  | 
					
						
						|  | class MaskToImage: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("IMAGE",) | 
					
						
						|  | FUNCTION = "mask_to_image" | 
					
						
						|  |  | 
					
						
						|  | def mask_to_image(self, mask): | 
					
						
						|  | result = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) | 
					
						
						|  | return (result,) | 
					
						
						|  |  | 
					
						
						|  | class ImageToMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "image": ("IMAGE",), | 
					
						
						|  | "channel": (["red", "green", "blue", "alpha"],), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  | FUNCTION = "image_to_mask" | 
					
						
						|  |  | 
					
						
						|  | def image_to_mask(self, image, channel): | 
					
						
						|  | channels = ["red", "green", "blue", "alpha"] | 
					
						
						|  | mask = image[:, :, :, channels.index(channel)] | 
					
						
						|  | return (mask,) | 
					
						
						|  |  | 
					
						
						|  | class ImageColorToMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "image": ("IMAGE",), | 
					
						
						|  | "color": ("INT", {"default": 0, "min": 0, "max": 0xFFFFFF, "step": 1, "display": "color"}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  | FUNCTION = "image_to_mask" | 
					
						
						|  |  | 
					
						
						|  | def image_to_mask(self, image, color): | 
					
						
						|  | temp = (torch.clamp(image, 0, 1.0) * 255.0).round().to(torch.int) | 
					
						
						|  | temp = torch.bitwise_left_shift(temp[:,:,:,0], 16) + torch.bitwise_left_shift(temp[:,:,:,1], 8) + temp[:,:,:,2] | 
					
						
						|  | mask = torch.where(temp == color, 1.0, 0).float() | 
					
						
						|  | return (mask,) | 
					
						
						|  |  | 
					
						
						|  | class SolidMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "value": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}), | 
					
						
						|  | "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "solid" | 
					
						
						|  |  | 
					
						
						|  | def solid(self, value, width, height): | 
					
						
						|  | out = torch.full((1, height, width), value, dtype=torch.float32, device="cpu") | 
					
						
						|  | return (out,) | 
					
						
						|  |  | 
					
						
						|  | class InvertMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "invert" | 
					
						
						|  |  | 
					
						
						|  | def invert(self, mask): | 
					
						
						|  | out = 1.0 - mask | 
					
						
						|  | return (out,) | 
					
						
						|  |  | 
					
						
						|  | class CropMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "width": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "height": ("INT", {"default": 512, "min": 1, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "crop" | 
					
						
						|  |  | 
					
						
						|  | def crop(self, mask, x, y, width, height): | 
					
						
						|  | mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) | 
					
						
						|  | out = mask[:, y:y + height, x:x + width] | 
					
						
						|  | return (out,) | 
					
						
						|  |  | 
					
						
						|  | class MaskComposite: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "destination": ("MASK",), | 
					
						
						|  | "source": ("MASK",), | 
					
						
						|  | "x": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "y": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "operation": (["multiply", "add", "subtract", "and", "or", "xor"],), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "combine" | 
					
						
						|  |  | 
					
						
						|  | def combine(self, destination, source, x, y, operation): | 
					
						
						|  | output = destination.reshape((-1, destination.shape[-2], destination.shape[-1])).clone() | 
					
						
						|  | source = source.reshape((-1, source.shape[-2], source.shape[-1])) | 
					
						
						|  |  | 
					
						
						|  | left, top = (x, y,) | 
					
						
						|  | right, bottom = (min(left + source.shape[-1], destination.shape[-1]), min(top + source.shape[-2], destination.shape[-2])) | 
					
						
						|  | visible_width, visible_height = (right - left, bottom - top,) | 
					
						
						|  |  | 
					
						
						|  | source_portion = source[:, :visible_height, :visible_width] | 
					
						
						|  | destination_portion = output[:, top:bottom, left:right] | 
					
						
						|  |  | 
					
						
						|  | if operation == "multiply": | 
					
						
						|  | output[:, top:bottom, left:right] = destination_portion * source_portion | 
					
						
						|  | elif operation == "add": | 
					
						
						|  | output[:, top:bottom, left:right] = destination_portion + source_portion | 
					
						
						|  | elif operation == "subtract": | 
					
						
						|  | output[:, top:bottom, left:right] = destination_portion - source_portion | 
					
						
						|  | elif operation == "and": | 
					
						
						|  | output[:, top:bottom, left:right] = torch.bitwise_and(destination_portion.round().bool(), source_portion.round().bool()).float() | 
					
						
						|  | elif operation == "or": | 
					
						
						|  | output[:, top:bottom, left:right] = torch.bitwise_or(destination_portion.round().bool(), source_portion.round().bool()).float() | 
					
						
						|  | elif operation == "xor": | 
					
						
						|  | output[:, top:bottom, left:right] = torch.bitwise_xor(destination_portion.round().bool(), source_portion.round().bool()).float() | 
					
						
						|  |  | 
					
						
						|  | output = torch.clamp(output, 0.0, 1.0) | 
					
						
						|  |  | 
					
						
						|  | return (output,) | 
					
						
						|  |  | 
					
						
						|  | class FeatherMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | "left": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "top": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "right": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "bottom": ("INT", {"default": 0, "min": 0, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "feather" | 
					
						
						|  |  | 
					
						
						|  | def feather(self, mask, left, top, right, bottom): | 
					
						
						|  | output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() | 
					
						
						|  |  | 
					
						
						|  | left = min(left, output.shape[-1]) | 
					
						
						|  | right = min(right, output.shape[-1]) | 
					
						
						|  | top = min(top, output.shape[-2]) | 
					
						
						|  | bottom = min(bottom, output.shape[-2]) | 
					
						
						|  |  | 
					
						
						|  | for x in range(left): | 
					
						
						|  | feather_rate = (x + 1.0) / left | 
					
						
						|  | output[:, :, x] *= feather_rate | 
					
						
						|  |  | 
					
						
						|  | for x in range(right): | 
					
						
						|  | feather_rate = (x + 1) / right | 
					
						
						|  | output[:, :, -x] *= feather_rate | 
					
						
						|  |  | 
					
						
						|  | for y in range(top): | 
					
						
						|  | feather_rate = (y + 1) / top | 
					
						
						|  | output[:, y, :] *= feather_rate | 
					
						
						|  |  | 
					
						
						|  | for y in range(bottom): | 
					
						
						|  | feather_rate = (y + 1) / bottom | 
					
						
						|  | output[:, -y, :] *= feather_rate | 
					
						
						|  |  | 
					
						
						|  | return (output,) | 
					
						
						|  |  | 
					
						
						|  | class GrowMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(cls): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | "expand": ("INT", {"default": 0, "min": -MAX_RESOLUTION, "max": MAX_RESOLUTION, "step": 1}), | 
					
						
						|  | "tapered_corners": ("BOOLEAN", {"default": True}), | 
					
						
						|  | }, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "expand_mask" | 
					
						
						|  |  | 
					
						
						|  | def expand_mask(self, mask, expand, tapered_corners): | 
					
						
						|  | c = 0 if tapered_corners else 1 | 
					
						
						|  | kernel = np.array([[c, 1, c], | 
					
						
						|  | [1, 1, 1], | 
					
						
						|  | [c, 1, c]]) | 
					
						
						|  | mask = mask.reshape((-1, mask.shape[-2], mask.shape[-1])) | 
					
						
						|  | out = [] | 
					
						
						|  | for m in mask: | 
					
						
						|  | output = m.numpy() | 
					
						
						|  | for _ in range(abs(expand)): | 
					
						
						|  | if expand < 0: | 
					
						
						|  | output = scipy.ndimage.grey_erosion(output, footprint=kernel) | 
					
						
						|  | else: | 
					
						
						|  | output = scipy.ndimage.grey_dilation(output, footprint=kernel) | 
					
						
						|  | output = torch.from_numpy(output) | 
					
						
						|  | out.append(output) | 
					
						
						|  | return (torch.stack(out, dim=0),) | 
					
						
						|  |  | 
					
						
						|  | class ThresholdMask: | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": { | 
					
						
						|  | "mask": ("MASK",), | 
					
						
						|  | "value": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}), | 
					
						
						|  | } | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | RETURN_TYPES = ("MASK",) | 
					
						
						|  | FUNCTION = "image_to_mask" | 
					
						
						|  |  | 
					
						
						|  | def image_to_mask(self, mask, value): | 
					
						
						|  | mask = (mask > value).float() | 
					
						
						|  | return (mask,) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class MaskPreview(nodes.SaveImage): | 
					
						
						|  | def __init__(self): | 
					
						
						|  | self.output_dir = folder_paths.get_temp_directory() | 
					
						
						|  | self.type = "temp" | 
					
						
						|  | self.prefix_append = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstupvxyz") for x in range(5)) | 
					
						
						|  | self.compress_level = 4 | 
					
						
						|  |  | 
					
						
						|  | @classmethod | 
					
						
						|  | def INPUT_TYPES(s): | 
					
						
						|  | return { | 
					
						
						|  | "required": {"mask": ("MASK",), }, | 
					
						
						|  | "hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"}, | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | FUNCTION = "execute" | 
					
						
						|  | CATEGORY = "mask" | 
					
						
						|  |  | 
					
						
						|  | def execute(self, mask, filename_prefix="ComfyUI", prompt=None, extra_pnginfo=None): | 
					
						
						|  | preview = mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])).movedim(1, -1).expand(-1, -1, -1, 3) | 
					
						
						|  | return self.save_images(preview, filename_prefix, prompt, extra_pnginfo) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | NODE_CLASS_MAPPINGS = { | 
					
						
						|  | "LatentCompositeMasked": LatentCompositeMasked, | 
					
						
						|  | "ImageCompositeMasked": ImageCompositeMasked, | 
					
						
						|  | "MaskToImage": MaskToImage, | 
					
						
						|  | "ImageToMask": ImageToMask, | 
					
						
						|  | "ImageColorToMask": ImageColorToMask, | 
					
						
						|  | "SolidMask": SolidMask, | 
					
						
						|  | "InvertMask": InvertMask, | 
					
						
						|  | "CropMask": CropMask, | 
					
						
						|  | "MaskComposite": MaskComposite, | 
					
						
						|  | "FeatherMask": FeatherMask, | 
					
						
						|  | "GrowMask": GrowMask, | 
					
						
						|  | "ThresholdMask": ThresholdMask, | 
					
						
						|  | "MaskPreview": MaskPreview | 
					
						
						|  | } | 
					
						
						|  |  | 
					
						
						|  | NODE_DISPLAY_NAME_MAPPINGS = { | 
					
						
						|  | "ImageToMask": "Convert Image to Mask", | 
					
						
						|  | "MaskToImage": "Convert Mask to Image", | 
					
						
						|  | } | 
					
						
						|  |  |