Spaces:
Running
Running
File size: 39,250 Bytes
864affd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 |
# *****************************************************************************
# MIT License
#
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# *****************************************************************************
import math
import typing as tp
from typing import Any, Dict, List, Optional
import torch
from torch import nn
from torch.nn import functional as F
class _ScaledEmbedding(torch.nn.Module):
    r"""Embedding whose stored table is divided by ``scale`` and whose outputs
    are multiplied back by ``scale``, which boosts the effective learning rate.

    Args:
        num_embeddings (int): number of embeddings
        embedding_dim (int): embedding dimensions
        scale (float, optional): amount to scale learning rate (Default: 10.0)
        smooth (bool, optional): choose to apply smoothing (Default: ``False``)
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth: bool = False):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        table = self.embedding.weight.data
        if smooth:
            # A running sum of Gaussian rows grows like sqrt(n), so each row is
            # divided by the square root of its (1-based) index.
            row_norm = torch.arange(1, num_embeddings + 1).sqrt().unsqueeze(1)
            table.copy_(torch.cumsum(table, dim=0) / row_norm)
        # Store the table shrunk by `scale`; `weight` and `forward` undo this.
        table.div_(scale)
        self.scale = scale

    @property
    def weight(self) -> torch.Tensor:
        """Embedding table multiplied back by ``scale``."""
        return self.embedding.weight * self.scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass for embedding with scale.

        Args:
            x (torch.Tensor): tensor of indices into the embedding table

        Returns:
            (Tensor):
                Looked-up embeddings multiplied by ``scale``; the shape is the
                shape of ``x`` with a trailing ``embedding_dim`` axis.
        """
        return self.embedding(x) * self.scale
