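# NegPiP node for ComfyUI.
# The CLIP text encoder is hooked so that a separate key/value embedding pair
# is produced for every token during prompt weighting and interleaved into a
# single sequence; a cross-attention (attn2) patch splits the pair apart again
# at sampling time. When a token's weight is negative, only its value
# embedding is negated, so a prompt such as "(word:-1.0)" pushes generation
# away from the concept instead of merely down-weighting it.
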
import torch
import copy
from comfy.sd1_clip import gen_empty_tokens

class Negpip:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "clip": ("CLIP", ),
            },
        }

    RETURN_TYPES = ("MODEL", "CLIP")
    FUNCTION = "apply"
    CATEGORY = "loaders"

    def apply(self, model, clip):
        new_clip = copy.copy(clip)
        # Hook every text encoder present on this CLIP object
        # (clip_l: SD1.x/SDXL ViT-L, clip_g: SDXL ViT-bigG, clip_h: SD2.x ViT-H).
        if hasattr(new_clip.cond_stage_model, "clip_g"):
            new_clip.cond_stage_model.clip_g.encode_token_weights = hook_clip_encode_token_weights(new_clip.cond_stage_model.clip_g)
        if hasattr(new_clip.cond_stage_model, "clip_h"):
            new_clip.cond_stage_model.clip_h.encode_token_weights = hook_clip_encode_token_weights(new_clip.cond_stage_model.clip_h)
        if hasattr(new_clip.cond_stage_model, "clip_l"):
            new_clip.cond_stage_model.clip_l.encode_token_weights = hook_clip_encode_token_weights(new_clip.cond_stage_model.clip_l)
        new_model = model.clone()

        def negpip_apply(q, k, v, extra_options):
            # The hooked encoder interleaved k/v embeddings along the token
            # axis as [k0, v0, k1, v1, ...]; split them back apart here:
            # even positions feed the keys, odd positions feed the values.
            new_k = k[:, 0::2]
            new_v = v[:, 1::2]
            return q, new_k, new_v

        new_model.set_model_attn2_patch(negpip_apply)

        return new_model, new_clip

# Separate k and v embeddings are computed inside the prompt-weighting
# calculation and interleaved; they are split apart again in the attn2 patch.
# When a weight is negative, only v is negated.
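
# Round-trip sketch of the interleave/split (illustration only; the tensors
# are hypothetical):
#   zk, zv = torch.randn(1, 3, 4), torch.randn(1, 3, 4)
#   z = torch.stack([zk, zv], dim=2).reshape(1, 6, 4)  # [k0, v0, k1, v1, k2, v2]
#   assert torch.equal(z[:, 0::2], zk) and torch.equal(z[:, 1::2], zv)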
# Returns a replacement for the encoder's encode_token_weights method.
# The parameter is named `self` because the closure stands in for a bound method.
def hook_clip_encode_token_weights(self):

    def encode_token_weights(token_weight_pairs):
        to_encode = list()
        max_token_len = 0
        has_weights = False
        for x in token_weight_pairs:
            tokens = list(map(lambda a: a[0], x))
            max_token_len = max(len(tokens), max_token_len)
            has_weights = has_weights or not all(map(lambda a: a[1] == 1.0, x))
            to_encode.append(tokens)

        sections = len(to_encode)
        if has_weights or sections == 0:
            # Embedding of an empty prompt, used below as the reference point
            # when rescaling weighted tokens.
            to_encode.append(gen_empty_tokens(self.special_tokens, max_token_len))

        out, pooled = self.encode(to_encode)
        if pooled is not None:
            first_pooled = pooled[0:1].cpu()
        else:
            first_pooled = pooled

        output = []
        for k in range(sections):
            # Two copies of the section embedding: zk will feed attention keys,
            # zv will feed attention values (only zv may be sign-flipped below).
            zk = out[k:k+1].clone()
            zv = out[k:k+1].clone()
            if has_weights:
                z_empty = out[-1]
                for i in range(len(zk)):
                    for j in range(len(zk[i])):
                        weight = token_weight_pairs[k][j][1]
                        # NegPiP core: rescale with |weight| as usual, but flip
                        # the sign of the value embedding for negative weights.
                        if weight < 0:
                            weight = -weight
                            sign = -1
                        else:
                            sign = 1
                        zk[i][j] = (zk[i][j] - z_empty[0][j]) * weight + z_empty[0][j]
                        zv[i][j] = sign * ((zv[i][j] - z_empty[0][j]) * weight + z_empty[0][j])

            # Interleave k and v embeddings token-by-token: [k0, v0, k1, v1, ...]
            z = torch.zeros_like(zk).repeat(1, 2, 1)
            for i in range(zk.shape[1]):
                z[:, 2*i, :] = zk[:, i, :]
                z[:, 2*i+1, :] = zv[:, i, :]
            output.append(z)

        if len(output) == 0:
            return out[-1:].cpu(), first_pooled
        return torch.cat(output, dim=-2).cpu(), first_pooled

    return encode_token_weights
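
# Shape sketch for one prompt section (assuming SD1.x clip_l: 77 tokens,
# hidden size 768):
#   out[k:k+1]          -> [1, 77, 768]
#   z (k/v interleaved) -> [1, 154, 768]
#   torch.cat(output)   -> [1, 154 * sections, 768]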


NODE_CLASS_MAPPINGS = {
    "Negpip": Negpip,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "Negpip": "Apply Negpip",
}

__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
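
# Usage sketch (hypothetical workflow): apply the node before sampling, then
# use negative weights directly inside the positive prompt.
#   patched_model, patched_clip = Negpip().apply(model, clip)
#   # e.g. prompt text: "a city street, (watermark:-1.3)"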