# CDTuner — ComfyUI custom node.
# (Removed non-Python scrape residue that preceded this file: a byte-size
# header, a commit hash, and a line-number gutter column.)
import torch


class CDTuner:
    """ComfyUI node that patches a diffusion model's first/last conv-related
    tensors at sampling time to tweak perceived detail and contrast, active
    only inside a configurable step window."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "detail_1": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "detail_2": ("FLOAT", {
                    "default": 0,
                    "min": -10,
                    "max": 10,
                    "step": 0.1
                }),
                "contrast_1": ("FLOAT", {
                    "default": 0,
                    "min": -20,
                    "max": 20,
                    "step": 0.1
                }),
                "start": ("INT", {
                    "default": 0, 
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
                "end": ("INT", {
                    "default": 1000, 
                    "min": 0,
                    "max": 1000,
                    "step": 1,
                    "display": "number"
                }),
            },
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, detail_1, detail_2, contrast_1, start, end):
        '''
        detail_1: shrink the first conv layer's weight and grow its bias,
            which increases detail (presumably -- inherited heuristic).
        detail_2: same idea for the GroupNorm before the last conv layer.
        contrast_1: add to channel 0 of the last conv layer's bias, which
            increases contrast (presumably).
        start/end: step window (0-1000) in which the adjustment is active.
        '''
        new_model = model.clone()
        ratios = fineman([detail_1, detail_2, contrast_1])
        # Kept only for backward compatibility with anything that might
        # inspect these attributes.  The wrapper below deliberately closes
        # over LOCALS instead: the old code read self.start/self.end/
        # self.storedweights at call time, so a second apply() on the same
        # node instance clobbered the window (and stored weights) of every
        # previously patched model.
        self.storedweights = {}
        self.start = start
        self.end = end
        stored = {}

        # Patch wrapped around every UNet call.
        def apply_cdtuner(model_function, kwargs):
            # Timesteps count down from ~1000, so the active step window
            # [start, end] maps to timesteps [1000-end, 1000-start].
            t = kwargs["timestep"][0]
            if t < (1000 - end) or t > (1000 - start):
                return model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])
            for i, name in enumerate(ADJUSTS):
                # Save the original tensor so it can be restored afterwards.
                stored[name] = getset_nested_module_tensor(True, new_model, name).clone()
                if i < 4:
                    # First four ADJUSTS entries take a multiplicative scale.
                    new_weight = stored[name] * ratios[i]
                else:
                    # Last entry takes an additive per-channel offset.
                    device = stored[name].device
                    dtype = stored[name].dtype
                    new_weight = stored[name] + torch.tensor(ratios[i], device=device, dtype=dtype)
                # Overwrite the live tensor with the adjusted one.
                getset_nested_module_tensor(False, new_model, name, new_tensor=new_weight)
            try:
                retval = model_function(kwargs["input"], kwargs["timestep"], **kwargs["c"])
            finally:
                # Restore the originals even if the UNet call raises, so the
                # model is never left with the temporary weights applied.
                for name in ADJUSTS:
                    getset_nested_module_tensor(False, new_model, name, new_tensor=stored[name])

            return retval

        new_model.set_model_unet_function_wrapper(apply_cdtuner)

        return (new_model, )


def getset_nested_module_tensor(clone, model, tensor_path, new_tensor=None):
    """Fetch or overwrite a tensor addressed by a dotted path.

    With clone=True, walk the full path (e.g.
    'model.diffusion_model.out.2.bias') and return the tensor found there.
    With clone=False, walk to the parent module and replace the final
    attribute with ``new_tensor`` wrapped in a Parameter.  Purely numeric
    path segments are treated as container indices.
    """
    parts = tensor_path.split('.')
    # When writing we stop one level short, so setattr targets the parent.
    node = model
    for part in (parts if clone else parts[:-1]):
        node = node[int(part)] if part.isdigit() else getattr(node, part)

    if clone:
        return node

    setattr(node, parts[-1], torch.nn.Parameter(new_tensor))

# why the name "fineman"?
def fineman(fine):
    """Map [detail_1, detail_2, contrast_1] to the five per-tensor adjusters.

    Returns four multiplicative scales (weight shrunk, bias grown, for two
    layer pairs) followed by one additive 4-channel offset where only
    channel 0 is driven by the contrast setting.
    """
    detail_1, detail_2, contrast_1 = fine[0], fine[1], fine[2]
    return [
        1 - detail_1 * 0.01,
        1 + detail_1 * 0.02,
        1 - detail_2 * 0.01,
        1 + detail_2 * 0.02,
        [contrast_1 * 0.02, 0, 0, 0],
    ]


# Dotted paths (resolved by getset_nested_module_tensor) of the five tensors
# adjusted during sampling: the first conv layer's weight/bias, the weight/bias
# of the normalization layer before the final conv (a GroupNorm, per the
# apply() docstring), and the final conv layer's bias.  Order matters: the
# first four receive multiplicative scales, the last an additive offset.
ADJUSTS = [
    "model.diffusion_model.input_blocks.0.0.weight",
    "model.diffusion_model.input_blocks.0.0.bias",
    "model.diffusion_model.out.0.weight",
    "model.diffusion_model.out.0.bias",
    "model.diffusion_model.out.2.bias",
]

# ComfyUI registration table: maps the node's internal name to its class.
NODE_CLASS_MAPPINGS = {
    "CDTuner": CDTuner,
}

# Human-readable label shown in the ComfyUI node menu.
NODE_DISPLAY_NAME_MAPPINGS = {
    "CDTuner": "Apply CDTuner",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]