Spaces:
Runtime error
Runtime error
File size: 2,669 Bytes
f2dbf59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import torch
import numpy as np
class MaskSubNode:
    """Node that subtracts up to six optional masks from a base mask.

    Every input is reshaped to ``(-1, H, W)`` using its last two dimensions,
    the optional masks are subtracted one after another, and the result is
    clamped to ``[0.0, 1.0]``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required ``mask`` input and six optional ``src`` inputs."""
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "src1": ("MASK",),
                "src2": ("MASK",),
                "src3": ("MASK",),
                "src4": ("MASK",),
                "src5": ("MASK",),
                "src6": ("MASK",),
            }
        }

    RETURN_TYPES = ("MASK",)
    # FUNCTION names the entry-point method; it must match a method that
    # actually exists on this class (presumably the host node framework
    # resolves it by name -- confirm against the framework docs).
    FUNCTION = "sub"
    CATEGORY = "tbox/Mask"

    def sub_mask(self, dst, src):
        """Return ``dst - src`` with ``src`` reshaped to ``(-1, H, W)``.

        Args:
            dst: Tensor to subtract from.
            src: Tensor to subtract, or ``None``.

        Returns:
            The difference, or ``dst`` itself when ``src`` is ``None``.
        """
        # Identity test: comparing a tensor to None with ``!=`` relies on
        # tensor comparison semantics rather than a plain None check.
        if src is not None:
            mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
            return dst - mask
        return dst

    def sub(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
        """Subtract every provided ``src`` mask from ``mask``.

        Args:
            mask: Base mask tensor; its last two dimensions are kept and all
                leading dimensions are flattened into one.
            src1..src6: Optional masks to subtract; ``None`` entries are skipped.

        Returns:
            A 1-tuple holding the clamped result of shape ``(-1, H, W)``.
        """
        print(f'mask shape: {mask.shape}')
        # Clone so the in-place slice assignments never touch the caller's tensor.
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
        for src in (src1, src2, src3, src4, src5, src6):
            # Slice assignment keeps ``output``'s shape fixed, so a source that
            # does not fit it raises instead of silently growing the batch.
            output[:, :, :] = self.sub_mask(output, src)
        output = torch.clamp(output, 0.0, 1.0)
        return (output, )

    # The entry point was originally (mis)named ``add`` while FUNCTION said
    # "sub"; keep the old name working for any caller that used it directly.
    add = sub
class MaskAddNode:
    """Node that adds up to six optional masks onto a base mask.

    Every input is reshaped to ``(-1, H, W)`` using its last two dimensions,
    the optional masks are added one after another, and the result is
    clamped to ``[0.0, 1.0]``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare one required ``mask`` input and six optional ``src`` inputs."""
        return {
            "required": {
                "mask": ("MASK",),
            },
            "optional": {
                "src1": ("MASK",),
                "src2": ("MASK",),
                "src3": ("MASK",),
                "src4": ("MASK",),
                "src5": ("MASK",),
                "src6": ("MASK",),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "add"
    # Only the final category value ever took effect; the earlier, shadowed
    # ``CATEGORY = "mask"`` assignment was dead and has been dropped.
    CATEGORY = "tbox/Mask"

    def add_mask(self, dst, src):
        """Return ``dst + src`` with ``src`` reshaped to ``(-1, H, W)``.

        Args:
            dst: Tensor to add onto.
            src: Tensor to add, or ``None``.

        Returns:
            The sum, or ``dst`` itself when ``src`` is ``None``.
        """
        # Identity test: comparing a tensor to None with ``!=`` relies on
        # tensor comparison semantics rather than a plain None check.
        if src is not None:
            mask = src.reshape((-1, src.shape[-2], src.shape[-1]))
            return dst + mask
        return dst

    def add(self, mask, src1=None, src2=None, src3=None, src4=None, src5=None, src6=None):
        """Add every provided ``src`` mask onto ``mask``.

        Args:
            mask: Base mask tensor; its last two dimensions are kept and all
                leading dimensions are flattened into one.
            src1..src6: Optional masks to add; ``None`` entries are skipped.

        Returns:
            A 1-tuple holding the clamped result of shape ``(-1, H, W)``.
        """
        print(f'mask shape: {mask.shape}')
        # Clone so the in-place slice assignments never touch the caller's tensor.
        output = mask.reshape((-1, mask.shape[-2], mask.shape[-1])).clone()
        for src in (src1, src2, src3, src4, src5, src6):
            # Slice assignment keeps ``output``'s shape fixed, so a source that
            # does not fit it raises instead of silently growing the batch.
            output[:, :, :] = self.add_mask(output, src)
        output = torch.clamp(output, 0.0, 1.0)
        return (output, )
|