class _HEncLayer(torch.nn.Module):
    r"""Encoder layer shared by the time branch and the frequency branch.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int, optional): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency domain (Default: ``True``)
        norm_type (string, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 0)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class. (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 4,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 0,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        dconv_args = {} if dconv_kw is None else dconv_kw

        def make_norm(num_channels: int) -> nn.Module:
            # Any norm_type other than "group_norm" disables normalization.
            if norm_type == "group_norm":
                return nn.GroupNorm(norm_groups, num_channels)
            return nn.Identity()

        padding = kernel_size // 4 if pad else 0
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.pad = padding

        if freq:
            # Frequency branch: 2d conv that strides over the frequency axis only.
            conv_cls = nn.Conv2d
            conv_kernel, conv_stride, conv_pad = [kernel_size, 1], [stride, 1], [padding, 0]
        else:
            conv_cls = nn.Conv1d
            conv_kernel, conv_stride, conv_pad = kernel_size, stride, padding

        self.conv = conv_cls(chin, chout, conv_kernel, conv_stride, conv_pad)
        self.norm1 = make_norm(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm2 = nn.Identity()
            self.dconv = nn.Identity()
        else:
            self.rewrite = conv_cls(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = make_norm(2 * chout)
            self.dconv = _DConv(chout, **dconv_args)

    def forward(self, x: torch.Tensor, inject: Optional[torch.Tensor] = None) -> torch.Tensor:
        r"""Forward pass for encoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            inject (torch.Tensor, optional): on last layer, combine frequency and time branches through inject param,
                same shape as x (default: ``None``)

        Returns:
            Tensor
                output tensor after encoder layer of shape `(B, C, F / stride, T)` for frequency
                    and shape `(B, C, ceil(T / stride))` for time
        """
        if not self.freq:
            if x.dim() == 4:
                # Fold the frequency axis into the channels for the time branch.
                batch, _, _, frames = x.shape
                x = x.view(batch, -1, frames)
            remainder = x.shape[-1] % self.stride
            if remainder != 0:
                # Right-pad so the length is a multiple of the stride.
                x = F.pad(x, (0, self.stride - remainder))

        y = self.conv(x)
        if self.empty:
            return y

        if inject is not None:
            if inject.shape[-1] != y.shape[-1]:
                raise ValueError("Injection shapes do not align")
            if inject.dim() == 3 and y.dim() == 4:
                inject = inject[:, :, None]
            y = y + inject

        y = F.gelu(self.norm1(y))
        if self.freq:
            # Run the residual branch along time, treating every frequency bin
            # as an independent batch entry, then restore the layout.
            batch, channels, bins, frames = y.shape
            y = y.permute(0, 2, 1, 3).reshape(-1, channels, frames)
            y = self.dconv(y)
            y = y.view(batch, bins, channels, frames).permute(0, 2, 1, 3)
        else:
            y = self.dconv(y)
        return F.glu(self.norm2(self.rewrite(y)), dim=1)
class _HDecLayer(torch.nn.Module):
    r"""Decoder layer. This used both by the time and the frequency branches.

    Args:
        chin (int): number of input channels.
        chout (int): number of output channels.
        last (bool, optional): whether current layer is final layer (Default: ``False``)
        kernel_size (int, optional): Kernel size for encoder (Default: 8)
        stride (int): Stride for encoder layer (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 1)
        empty (bool, optional): used to make a layer with just the first conv. this is used
            before merging the time and freq. branches. (Default: ``False``)
        freq (bool, optional): boolean for whether conv layer is for frequency (Default: ``True``)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        context (int, optional): context size for the 1x1 conv. (Default: 1)
        dconv_kw (Dict[str, Any] or None, optional): dictionary of kwargs for the DConv class.
            Accepted for signature parity with the encoder layer; this layer does not use it.
            (Default: ``None``)
        pad (bool, optional): true to pad the input. Padding is done so that the output size is
            always the input size / stride. (Default: ``True``)

    Raises:
        ValueError: if ``pad`` is true and ``kernel_size - stride`` is odd.
    """

    def __init__(
        self,
        chin: int,
        chout: int,
        last: bool = False,
        kernel_size: int = 8,
        stride: int = 4,
        norm_groups: int = 1,
        empty: bool = False,
        freq: bool = True,
        norm_type: str = "group_norm",
        context: int = 1,
        dconv_kw: Optional[Dict[str, Any]] = None,
        pad: bool = True,
    ):
        super().__init__()
        if dconv_kw is None:
            dconv_kw = {}
        norm_fn = lambda d: nn.Identity()  # noqa
        if norm_type == "group_norm":
            norm_fn = lambda d: nn.GroupNorm(norm_groups, d)  # noqa
        if pad:
            if (kernel_size - stride) % 2 != 0:
                raise ValueError("Kernel size and stride do not align")
            pad = (kernel_size - stride) // 2
        else:
            pad = 0
        # Number of samples trimmed from the transposed-conv output in forward().
        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d
        if freq:
            # Frequency branch: 2d convs that only stride over the frequency axis.
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d
        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)
        if self.empty:
            self.rewrite = nn.Identity()
            self.norm1 = nn.Identity()
        else:
            self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            self.norm1 = norm_fn(2 * chin)

    def forward(
        self, x: torch.Tensor, skip: Optional[torch.Tensor], length: int
    ) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        r"""Forward pass for decoding layer.

        Size depends on whether frequency or time

        Args:
            x (torch.Tensor): tensor input of shape `(B, C, F, T)` for frequency and shape
                `(B, C, T)` for time
            skip (torch.Tensor, optional): skip connection added to ``x``. Required when the layer
                is not ``empty``; must be ``None`` when the layer is ``empty``.
            length (int): Size of tensor for output

        Returns:
            (Tensor, Tensor):
                Tensor
                    output tensor after decoder layer of shape `(B, C, F * stride, T)` for frequency domain except last
                        frequency layer shape is `(B, C, kernel_size, T)`. Shape is `(B, C, stride * T)`
                        for time domain.
                Tensor
                    contains the output just before final transposed convolution, which is used when the
                        freq. and time branch separate. Otherwise, does not matter. Shape is
                        `(B, C, F, T)` for frequency and `(B, C, T)` for time.

        Raises:
            ValueError: if ``skip`` is missing on a non-empty layer, if ``skip`` is given on an
                empty layer, or if the last dimension of the output differs from ``length``.
        """
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)
        if not self.empty:
            if skip is None:
                # Fail with a clear message instead of a TypeError from `x + None`.
                raise ValueError("Skip must not be none when empty is false.")
            x = x + skip
            y = F.glu(self.norm1(self.rewrite(x)), dim=1)
        else:
            y = x
            if skip is not None:
                raise ValueError("Skip must be none when empty is true.")
        z = self.norm2(self.conv_tr(y))
        if self.freq:
            if self.pad:
                # Trim symmetrically along the frequency axis.
                z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            if z.shape[-1] != length:
                raise ValueError("Last index of z must be equal to length")
        if not self.last:
            z = F.gelu(z)
        return z, y
class HDemucs(torch.nn.Module):
    r"""Hybrid Demucs model from
    *Hybrid Spectrogram and Waveform Source Separation* :cite:`defossez2021hybrid`.

    See Also:
        * :class:`torchaudio.pipelines.SourceSeparationBundle`: Source separation pipeline with pre-trained models.

    Args:
        sources (List[str]): list of source names. List can contain the following source
            options: [``"bass"``, ``"drums"``, ``"other"``, ``"mixture"``, ``"vocals"``].
        audio_channels (int, optional): input/output audio channels. (Default: 2)
        channels (int, optional): initial number of hidden channels. (Default: 48)
        growth (int, optional): increase the number of hidden channels by this factor at each layer. (Default: 2)
        nfft (int, optional): number of fft bins. Note that changing this requires careful computation of
            various shape parameters and will not work out of the box for hybrid models. (Default: 4096)
        depth (int, optional): number of layers in encoder and decoder (Default: 6)
        freq_emb (float, optional): add frequency embedding after the first frequency layer if > 0,
            the actual value controls the weight of the embedding. (Default: 0.2)
        emb_scale (int, optional): equivalent to scaling the embedding learning rate (Default: 10)
        emb_smooth (bool, optional): initialize the embedding with a smooth one (with respect to frequencies).
            (Default: ``True``)
        kernel_size (int, optional): kernel_size for encoder and decoder layers. (Default: 8)
        time_stride (int, optional): stride for the final time layer, after the merge. (Default: 2)
        stride (int, optional): stride for encoder and decoder layers. (Default: 4)
        context (int, optional): context for 1x1 conv in the decoder. (Default: 1)
        context_enc (int, optional): context for 1x1 conv in the encoder. (Default: 0)
        norm_starts (int, optional): layer at which group norm starts being used.
            decoder layers are numbered in reverse order. (Default: 4)
        norm_groups (int, optional): number of groups for group norm. (Default: 4)
        dconv_depth (int, optional): depth of residual DConv branch. (Default: 2)
        dconv_comp (int, optional): compression of DConv branch. (Default: 4)
        dconv_attn (int, optional): adds attention layers in DConv branch starting at this layer. (Default: 4)
        dconv_lstm (int, optional): adds a LSTM layer in DConv branch starting at this layer. (Default: 4)
        dconv_init (float, optional): initial scale for the DConv branch LayerScale. (Default: 1e-4)
    """

    def __init__(
        self,
        sources: List[str],
        audio_channels: int = 2,
        channels: int = 48,
        growth: int = 2,
        nfft: int = 4096,
        depth: int = 6,
        freq_emb: float = 0.2,
        emb_scale: int = 10,
        emb_smooth: bool = True,
        kernel_size: int = 8,
        time_stride: int = 2,
        stride: int = 4,
        context: int = 1,
        context_enc: int = 0,
        norm_starts: int = 4,
        norm_groups: int = 4,
        dconv_depth: int = 2,
        dconv_comp: int = 4,
        dconv_attn: int = 4,
        dconv_lstm: int = 4,
        dconv_init: float = 1e-4,
    ):
        super().__init__()
        self.depth = depth
        self.nfft = nfft
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.channels = channels

        self.hop_length = self.nfft // 4
        self.freq_emb = None

        self.freq_encoder = nn.ModuleList()
        self.freq_decoder = nn.ModuleList()

        self.time_encoder = nn.ModuleList()
        self.time_decoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin * 2  # number of channels for the freq branch
        chout = channels
        chout_z = channels
        freqs = self.nfft // 2

        for index in range(self.depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm_type = "group_norm" if index >= norm_starts else "none"
            # Once a single frequency bin is left, the layer works on time only.
            freq = freqs > 1
            stri = stride
            ker = kernel_size
            if not freq:
                if freqs != 1:
                    raise ValueError("When freq is false, freqs must be 1.")
                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False
            if freq and freqs <= kernel_size:
                # Last frequency layer: one un-padded kernel spans all remaining bins.
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm_type": norm_type,
                "norm_groups": norm_groups,
                "dconv_kw": {
                    "lstm": lstm,
                    "attn": attn,
                    "depth": dconv_depth,
                    "compress": dconv_comp,
                    "init": dconv_init,
                },
            }
            # Time-branch layers reuse the settings, but are never frequency
            # layers and always use the nominal kernel size, stride and padding.
            kwt = dict(kw)
            kwt["freq"] = False
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = _HEncLayer(chin_z, chout_z, context=context_enc, **kw)
            if freq:
                if last_freq and nfft == 2048:
                    kwt["stride"] = 2
                    kwt["kernel_size"] = 4
                tenc = _HEncLayer(chin, chout, context=context_enc, empty=last_freq, **kwt)
                self.time_encoder.append(tenc)

            self.freq_encoder.append(enc)
            if index == 0:
                # The outermost decoder layers emit one set of channels per source.
                chin = self.audio_channels * len(self.sources)
                chin_z = chin * 2
            dec = _HDecLayer(chout_z, chin_z, last=index == 0, context=context, **kw_dec)
            if freq:
                tdec = _HDecLayer(chout, chin, empty=last_freq, last=index == 0, context=context, **kwt)
                self.time_decoder.insert(0, tdec)
            # Decoders are stored in reverse order so they can be iterated directly.
            self.freq_decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z
            chout = int(growth * chout)
            chout_z = int(growth * chout_z)
            if freq:
                if freqs <= kernel_size:
                    freqs = 1
                else:
                    freqs //= stride
            if index == 0 and freq_emb:
                self.freq_emb = _ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        _rescale_module(self)

    def _spec(self, x):
        """Compute the spectrogram of ``x`` with ``ceil(num_frames / hop_length)`` time frames.

        The top frequency bin returned by ``_spectro`` is dropped.

        Raises:
            ValueError: if ``hop_length`` is not ``nfft // 4`` or the frame count is unexpected.
        """
        hl = self.hop_length
        nfft = self.nfft

        # We re-pad the signal in order to keep the property
        # that the size of the output is exactly the size of the input
        # divided by the stride (here hop_length), when divisible.
        # This is achieved by padding by 1/4th of the kernel size (here nfft).
        # which is not supported by torch.stft.
        # Having all convolution operations follow this convention allow to easily
        # align the time and frequency branches later on.
        if hl != nfft // 4:
            raise ValueError("Hop length must be nfft // 4")
        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3
        x = self._pad1d(x, pad, pad + le * hl - x.shape[-1], mode="reflect")

        z = _spectro(x, nfft, hl)[..., :-1, :]
        if z.shape[-1] != le + 4:
            raise ValueError("Spectrogram's last dimension must be 4 + input size divided by stride")
        # Drop the two extra frames on each side introduced by the padding above.
        z = z[..., 2 : 2 + le]
        return z

    def _ispec(self, z, length: Optional[int] = None):
        """Invert ``_spec``: turn a complex spectrogram into a waveform of ``length`` samples.

        Args:
            z (torch.Tensor): complex spectrogram whose last two axes are (frequency, frames).
            length (int, optional): number of output samples. When ``None``, it is taken to be
                ``hop_length`` times the number of frames of ``z``.
        """
        hl = self.hop_length
        if length is None:
            # The documented default used to crash on `None / hl`.
            length = hl * z.shape[-1]
        # Restore the frequency bin dropped by `_spec`, then the two frames per side.
        z = F.pad(z, [0, 0, 0, 1])
        z = F.pad(z, [2, 2])
        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad
        x = _ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]
        return x

    def _pad1d(self, x: torch.Tensor, padding_left: int, padding_right: int, mode: str = "zero", value: float = 0.0):
        """Wrapper around F.pad, in order for reflect padding when num_frames is shorter than max_pad.
        Add extra zero padding around in order for padding to not break.

        ``mode="zero"`` (the default) is an alias for constant padding with ``value``; every
        other mode is forwarded to ``F.pad`` unchanged."""
        length = x.shape[-1]
        if mode == "zero":
            # F.pad has no "zero" mode; the default used to raise unconditionally.
            mode = "constant"
        if mode == "reflect":
            max_pad = max(padding_left, padding_right)
            if length <= max_pad:
                # Reflect padding needs the input to be longer than the padding.
                x = F.pad(x, (0, max_pad - length + 1))
        return F.pad(x, (padding_left, padding_right), mode, value)

    def _magnitude(self, z):
        """Convert a complex `(B, C, Fr, T)` tensor into a real `(B, 2 * C, Fr, T)` tensor."""
        # move the complex dimension to the channel one.
        B, C, Fr, T = z.shape
        m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
        m = m.reshape(B, C * 2, Fr, T)
        return m

    def _mask(self, m):
        """Convert a real `(B, S, 2 * C, Fr, T)` tensor into a complex `(B, S, C, Fr, T)` tensor.

        This is the inverse of the channel layout produced by ``_magnitude``.
        """
        B, S, C, Fr, T = m.shape
        out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
        out = torch.view_as_complex(out.contiguous())
        return out

    def forward(self, input: torch.Tensor):
        r"""HDemucs forward call

        Args:
            input (torch.Tensor): input mixed tensor of shape `(batch_size, channel, num_frames)`

        Returns:
            Tensor
                output tensor split into sources of shape `(batch_size, num_sources, channel, num_frames)`

        Raises:
            ValueError: if ``input`` is not 3D or its channel count differs from ``audio_channels``.
        """
        if input.ndim != 3:
            raise ValueError(f"Expected 3D tensor with dimensions (batch, channel, frames). Found: {input.shape}")

        if input.shape[1] != self.audio_channels:
            raise ValueError(
                f"The channel dimension of input Tensor must match `audio_channels` of HDemucs model. "
                f"Found:{input.shape[1]}."
            )

        x = input
        length = x.shape[-1]

        z = self._spec(input)
        mag = self._magnitude(z)
        x = mag

        B, C, Fq, T = x.shape

        # unlike previous Demucs, we always normalize because it is easier.
        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)
        # x will be the freq. branch input.

        # Prepare the time branch input.
        xt = input
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []  # skip connections, freq.
        saved_t = []  # skip connections, time.
        lengths: List[int] = []  # saved lengths to properly remove padding, freq branch.
        lengths_t: List[int] = []  # saved lengths for time branch.

        for idx, encode in enumerate(self.freq_encoder):
            lengths.append(x.shape[-1])
            inject = None
            if idx < len(self.time_encoder):
                # we have not yet merged branches.
                lengths_t.append(xt.shape[-1])
                tenc = self.time_encoder[idx]
                xt = tenc(xt)
                if not tenc.empty:
                    # save for skip connection
                    saved_t.append(xt)
                else:
                    # tenc contains just the first conv., so that now time and freq.
                    # branches have the same shape and can be merged.
                    inject = xt
            x = encode(x, inject)
            if idx == 0 and self.freq_emb is not None:
                # add frequency embedding to allow for non equivariant convolutions
                # over the frequency axis.
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)
        xt = torch.zeros_like(x)
        # initialize everything to zero (signal will go through u-net skips).

        # Index of the first freq. decoder layer that has a matching time decoder layer.
        offset = self.depth - len(self.time_decoder)
        for idx, decode in enumerate(self.freq_decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))
            # `pre` contains the output just before final transposed convolution,
            # which is used when the freq. and time branch separate.
            if idx >= offset:
                tdec = self.time_decoder[idx - offset]
                length_t = lengths_t.pop(-1)
                if tdec.empty:
                    if pre.shape[2] != 1:
                        raise ValueError(f"If tdec empty is True, pre shape does not match {pre.shape}")
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        if len(saved) != 0:
            raise AssertionError("saved is not empty")
        if len(lengths_t) != 0:
            raise AssertionError("lengths_t is not empty")
        if len(saved_t) != 0:
            raise AssertionError("saved_t is not empty")

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        # Undo the input normalization on the freq. branch output.
        x = x * std[:, None] + mean[:, None]

        zout = self._mask(x)
        x = self._ispec(zout, length)

        xt = xt.view(B, S, -1, length)
        xt = xt * stdt[:, None] + meant[:, None]
        # The final estimate is the sum of the time and frequency branches.
        x = xt + x
        return x
class _DConv(torch.nn.Module):
    r"""
    New residual branches in each encoder layer.
    This alternates dilated convolutions, potentially with LSTMs and attention.
    Also before entering each residual branch, dimension is projected on a smaller subspace,
    e.g. of dim `channels // compress`.

    Args:
        channels (int): input/output channels for residual branch.
        compress (float, optional): amount of channel compression inside the branch. (default: 4)
        depth (int, optional): number of layers in the residual branch. Each layer has its own
            projection, and potentially LSTM and attention. A non-positive value uses
            ``abs(depth)`` layers without dilation. (default: 2)
        init (float, optional): initial scale for LayerScale. (default: 1e-4)
        norm_type (str, optional): Norm type, either ``group_norm`` or ``none`` (Default: ``group_norm``)
        attn (bool, optional): use LocalAttention. (Default: ``False``)
        heads (int, optional): number of heads for the LocalAttention. (default: 4)
        ndecay (int, optional): number of decay controls in the LocalAttention. (default: 4)
        lstm (bool, optional): use LSTM. (Default: ``False``)
        kernel_size (int, optional): kernel size for the (dilated) convolutions. (default: 3)

    Raises:
        ValueError: if ``kernel_size`` is even.
    """

    def __init__(
        self,
        channels: int,
        compress: float = 4,
        depth: int = 2,
        init: float = 1e-4,
        norm_type: str = "group_norm",
        attn: bool = False,
        heads: int = 4,
        ndecay: int = 4,
        lstm: bool = False,
        kernel_size: int = 3,
    ):
        super().__init__()
        if kernel_size % 2 == 0:
            raise ValueError("Kernel size should not be divisible by 2")

        self.channels = channels
        self.compress = compress
        self.depth = abs(depth)
        use_dilation = depth > 0

        def make_norm(num_channels: int) -> nn.Module:
            # Any norm_type other than "group_norm" disables normalization.
            if norm_type == "group_norm":
                return nn.GroupNorm(1, num_channels)
            return nn.Identity()

        hidden = int(channels / compress)
        half_kernel = kernel_size // 2

        self.layers = nn.ModuleList([])
        for level in range(self.depth):
            dilation = 2 ** level if use_dilation else 1
            branch = [
                nn.Conv1d(channels, hidden, kernel_size, dilation=dilation, padding=dilation * half_kernel),
                make_norm(hidden),
                nn.GELU(),
                nn.Conv1d(hidden, 2 * channels, 1),
                make_norm(2 * channels),
                nn.GLU(1),
                _LayerScale(channels, init),
            ]
            # Optional modules go right after the activation. Both are inserted
            # at index 3, so the LSTM ends up in front of the attention.
            if attn:
                branch.insert(3, _LocalState(hidden, heads=heads, ndecay=ndecay))
            if lstm:
                branch.insert(3, _BLSTM(hidden, layers=2, skip=True))
            self.layers.append(nn.Sequential(*branch))

    def forward(self, x):
        r"""DConv forward call

        Args:
            x (torch.Tensor): input tensor for convolution

        Returns:
            Tensor
                Output after adding every residual layer in turn.
        """
        for residual in self.layers:
            x = x + residual(x)
        return x
class _BLSTM(torch.nn.Module):
    r"""
    BiLSTM with same hidden units as input dim.
    Inputs longer than ``max_steps`` (fixed to 200 here) are split into overlapping
    chunks and the LSTM is applied separately on each chunk.

    Args:
        dim (int): dimensions at LSTM layer.
        layers (int, optional): number of LSTM layers. (default: 1)
        skip (bool, optional): add the input to the output. (default: ``False``)
    """

    def __init__(self, dim, layers: int = 1, skip: bool = False):
        super().__init__()
        self.max_steps = 200
        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
        self.linear = nn.Linear(2 * dim, dim)
        self.skip = skip

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""BLSTM forward call

        Args:
            x (torch.Tensor): input tensor for BLSTM shape is `(batch_size, dim, time_steps)`

        Returns:
            Tensor
                Output after being run through bidirectional LSTM. Shape is `(batch_size, dim, time_steps)`
        """
        batch, channels, steps = x.shape
        residual = x

        chunked = self.max_steps is not None and steps > self.max_steps
        width = 0
        hop = 0
        num_chunks = 0
        if chunked:
            width = self.max_steps
            hop = width // 2
            chunks = _unfold(x, width, hop)
            num_chunks = chunks.shape[2]
            # Every chunk becomes an independent batch entry.
            x = chunks.permute(0, 2, 1, 3).reshape(-1, channels, width)

        # The LSTM expects (time, batch, features).
        x = x.permute(2, 0, 1)
        x = self.lstm(x)[0]
        x = self.linear(x)
        x = x.permute(1, 2, 0)

        if chunked:
            chunks = x.reshape(batch, -1, channels, width)
            margin = hop // 2
            # Drop the overlapping margin of each chunk, except at the outer
            # edge of the first and of the last chunk.
            first, inner, final = slice(None, -margin), slice(margin, -margin), slice(margin, None)
            pieces = []
            for k in range(num_chunks):
                if k == 0:
                    keep = first
                elif k == num_chunks - 1:
                    keep = final
                else:
                    keep = inner
                pieces.append(chunks[:, k, :, keep])
            x = torch.cat(pieces, -1)[..., :steps]

        if self.skip:
            x = x + residual
        return x
