_Noxty committed
Commit b9508e3 · verified · 1 Parent(s): ab4641a

Update libs/infer_packs/attentions_onnx.py

Files changed (1)
  1. libs/infer_packs/attentions_onnx.py +459 -459
libs/infer_packs/attentions_onnx.py CHANGED
@@ -1,459 +1,459 @@
 ############################## Warning! ##############################
 # #
 # Onnx Export Not Support All Of Non-Torch Types #
 # Include Python Built-in Types!!!!!!!!!!!!!!!!! #
 # If You Want TO Change This File #
 # Do Not Use All Of Non-Torch Types! #
 # #
 ############################## Warning! ##############################
 import copy
 import math
 from typing import Optional
 
 import numpy as np
 import torch
 from torch import nn
 from torch.nn import functional as F
 
-from infer.lib.infer_pack import commons, modules
-from infer.lib.infer_pack.modules import LayerNorm
+from libs.infer_pack import commons, modules
+from libs.infer_pack.modules import LayerNorm
 
 
 class Encoder(nn.Module):
     def __init__(
         self,
         hidden_channels,
         filter_channels,
         n_heads,
         n_layers,
         kernel_size=1,
         p_dropout=0.0,
         window_size=10,
         **kwargs
     ):
         super(Encoder, self).__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
         self.n_heads = n_heads
         self.n_layers = int(n_layers)
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.window_size = window_size
 
         self.drop = nn.Dropout(p_dropout)
         self.attn_layers = nn.ModuleList()
         self.norm_layers_1 = nn.ModuleList()
         self.ffn_layers = nn.ModuleList()
         self.norm_layers_2 = nn.ModuleList()
         for i in range(self.n_layers):
             self.attn_layers.append(
                 MultiHeadAttention(
                     hidden_channels,
                     hidden_channels,
                     n_heads,
                     p_dropout=p_dropout,
                     window_size=window_size,
                 )
             )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
                     hidden_channels,
                     hidden_channels,
                     filter_channels,
                     kernel_size,
                     p_dropout=p_dropout,
                 )
             )
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask):
         attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
         zippep = zip(
             self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
         )
         for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zippep:
             y = attn_layers(x, x, attn_mask)
             y = self.drop(y)
             x = norm_layers_1(x + y)
 
             y = ffn_layers(x, x_mask)
             y = self.drop(y)
             x = norm_layers_2(x + y)
         x = x * x_mask
         return x
 
 
 class Decoder(nn.Module):
     def __init__(
         self,
         hidden_channels,
         filter_channels,
         n_heads,
         n_layers,
         kernel_size=1,
         p_dropout=0.0,
         proximal_bias=False,
         proximal_init=True,
         **kwargs
     ):
         super(Decoder, self).__init__()
         self.hidden_channels = hidden_channels
         self.filter_channels = filter_channels
         self.n_heads = n_heads
         self.n_layers = n_layers
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.proximal_bias = proximal_bias
         self.proximal_init = proximal_init
 
         self.drop = nn.Dropout(p_dropout)
         self.self_attn_layers = nn.ModuleList()
         self.norm_layers_0 = nn.ModuleList()
         self.encdec_attn_layers = nn.ModuleList()
         self.norm_layers_1 = nn.ModuleList()
         self.ffn_layers = nn.ModuleList()
         self.norm_layers_2 = nn.ModuleList()
         for i in range(self.n_layers):
             self.self_attn_layers.append(
                 MultiHeadAttention(
                     hidden_channels,
                     hidden_channels,
                     n_heads,
                     p_dropout=p_dropout,
                     proximal_bias=proximal_bias,
                     proximal_init=proximal_init,
                 )
             )
             self.norm_layers_0.append(LayerNorm(hidden_channels))
             self.encdec_attn_layers.append(
                 MultiHeadAttention(
                     hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
                 )
             )
             self.norm_layers_1.append(LayerNorm(hidden_channels))
             self.ffn_layers.append(
                 FFN(
                     hidden_channels,
                     hidden_channels,
                     filter_channels,
                     kernel_size,
                     p_dropout=p_dropout,
                     causal=True,
                 )
             )
             self.norm_layers_2.append(LayerNorm(hidden_channels))
 
     def forward(self, x, x_mask, h, h_mask):
         """
         x: decoder input
         h: encoder output
         """
         self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
             device=x.device, dtype=x.dtype
         )
         encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
         x = x * x_mask
         for i in range(self.n_layers):
             y = self.self_attn_layers[i](x, x, self_attn_mask)
             y = self.drop(y)
             x = self.norm_layers_0[i](x + y)
 
             y = self.encdec_attn_layers[i](x, h, encdec_attn_mask)
             y = self.drop(y)
             x = self.norm_layers_1[i](x + y)
 
             y = self.ffn_layers[i](x, x_mask)
             y = self.drop(y)
             x = self.norm_layers_2[i](x + y)
         x = x * x_mask
         return x
 
 
 class MultiHeadAttention(nn.Module):
     def __init__(
         self,
         channels,
         out_channels,
         n_heads,
         p_dropout=0.0,
         window_size=None,
         heads_share=True,
         block_length=None,
         proximal_bias=False,
         proximal_init=False,
     ):
         super(MultiHeadAttention, self).__init__()
         assert channels % n_heads == 0
 
         self.channels = channels
         self.out_channels = out_channels
         self.n_heads = n_heads
         self.p_dropout = p_dropout
         self.window_size = window_size
         self.heads_share = heads_share
         self.block_length = block_length
         self.proximal_bias = proximal_bias
         self.proximal_init = proximal_init
         self.attn = None
 
         self.k_channels = channels // n_heads
         self.conv_q = nn.Conv1d(channels, channels, 1)
         self.conv_k = nn.Conv1d(channels, channels, 1)
         self.conv_v = nn.Conv1d(channels, channels, 1)
         self.conv_o = nn.Conv1d(channels, out_channels, 1)
         self.drop = nn.Dropout(p_dropout)
 
         if window_size is not None:
             n_heads_rel = 1 if heads_share else n_heads
             rel_stddev = self.k_channels**-0.5
             self.emb_rel_k = nn.Parameter(
                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                 * rel_stddev
             )
             self.emb_rel_v = nn.Parameter(
                 torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                 * rel_stddev
             )
 
         nn.init.xavier_uniform_(self.conv_q.weight)
         nn.init.xavier_uniform_(self.conv_k.weight)
         nn.init.xavier_uniform_(self.conv_v.weight)
         if proximal_init:
             with torch.no_grad():
                 self.conv_k.weight.copy_(self.conv_q.weight)
                 self.conv_k.bias.copy_(self.conv_q.bias)
 
     def forward(
         self, x: torch.Tensor, c: torch.Tensor, attn_mask: Optional[torch.Tensor] = None
     ):
         q = self.conv_q(x)
         k = self.conv_k(c)
         v = self.conv_v(c)
 
         x, _ = self.attention(q, k, v, mask=attn_mask)
 
         x = self.conv_o(x)
         return x
 
     def attention(
         self,
         query: torch.Tensor,
         key: torch.Tensor,
         value: torch.Tensor,
         mask: Optional[torch.Tensor] = None,
     ):
         # reshape [b, d, t] -> [b, n_h, t, d_k]
         b, d, t_s = key.size()
         t_t = query.size(2)
         query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
         key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
         value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
 
         scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
         if self.window_size is not None:
             key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
             rel_logits = self._matmul_with_relative_keys(
                 query / math.sqrt(self.k_channels), key_relative_embeddings
             )
             scores_local = self._relative_position_to_absolute_position(rel_logits)
             scores = scores + scores_local
         if self.proximal_bias:
             assert t_s == t_t, "Proximal bias is only available for self-attention."
             scores = scores + self._attention_bias_proximal(t_s).to(
                 device=scores.device, dtype=scores.dtype
             )
         if mask is not None:
             scores = scores.masked_fill(mask == 0, -1e4)
             if self.block_length is not None:
                 assert (
                     t_s == t_t
                 ), "Local attention is only available for self-attention."
                 block_mask = (
                     torch.ones_like(scores)
                     .triu(-self.block_length)
                     .tril(self.block_length)
                 )
                 scores = scores.masked_fill(block_mask == 0, -1e4)
         p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
         p_attn = self.drop(p_attn)
         output = torch.matmul(p_attn, value)
         if self.window_size is not None:
             relative_weights = self._absolute_position_to_relative_position(p_attn)
             value_relative_embeddings = self._get_relative_embeddings(
                 self.emb_rel_v, t_s
             )
             output = output + self._matmul_with_relative_values(
                 relative_weights, value_relative_embeddings
             )
         output = (
             output.transpose(2, 3).contiguous().view(b, d, t_t)
         )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
         return output, p_attn
 
     def _matmul_with_relative_values(self, x, y):
         """
         x: [b, h, l, m]
         y: [h or 1, m, d]
         ret: [b, h, l, d]
         """
         ret = torch.matmul(x, y.unsqueeze(0))
         return ret
 
     def _matmul_with_relative_keys(self, x, y):
         """
         x: [b, h, l, d]
         y: [h or 1, m, d]
         ret: [b, h, l, m]
         """
         ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
         return ret
 
     def _get_relative_embeddings(self, relative_embeddings, length):
         max_relative_position = 2 * self.window_size + 1
         # Pad first before slice to avoid using cond ops.
 
         pad_length = torch.clamp(length - (self.window_size + 1), min=0)
         slice_start_position = torch.clamp((self.window_size + 1) - length, min=0)
         slice_end_position = slice_start_position + 2 * length - 1
         padded_relative_embeddings = F.pad(
             relative_embeddings,
             # commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
             [0, 0, pad_length, pad_length, 0, 0],
         )
         used_relative_embeddings = padded_relative_embeddings[
             :, slice_start_position:slice_end_position
         ]
         return used_relative_embeddings
 
     def _relative_position_to_absolute_position(self, x):
         """
         x: [b, h, l, 2*l-1]
         ret: [b, h, l, l]
         """
         batch, heads, length, _ = x.size()
         # Concat columns of pad to shift from relative to absolute indexing.
         x = F.pad(
             x,
             # commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]])
             [0, 1, 0, 0, 0, 0, 0, 0],
         )
 
         # Concat extra elements so to add up to shape (len+1, 2*len-1).
         x_flat = x.view([batch, heads, length * 2 * length])
         x_flat = F.pad(
             x_flat,
             [0, length - 1, 0, 0, 0, 0],
         )
 
         # Reshape and slice out the padded elements.
         x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
             :, :, :length, length - 1 :
         ]
         return x_final
 
     def _absolute_position_to_relative_position(self, x):
         """
         x: [b, h, l, l]
         ret: [b, h, l, 2*l-1]
         """
         batch, heads, length, _ = x.size()
         # padd along column
         x = F.pad(
             x,
             [0, length - 1, 0, 0, 0, 0, 0, 0],
         )
         x_flat = x.view([batch, heads, length*length + length * (length - 1)])
         # add 0's in the beginning that will skew the elements after reshape
         x_flat = F.pad(
             x_flat,
             [length, 0, 0, 0, 0, 0],
         )
         x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
         return x_final
 
     def _attention_bias_proximal(self, length):
         """Bias for self-attention to encourage attention to close positions.
         Args:
           length: an integer scalar.
         Returns:
           a Tensor with shape [1, 1, length, length]
         """
         r = torch.arange(length, dtype=torch.float32)
         diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
         return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
 
 
 class FFN(nn.Module):
     def __init__(
         self,
         in_channels,
         out_channels,
         filter_channels,
         kernel_size,
         p_dropout=0.0,
         activation: str = None,
         causal=False,
     ):
         super(FFN, self).__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.filter_channels = filter_channels
         self.kernel_size = kernel_size
         self.p_dropout = p_dropout
         self.activation = activation
         self.causal = causal
         self.is_activation = True if activation == "gelu" else False
         # if causal:
         #     self.padding = self._causal_padding
         # else:
         #     self.padding = self._same_padding
 
         self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
         self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
         self.drop = nn.Dropout(p_dropout)
 
     def padding(self, x: torch.Tensor, x_mask: torch.Tensor) -> torch.Tensor:
         if self.causal:
             padding = self._causal_padding(x * x_mask)
         else:
             padding = self._same_padding(x * x_mask)
         return padding
 
     def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
         x = self.conv_1(self.padding(x, x_mask))
         if self.is_activation:
             x = x * torch.sigmoid(1.702 * x)
         else:
             x = torch.relu(x)
         x = self.drop(x)
 
         x = self.conv_2(self.padding(x, x_mask))
         return x * x_mask
 
     def _causal_padding(self, x):
         if self.kernel_size == 1:
             return x
         pad_l = self.kernel_size - 1
         pad_r = 0
         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         x = F.pad(
             x,
             # commons.convert_pad_shape(padding)
             [pad_l, pad_r, 0, 0, 0, 0],
         )
         return x
 
     def _same_padding(self, x):
         if self.kernel_size == 1:
             return x
         pad_l = (self.kernel_size - 1) // 2
         pad_r = self.kernel_size // 2
         # padding = [[0, 0], [0, 0], [pad_l, pad_r]]
         x = F.pad(
             x,
             # commons.convert_pad_shape(padding)
             [pad_l, pad_r, 0, 0, 0, 0],
         )
         return x
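
For context, a quick smoke test of the Encoder defined in this file is sketched below. It is an illustration only: the import path assumes the repository root is on sys.path so that libs.infer_packs.attentions_onnx and the updated libs.infer_pack dependencies resolve, and the channel sizes are placeholder values rather than anything taken from this commit.

import torch
from libs.infer_packs.attentions_onnx import Encoder  # import path assumed from this repo layout

# Illustrative hyperparameters (not from this commit); window_size=None keeps the
# example on the plain attention path for a simple eager-mode check.
enc = Encoder(
    hidden_channels=192,
    filter_channels=768,
    n_heads=2,
    n_layers=6,
    kernel_size=3,
    p_dropout=0.0,
    window_size=None,
)
enc.eval()

x = torch.randn(1, 192, 100)    # [batch, hidden_channels, time]
x_mask = torch.ones(1, 1, 100)  # 1 = valid frame
with torch.no_grad():
    y = enc(x, x_mask)
print(y.shape)  # torch.Size([1, 192, 100])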