import math
import random

import torch
import torch.nn as nn

from .base import LycorisBaseModule
from ..utils import product


class DyLoraModule(LycorisBaseModule):
    support_module = {
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
    }

    def __init__(
        self,
        lora_name,
        org_module: nn.Module,
        multiplier=1.0,
        lora_dim=4,
        alpha=1,
        dropout=0.0,
        rank_dropout=0.0,
        module_dropout=0.0,
        use_tucker=False,
        block_size=4,
        use_scalar=False,
        rank_dropout_scale=False,
        weight_decompose=False,
        bypass_mode=None,
        rs_lora=False,
        train_on_input=False,
        **kwargs,
    ):
        """if alpha == 0 or None, alpha is rank (no scaling)."""
        super().__init__(
            lora_name,
            org_module,
            multiplier,
            dropout,
            rank_dropout,
            module_dropout,
            rank_dropout_scale,
            bypass_mode,
        )
        if self.module_type not in self.support_module:
            raise ValueError(f"{self.module_type} is not supported in IA^3 algo.")
        assert lora_dim % block_size == 0, "lora_dim must be a multiple of block_size"
        self.block_count = lora_dim // block_size
        self.block_size = block_size

        shape = (
            self.shape[0],
            product(self.shape[1:]),
        )

        self.lora_dim = lora_dim
        # DyLoRA splits the rank dimension into block_count chunks of block_size;
        # a sub-rank weight is recovered by concatenating a prefix of these blocks.
        self.up_list = nn.ParameterList(
            [torch.empty(shape[0], self.block_size) for i in range(self.block_count)]
        )
        self.down_list = nn.ParameterList(
            [torch.empty(self.block_size, shape[1]) for i in range(self.block_count)]
        )

        if isinstance(alpha, torch.Tensor):
            alpha = alpha.detach().float().numpy()  # without casting, bf16 causes error
        alpha = lora_dim if alpha is None or alpha == 0 else alpha
        self.scale = alpha / self.lora_dim
        self.register_buffer("alpha", torch.tensor(alpha))  # treated as a constant

        # More experiments are needed on the init method
        for v in self.down_list:
            torch.nn.init.kaiming_uniform_(v, a=math.sqrt(5))
        for v in self.up_list:
            torch.nn.init.zeros_(v)

    def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
        # DyLoRA keeps its weights as per-block ParameterLists, so the merged
        # lora_up/lora_down tensors of a standard state dict are not loaded here.
        return

    def custom_state_dict(self):
        # Export the block lists as merged lora_up/lora_down tensors in the
        # standard LoRA layout.
        destination = {}
        destination["alpha"] = self.alpha
        destination["lora_up.weight"] = nn.Parameter(
            torch.concat(list(self.up_list), dim=1)
        )
        destination["lora_down.weight"] = nn.Parameter(
            torch.concat(list(self.down_list)).reshape(
                self.lora_dim, -1, *self.shape[2:]
            )
        )
        return destination

    def get_weight(self, rank):
        # Blocks strictly below the requested rank are detached via `.data`
        # (kept frozen), while the block that contains `rank` stays trainable.
        b = math.ceil(rank / self.block_size)
        down = torch.concat(
            list(i.data for i in self.down_list[:b]) + list(self.down_list[b : (b + 1)])
        )
        up = torch.concat(
            list(i.data for i in self.up_list[:b]) + list(self.up_list[b : (b + 1)]),
            dim=1,
        )
        return down, up, self.alpha / (b + 1)

    def get_random_rank_weight(self):
        b = random.randint(0, self.block_count - 1)
        return self.get_weight(b * self.block_size)

    def get_diff_weight(self, multiplier=1, shape=None, device=None, rank=None):
        if rank is None:
            down, up, scale = self.get_random_rank_weight()
        else:
            down, up, scale = self.get_weight(rank)
        w = up @ (down * (scale * multiplier))
        if device is not None:
            w = w.to(device)
        if shape is not None:
            w = w.view(shape)
        else:
            w = w.view(self.shape)
        return w, None

    def get_merged_weight(self, multiplier=1, shape=None, device=None, rank=None):
        diff, _ = self.get_diff_weight(multiplier, shape, device, rank)
        return diff + self.org_weight, None

    def bypass_forward_diff(self, x, scale=1, rank=None):
        if rank is None:
            down, up, gamma = self.get_random_rank_weight()
        else:
            down, up, gamma = self.get_weight(rank)
        # The selected blocks may cover fewer ranks than lora_dim, so infer the
        # active rank from the concatenated weights themselves.
        active_dim = down.size(0)
        down = down.view(active_dim, -1, *self.shape[2:])
        up = up.view(-1, active_dim, *(1 for _ in self.shape[2:]))
        scale = scale * gamma
        return self.op(self.op(x, down, **self.kw_dict), up) * scale

    def bypass_forward(self, x, scale=1, rank=None):
        return self.org_forward(x) + self.bypass_forward_diff(x, scale, rank)

    def forward(self, x, *args, **kwargs):
        if self.module_dropout and self.training:
            if torch.rand(1) < self.module_dropout:
                return self.org_forward(x)
        if self.bypass_mode:
            return self.bypass_forward(x, self.multiplier)
        else:
            weight = self.get_merged_weight(multiplier=self.multiplier)[0]
            bias = (
                None
                if self.org_module[0].bias is None
                else self.org_module[0].bias.data
            )
            return self.op(x, weight, bias, **self.kw_dict)
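

# Minimal usage sketch (illustrative, not part of the library API). It assumes
# LycorisBaseModule derives `self.shape` from the wrapped layer's weight, which
# is what the methods above rely on; the module name "sketch_dylora" is an
# arbitrary key. Run this as a module inside the package so the relative
# imports resolve.
if __name__ == "__main__":
    base = nn.Linear(64, 32)
    dylora = DyLoraModule(
        "sketch_dylora",
        base,
        multiplier=1.0,
        lora_dim=8,
        alpha=8,
        block_size=4,
    )
    # Delta weight at a fixed sub-rank and at a randomly sampled one.
    diff_r4, _ = dylora.get_diff_weight(rank=4)
    diff_rand, _ = dylora.get_diff_weight()
    print(diff_r4.shape, diff_rand.shape)  # both torch.Size([32, 64])
    # Merged lora_up/lora_down tensors as they would be exported.
    sd = dylora.custom_state_dict()
    print(sd["lora_up.weight"].shape, sd["lora_down.weight"].shape)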