Spaces:
Runtime error
Runtime error
import torch | |
import numpy as np | |
class MaskSubNode: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"mask": ("MASK",), | |
}, | |
"optional": { | |
"src1": ("MASK",), | |
"src2": ("MASK",), | |
"src3": ("MASK",), | |
"src4": ("MASK",), | |
"src5": ("MASK",), | |
"src6": ("MASK",), | |
} | |
} | |
CATEGORY = "mask" | |
RETURN_TYPES = ("MASK",) | |
FUNCTION = "sub" | |
CATEGORY = "tbox/Mask" | |
def sub_mask(self, dst, src): | |
if src != None: | |
mask = src.reshape((-1, src.shape[-2], src.shape[-1])) | |
return dst - mask | |
return dst | |
def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None): | |
print(f'mask shape: {mask.shape}') | |
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() | |
output[:, :, :] = self.sub_mask(output, src1) | |
output[:, :, :] = self.sub_mask(output, src2) | |
output[:, :, :] = self.sub_mask(output, src3) | |
output[:, :, :] = self.sub_mask(output, src4) | |
output[:, :, :] = self.sub_mask(output, src5) | |
output[:, :, :] = self.sub_mask(output, src6) | |
output = torch.clamp(output, 0.0, 1.0) | |
return (output, ) | |
class MaskAddNode: | |
def INPUT_TYPES(cls): | |
return { | |
"required": { | |
"mask": ("MASK",), | |
}, | |
"optional": { | |
"src1": ("MASK",), | |
"src2": ("MASK",), | |
"src3": ("MASK",), | |
"src4": ("MASK",), | |
"src5": ("MASK",), | |
"src6": ("MASK",), | |
} | |
} | |
CATEGORY = "mask" | |
RETURN_TYPES = ("MASK",) | |
FUNCTION = "add" | |
CATEGORY = "tbox/Mask" | |
def add_mask(self, dst, src): | |
if src != None: | |
mask = src.reshape((-1, src.shape[-2], src.shape[-1])) | |
return dst + mask | |
return dst | |
def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None): | |
print(f'mask shape: {mask.shape}') | |
output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone() | |
output[:, :, :] = self.add_mask(output, src1) | |
output[:, :, :] = self.add_mask(output, src2) | |
output[:, :, :] = self.add_mask(output, src3) | |
output[:, :, :] = self.add_mask(output, src4) | |
output[:, :, :] = self.add_mask(output, src5) | |
output[:, :, :] = self.add_mask(output, src6) | |
output = torch.clamp(output, 0.0, 1.0) | |
return (output, ) | |