ret45 committed
Commit 08d0710 · verified · 1 Parent(s): 2774d83

src_inference/lora_helper.py

Files changed (1):
  lora_helper.py  +194 −0
lora_helper.py ADDED
@@ -0,0 +1,194 @@
from diffusers.models.attention_processor import FluxAttnProcessor2_0
from safetensors import safe_open
import re
import torch
from .layers_cache import MultiDoubleStreamBlockLoraProcessor, MultiSingleStreamBlockLoraProcessor

device = "cuda"

def load_safetensors(path):
    # Read every tensor from a .safetensors file into a plain dict on CPU.
    tensors = {}
    with safe_open(path, framework="pt", device="cpu") as f:
        for key in f.keys():
            tensors[key] = f.get_tensor(key)
    return tensors

def get_lora_rank(checkpoint):
    # The LoRA rank is the output dimension of any "down" projection weight.
    for k in checkpoint.keys():
        if k.endswith(".down.weight"):
            return checkpoint[k].shape[0]

def load_checkpoint(local_path):
    if local_path is not None:
        if '.safetensors' in local_path:
            print(f"Loading .safetensors checkpoint from {local_path}")
            checkpoint = load_safetensors(local_path)
        else:
            print(f"Loading checkpoint from {local_path}")
            checkpoint = torch.load(local_path, map_location='cpu')
    return checkpoint

def update_model_with_lora(checkpoint, lora_weights, transformer, cond_size):
    number = len(lora_weights)
    ranks = [get_lora_rank(checkpoint) for _ in range(number)]
    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))
    for name, attn_processor in transformer.attn_processors.items():
        match = re.search(r'\.(\d+)\.', name)
        if match:
            layer_index = int(match.group(1))

        if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
            # Collect the checkpoint entries that belong to this double-stream block.
            lora_state_dicts = {}
            for key, value in checkpoint.items():
                # Match based on the layer index embedded in the key.
                if re.search(r'\.(\d+)\.', key):
                    checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                    if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                        lora_state_dicts[key] = value

            lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
            )

            # Load the weights from the checkpoint dictionary into the corresponding layers
            for n in range(number):
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                lora_attn_procs[name].proj_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.down.weight', None)
                lora_attn_procs[name].proj_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.proj_loras.{n}.up.weight', None)
                lora_attn_procs[name].to(device)

        elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
            # Single-stream blocks carry q/k/v LoRAs only (no proj_loras).
            lora_state_dicts = {}
            for key, value in checkpoint.items():
                if re.search(r'\.(\d+)\.', key):
                    checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                    if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                        lora_state_dicts[key] = value

            lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=lora_weights, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=number
            )
            for n in range(number):
                lora_attn_procs[name].q_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.down.weight', None)
                lora_attn_procs[name].q_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.q_loras.{n}.up.weight', None)
                lora_attn_procs[name].k_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.down.weight', None)
                lora_attn_procs[name].k_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.k_loras.{n}.up.weight', None)
                lora_attn_procs[name].v_loras[n].down.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.down.weight', None)
                lora_attn_procs[name].v_loras[n].up.weight.data = lora_state_dicts.get(f'{name}.v_loras.{n}.up.weight', None)
                lora_attn_procs[name].to(device)
        else:
            lora_attn_procs[name] = FluxAttnProcessor2_0()

    transformer.set_attn_processor(lora_attn_procs)

def update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size):
    ck_number = len(checkpoints)
    cond_lora_number = [len(ls) for ls in lora_weights]
    cond_number = sum(cond_lora_number)
    ranks = [get_lora_rank(checkpoint) for checkpoint in checkpoints]
    # Flatten the per-checkpoint weight lists into a single list of condition weights.
    multi_lora_weight = []
    for ls in lora_weights:
        for n in ls:
            multi_lora_weight.append(n)

    lora_attn_procs = {}
    double_blocks_idx = list(range(19))
    single_blocks_idx = list(range(38))
    for name, attn_processor in transformer.attn_processors.items():
        match = re.search(r'\.(\d+)\.', name)
        if match:
            layer_index = int(match.group(1))

        if name.startswith("transformer_blocks") and layer_index in double_blocks_idx:
            # One state dict per checkpoint, filtered to this double-stream block.
            lora_state_dicts = [{} for _ in range(ck_number)]
            for idx, checkpoint in enumerate(checkpoints):
                for key, value in checkpoint.items():
                    # Match based on the layer index embedded in the key.
                    if re.search(r'\.(\d+)\.', key):
                        checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                        if checkpoint_layer_index == layer_index and key.startswith("transformer_blocks"):
                            lora_state_dicts[idx][key] = value

            lora_attn_procs[name] = MultiDoubleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
            )

            # Load the weights from the checkpoint dictionaries into the corresponding layers
            num = 0
            for idx in range(ck_number):
                for n in range(cond_lora_number[idx]):
                    lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                    lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                    lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                    lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                    lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                    lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                    lora_attn_procs[name].proj_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.down.weight', None)
                    lora_attn_procs[name].proj_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.proj_loras.{n}.up.weight', None)
                    lora_attn_procs[name].to(device)
                    num += 1

        elif name.startswith("single_transformer_blocks") and layer_index in single_blocks_idx:
            lora_state_dicts = [{} for _ in range(ck_number)]
            for idx, checkpoint in enumerate(checkpoints):
                for key, value in checkpoint.items():
                    # Match based on the layer index embedded in the key.
                    if re.search(r'\.(\d+)\.', key):
                        checkpoint_layer_index = int(re.search(r'\.(\d+)\.', key).group(1))
                        if checkpoint_layer_index == layer_index and key.startswith("single_transformer_blocks"):
                            lora_state_dicts[idx][key] = value

            lora_attn_procs[name] = MultiSingleStreamBlockLoraProcessor(
                dim=3072, ranks=ranks, network_alphas=ranks, lora_weights=multi_lora_weight, device=device, dtype=torch.bfloat16, cond_width=cond_size, cond_height=cond_size, n_loras=cond_number
            )
            # Load the weights from the checkpoint dictionaries into the corresponding layers
            num = 0
            for idx in range(ck_number):
                for n in range(cond_lora_number[idx]):
                    lora_attn_procs[name].q_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.down.weight', None)
                    lora_attn_procs[name].q_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.q_loras.{n}.up.weight', None)
                    lora_attn_procs[name].k_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.down.weight', None)
                    lora_attn_procs[name].k_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.k_loras.{n}.up.weight', None)
                    lora_attn_procs[name].v_loras[num].down.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.down.weight', None)
                    lora_attn_procs[name].v_loras[num].up.weight.data = lora_state_dicts[idx].get(f'{name}.v_loras.{n}.up.weight', None)
                    lora_attn_procs[name].to(device)
                    num += 1

        else:
            lora_attn_procs[name] = FluxAttnProcessor2_0()

    transformer.set_attn_processor(lora_attn_procs)

def set_single_lora(transformer, local_path, lora_weights=[], cond_size=512):
    # Attach a single LoRA checkpoint to the Flux transformer.
    checkpoint = load_checkpoint(local_path)
    update_model_with_lora(checkpoint, lora_weights, transformer, cond_size)

def set_multi_lora(transformer, local_paths, lora_weights=[[]], cond_size=512):
    # Attach several LoRA checkpoints at once; lora_weights is one weight list per checkpoint.
    checkpoints = [load_checkpoint(local_path) for local_path in local_paths]
    update_model_with_multi_lora(checkpoints, lora_weights, transformer, cond_size)

def unset_lora(transformer):
    # Restore the default Flux attention processors, dropping all LoRA layers.
    lora_attn_procs = {}
    for name, attn_processor in transformer.attn_processors.items():
        lora_attn_procs[name] = FluxAttnProcessor2_0()
    transformer.set_attn_processor(lora_attn_procs)

'''
Example usage:
unset_lora(pipe.transformer)
lora_path = "./lora.safetensors"
lora_weights = [1, 1]
set_single_lora(pipe.transformer, local_path=lora_path, lora_weights=lora_weights, cond_size=512)
'''
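
Minimal usage sketch (not part of this commit), assuming the package is importable as src_inference, a diffusers Flux pipeline is already loaded as pipe, and the LoRA checkpoints were trained for these processors; the file paths and weight values below are placeholders:

from src_inference.lora_helper import set_single_lora, set_multi_lora, unset_lora

# Single LoRA: one checkpoint, one weight per condition it contains (placeholder path).
set_single_lora(pipe.transformer, "./subject_lora.safetensors", lora_weights=[1.0], cond_size=512)

# Several LoRAs at once: one path per checkpoint, weights grouped per checkpoint (placeholder paths).
set_multi_lora(pipe.transformer,
               ["./subject_lora.safetensors", "./style_lora.safetensors"],
               lora_weights=[[1.0], [1.0]],
               cond_size=512)

# Restore the stock FluxAttnProcessor2_0 processors before switching to other LoRAs.
unset_lora(pipe.transformer)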