ret45 commited on
Commit
2774d83
·
verified ·
1 Parent(s): a37e2ab

src_inference/layers_cache.py

Browse files
Files changed (1) hide show
  1. layers_cache.py +366 -0
layers_cache.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import math
3
+ from typing import Callable, List, Optional, Tuple, Union
4
+ from einops import rearrange
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from torch import Tensor
9
+ from diffusers.models.attention_processor import Attention
10
+
11
+ class LoRALinearLayer(nn.Module):
12
+ def __init__(
13
+ self,
14
+ in_features: int,
15
+ out_features: int,
16
+ rank: int = 4,
17
+ network_alpha: Optional[float] = None,
18
+ device: Optional[Union[torch.device, str]] = None,
19
+ dtype: Optional[torch.dtype] = None,
20
+ cond_width=512,
21
+ cond_height=512,
22
+ number=0,
23
+ n_loras=1
24
+ ):
25
+ super().__init__()
26
+ self.down = nn.Linear(in_features, rank, bias=False, device=device, dtype=dtype)
27
+ self.up = nn.Linear(rank, out_features, bias=False, device=device, dtype=dtype)
28
+ # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
29
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
30
+ self.network_alpha = network_alpha
31
+ self.rank = rank
32
+ self.out_features = out_features
33
+ self.in_features = in_features
34
+
35
+ nn.init.normal_(self.down.weight, std=1 / rank)
36
+ nn.init.zeros_(self.up.weight)
37
+
38
+ self.cond_height = cond_height
39
+ self.cond_width = cond_width
40
+ self.number = number
41
+ self.n_loras = n_loras
42
+
43
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
44
+ orig_dtype = hidden_states.dtype
45
+ dtype = self.down.weight.dtype
46
+
47
+ ####
48
+ batch_size = hidden_states.shape[0]
49
+ cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
50
+ block_size = hidden_states.shape[1] - cond_size * self.n_loras
51
+ shape = (batch_size, hidden_states.shape[1], 3072)
52
+ mask = torch.ones(shape, device=hidden_states.device, dtype=dtype)
53
+ mask[:, :block_size+self.number*cond_size, :] = 0
54
+ mask[:, block_size+(self.number+1)*cond_size:, :] = 0
55
+ hidden_states = mask * hidden_states
56
+ ####
57
+
58
+ down_hidden_states = self.down(hidden_states.to(dtype))
59
+ up_hidden_states = self.up(down_hidden_states)
60
+
61
+ if self.network_alpha is not None:
62
+ up_hidden_states *= self.network_alpha / self.rank
63
+
64
+ return up_hidden_states.to(orig_dtype)
65
+
66
+
67
+ class MultiSingleStreamBlockLoraProcessor(nn.Module):
68
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
69
+ super().__init__()
70
+ # Initialize a list to store the LoRA layers
71
+ self.n_loras = n_loras
72
+ self.cond_width = cond_width
73
+ self.cond_height = cond_height
74
+
75
+ self.q_loras = nn.ModuleList([
76
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
77
+ for i in range(n_loras)
78
+ ])
79
+ self.k_loras = nn.ModuleList([
80
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
81
+ for i in range(n_loras)
82
+ ])
83
+ self.v_loras = nn.ModuleList([
84
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
85
+ for i in range(n_loras)
86
+ ])
87
+ self.lora_weights = lora_weights
88
+ self.bank_attn = None
89
+ self.bank_kv = []
90
+
91
+
92
+ def __call__(self,
93
+ attn: Attention,
94
+ hidden_states: torch.FloatTensor,
95
+ encoder_hidden_states: torch.FloatTensor = None,
96
+ attention_mask: Optional[torch.FloatTensor] = None,
97
+ image_rotary_emb: Optional[torch.Tensor] = None,
98
+ use_cond = False,
99
+ image_emb: torch.FloatTensor = None
100
+ ) -> torch.FloatTensor:
101
+
102
+ scaled_cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
103
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
104
+ scaled_seq_len = hidden_states.shape[1]
105
+ block_size = scaled_seq_len - scaled_cond_size * self.n_loras
106
+
107
+ if len(self.bank_kv)== 0:
108
+ cache = True
109
+ else:
110
+ cache = False
111
+
112
+ if cache:
113
+ query = attn.to_q(hidden_states)
114
+ key = attn.to_k(hidden_states)
115
+ value = attn.to_v(hidden_states)
116
+ for i in range(self.n_loras):
117
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
118
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
119
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
120
+
121
+ inner_dim = key.shape[-1]
122
+ head_dim = inner_dim // attn.heads
123
+
124
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
125
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
126
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
127
+
128
+
129
+ self.bank_kv.append(key[:, :, block_size:, :])
130
+ self.bank_kv.append(value[:, :, block_size:, :])
131
+
132
+ if attn.norm_q is not None:
133
+ query = attn.norm_q(query)
134
+ if attn.norm_k is not None:
135
+ key = attn.norm_k(key)
136
+
137
+ if image_rotary_emb is not None:
138
+ from diffusers.models.embeddings import apply_rotary_emb
139
+ query = apply_rotary_emb(query, image_rotary_emb)
140
+ key = apply_rotary_emb(key, image_rotary_emb)
141
+
142
+ num_cond_blocks = self.n_loras
143
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
144
+ mask[ :block_size, :] = 0 # First block_size row
145
+ for i in range(num_cond_blocks):
146
+ start = i * scaled_cond_size + block_size
147
+ end = (i + 1) * scaled_cond_size + block_size
148
+ mask[start:end, start:end] = 0 # Diagonal blocks
149
+ mask = mask * -1e20
150
+ mask = mask.to(query.dtype)
151
+
152
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
153
+ else:
154
+ query = attn.to_q(hidden_states)
155
+ key = attn.to_k(hidden_states)
156
+ value = attn.to_v(hidden_states)
157
+
158
+ inner_dim = query.shape[-1]
159
+ head_dim = inner_dim // attn.heads
160
+
161
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
162
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
163
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
164
+
165
+ zero_pad = torch.zeros_like(self.bank_kv[0], dtype=query.dtype, device=query.device)
166
+
167
+
168
+ key = torch.concat([key[:, :, :scaled_seq_len, :], self.bank_kv[0]], dim=-2)
169
+ value = torch.concat([value[:, :, :scaled_seq_len, :], self.bank_kv[1]], dim=-2)
170
+
171
+ if attn.norm_q is not None:
172
+ query = attn.norm_q(query)
173
+ if attn.norm_k is not None:
174
+ key = attn.norm_k(key)
175
+
176
+ query = torch.concat([query[:, :, :scaled_seq_len, :], zero_pad], dim=-2)
177
+
178
+ if image_rotary_emb is not None:
179
+ from diffusers.models.embeddings import apply_rotary_emb
180
+ query = apply_rotary_emb(query, image_rotary_emb)
181
+ key = apply_rotary_emb(key, image_rotary_emb)
182
+
183
+ query = query[:, :, :scaled_seq_len, :]
184
+
185
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
186
+
187
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
188
+ hidden_states = hidden_states.to(query.dtype)
189
+
190
+ hidden_states = hidden_states[:, : scaled_seq_len,:]
191
+
192
+ return hidden_states
193
+
194
+
195
+ class MultiDoubleStreamBlockLoraProcessor(nn.Module):
196
+ def __init__(self, dim: int, ranks=[], lora_weights=[], network_alphas=[], device=None, dtype=None, cond_width=512, cond_height=512, n_loras=1):
197
+ super().__init__()
198
+
199
+ # Initialize a list to store the LoRA layers
200
+ self.n_loras = n_loras
201
+ self.cond_width = cond_width
202
+ self.cond_height = cond_height
203
+ self.q_loras = nn.ModuleList([
204
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
205
+ for i in range(n_loras)
206
+ ])
207
+ self.k_loras = nn.ModuleList([
208
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
209
+ for i in range(n_loras)
210
+ ])
211
+ self.v_loras = nn.ModuleList([
212
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
213
+ for i in range(n_loras)
214
+ ])
215
+ self.proj_loras = nn.ModuleList([
216
+ LoRALinearLayer(dim, dim, ranks[i],network_alphas[i], device=device, dtype=dtype, cond_width=cond_width, cond_height=cond_height, number=i, n_loras=n_loras)
217
+ for i in range(n_loras)
218
+ ])
219
+ self.lora_weights = lora_weights
220
+ self.bank_attn = None
221
+ self.bank_kv = []
222
+
223
+
224
+ def __call__(self,
225
+ attn: Attention,
226
+ hidden_states: torch.FloatTensor,
227
+ encoder_hidden_states: torch.FloatTensor = None,
228
+ attention_mask: Optional[torch.FloatTensor] = None,
229
+ image_rotary_emb: Optional[torch.Tensor] = None,
230
+ use_cond=False,
231
+ image_emb: torch.FloatTensor = None
232
+ ) -> torch.FloatTensor:
233
+
234
+ scaled_cond_size = self.cond_width // 8 * self.cond_height // 8 * 16 // 64
235
+ batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
236
+ block_size = hidden_states.shape[1]
237
+ scaled_seq_len = encoder_hidden_states.shape[1] + hidden_states.shape[1]
238
+ scaled_block_size = scaled_seq_len
239
+
240
+ # `context` projections.
241
+ inner_dim = 3072
242
+ head_dim = inner_dim // attn.heads
243
+ encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
244
+ encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
245
+ encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
246
+
247
+ encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
248
+ batch_size, -1, attn.heads, head_dim
249
+ ).transpose(1, 2)
250
+ encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
251
+ batch_size, -1, attn.heads, head_dim
252
+ ).transpose(1, 2)
253
+ encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
254
+ batch_size, -1, attn.heads, head_dim
255
+ ).transpose(1, 2)
256
+
257
+ if attn.norm_added_q is not None:
258
+ encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
259
+ if attn.norm_added_k is not None:
260
+ encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
261
+
262
+ if len(self.bank_kv)== 0:
263
+ cache = True
264
+ else:
265
+ cache = False
266
+
267
+ if cache:
268
+
269
+ query = attn.to_q(hidden_states)
270
+ key = attn.to_k(hidden_states)
271
+ value = attn.to_v(hidden_states)
272
+ for i in range(self.n_loras):
273
+ query = query + self.lora_weights[i] * self.q_loras[i](hidden_states)
274
+ key = key + self.lora_weights[i] * self.k_loras[i](hidden_states)
275
+ value = value + self.lora_weights[i] * self.v_loras[i](hidden_states)
276
+
277
+ inner_dim = key.shape[-1]
278
+ head_dim = inner_dim // attn.heads
279
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
280
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
281
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
282
+
283
+
284
+ self.bank_kv.append(key)
285
+ self.bank_kv.append(value)
286
+
287
+ if attn.norm_q is not None:
288
+ query = attn.norm_q(query)
289
+ if attn.norm_k is not None:
290
+ key = attn.norm_k(key)
291
+
292
+ # attention
293
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
294
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
295
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
296
+
297
+ if image_rotary_emb is not None:
298
+ from diffusers.models.embeddings import apply_rotary_emb
299
+ query = apply_rotary_emb(query, image_rotary_emb)
300
+ key = apply_rotary_emb(key, image_rotary_emb)
301
+
302
+ num_cond_blocks = self.n_loras
303
+ mask = torch.ones((scaled_seq_len, scaled_seq_len), device=hidden_states.device)
304
+ mask[ :scaled_block_size-block_size, :] = 0 # First block_size row
305
+ for i in range(num_cond_blocks):
306
+ start = i * scaled_cond_size + scaled_block_size-block_size
307
+ end = (i + 1) * scaled_cond_size + scaled_block_size-block_size
308
+ mask[start:end, start:end] = 0 # Diagonal blocks
309
+ mask = mask * -1e20
310
+ mask = mask.to(query.dtype)
311
+
312
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=mask)
313
+
314
+ else:
315
+ query = attn.to_q(hidden_states)
316
+ key = attn.to_k(hidden_states)
317
+ value = attn.to_v(hidden_states)
318
+
319
+ inner_dim = query.shape[-1]
320
+ head_dim = inner_dim // attn.heads
321
+
322
+ query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
323
+ key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
324
+ value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
325
+
326
+ zero_pad = torch.zeros_like(self.bank_kv[0], dtype=query.dtype, device=query.device)
327
+
328
+ key = torch.concat([key[:, :, :block_size, :], self.bank_kv[0]], dim=-2)
329
+ value = torch.concat([value[:, :, :block_size, :], self.bank_kv[1]], dim=-2)
330
+
331
+ if attn.norm_q is not None:
332
+ query = attn.norm_q(query)
333
+ if attn.norm_k is not None:
334
+ key = attn.norm_k(key)
335
+
336
+ query = torch.concat([query[:, :, :block_size, :], zero_pad], dim=-2)
337
+
338
+ # attention
339
+ query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
340
+ key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
341
+ value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
342
+
343
+ if image_rotary_emb is not None:
344
+ from diffusers.models.embeddings import apply_rotary_emb
345
+ query = apply_rotary_emb(query, image_rotary_emb)
346
+ key = apply_rotary_emb(key, image_rotary_emb)
347
+
348
+ query = query[:, :, :scaled_block_size, :]
349
+
350
+ hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False, attn_mask=None)
351
+
352
+ hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
353
+ hidden_states = hidden_states.to(query.dtype)
354
+
355
+ encoder_hidden_states, hidden_states = (
356
+ hidden_states[:, : encoder_hidden_states.shape[1]],
357
+ hidden_states[:, encoder_hidden_states.shape[1] :],
358
+ )
359
+
360
+ # Linear projection (with LoRA weight applied to each proj layer)
361
+ hidden_states = attn.to_out[0](hidden_states)
362
+ encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
363
+
364
+ hidden_states = hidden_states[:, :block_size,:]
365
+
366
+ return hidden_states, encoder_hidden_states