class _LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).
    """

    def __init__(self, channels: int, heads: int = 4, ndecay: int = 4):
        r"""
        Args:
            channels (int): Size of Conv1d layers.
            heads (int, optional): number of attention heads. (default: 4)
            ndecay (int, optional): number of decay terms per head. (default: 4)

        Raises:
            ValueError: if ``channels`` is not divisible by ``heads``.
        """
        super().__init__()
        if channels % heads != 0:
            raise ValueError("Channels must be divisible by heads.")
        self.heads = heads
        self.ndecay = ndecay
        self.content = nn.Conv1d(channels, channels, 1)
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)

        self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
        if ndecay:
            # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
            self.query_decay.weight.data *= 0.01
            if self.query_decay.bias is None:
                raise ValueError("bias must not be None.")
            self.query_decay.bias.data[:] = -2
        self.proj = nn.Conv1d(channels, channels, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LocalState forward call

        Args:
            x (torch.Tensor): input tensor of shape `(B, C, T)`

        Returns:
            Tensor
                ``x`` plus the projected attention output, same shape as ``x``.
        """
        batch, _, steps = x.shape
        heads = self.heads

        positions = torch.arange(steps, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        distance = positions[:, None] - positions[None, :]

        queries = self.query(x).view(batch, heads, -1, steps)
        keys = self.key(x).view(batch, heads, -1, steps)
        # t are keys, s are queries
        scores = torch.einsum("bhct,bhcs->bhts", keys, queries)
        scores = scores / math.sqrt(keys.shape[2])

        if self.ndecay:
            # Penalty growing with the key/query distance; its strength is
            # predicted per query through a sigmoid.
            rates = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
            strength = torch.sigmoid(self.query_decay(x).view(batch, heads, -1, steps)) / 2
            penalty = -rates.view(-1, 1, 1) * distance.abs() / math.sqrt(self.ndecay)
            scores = scores + torch.einsum("fts,bhfs->bhts", penalty, strength)

        # Kill self reference.
        diagonal = torch.eye(steps, device=scores.device, dtype=torch.bool)
        scores = scores.masked_fill(diagonal, -100)
        # Normalize over the key axis.
        weights = torch.softmax(scores, dim=2)

        content = self.content(x).view(batch, heads, -1, steps)
        attended = torch.einsum("bhts,bhct->bhcs", weights, content)
        attended = attended.reshape(batch, -1, steps)
        return x + self.proj(attended)
class _LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally residual outputs close to 0 initially, then learnt.
    """

    def __init__(self, channels: int, init: float = 0):
        r"""
        Args:
            channels (int): Size of rescaling
            init (float, optional): Scale to default to (default: 0)
        """
        super().__init__()
        # One learnable factor per channel, all starting at `init`.
        self.scale = nn.Parameter(torch.full((channels,), float(init)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""LayerScale forward call

        Args:
            x (torch.Tensor): input tensor for LayerScale; the channel axis is
                the second-to-last one.

        Returns:
            Tensor
                Output after rescaling tensor.
        """
        per_channel = self.scale[:, None]
        return per_channel * x
def _unfold(a: torch.Tensor, kernel_size: int, stride: int) -> torch.Tensor:
    """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
    with K the kernel size, by extracting frames with the given stride.
    This will pad the input so that `F = ceil(T / stride)`.
    see https://github.com/pytorch/pytorch/issues/60466

    Raises:
        ValueError: if the last axis of the padded input does not have stride 1.
    """
    *lead_shape, length = a.shape
    n_frames = math.ceil(length / stride)
    # Length needed for the last frame to be complete.
    target = (n_frames - 1) * stride + kernel_size
    padded = F.pad(input=a, pad=[0, target - length])

    *lead_strides, last_stride = padded.stride()
    if last_stride != 1:
        raise ValueError("Data should be contiguous.")
    # Frames are overlapping views of the padded tensor: consecutive frames
    # start `stride` elements apart, elements inside a frame are adjacent.
    frame_shape = list(lead_shape) + [n_frames, kernel_size]
    frame_strides = list(lead_strides) + [stride, 1]
    return padded.as_strided(frame_shape, frame_strides)
def _rescale_module(module):
    r"""
    Rescales, in place, the initial weights (and biases) of every 1d/2d
    convolution and transposed convolution found in ``module``.
    """
    conv_types = (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)
    for layer in module.modules():
        if not isinstance(layer, conv_types):
            continue
        # Divide by sqrt(std / 0.1), where std is the weight's standard deviation.
        factor = (layer.weight.std().detach() / 0.1) ** 0.5
        layer.weight.data /= factor
        if layer.bias is not None:
            layer.bias.data /= factor
def _spectro(x: torch.Tensor, n_fft: int = 512, hop_length: int = 0, pad: int = 0) -> torch.Tensor:
    """Compute a normalized, centered complex STFT over the last axis of ``x``.

    Leading axes are flattened for ``torch.stft`` and restored afterwards, so
    the result has shape ``[*x.shape[:-1], freqs, frames]``. The Hann window has
    ``n_fft`` samples while the FFT size is ``n_fft * (1 + pad)``.
    """
    *lead_shape, length = x.shape
    flat = x.reshape(-1, int(length))
    # `.to(flat)` matches the window's device and dtype to the signal.
    window = torch.hann_window(n_fft).to(flat)
    spec = torch.stft(
        flat,
        n_fft * (1 + pad),
        hop_length,
        window=window,
        win_length=n_fft,
        normalized=True,
        center=True,
        return_complex=True,
        pad_mode="reflect",
    )
    _, freqs, frames = spec.shape
    return spec.view(list(lead_shape) + [freqs, frames])
def _ispectro(z: torch.Tensor, hop_length: int = 0, length: int = 0, pad: int = 0) -> torch.Tensor:
    """Invert a normalized, centered complex STFT over the last two axes of ``z``.

    The FFT size is derived from the number of frequency bins as
    ``2 * freqs - 2``. Leading axes are flattened for ``torch.istft`` and
    restored afterwards, so the result has shape ``[*z.shape[:-2], samples]``.
    """
    *lead_shape, freqs, frames = z.shape
    freqs, frames = int(freqs), int(frames)
    n_fft = 2 * freqs - 2
    win_length = n_fft // (1 + pad)
    flat = z.view(-1, freqs, frames)
    # `.to(flat.real)` matches the window's device and real dtype to the input.
    window = torch.hann_window(win_length).to(flat.real)
    wave = torch.istft(
        flat,
        n_fft,
        hop_length,
        window=window,
        win_length=win_length,
        normalized=True,
        length=length,
        center=True,
    )
    _, samples = wave.shape
    return wave.view(list(lead_shape) + [samples])
def hdemucs_low(sources: List[str]) -> HDemucs:
    """Builds low nfft (1024) version of :class:`HDemucs`, suitable for sample rates around 8 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model with ``nfft=1024`` and ``depth=5``.
    """
    model = HDemucs(sources=sources, nfft=1024, depth=5)
    return model
def hdemucs_medium(sources: List[str]) -> HDemucs:
    r"""Builds medium nfft (2048) version of :class:`HDemucs`, suitable for sample rates of 16-32 kHz.

    .. note::

        Medium HDemucs has not been tested against the original Hybrid Demucs as this nfft and depth configuration is
        not compatible with the original implementation in https://github.com/facebookresearch/demucs

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model with ``nfft=2048`` and ``depth=6``.
    """
    model = HDemucs(sources=sources, nfft=2048, depth=6)
    return model
def hdemucs_high(sources: List[str]) -> HDemucs:
    r"""Builds high nfft (4096) version of :class:`HDemucs`, suitable for sample rates of 44.1-48 kHz.

    Args:
        sources (List[str]): See :py:func:`HDemucs`.

    Returns:
        HDemucs:
            HDemucs model with ``nfft=4096`` and ``depth=6``.
    """
    model = HDemucs(sources=sources, nfft=4096, depth=6)
    return model
|