_Noxty committed on
Commit ab4641a · verified · 1 parent: 2cba714

Update libs/infer_packs/modules.py

Files changed (1)
  1. libs/infer_packs/modules.py +615 -615
libs/infer_packs/modules.py CHANGED
@@ -1,17 +1,17 @@
 import copy
 import math
 from typing import Optional, Tuple
 
 import numpy as np
 import scipy
 import torch
 from torch import nn
 from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
 from torch.nn import functional as F
 from torch.nn.utils import remove_weight_norm, weight_norm
 
-from infer.lib.infer_pack import commons
-from infer.lib.infer_pack.commons import get_padding, init_weights
-from infer.lib.infer_pack.transforms import piecewise_rational_quadratic_transform
+from libs.infer_pack import commons
+from libs.infer_pack.commons import get_padding, init_weights
+from libs.infer_pack.transforms import piecewise_rational_quadratic_transform
 
 LRELU_SLOPE = 0.1
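
For reference, a minimal sketch of how the re-pathed module can be exercised after this change. It assumes the repository root is on PYTHONPATH, that libs/infer_packs and libs/infer_pack are importable packages (so that the updated libs.infer_pack imports inside modules.py resolve), and the tensor shapes below are illustrative only, not anything prescribed by the commit.

import torch

from libs.infer_packs import modules  # the file updated in this commit

# WN is the WaveNet-style block defined in modules.py; kernel_size must be odd.
wn = modules.WN(hidden_channels=192, kernel_size=5, dilation_rate=1, n_layers=4)

x = torch.randn(1, 192, 100)    # (batch, hidden_channels, frames) -- assumed shape
x_mask = torch.ones(1, 1, 100)  # all frames marked valid
y = wn(x, x_mask)               # -> torch.Size([1, 192, 100])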