File size: 4,385 Bytes
baa8e90 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
class CDTuner:
    """Node that temporarily adjusts a few UNet tensors while the model runs.

    ``apply`` clones the incoming model and installs a UNet function wrapper
    on the clone. On each wrapped call inside the configured timestep window
    the tensors listed in ``ADJUSTS`` are rewritten, the model function is
    run, and the original tensors are put back.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification: a model, three strengths and a step window."""
        return {
            "required": {
                "model": ("MODEL", ),
                "detail_1": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "detail_2": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "contrast_1": ("FLOAT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 0.1
                }),
                "start": ("INT", {
                    "default": 0,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "end": ("INT", {
                    "default": 1000,
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, detail_1, detail_2, contrast_1, start, end):
        """Return a clone of ``model`` whose UNet call is wrapped by the tuner.

        Args:
            model: Object providing ``clone()``; the clone must provide
                ``set_model_unet_function_wrapper``.
            detail_1: Increases detail (?) by decreasing the weight and
                increasing the bias of the first Conv layer.
            detail_2: Same idea, applied to the GroupNorm in front of the
                last Conv layer.
            contrast_1: Increases contrast (?) by increasing channel 0 of the
                last Conv layer's bias.
            start: The adjustment is skipped while the timestep is greater
                than ``1000 - start``.
            end: The adjustment is skipped while the timestep is less than
                ``1000 - end``.

        Returns:
            A one-element tuple containing the patched clone.
        """
        new_model = model.clone()
        ratios = fineman([detail_1, detail_2, contrast_1])

        # Still assigned so existing readers of these attributes keep working.
        # The wrapper below uses the values captured in its closure instead,
        # so a later apply() on this instance cannot change the behaviour of
        # a model returned earlier.
        self.storedweights = {}
        self.start = start
        self.end = end

        # Patch applied around the UNet computation.
        def apply_cdtuner(model_function, kwargs):
            # NOTE(review): the window test assumes kwargs["timestep"] is on a
            # 0-1000 scale -- confirm against the caller.
            timestep = kwargs["timestep"][0]
            if timestep < (1000 - end) or timestep > (1000 - start):
                return model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])

            originals = {}
            try:
                for i, name in enumerate(ADJUSTS):
                    # Keep a copy of the original tensor so it can be restored.
                    original = getset_nested_module_tensor(True, new_model, name).clone()
                    originals[name] = original
                    if i < 4:
                        # The first four entries are scaled by a scalar ratio.
                        new_weight = original * ratios[i]
                    else:
                        # The remaining entry gets a per-channel offset added.
                        new_weight = original + torch.tensor(
                            ratios[i], device=original.device, dtype=original.dtype
                        )
                    # Overwrite the tensor on the model.
                    getset_nested_module_tensor(False, new_model, name, new_tensor=new_weight)

                return model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])
            finally:
                # Put the original tensors back even if patching or the model
                # call raised; otherwise the model would stay modified and the
                # next call would stack a second adjustment on top.
                for name, original in originals.items():
                    getset_nested_module_tensor(False, new_model, name, new_tensor=original)

        new_model.set_model_unet_function_wrapper(apply_cdtuner)
        return (new_model, )
def getset_nested_module_tensor(clone, model, tensor_path, new_tensor=None):
    """Read or replace an attribute addressed by a dotted path under ``model``.

    Path components made only of digits are used as integer indexes
    (``obj[int(part)]``); every other component is looked up with ``getattr``.

    Args:
        clone: ``True`` selects read mode: the whole path is resolved and the
            object found there is returned as-is (no copy is made, despite
            the parameter name). ``False`` selects write mode.
        model: Root object the path is resolved against.
        tensor_path: Dotted path such as ``"a.0.weight"``.
        new_tensor: Tensor to install in write mode; ignored in read mode.

    Returns:
        The resolved object in read mode, ``None`` in write mode.

    Raises:
        ValueError: In write mode when ``new_tensor`` is ``None``.
    """
    if not clone and new_tensor is None:
        # torch.nn.Parameter(None) would install an empty parameter and
        # silently discard the existing tensor, so refuse explicitly.
        raise ValueError(
            "new_tensor is required when setting '{}'".format(tensor_path)
        )

    parts = tensor_path.split('.')
    # Read mode walks the full path; write mode stops at the owner of the
    # final attribute so that attribute can be replaced on it.
    walk = parts if clone else parts[:-1]

    target = model
    for part in walk:
        if part.isdigit():
            target = target[int(part)]
        else:
            target = getattr(target, part)

    if clone:
        return target

    setattr(target, parts[-1], torch.nn.Parameter(new_tensor))
# Why is this called fineman?
def fineman(fine):
    """Convert the three user strengths into five per-tensor adjustment values.

    Args:
        fine: Indexable holding, in order, the ``detail_1``, ``detail_2`` and
            ``contrast_1`` values.

    Returns:
        A five-element list: four scalar factors (two derived from each
        detail value) followed by a four-element offset list whose first
        entry is derived from the contrast value and whose rest are ``0``.
    """
    detail_1 = fine[0]
    detail_2 = fine[1]
    contrast_1 = fine[2]

    first_pair = [1 - detail_1 * 0.01, 1 + detail_1 * 0.02]
    second_pair = [1 - detail_2 * 0.01, 1 + detail_2 * 0.02]
    contrast_offsets = [contrast_1 * 0.02, 0, 0, 0]

    return first_pair + second_pair + [contrast_offsets]
# Dotted paths of the tensors that CDTuner.apply rewrites around each model
# call. Order matters: the first four entries are multiplied by a scalar
# ratio and the remaining entry has an offset added to it.
_UNET_PREFIX = "model.diffusion_model."
ADJUSTS = [
    _UNET_PREFIX + suffix
    for suffix in (
        "input_blocks.0.0.weight",
        "input_blocks.0.0.bias",
        "out.0.weight",
        "out.0.bias",
        "out.2.bias",
    )
]
# Registration tables exported by this module (presumably consumed by the
# host application's node loader -- confirm): node identifier -> class, and
# node identifier -> display label.
NODE_CLASS_MAPPINGS = {"CDTuner": CDTuner}
NODE_DISPLAY_NAME_MAPPINGS = {"CDTuner": "Apply CDTuner"}

# Only the two registration tables are part of the public API.
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
]
|