import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..logging import warning_once


class NormModule(LycorisBaseModule):
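    """Norm tuning: learn additive offsets on a normalization layer's affine parameters.

    ``w_norm``/``b_norm`` have the same shape as the original weight/bias and are
    added to them (scaled by ``multiplier``) at forward and merge time.
    """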
    name = "norm"
    support_module = {
        "layernorm",
        "groupnorm",
    }
    weight_list = ["w_norm", "b_norm"]
    weight_list_det = ["w_norm"]

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        rank_dropout_scale=False,
        **kwargs,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(
            lora_name=lora_name,
            org_module=org_module,
            multiplier=multiplier,
            rank_dropout=rank_dropout,
            module_dropout=module_dropout,
            rank_dropout_scale=rank_dropout_scale,
            **kwargs,
        )
        if self.module_type == "unknown":
            if not hasattr(org_module, "weight") or not hasattr(org_module, "_norm"):
                warning_once(f"{type(org_module)} is not supported in Norm algo.")
                self.not_supported = True
                return
            else:
                self.dim = org_module.weight.numel()
                self.not_supported = False
        elif self.module_type not in self.support_module:
            warning_once(f"{self.module_type} is not supported in Norm algo.")
            self.not_supported = True
            return

        self.w_norm = nn.Parameter(torch.zeros(self.dim))
        if hasattr(org_module, "bias"):
            self.b_norm = nn.Parameter(torch.zeros(self.dim))
        if hasattr(org_module, "_norm"):
            self.org_norm = org_module._norm
        else:
            self.org_norm = None

    @classmethod
    def make_module_from_state_dict(cls, lora_name, orig_module, w_norm, b_norm):
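        """Rebuild a NormModule from saved offsets; ``b_norm`` may be None."""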
        module = cls(
            lora_name,
            orig_module,
            1,
        )
        module.w_norm.copy_(w_norm)
        if b_norm is not None:
            module.b_norm.copy_(b_norm)
        return module

    def make_weight(self, scale=1, device=None):
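        """Return (weight, bias) of the original norm layer with the scaled offsets added."""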
        org_weight = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype)
        if hasattr(self.org_module[0], "bias"):
            org_bias = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype)
        else:
            org_bias = None
        if self.rank_dropout and self.training:
            # Channel-wise Bernoulli mask applied to the learned offset during training.
            drop = (torch.rand(self.dim, device=device) < self.rank_dropout).to(
                self.w_norm.dtype
            )
            if self.rank_dropout_scale:
                # Rescale by the mean keep rate so the expected offset magnitude is preserved.
                drop /= drop.mean()
        else:
            drop = 1
        weight = self.w_norm.to(device) * drop * scale
        if org_bias is not None:
            bias = self.b_norm.to(device) * drop * scale
        return org_weight + weight, org_bias + bias if org_bias is not None else None

    def get_diff_weight(self, multiplier=1, shape=None, device=None):
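        """Return only the learned offsets scaled by ``multiplier``, optionally reshaped and moved to ``device``."""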
        if self.not_supported:
            return 0, 0
        w = self.w_norm * multiplier
        if device is not None:
            w = w.to(device)
        if shape is not None:
            w = w.view(shape)
        if self.b_norm is not None:
            b = self.b_norm * multiplier
            if device is not None:
                b = b.to(device)
            if shape is not None:
                b = b.view(shape)
        else:
            b = None
        return w, b

    def get_merged_weight(self, multiplier=1, shape=None, device=None):
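        """Return the original weight/bias with the scaled offsets merged in."""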
        if self.not_supported:
            return None, None
        diff_w, diff_b = self.get_diff_weight(multiplier, shape, device)
        org_w = self.org_module[0].weight.to(device, dtype=self.w_norm.dtype)
        weight = org_w + diff_w
        if diff_b is not None:
            org_b = self.org_module[0].bias.to(device, dtype=self.b_norm.dtype)
            bias = org_b + diff_b
        else:
            bias = None
        return weight, bias

    def forward(self, x):
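        # Bypass to the original forward when unsupported, or when module dropout fires during training.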
        if self.not_supported or (
            self.module_dropout
            and self.training
            and torch.rand(1) < self.module_dropout
        ):
            return self.org_forward(x)
        scale = self.multiplier

        w, b = self.make_weight(scale, x.device)
        if self.org_norm is not None:
            normed = self.org_norm(x)
            scaled = normed * w
            if b is not None:
                scaled += b
            return scaled

        kw_dict = self.kw_dict | {"weight": w, "bias": b}
        return self.op(x, **kw_dict)


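# Minimal smoke test: wrap LayerNorm and GroupNorm, run a forward and backward pass (requires CUDA).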
if __name__ == "__main__":
    base = nn.LayerNorm(128).cuda()
    norm = NormModule("test", base, 1).cuda()
    print(norm)
    test_input = torch.randn(1, 128).cuda()
    test_output = norm(test_input)
    torch.sum(test_output).backward()
    print(test_output.shape)

    base = nn.GroupNorm(4, 128).cuda()
    norm = NormModule("test", base, 1).cuda()
    print(norm)
    test_input = torch.randn(1, 128, 3, 3).cuda()
    test_output = norm(test_input)
    torch.sum(test_output).backward()
    print(test_output.shape)