waytan22 committed (verified)
Commit e730386 · 1 Parent(s): 18e91e0

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See the raw diff for the complete changeset.
Files changed (50)
  1. .gitattributes +1 -0
  2. ckpt/60000_alnew.pt +3 -0
  3. ckpt/encode-s12k.pt +3 -0
  4. ckpt/model_1rvq/model_2_fixed.safetensors +3 -0
  5. ckpt/model_septoken/model_2.safetensors +3 -0
  6. ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors +0 -0
  7. ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors.index.json +0 -0
  8. ckpt/models--lengyue233--content-vec-best/blobs/5186a71b15933aca2d9942db95e1aff02642d1f0 +71 -0
  9. ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e +3 -0
  10. ckpt/models--lengyue233--content-vec-best/refs/main +1 -0
  11. ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/config.json +71 -0
  12. ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin +3 -0
  13. ckpt/vae/autoencoder_music_1320k.ckpt +3 -0
  14. ckpt/vae/stable_audio_1920_vae.json +122 -0
  15. third_party/Qwen2-7B/LICENSE +202 -0
  16. third_party/Qwen2-7B/README.md +97 -0
  17. third_party/Qwen2-7B/config.json +27 -0
  18. third_party/Qwen2-7B/generation_config.json +7 -0
  19. third_party/Qwen2-7B/merges.txt +0 -0
  20. third_party/Qwen2-7B/tokenizer.json +0 -0
  21. third_party/Qwen2-7B/tokenizer_config.json +40 -0
  22. third_party/Qwen2-7B/vocab.json +0 -0
  23. third_party/demucs/__init__.py +0 -0
  24. third_party/demucs/ckpt/htdemucs.pth +3 -0
  25. third_party/demucs/ckpt/htdemucs.yaml +1 -0
  26. third_party/demucs/models/__init__.py +0 -0
  27. third_party/demucs/models/apply.py +315 -0
  28. third_party/demucs/models/audio.py +291 -0
  29. third_party/demucs/models/demucs.py +452 -0
  30. third_party/demucs/models/htdemucs.py +955 -0
  31. third_party/demucs/models/pretrained.py +34 -0
  32. third_party/demucs/models/spec.py +51 -0
  33. third_party/demucs/models/states.py +102 -0
  34. third_party/demucs/models/transformer.py +765 -0
  35. third_party/demucs/models/utils.py +125 -0
  36. third_party/demucs/run.py +109 -0
  37. third_party/hub/version.txt +1 -0
  38. third_party/stable_audio_tools/.gitignore +164 -0
  39. third_party/stable_audio_tools/LICENSE +21 -0
  40. third_party/stable_audio_tools/LICENSES/LICENSE_ADP.txt +21 -0
  41. third_party/stable_audio_tools/LICENSES/LICENSE_AURALOSS.txt +201 -0
  42. third_party/stable_audio_tools/LICENSES/LICENSE_DESCRIPT.txt +21 -0
  43. third_party/stable_audio_tools/LICENSES/LICENSE_META.txt +21 -0
  44. third_party/stable_audio_tools/LICENSES/LICENSE_NVIDIA.txt +21 -0
  45. third_party/stable_audio_tools/LICENSES/LICENSE_XTRANSFORMERS.txt +21 -0
  46. third_party/stable_audio_tools/README.md +157 -0
  47. third_party/stable_audio_tools/config/model_1920.json +122 -0
  48. third_party/stable_audio_tools/config/model_config.json +122 -0
  49. third_party/stable_audio_tools/defaults.ini +56 -0
  50. third_party/stable_audio_tools/docs/autoencoders.md +357 -0
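The commit message above, "Upload folder using huggingface_hub", corresponds to that library's folder-upload API. A minimal sketch of how such a commit is typically produced; the repo id and local path are placeholders, not values taken from this page:

```python
# Hedged sketch: produce a commit like this one with huggingface_hub's upload_folder.
# repo_id and folder_path are hypothetical.
from huggingface_hub import HfApi

api = HfApi()  # assumes you are already authenticated (e.g. via `huggingface-cli login`)
api.upload_folder(
    folder_path="./local_checkpoint_dir",      # local folder containing ckpt/, third_party/, ...
    repo_id="your-username/your-model-repo",   # hypothetical target repository
    repo_type="model",
    commit_message="Upload folder using huggingface_hub",
)
```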
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e filter=lfs diff=lfs merge=lfs -text
ckpt/60000_alnew.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8763fc75e5db768c334a9fbadd08e2004eccb6e15156c76b4c2a3984f8fbb884
+ size 11318365872
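The three lines above are a Git LFS pointer stub rather than the checkpoint itself: the real file is addressed by its SHA-256 (`oid`) and byte size. A hedged sketch of verifying a downloaded copy against the values recorded in this pointer (the local path is a placeholder); the same check applies to the other LFS-tracked files in this commit:

```python
# Hedged sketch: check a downloaded LFS object against the pointer's oid/size.
import hashlib
import os

path = "ckpt/60000_alnew.pt"  # hypothetical local copy of the checkpoint
expected_oid = "8763fc75e5db768c334a9fbadd08e2004eccb6e15156c76b4c2a3984f8fbb884"
expected_size = 11318365872

assert os.path.getsize(path) == expected_size, "size mismatch"

sha = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # read in 1 MiB chunks
        sha.update(chunk)
assert sha.hexdigest() == expected_oid, "hash mismatch"
```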
ckpt/encode-s12k.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e250df56b035f74c1f66f15133f4c78f664d70fa0b09aa9a752b7871bb58c02f
+ size 3957949089
ckpt/model_1rvq/model_2_fixed.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:339a16956b859a82defc02bfd32c3744d11ff942065f6ec9306dfd4400d62110
+ size 4704507596
ckpt/model_septoken/model_2.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:758aa342942a7b7c0ae179af1a952e0b944e39128ea816741499b3031113aaee
+ size 4808167708
ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors ADDED
File without changes
ckpt/models--lengyue233--content-vec-best/.no_exist/c0b9ba13db21beaa4053faae94c102ebe326fd68/model.safetensors.index.json ADDED
File without changes
ckpt/models--lengyue233--content-vec-best/blobs/5186a71b15933aca2d9942db95e1aff02642d1f0 ADDED
@@ -0,0 +1,71 @@
+ {
+ "activation_dropout": 0.1,
+ "apply_spec_augment": true,
+ "architectures": [
+ "HubertModelWithFinalProj"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "classifier_proj_size": 256,
+ "conv_bias": false,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "do_stable_layer_norm": false,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_norm": "group",
+ "feat_proj_dropout": 0.0,
+ "feat_proj_layer_norm": true,
+ "final_dropout": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.1,
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.05,
+ "model_type": "hubert",
+ "num_attention_heads": 12,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "torch_dtype": "float32",
+ "transformers_version": "4.27.3",
+ "use_weighted_layer_sum": false,
+ "vocab_size": 32
+ }
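A quick sanity check on the feature-extractor settings above: the seven `conv_stride` values compound to a hop of 320 input samples per output frame, i.e. 20 ms per frame at the 16 kHz input rate HuBERT-style models conventionally use (the sample rate itself is not stated in this config):

```python
# Arithmetic check on the config above; the 16 kHz rate is an assumption.
strides = [5, 2, 2, 2, 2, 2, 2]   # "conv_stride" from the JSON
hop = 1
for s in strides:
    hop *= s
assert hop == 320                 # input samples per output frame
print(hop / 16000 * 1000)         # -> 20.0 ms per frame, assuming 16 kHz input
```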
ckpt/models--lengyue233--content-vec-best/blobs/d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
+ size 378342945
ckpt/models--lengyue233--content-vec-best/refs/main ADDED
@@ -0,0 +1 @@
+ c0b9ba13db21beaa4053faae94c102ebe326fd68
ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/config.json ADDED
@@ -0,0 +1,71 @@
+ {
+ "activation_dropout": 0.1,
+ "apply_spec_augment": true,
+ "architectures": [
+ "HubertModelWithFinalProj"
+ ],
+ "attention_dropout": 0.1,
+ "bos_token_id": 1,
+ "classifier_proj_size": 256,
+ "conv_bias": false,
+ "conv_dim": [
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512,
+ 512
+ ],
+ "conv_kernel": [
+ 10,
+ 3,
+ 3,
+ 3,
+ 3,
+ 2,
+ 2
+ ],
+ "conv_stride": [
+ 5,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "ctc_loss_reduction": "sum",
+ "ctc_zero_infinity": false,
+ "do_stable_layer_norm": false,
+ "eos_token_id": 2,
+ "feat_extract_activation": "gelu",
+ "feat_extract_norm": "group",
+ "feat_proj_dropout": 0.0,
+ "feat_proj_layer_norm": true,
+ "final_dropout": 0.1,
+ "hidden_act": "gelu",
+ "hidden_dropout": 0.1,
+ "hidden_size": 768,
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "layer_norm_eps": 1e-05,
+ "layerdrop": 0.1,
+ "mask_feature_length": 10,
+ "mask_feature_min_masks": 0,
+ "mask_feature_prob": 0.0,
+ "mask_time_length": 10,
+ "mask_time_min_masks": 2,
+ "mask_time_prob": 0.05,
+ "model_type": "hubert",
+ "num_attention_heads": 12,
+ "num_conv_pos_embedding_groups": 16,
+ "num_conv_pos_embeddings": 128,
+ "num_feat_extract_layers": 7,
+ "num_hidden_layers": 12,
+ "pad_token_id": 0,
+ "torch_dtype": "float32",
+ "transformers_version": "4.27.3",
+ "use_weighted_layer_sum": false,
+ "vocab_size": 32
+ }
ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d8dd400e054ddf4e6be75dab5a2549db748cc99e756a097c496c099f65a4854e
+ size 378342945
ckpt/vae/autoencoder_music_1320k.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:10ccb6c83613781ad32e998a90597ba7eb9292911a224598da1fd53728eb4cd3
+ size 674920616
ckpt/vae/stable_audio_1920_vae.json ADDED
@@ -0,0 +1,122 @@
+ {
+ "model_type": "autoencoder",
+ "sample_size": 403200,
+ "sample_rate": 48000,
+ "audio_channels": 2,
+ "model": {
+ "encoder": {
+ "type": "oobleck",
+ "config": {
+ "in_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 6, 10],
+ "latent_dim": 128,
+ "use_snake": true
+ }
+ },
+ "decoder": {
+ "type": "oobleck",
+ "config": {
+ "out_channels": 2,
+ "channels": 128,
+ "c_mults": [1, 2, 4, 8, 16],
+ "strides": [2, 4, 4, 6, 10],
+ "latent_dim": 64,
+ "use_snake": true,
+ "final_tanh": false
+ }
+ },
+ "bottleneck": {
+ "type": "vae"
+ },
+ "latent_dim": 64,
+ "downsampling_ratio": 1920,
+ "io_channels": 2
+ },
+ "training": {
+ "learning_rate": 1.5e-4,
+ "warmup_steps": 0,
+ "use_ema": true,
+ "optimizer_configs": {
+ "autoencoder": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "betas": [0.8, 0.99],
+ "lr": 1.5e-4,
+ "weight_decay": 1e-3
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 200000,
+ "power": 0.5,
+ "warmup": 0.999
+ }
+ }
+ },
+ "discriminator": {
+ "optimizer": {
+ "type": "AdamW",
+ "config": {
+ "betas": [0.8, 0.99],
+ "lr": 3e-4,
+ "weight_decay": 1e-3
+ }
+ },
+ "scheduler": {
+ "type": "InverseLR",
+ "config": {
+ "inv_gamma": 200000,
+ "power": 0.5,
+ "warmup": 0.999
+ }
+ }
+ }
+ },
+ "loss_configs": {
+ "discriminator": {
+ "type": "encodec",
+ "config": {
+ "filters": 64,
+ "n_ffts": [2048, 1024, 512, 256, 128],
+ "hop_lengths": [512, 256, 128, 64, 32],
+ "win_lengths": [2048, 1024, 512, 256, 128]
+ },
+ "weights": {
+ "adversarial": 0.1,
+ "feature_matching": 5.0
+ }
+ },
+ "spectral": {
+ "type": "mrstft",
+ "config": {
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
+ "perceptual_weighting": true
+ },
+ "weights": {
+ "mrstft": 1.0
+ }
+ },
+ "time": {
+ "type": "l1",
+ "weights": {
+ "l1": 0.0
+ }
+ },
+ "bottleneck": {
+ "type": "kl",
+ "weights": {
+ "kl": 1e-4
+ }
+ }
+ },
+ "demo": {
+ "demo_every": 2000
+ }
+ }
+ }
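The numbers in this VAE config are internally consistent: the five encoder strides multiply to the stated `downsampling_ratio`, and one `sample_size` window of 403200 samples at 48 kHz corresponds to 8.4 s of stereo audio, or 210 latent frames of dimension 64. A small check using only values copied from the JSON above:

```python
# Consistency check on stable_audio_1920_vae.json (all values from the config).
strides = [2, 4, 4, 6, 10]
ratio = 1
for s in strides:
    ratio *= s
assert ratio == 1920              # matches "downsampling_ratio"
print(403200 / 48000)             # -> 8.4 seconds per training window
print(403200 // ratio)            # -> 210 latent frames, each of "latent_dim" 64
```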
third_party/Qwen2-7B/LICENSE ADDED
@@ -0,0 +1,202 @@
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2024 Alibaba Cloud
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
third_party/Qwen2-7B/README.md ADDED
@@ -0,0 +1,97 @@
+ ---
+ language:
+ - en
+ pipeline_tag: text-generation
+ tags:
+ - pretrained
+ license: apache-2.0
+ ---
+
+ # Qwen2-7B
+
+ ## Introduction
+
+ Qwen2 is the new series of Qwen large language models. For Qwen2, we release a number of base language models and instruction-tuned language models ranging from 0.5 to 72 billion parameters, including a Mixture-of-Experts model. This repo contains the 7B Qwen2 base language model.
+
+ Compared with state-of-the-art open-source language models, including the previously released Qwen1.5, Qwen2 has generally surpassed most open-source models and demonstrated competitiveness against proprietary models across a series of benchmarks targeting language understanding, language generation, multilingual capability, coding, mathematics, reasoning, etc.
+
+ For more details, please refer to our [blog](https://qwenlm.github.io/blog/qwen2/), [GitHub](https://github.com/QwenLM/Qwen2), and [Documentation](https://qwen.readthedocs.io/en/latest/).
+ <br>
+
+ ## Model Details
+ Qwen2 is a language model series including decoder language models of different model sizes. For each size, we release the base language model and the aligned chat model. It is based on the Transformer architecture with SwiGLU activation, attention QKV bias, group query attention, etc. Additionally, we have an improved tokenizer adaptive to multiple natural languages and code.
+
+ ## Requirements
+ The code of Qwen2 is included in the latest Hugging Face transformers, and we advise you to install `transformers>=4.37.0`; otherwise you might encounter the following error:
+ ```
+ KeyError: 'qwen2'
+ ```
+
+ ## Usage
+
+ We do not advise you to use base language models for text generation. Instead, you can apply post-training, e.g., SFT, RLHF, continued pretraining, etc., on this model.
+
+ ### Performance
+
+ The evaluation of base models mainly focuses on model performance in natural language understanding, general question answering, coding, mathematics, scientific knowledge, reasoning, multilingual capability, etc.
+
+ The datasets for evaluation include:
+
+ **English Tasks**: MMLU (5-shot), MMLU-Pro (5-shot), GPQA (5-shot), Theorem QA (5-shot), BBH (3-shot), HellaSwag (10-shot), Winogrande (5-shot), TruthfulQA (0-shot), ARC-C (25-shot)
+
+ **Coding Tasks**: EvalPlus (0-shot) (HumanEval, MBPP, HumanEval+, MBPP+), MultiPL-E (0-shot) (Python, C++, JAVA, PHP, TypeScript, C#, Bash, JavaScript)
+
+ **Math Tasks**: GSM8K (4-shot), MATH (4-shot)
+
+ **Chinese Tasks**: C-Eval (5-shot), CMMLU (5-shot)
+
+ **Multilingual Tasks**: Multi-Exam (M3Exam 5-shot, IndoMMLU 3-shot, ruMMLU 5-shot, mMMLU 5-shot), Multi-Understanding (BELEBELE 5-shot, XCOPA 5-shot, XWinograd 5-shot, XStoryCloze 0-shot, PAWS-X 5-shot), Multi-Mathematics (MGSM 8-shot), Multi-Translation (Flores-101 5-shot)
+
+ #### Qwen2-7B performance
+ | Datasets | Mistral-7B | Gemma-7B | Llama-3-8B | Qwen1.5-7B | Qwen2-7B |
+ | :-------- | :---------: | :------: | :--------: | :--------: | :------: |
+ | # Params | 7.2B | 8.5B | 8.0B | 7.7B | 7.6B |
+ | # Non-emb Params | 7.0B | 7.8B | 7.0B | 6.5B | 6.5B |
+ | ***English*** | | | | | |
+ | MMLU | 64.2 | 64.6 | 66.6 | 61.0 | **70.3** |
+ | MMLU-Pro | 30.9 | 33.7 | 35.4 | 29.9 | **40.0** |
+ | GPQA | 24.7 | 25.7 | 25.8 | 26.7 | **31.8** |
+ | Theorem QA | 19.2 | 21.5 | 22.1 | 14.2 | **31.1** |
+ | BBH | 56.1 | 55.1 | 57.7 | 40.2 | **62.6** |
+ | HellaSwag | **83.2** | 82.2 | 82.1 | 78.5 | 80.7 |
+ | Winogrande | 78.4 | **79.0** | 77.4 | 71.3 | 77.0 |
+ | ARC-C | 60.0 | **61.1** | 59.3 | 54.2 | 60.6 |
+ | TruthfulQA | 42.2 | 44.8 | 44.0 | 51.1 | **54.2** |
+ | ***Coding*** | | | | | |
+ | HumanEval | 29.3 | 37.2 | 33.5 | 36.0 | **51.2** |
+ | MBPP | 51.1 | 50.6 | 53.9 | 51.6 | **65.9** |
+ | EvalPlus | 36.4 | 39.6 | 40.3 | 40.0 | **54.2** |
+ | MultiPL-E | 29.4 | 29.7 | 22.6 | 28.1 | **46.3** |
+ | ***Mathematics*** | | | | | |
+ | GSM8K | 52.2 | 46.4 | 56.0 | 62.5 | **79.9** |
+ | MATH | 13.1 | 24.3 | 20.5 | 20.3 | **44.2** |
+ | ***Chinese*** | | | | | |
+ | C-Eval | 47.4 | 43.6 | 49.5 | 74.1 | **83.2** |
+ | CMMLU | - | - | 50.8 | 73.1 | **83.9** |
+ | ***Multilingual*** | | | | | |
+ | Multi-Exam | 47.1 | 42.7 | 52.3 | 47.7 | **59.2** |
+ | Multi-Understanding | 63.3 | 58.3 | 68.6 | 67.6 | **72.0** |
+ | Multi-Mathematics | 26.3 | 39.1 | 36.3 | 37.3 | **57.5** |
+ | Multi-Translation | 23.3 | 31.2 | **31.9** | 28.4 | 31.5 |
+
+ ## Citation
+
+ If you find our work helpful, feel free to cite us.
+
+ ```
+ @article{qwen2,
+   title={Qwen2 Technical Report},
+   year={2024}
+ }
+ ```
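The Requirements note above boils down to one version constraint: older transformers releases do not know the `qwen2` model type. A hedged loading sketch for the base model; the public hub id `Qwen/Qwen2-7B` is used purely for illustration, since this commit only vendors the config and tokenizer files:

```python
# Hedged sketch of the transformers>=4.37.0 requirement; with older versions,
# from_pretrained raises KeyError: 'qwen2'. The model id is illustrative.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")
model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-7B", torch_dtype="auto")
```

As the README says, the base model is intended for further post-training rather than direct text generation.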
third_party/Qwen2-7B/config.json ADDED
@@ -0,0 +1,27 @@
+ {
+ "architectures": [
+ "Qwen2ForCausalLM"
+ ],
+ "attention_dropout": 0.0,
+ "bos_token_id": 151643,
+ "eos_token_id": 151643,
+ "hidden_act": "silu",
+ "hidden_size": 3584,
+ "initializer_range": 0.02,
+ "intermediate_size": 18944,
+ "max_position_embeddings": 131072,
+ "max_window_layers": 28,
+ "model_type": "qwen2",
+ "num_attention_heads": 28,
+ "num_hidden_layers": 28,
+ "num_key_value_heads": 4,
+ "rms_norm_eps": 1e-06,
+ "rope_theta": 1000000.0,
+ "sliding_window": 131072,
+ "tie_word_embeddings": false,
+ "torch_dtype": "bfloat16",
+ "transformers_version": "4.37.2",
+ "use_cache": true,
+ "use_sliding_window": false,
+ "vocab_size": 152064
+ }
third_party/Qwen2-7B/generation_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+ "bos_token_id": 151643,
+ "do_sample": false,
+ "eos_token_id": 151643,
+ "max_new_tokens": 2048,
+ "transformers_version": "4.37.0"
+ }
third_party/Qwen2-7B/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
third_party/Qwen2-7B/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
third_party/Qwen2-7B/tokenizer_config.json ADDED
@@ -0,0 +1,40 @@
+ {
+ "add_prefix_space": false,
+ "added_tokens_decoder": {
+ "151643": {
+ "content": "<|endoftext|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151644": {
+ "content": "<|im_start|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ },
+ "151645": {
+ "content": "<|im_end|>",
+ "lstrip": false,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false,
+ "special": true
+ }
+ },
+ "additional_special_tokens": ["<|im_start|>", "<|im_end|>"],
+ "bos_token": null,
+ "chat_template": "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+ "clean_up_tokenization_spaces": false,
+ "eos_token": "<|endoftext|>",
+ "errors": "replace",
+ "model_max_length": 32768,
+ "pad_token": "<|endoftext|>",
+ "split_special_tokens": false,
+ "tokenizer_class": "Qwen2Tokenizer",
+ "unk_token": null
+ }
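The `chat_template` in this tokenizer config is the ChatML-style template, with a default system prompt injected when none is supplied. A hedged example of rendering it via transformers; the tokenizer id is illustrative, and the expected output is derived from the template text above:

```python
# Hedged sketch: render the ChatML-style chat_template defined above.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2-7B")  # illustrative id
messages = [{"role": "user", "content": "Hello"}]
text = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
# Expected, per the template above:
# <|im_start|>system\nYou are a helpful assistant<|im_end|>\n
# <|im_start|>user\nHello<|im_end|>\n
# <|im_start|>assistant\n
```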
third_party/Qwen2-7B/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
third_party/demucs/__init__.py ADDED
File without changes
third_party/demucs/ckpt/htdemucs.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e4378974c3df2fbcf872d2aeb32218e4de376799494579655775a375d09931c2
+ size 168138881
third_party/demucs/ckpt/htdemucs.yaml ADDED
@@ -0,0 +1 @@
+ models: ['htdemucs']
third_party/demucs/models/__init__.py ADDED
File without changes
third_party/demucs/models/apply.py ADDED
@@ -0,0 +1,315 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : apply.py
5
+ @Time : 2023/8/8 下午4:22
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Apply
10
+ """
11
+
12
+ from concurrent.futures import ThreadPoolExecutor
13
+ import torch
14
+ import os
15
+ import random
16
+ import typing as tp
17
+
18
+ import torch as th
19
+ from torch import nn
20
+ from torch.nn import functional as F
21
+ import tqdm
22
+
23
+ from .htdemucs import HTDemucs
24
+ from .audio import load_track, save_audio
25
+ from .utils import center_trim, DummyPoolExecutor
26
+
27
+ Model = tp.Union[HTDemucs]
28
+
29
+
30
+ class BagOfModels(nn.Module):
31
+ def __init__(self, models: tp.List[Model],
32
+ weights: tp.Optional[tp.List[tp.List[float]]] = None,
33
+ segment: tp.Optional[float] = None):
34
+ """
35
+ Represents a bag of models with specific weights.
36
+ You should call `apply_model` rather than calling directly the forward here for
37
+ optimal performance.
38
+
39
+ Args:
40
+ models (list[nn.Module]): list of Demucs/HDemucs models.
41
+ weights (list[list[float]]): list of weights. If None, assumed to
42
+ be all ones, otherwise it should be a list of N list (N number of models),
43
+ each containing S floats (S number of sources).
44
+ segment (None or float): overrides the `segment` attribute of each model
45
+ (this is performed inplace, be careful is you reuse the models passed).
46
+ """
47
+ super().__init__()
48
+ assert len(models) > 0
49
+ first = models[0]
50
+ for other in models:
51
+ assert other.sources == first.sources
52
+ assert other.samplerate == first.samplerate
53
+ assert other.audio_channels == first.audio_channels
54
+ if segment is not None:
55
+ other.segment = segment
56
+
57
+ self.audio_channels = first.audio_channels
58
+ self.samplerate = first.samplerate
59
+ self.sources = first.sources
60
+ self.models = nn.ModuleList(models)
61
+
62
+ if weights is None:
63
+ weights = [[1. for _ in first.sources] for _ in models]
64
+ else:
65
+ assert len(weights) == len(models)
66
+ for weight in weights:
67
+ assert len(weight) == len(first.sources)
68
+ self.weights = weights
69
+
70
+ @property
71
+ def max_allowed_segment(self) -> float:
72
+ max_allowed_segment = float('inf')
73
+ for model in self.models:
74
+ if isinstance(model, HTDemucs):
75
+ max_allowed_segment = min(max_allowed_segment, float(model.segment))
76
+ return max_allowed_segment
77
+
78
+ def forward(self, x):
79
+ raise NotImplementedError("Call `apply_model` on this.")
80
+
81
+ def separate(self, source_file, output_dir, stem=None, device=None):
82
+ wav, _ = load_track(source_file, self.audio_channels, self.samplerate)
83
+ ref = wav.mean(0)
84
+ wav -= ref.mean()
85
+ wav /= ref.std()
86
+ sources = apply_model(self, wav[None], device=device, shifts=1, split=True, overlap=0.25,
87
+ progress=True, num_workers=0, segment=None)[0]
88
+ sources *= ref.std()
89
+ sources += ref.mean()
90
+
91
+ output_paths = []
92
+ name, ext = os.path.splitext(os.path.split(source_file)[-1])
93
+ if ext != ".flac":
94
+ ext = ".flac"
95
+ kwargs = {
96
+ 'samplerate': self.samplerate,
97
+ 'bitrate': 320,
98
+ 'clip': "rescale",
99
+ 'as_float': False,
100
+ 'bits_per_sample': 16,
101
+ }
102
+ if stem is None:
103
+ for source, stem in zip(sources, self.sources):
104
+ output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
105
+ save_audio(source, output_stem_path, **kwargs)
106
+ output_paths.append(output_stem_path)
107
+ else:
108
+ sources = list(sources)
109
+ output_stem_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
110
+ save_audio(sources.pop(self.sources.index(stem)), output_stem_path, **kwargs)
111
+ other_stem = torch.zeros_like(sources[0])
112
+ for i in sources:
113
+ other_stem += i
114
+ output_no_stem_path = os.path.join(output_dir, f"{name}_no_{stem}{ext}")
115
+ save_audio(other_stem, output_no_stem_path, **kwargs)
116
+ output_paths = [output_stem_path, output_no_stem_path]
117
+
118
+ return output_paths
119
+
120
+
121
+ class TensorChunk:
122
+ def __init__(self, tensor, offset=0, length=None):
123
+ total_length = tensor.shape[-1]
124
+ assert offset >= 0
125
+ assert offset < total_length
126
+
127
+ if length is None:
128
+ length = total_length - offset
129
+ else:
130
+ length = min(total_length - offset, length)
131
+
132
+ if isinstance(tensor, TensorChunk):
133
+ self.tensor = tensor.tensor
134
+ self.offset = offset + tensor.offset
135
+ else:
136
+ self.tensor = tensor
137
+ self.offset = offset
138
+ self.length = length
139
+ self.device = tensor.device
140
+
141
+ @property
142
+ def shape(self):
143
+ shape = list(self.tensor.shape)
144
+ shape[-1] = self.length
145
+ return shape
146
+
147
+ def padded(self, target_length):
148
+ delta = target_length - self.length
149
+ total_length = self.tensor.shape[-1]
150
+ assert delta >= 0
151
+
152
+ start = self.offset - delta // 2
153
+ end = start + target_length
154
+
155
+ correct_start = max(0, start)
156
+ correct_end = min(total_length, end)
157
+
158
+ pad_left = correct_start - start
159
+ pad_right = end - correct_end
160
+
161
+ out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
162
+ assert out.shape[-1] == target_length
163
+ return out
164
+
165
+
166
+ def tensor_chunk(tensor_or_chunk):
167
+ if isinstance(tensor_or_chunk, TensorChunk):
168
+ return tensor_or_chunk
169
+ else:
170
+ assert isinstance(tensor_or_chunk, th.Tensor)
171
+ return TensorChunk(tensor_or_chunk)
172
+
173
+
174
+ def apply_model(model: tp.Union[BagOfModels, Model],
175
+ mix: tp.Union[th.Tensor, TensorChunk],
176
+ shifts: int = 1, split: bool = True,
177
+ overlap: float = 0.25, transition_power: float = 1.,
178
+ progress: bool = False, device=None,
179
+ num_workers: int = 0, segment: tp.Optional[float] = None,
180
+ pool=None) -> th.Tensor:
181
+ """
182
+ Apply model to a given mixture.
183
+
184
+ Args:
185
+ shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
186
+ and apply the oppositve shift to the output. This is repeated `shifts` time and
187
+ all predictions are averaged. This effectively makes the model time equivariant
188
+ and improves SDR by up to 0.2 points.
189
+ split (bool): if True, the input will be broken down in 8 seconds extracts
190
+ and predictions will be performed individually on each and concatenated.
191
+ Useful for model with large memory footprint like Tasnet.
192
+ progress (bool): if True, show a progress bar (requires split=True)
193
+ device (torch.device, str, or None): if provided, device on which to
194
+ execute the computation, otherwise `mix.device` is assumed.
195
+ When `device` is different from `mix.device`, only local computations will
196
+ be on `device`, while the entire tracks will be stored on `mix.device`.
197
+ num_workers (int): if non zero, device is 'cpu', how many threads to
198
+ use in parallel.
199
+ segment (float or None): override the model segment parameter.
200
+ """
201
+ if device is None:
202
+ device = mix.device
203
+ else:
204
+ device = th.device(device)
205
+ if pool is None:
206
+ if num_workers > 0 and device.type == 'cpu':
207
+ pool = ThreadPoolExecutor(num_workers)
208
+ else:
209
+ pool = DummyPoolExecutor()
210
+ kwargs: tp.Dict[str, tp.Any] = {
211
+ 'shifts': shifts,
212
+ 'split': split,
213
+ 'overlap': overlap,
214
+ 'transition_power': transition_power,
215
+ 'progress': progress,
216
+ 'device': device,
217
+ 'pool': pool,
218
+ 'segment': segment,
219
+ }
220
+ out: tp.Union[float, th.Tensor]
221
+ if isinstance(model, BagOfModels):
222
+ # Special treatment for bag of model.
223
+ # We explicitely apply multiple times `apply_model` so that the random shifts
224
+ # are different for each model.
225
+ estimates: tp.Union[float, th.Tensor] = 0.
226
+ totals = [0.] * len(model.sources)
227
+ for sub_model, model_weights in zip(model.models, model.weights):
228
+ original_model_device = next(iter(sub_model.parameters())).device
229
+ sub_model.to(device)
230
+
231
+ out = apply_model(sub_model, mix, **kwargs)
232
+ sub_model.to(original_model_device)
233
+ for k, inst_weight in enumerate(model_weights):
234
+ out[:, k, :, :] *= inst_weight
235
+ totals[k] += inst_weight
236
+ estimates += out
237
+ del out
238
+
239
+ assert isinstance(estimates, th.Tensor)
240
+ for k in range(estimates.shape[1]):
241
+ estimates[:, k, :, :] /= totals[k]
242
+ return estimates
243
+
244
+ model.to(device)
245
+ model.eval()
246
+ assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
247
+ batch, channels, length = mix.shape
248
+ if shifts:
249
+ kwargs['shifts'] = 0
250
+ max_shift = int(0.5 * model.samplerate)
251
+ mix = tensor_chunk(mix)
252
+ assert isinstance(mix, TensorChunk)
253
+ padded_mix = mix.padded(length + 2 * max_shift)
254
+ out = 0.
255
+ for _ in range(shifts):
256
+ offset = random.randint(0, max_shift)
257
+ shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
258
+ shifted_out = apply_model(model, shifted, **kwargs)
259
+ out += shifted_out[..., max_shift - offset:]
260
+ out /= shifts
261
+ assert isinstance(out, th.Tensor)
262
+ return out
263
+ elif split:
264
+ kwargs['split'] = False
265
+ out = th.zeros(batch, len(model.sources), channels, length, device=mix.device)
266
+ sum_weight = th.zeros(length, device=mix.device)
267
+ if segment is None:
268
+ segment = model.segment
269
+ assert segment is not None and segment > 0.
270
+ segment_length: int = int(model.samplerate * segment)
271
+ stride = int((1 - overlap) * segment_length)
272
+ offsets = range(0, length, stride)
273
+ scale = float(format(stride / model.samplerate, ".2f"))
274
+ # We start from a triangle shaped weight, with maximal weight in the middle
275
+ # of the segment. Then we normalize and take to the power `transition_power`.
276
+ # Large values of transition power will lead to sharper transitions.
277
+ weight = th.cat([th.arange(1, segment_length // 2 + 1, device=device),
278
+ th.arange(segment_length - segment_length // 2, 0, -1, device=device)])
279
+ assert len(weight) == segment_length
280
+ # If the overlap < 50%, this will translate to linear transition when
281
+ # transition_power is 1.
282
+ weight = (weight / weight.max())**transition_power
283
+ futures = []
284
+ for offset in offsets:
285
+ chunk = TensorChunk(mix, offset, segment_length)
286
+ future = pool.submit(apply_model, model, chunk, **kwargs)
287
+ futures.append((future, offset))
288
+ offset += segment_length
289
+ if progress:
290
+ futures = tqdm.tqdm(futures, unit_scale=scale, ncols=120, unit='seconds')
291
+ for future, offset in futures:
292
+ chunk_out = future.result()
293
+ chunk_length = chunk_out.shape[-1]
294
+ out[..., offset:offset + segment_length] += (
295
+ weight[:chunk_length] * chunk_out).to(mix.device)
296
+ sum_weight[offset:offset + segment_length] += weight[:chunk_length].to(mix.device)
297
+ assert sum_weight.min() > 0
298
+ out /= sum_weight
299
+ assert isinstance(out, th.Tensor)
300
+ return out
301
+ else:
302
+ valid_length: int
303
+ if isinstance(model, HTDemucs) and segment is not None:
304
+ valid_length = int(segment * model.samplerate)
305
+ elif hasattr(model, 'valid_length'):
306
+ valid_length = model.valid_length(length) # type: ignore
307
+ else:
308
+ valid_length = length
309
+ mix = tensor_chunk(mix)
310
+ assert isinstance(mix, TensorChunk)
311
+ padded_mix = mix.padded(valid_length).to(device)
312
+ with th.no_grad():
313
+ out = model(padded_mix)
314
+ assert isinstance(out, th.Tensor)
315
+ return center_trim(out, length)
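`apply_model` above is the main entry point of this module: it handles shift augmentation, chunked (split) inference with overlapping triangular weights, and `BagOfModels` averaging. A hedged usage sketch; loading an actual HTDemucs checkpoint (see `pretrained.py` in this folder) is assumed and not shown here:

```python
# Hedged sketch: run an already-loaded Demucs-style model through apply_model.
import torch

def separate_stems(model, mix: torch.Tensor) -> torch.Tensor:
    """mix: (batch, channels, samples) waveform -> (batch, sources, channels, samples)."""
    # Import path follows the layout added in this commit.
    from third_party.demucs.models.apply import apply_model
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return apply_model(model, mix, shifts=1, split=True, overlap=0.25,
                       progress=True, device=device)
```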
third_party/demucs/models/audio.py ADDED
@@ -0,0 +1,291 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : audio.py
5
+ @Time : 2023/8/8 下午7:18
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Audio
10
+ """
11
+ import json
12
+ import subprocess as sp
13
+ import typing as tp
14
+ from pathlib import Path
15
+
16
+ import lameenc
17
+ import julius
18
+ import torch
19
+ import numpy as np
20
+ import torchaudio as ta
21
+
22
+ from .utils import temp_filenames
23
+
24
+
25
+ def _read_info(path):
26
+ stdout_data = sp.check_output([
27
+ 'ffprobe', "-loglevel", "panic",
28
+ str(path), '-print_format', 'json', '-show_format', '-show_streams'
29
+ ])
30
+ return json.loads(stdout_data.decode('utf-8'))
31
+
32
+
33
+ class AudioFile:
34
+ """
35
+ Allows to read audio from any format supported by ffmpeg, as well as resampling or
36
+ converting to mono on the fly. See :method:`read` for more details.
37
+ """
38
+ def __init__(self, path: Path):
39
+ self.path = Path(path)
40
+ self._info = None
41
+
42
+ def __repr__(self):
43
+ features = [("path", self.path)]
44
+ features.append(("samplerate", self.samplerate()))
45
+ features.append(("channels", self.channels()))
46
+ features.append(("streams", len(self)))
47
+ features_str = ", ".join(f"{name}={value}" for name, value in features)
48
+ return f"AudioFile({features_str})"
49
+
50
+ @property
51
+ def info(self):
52
+ if self._info is None:
53
+ self._info = _read_info(self.path)
54
+ return self._info
55
+
56
+ @property
57
+ def duration(self):
58
+ return float(self.info['format']['duration'])
59
+
60
+ @property
61
+ def _audio_streams(self):
62
+ return [
63
+ index for index, stream in enumerate(self.info["streams"])
64
+ if stream["codec_type"] == "audio"
65
+ ]
66
+
67
+ def __len__(self):
68
+ return len(self._audio_streams)
69
+
70
+ def channels(self, stream=0):
71
+ return int(self.info['streams'][self._audio_streams[stream]]['channels'])
72
+
73
+ def samplerate(self, stream=0):
74
+ return int(self.info['streams'][self._audio_streams[stream]]['sample_rate'])
75
+
76
+ def read(self,
77
+ seek_time=None,
78
+ duration=None,
79
+ streams=slice(None),
80
+ samplerate=None,
81
+ channels=None):
82
+ """
83
+ Slightly more efficient implementation than stempeg,
84
+ in particular, this will extract all stems at once
85
+ rather than having to loop over one file multiple times
86
+ for each stream.
87
+
88
+ Args:
89
+ seek_time (float): seek time in seconds or None if no seeking is needed.
90
+ duration (float): duration in seconds to extract or None to extract until the end.
91
+ streams (slice, int or list): streams to extract, can be a single int, a list or
92
+ a slice. If it is a slice or list, the output will be of size [S, C, T]
93
+ with S the number of streams, C the number of channels and T the number of samples.
94
+ If it is an int, the output will be [C, T].
95
+ samplerate (int): if provided, will resample on the fly. If None, no resampling will
96
+ be done. Original sampling rate can be obtained with :method:`samplerate`.
97
+ channels (int): if 1, will convert to mono. We do not rely on ffmpeg for that
98
+ as ffmpeg automatically scale by +3dB to conserve volume when playing on speakers.
99
+ See https://sound.stackexchange.com/a/42710.
100
+ Our definition of mono is simply the average of the two channels. Any other
101
+ value will be ignored.
102
+ """
103
+ streams = np.array(range(len(self)))[streams]
104
+ single = not isinstance(streams, np.ndarray)
105
+ if single:
106
+ streams = [streams]
107
+
108
+ if duration is None:
109
+ target_size = None
110
+ query_duration = None
111
+ else:
112
+ target_size = int((samplerate or self.samplerate()) * duration)
113
+ query_duration = float((target_size + 1) / (samplerate or self.samplerate()))
114
+
115
+ with temp_filenames(len(streams)) as filenames:
116
+ command = ['ffmpeg', '-y']
117
+ command += ['-loglevel', 'panic']
118
+ if seek_time:
119
+ command += ['-ss', str(seek_time)]
120
+ command += ['-i', str(self.path)]
121
+ for stream, filename in zip(streams, filenames):
122
+ command += ['-map', f'0:{self._audio_streams[stream]}']
123
+ if query_duration is not None:
124
+ command += ['-t', str(query_duration)]
125
+ command += ['-threads', '1']
126
+ command += ['-f', 'f32le']
127
+ if samplerate is not None:
128
+ command += ['-ar', str(samplerate)]
129
+ command += [filename]
130
+
131
+ sp.run(command, check=True)
132
+ wavs = []
133
+ for filename in filenames:
134
+ wav = np.fromfile(filename, dtype=np.float32)
135
+ wav = torch.from_numpy(wav)
136
+ wav = wav.view(-1, self.channels()).t()
137
+ if channels is not None:
138
+ wav = convert_audio_channels(wav, channels)
139
+ if target_size is not None:
140
+ wav = wav[..., :target_size]
141
+ wavs.append(wav)
142
+ wav = torch.stack(wavs, dim=0)
143
+ if single:
144
+ wav = wav[0]
145
+ return wav
146
+
147
+
148
+ def convert_audio_channels(wav, channels=2):
149
+ """Convert audio to the given number of channels."""
150
+ *shape, src_channels, length = wav.shape
151
+ if src_channels == channels:
152
+ pass
153
+ elif channels == 1:
154
+ # Case 1:
155
+ # The caller asked 1-channel audio, but the stream have multiple
156
+ # channels, downmix all channels.
157
+ wav = wav.mean(dim=-2, keepdim=True)
158
+ elif src_channels == 1:
159
+ # Case 2:
160
+ # The caller asked for multiple channels, but the input file have
161
+ # one single channel, replicate the audio over all channels.
162
+ wav = wav.expand(*shape, channels, length)
163
+ elif src_channels >= channels:
164
+ # Case 3:
165
+ # The caller asked for multiple channels, and the input file have
166
+ # more channels than requested. In that case return the first channels.
167
+ wav = wav[..., :channels, :]
168
+ else:
169
+ # Case 4: What is a reasonable choice here?
170
+ raise ValueError('The audio file has less channels than requested but is not mono.')
171
+ return wav
172
+
173
+
174
+ def convert_audio(wav, from_samplerate, to_samplerate, channels):
175
+ """Convert audio from a given samplerate to a target one and target number of channels."""
176
+ wav = convert_audio_channels(wav, channels)
177
+ return julius.resample_frac(wav, from_samplerate, to_samplerate)
178
+
179
+
180
+ def i16_pcm(wav):
181
+ """Convert audio to 16 bits integer PCM format."""
182
+ if wav.dtype.is_floating_point:
183
+ return (wav.clamp_(-1, 1) * (2**15 - 1)).short()
184
+ else:
185
+ return wav
186
+
187
+
188
+ def f32_pcm(wav):
189
+ """Convert audio to float 32 bits PCM format."""
190
+ if wav.dtype.is_floating_point:
191
+ return wav
192
+ else:
193
+ return wav.float() / (2**15 - 1)
194
+
195
+
196
+ def as_dtype_pcm(wav):
197
+ """Convert audio to either f32 pcm or i16 pcm depending on the given dtype."""
198
+ if wav.dtype.is_floating_point:
199
+ return f32_pcm(wav)
200
+ else:
201
+ return i16_pcm(wav)
202
+
203
+
204
+ def encode_mp3(wav, path, samplerate=44100, bitrate=320, verbose=False):
205
+ """Save given audio as mp3. This should work on all OSes."""
206
+ c, _ = wav.shape
207
+ wav = i16_pcm(wav)
208
+ encoder = lameenc.Encoder()
209
+ encoder.set_bit_rate(bitrate)
210
+ encoder.set_in_sample_rate(samplerate)
211
+ encoder.set_channels(c)
212
+ encoder.set_quality(2) # 2-highest, 7-fastest
213
+ if not verbose:
214
+ encoder.silence()
215
+ wav = wav.data.cpu()
216
+ wav = wav.transpose(0, 1).numpy()
217
+ mp3_data = encoder.encode(wav.tobytes())
218
+ mp3_data += encoder.flush()
219
+ with open(path, "wb") as f:
220
+ f.write(mp3_data)
221
+
222
+
223
+ def prevent_clip(wav, mode='rescale'):
224
+ """
225
+ different strategies for avoiding raw clipping.
226
+ """
227
+ if mode is None or mode == 'none':
228
+ return wav
229
+ assert wav.dtype.is_floating_point, "too late for clipping"
230
+ if mode == 'rescale':
231
+ wav = wav / max(1.01 * wav.abs().max(), 1)
232
+ elif mode == 'clamp':
233
+ wav = wav.clamp(-0.99, 0.99)
234
+ elif mode == 'tanh':
235
+ wav = torch.tanh(wav)
236
+ else:
237
+ raise ValueError(f"Invalid mode {mode}")
238
+ return wav
239
+
240
+
241
+ def save_audio(wav: torch.Tensor,
242
+ path: tp.Union[str, Path],
243
+ samplerate: int,
244
+ bitrate: int = 320,
245
+ clip: tp.Union[str] = 'rescale',
246
+ bits_per_sample: tp.Union[int] = 16,
247
+ as_float: bool = False):
248
+ """Save audio file, automatically preventing clipping if necessary
249
+ based on the given `clip` strategy. If the path ends in `.mp3`, this
250
+ will save as mp3 with the given `bitrate`.
251
+ """
252
+ wav = prevent_clip(wav, mode=clip)
253
+ path = Path(path)
254
+ suffix = path.suffix.lower()
255
+ if suffix == ".mp3":
256
+ encode_mp3(wav, path, samplerate, bitrate, verbose=True)
257
+ elif suffix == ".wav":
258
+ if as_float:
259
+ bits_per_sample = 32
260
+ encoding = 'PCM_F'
261
+ else:
262
+ encoding = 'PCM_S'
263
+ ta.save(str(path), wav, sample_rate=samplerate,
264
+ encoding=encoding, bits_per_sample=bits_per_sample)
265
+ elif suffix == ".flac":
266
+ ta.save(str(path), wav, sample_rate=samplerate, bits_per_sample=bits_per_sample)
267
+ else:
268
+ raise ValueError(f"Invalid suffix for path: {suffix}")
269
+
270
+
271
+ def load_track(track, audio_channels, samplerate):
272
+ errors = {}
273
+ wav = None
274
+
275
+ try:
276
+ wav = AudioFile(track).read(
277
+ streams=0,
278
+ samplerate=samplerate,
279
+ channels=audio_channels)
280
+ except sp.CalledProcessError:
281
+ errors['ffmpeg'] = 'FFmpeg could not read the file.'
282
+
283
+ if wav is None:
284
+ try:
285
+ wav, sr = ta.load(str(track))
286
+ except RuntimeError as err:
287
+ errors['torchaudio'] = err.args[0]
288
+ else:
289
+ wav = convert_audio(wav, sr, samplerate, audio_channels)
290
+
291
+ return wav, errors
third_party/demucs/models/demucs.py ADDED
@@ -0,0 +1,452 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : demucs.py
5
+ @Time : 2023/8/8 下午4:36
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Demucs
10
+ """
11
+
12
+ import math
13
+ import typing as tp
14
+
15
+ import julius
16
+ import torch
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+
20
+ from .states import capture_init
21
+ from .utils import center_trim, unfold
22
+ from .transformer import LayerScale
23
+
24
+
25
+ class BLSTM(nn.Module):
26
+ """
27
+ BiLSTM with same hidden units as input dim.
28
+ If `max_steps` is not None, input will be splitting in overlapping
29
+ chunks and the LSTM applied separately on each chunk.
30
+ """
31
+ def __init__(self, dim, layers=1, max_steps=None, skip=False):
32
+ super().__init__()
33
+ assert max_steps is None or max_steps % 4 == 0
34
+ self.max_steps = max_steps
35
+ self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
36
+ self.linear = nn.Linear(2 * dim, dim)
37
+ self.skip = skip
38
+
39
+ def forward(self, x):
40
+ b, c, t = x.shape
41
+ y = x
42
+ framed = False
43
+ if self.max_steps is not None and t > self.max_steps:
44
+ width = self.max_steps
45
+ stride = width // 2
46
+ frames = unfold(x, width, stride)
47
+ nframes = frames.shape[2]
48
+ framed = True
49
+ x = frames.permute(0, 2, 1, 3).reshape(-1, c, width)
50
+
51
+ x = x.permute(2, 0, 1)
52
+
53
+ x = self.lstm(x)[0]
54
+ x = self.linear(x)
55
+ x = x.permute(1, 2, 0)
56
+ if framed:
57
+ out = []
58
+ frames = x.reshape(b, -1, c, width)
59
+ limit = stride // 2
60
+ for k in range(nframes):
61
+ if k == 0:
62
+ out.append(frames[:, k, :, :-limit])
63
+ elif k == nframes - 1:
64
+ out.append(frames[:, k, :, limit:])
65
+ else:
66
+ out.append(frames[:, k, :, limit:-limit])
67
+ out = torch.cat(out, -1)
68
+ out = out[..., :t]
69
+ x = out
70
+ if self.skip:
71
+ x = x + y
72
+ return x
73
+
74
+
75
+ def rescale_conv(conv, reference):
76
+ """Rescale initial weight scale. It is unclear why it helps but it certainly does.
77
+ """
78
+ std = conv.weight.std().detach()
79
+ scale = (std / reference)**0.5
80
+ conv.weight.data /= scale
81
+ if conv.bias is not None:
82
+ conv.bias.data /= scale
83
+
84
+
85
+ def rescale_module(module, reference):
86
+ for sub in module.modules():
87
+ if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)):
88
+ rescale_conv(sub, reference)
89
+
90
+
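For intuition, a tiny sketch of the update `rescale_conv` applies (illustrative only, not part of this commit):

from torch import nn

conv = nn.Conv1d(64, 64, 3)
std, reference = conv.weight.std().item(), 0.1
conv.weight.data /= (std / reference) ** 0.5        # the same update rescale_conv performs
print(std, "->", conv.weight.std().item())          # the new std is roughly sqrt(std * reference)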
91
+ class DConv(nn.Module):
92
+ """
93
+ New residual branches in each encoder layer.
94
+ This alternates dilated convolutions, potentially with LSTMs and attention.
95
+ Also, before entering each residual branch, the dimension is projected onto a smaller subspace,
96
+ e.g. of dim `channels // compress`.
97
+ """
98
+ def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4,
99
+ norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True,
100
+ kernel=3, dilate=True):
101
+ """
102
+ Args:
103
+ channels: input/output channels for residual branch.
104
+ compress: amount of channel compression inside the branch.
105
+ depth: number of layers in the residual branch. Each layer has its own
106
+ projection, and potentially LSTM and attention.
107
+ init: initial scale for the LayerScale.
108
+ norm: use GroupNorm.
109
+ attn: use LocalAttention.
110
+ heads: number of heads for the LocalAttention.
111
+ ndecay: number of decay controls in the LocalAttention.
112
+ lstm: use LSTM.
113
+ gelu: Use GELU activation.
114
+ kernel: kernel size for the (dilated) convolutions.
115
+ dilate: if true, use dilation, increasing with the depth.
116
+ """
117
+
118
+ super().__init__()
119
+ assert kernel % 2 == 1
120
+ self.channels = channels
121
+ self.compress = compress
122
+ self.depth = abs(depth)
123
+ dilate = depth > 0
124
+
125
+ norm_fn: tp.Callable[[int], nn.Module]
126
+ norm_fn = lambda d: nn.Identity() # noqa
127
+ if norm:
128
+ norm_fn = lambda d: nn.GroupNorm(1, d) # noqa
129
+
130
+ hidden = int(channels / compress)
131
+
132
+ act: tp.Type[nn.Module]
133
+ if gelu:
134
+ act = nn.GELU
135
+ else:
136
+ act = nn.ReLU
137
+
138
+ self.layers = nn.ModuleList([])
139
+ for d in range(self.depth):
140
+ dilation = 2 ** d if dilate else 1
141
+ padding = dilation * (kernel // 2)
142
+ mods = [
143
+ nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
144
+ norm_fn(hidden), act(),
145
+ nn.Conv1d(hidden, 2 * channels, 1),
146
+ norm_fn(2 * channels), nn.GLU(1),
147
+ LayerScale(channels, init),
148
+ ]
149
+ if attn:
150
+ mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
151
+ if lstm:
152
+ mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
153
+ layer = nn.Sequential(*mods)
154
+ self.layers.append(layer)
155
+
156
+ def forward(self, x):
157
+ for layer in self.layers:
158
+ x = x + layer(x)
159
+ return x
160
+
161
+
162
+ class LocalState(nn.Module):
163
+ """Local state allows to have attention based only on data (no positional embedding),
164
+ but while setting a constraint on the time window (e.g. decaying penalty term).
165
+
166
+ Also a failed experiments with trying to provide some frequency based attention.
167
+ """
168
+ def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
169
+ super().__init__()
170
+ assert channels % heads == 0, (channels, heads)
171
+ self.heads = heads
172
+ self.nfreqs = nfreqs
173
+ self.ndecay = ndecay
174
+ self.content = nn.Conv1d(channels, channels, 1)
175
+ self.query = nn.Conv1d(channels, channels, 1)
176
+ self.key = nn.Conv1d(channels, channels, 1)
177
+ if nfreqs:
178
+ self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
179
+ if ndecay:
180
+ self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
181
+ # Initialize decay close to zero (there is a sigmoid), for maximum initial window.
182
+ self.query_decay.weight.data *= 0.01
183
+ assert self.query_decay.bias is not None # stupid type checker
184
+ self.query_decay.bias.data[:] = -2
185
+ self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
186
+
187
+ def forward(self, x):
188
+ b, _, t = x.shape
189
+ heads = self.heads
190
+ indexes = torch.arange(t, device=x.device, dtype=x.dtype)
191
+ # left index are keys, right index are queries
192
+ delta = indexes[:, None] - indexes[None, :]
193
+
194
+ queries = self.query(x).view(b, heads, -1, t)
195
+ keys = self.key(x).view(b, heads, -1, t)
196
+ # t are keys, s are queries
197
+ dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
198
+ dots /= keys.shape[2]**0.5
199
+ if self.nfreqs:
200
+ periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
201
+ freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
202
+ freq_q = self.query_freqs(x).view(b, heads, -1, t) / self.nfreqs ** 0.5
203
+ dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
204
+ if self.ndecay:
205
+ decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
206
+ decay_q = self.query_decay(x).view(b, heads, -1, t)
207
+ decay_q = torch.sigmoid(decay_q) / 2
208
+ decay_kernel = - decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
209
+ dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
210
+
211
+ # Kill self reference.
212
+ dots.masked_fill_(torch.eye(t, device=dots.device, dtype=torch.bool), -100)
213
+ weights = torch.softmax(dots, dim=2)
214
+
215
+ content = self.content(x).view(b, heads, -1, t)
216
+ result = torch.einsum("bhts,bhct->bhcs", weights, content)
217
+ if self.nfreqs:
218
+ time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
219
+ result = torch.cat([result, time_sig], 2)
220
+ result = result.reshape(b, -1, t)
221
+ return x + self.proj(result)
222
+
223
+
224
+ class Demucs(nn.Module):
225
+ @capture_init
226
+ def __init__(self,
227
+ sources,
228
+ # Channels
229
+ audio_channels=2,
230
+ channels=64,
231
+ growth=2.,
232
+ # Main structure
233
+ depth=6,
234
+ rewrite=True,
235
+ lstm_layers=0,
236
+ # Convolutions
237
+ kernel_size=8,
238
+ stride=4,
239
+ context=1,
240
+ # Activations
241
+ gelu=True,
242
+ glu=True,
243
+ # Normalization
244
+ norm_starts=4,
245
+ norm_groups=4,
246
+ # DConv residual branch
247
+ dconv_mode=1,
248
+ dconv_depth=2,
249
+ dconv_comp=4,
250
+ dconv_attn=4,
251
+ dconv_lstm=4,
252
+ dconv_init=1e-4,
253
+ # Pre/post processing
254
+ normalize=True,
255
+ resample=True,
256
+ # Weight init
257
+ rescale=0.1,
258
+ # Metadata
259
+ samplerate=44100,
260
+ segment=4 * 10):
261
+ """
262
+ Args:
263
+ sources (list[str]): list of source names
264
+ audio_channels (int): stereo or mono
265
+ channels (int): first convolution channels
266
+ depth (int): number of layers in the encoder and in the decoder.
267
+ growth (float): multiply (resp. divide) the number of channels by that
268
+ for each layer of the encoder (resp. decoder).
269
+ rewrite (bool): add 1x1 convolution to each layer.
270
+ lstm_layers (int): number of lstm layers, 0 = no lstm. Deactivated
271
+ by default, as this is now replaced by the smaller and faster LSTMs
273
+ in the DConv branches.
274
+ kernel_size (int): kernel size for convolutions
275
+ stride (int): stride for convolutions
276
+ context (int): kernel size of the convolution in the
277
+ decoder before the transposed convolution. If > 1,
278
+ will provide some context from neighboring time steps.
279
+ gelu: use GELU activation function.
280
+ glu (bool): use glu instead of ReLU for the 1x1 rewrite conv.
281
+ norm_starts: layer at which group norm starts being used.
282
+ decoder layers are numbered in reverse order.
283
+ norm_groups: number of groups for group norm.
284
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
285
+ dconv_depth: depth of residual DConv branch.
286
+ dconv_comp: compression of DConv branch.
287
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
288
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
289
+ dconv_init: initial scale for the DConv branch LayerScale.
290
+ normalize (bool): normalizes the input audio on the fly, and scales back
291
+ the output by the same amount.
292
+ resample (bool): upsample x2 the input and downsample /2 the output.
293
+ rescale (float): rescale initial weights of convolutions
294
+ to get their standard deviation closer to `rescale`.
295
+ samplerate (int): stored as meta information for easing
296
+ future evaluations of the model.
297
+ segment (float): duration of the chunks of audio to ideally evaluate the model on.
298
+ This is used by `demucs.apply.apply_model`.
299
+ """
300
+
301
+ super().__init__()
302
+ self.audio_channels = audio_channels
303
+ self.sources = sources
304
+ self.kernel_size = kernel_size
305
+ self.context = context
306
+ self.stride = stride
307
+ self.depth = depth
308
+ self.resample = resample
309
+ self.channels = channels
310
+ self.normalize = normalize
311
+ self.samplerate = samplerate
312
+ self.segment = segment
313
+ self.encoder = nn.ModuleList()
314
+ self.decoder = nn.ModuleList()
315
+ self.skip_scales = nn.ModuleList()
316
+
317
+ if glu:
318
+ activation = nn.GLU(dim=1)
319
+ ch_scale = 2
320
+ else:
321
+ activation = nn.ReLU()
322
+ ch_scale = 1
323
+ if gelu:
324
+ act2 = nn.GELU
325
+ else:
326
+ act2 = nn.ReLU
327
+
328
+ in_channels = audio_channels
329
+ padding = 0
330
+ for index in range(depth):
331
+ norm_fn = lambda d: nn.Identity() # noqa
332
+ if index >= norm_starts:
333
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
334
+
335
+ encode = []
336
+ encode += [
337
+ nn.Conv1d(in_channels, channels, kernel_size, stride),
338
+ norm_fn(channels),
339
+ act2(),
340
+ ]
341
+ attn = index >= dconv_attn
342
+ lstm = index >= dconv_lstm
343
+ if dconv_mode & 1:
344
+ encode += [DConv(channels, depth=dconv_depth, init=dconv_init,
345
+ compress=dconv_comp, attn=attn, lstm=lstm)]
346
+ if rewrite:
347
+ encode += [
348
+ nn.Conv1d(channels, ch_scale * channels, 1),
349
+ norm_fn(ch_scale * channels), activation]
350
+ self.encoder.append(nn.Sequential(*encode))
351
+
352
+ decode = []
353
+ if index > 0:
354
+ out_channels = in_channels
355
+ else:
356
+ out_channels = len(self.sources) * audio_channels
357
+ if rewrite:
358
+ decode += [
359
+ nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context),
360
+ norm_fn(ch_scale * channels), activation]
361
+ if dconv_mode & 2:
362
+ decode += [DConv(channels, depth=dconv_depth, init=dconv_init,
363
+ compress=dconv_comp, attn=attn, lstm=lstm)]
364
+ decode += [nn.ConvTranspose1d(channels, out_channels,
365
+ kernel_size, stride, padding=padding)]
366
+ if index > 0:
367
+ decode += [norm_fn(out_channels), act2()]
368
+ self.decoder.insert(0, nn.Sequential(*decode))
369
+ in_channels = channels
370
+ channels = int(growth * channels)
371
+
372
+ channels = in_channels
373
+ if lstm_layers:
374
+ self.lstm = BLSTM(channels, lstm_layers)
375
+ else:
376
+ self.lstm = None
377
+
378
+ if rescale:
379
+ rescale_module(self, reference=rescale)
380
+
381
+ def valid_length(self, length):
382
+ """
383
+ Return the nearest valid length to use with the model so that
384
+ there are no time steps left over in a convolution, e.g. for all
385
+ layers, (size of the input - kernel_size) % stride = 0.
386
+
387
+ Note that inputs are automatically padded if necessary to ensure that the output
388
+ has the same length as the input.
389
+ """
390
+ if self.resample:
391
+ length *= 2
392
+
393
+ for _ in range(self.depth):
394
+ length = math.ceil((length - self.kernel_size) / self.stride) + 1
395
+ length = max(1, length)
396
+
397
+ for _ in range(self.depth):
398
+ length = (length - 1) * self.stride + self.kernel_size
399
+
400
+ if self.resample:
401
+ length = math.ceil(length / 2)
402
+ return int(length)
403
+
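For intuition, a short sketch of how `valid_length` relates to the internal padding in `forward` (illustrative only; the instantiation below is a hypothetical configuration, not the checkpoint shipped with this commit):

# Assumed import path for this module:
# from third_party.demucs.models.demucs import Demucs
import torch

model = Demucs(sources=["drums", "bass", "other", "vocals"])   # hypothetical hyperparameters
length = 44100                                                 # one second at the default samplerate
print(model.valid_length(length) - length)                     # extra samples forward() will pad on its own
mix = torch.randn(1, 2, length)                                # (batch, audio_channels, time)
out = model(mix)                                               # (batch, len(sources), audio_channels, time)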
404
+ def forward(self, mix):
405
+ x = mix
406
+ length = x.shape[-1]
407
+
408
+ if self.normalize:
409
+ mono = mix.mean(dim=1, keepdim=True)
410
+ mean = mono.mean(dim=-1, keepdim=True)
411
+ std = mono.std(dim=-1, keepdim=True)
412
+ x = (x - mean) / (1e-5 + std)
413
+ else:
414
+ mean = 0
415
+ std = 1
416
+
417
+ delta = self.valid_length(length) - length
418
+ x = F.pad(x, (delta // 2, delta - delta // 2))
419
+
420
+ if self.resample:
421
+ x = julius.resample_frac(x, 1, 2)
422
+
423
+ saved = []
424
+ for encode in self.encoder:
425
+ x = encode(x)
426
+ saved.append(x)
427
+
428
+ if self.lstm:
429
+ x = self.lstm(x)
430
+
431
+ for decode in self.decoder:
432
+ skip = saved.pop(-1)
433
+ skip = center_trim(skip, x)
434
+ x = decode(x + skip)
435
+
436
+ if self.resample:
437
+ x = julius.resample_frac(x, 2, 1)
438
+ x = x * std + mean
439
+ x = center_trim(x, length)
440
+ x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
441
+ return x
442
+
443
+ def load_state_dict(self, state, strict=True):
444
+ # fix a mismatch with previous generation Demucs models.
445
+ for idx in range(self.depth):
446
+ for a in ['encoder', 'decoder']:
447
+ for b in ['bias', 'weight']:
448
+ new = f'{a}.{idx}.3.{b}'
449
+ old = f'{a}.{idx}.2.{b}'
450
+ if old in state and new not in state:
451
+ state[new] = state.pop(old)
452
+ super().load_state_dict(state, strict=strict)
third_party/demucs/models/htdemucs.py ADDED
@@ -0,0 +1,955 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : htdemucs.py
5
+ @Time : 2023/8/8 4:27 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : The spectrogram and Hybrid version of Demucs
10
+ """
11
+
12
+ import math
13
+ import typing as tp
14
+ from copy import deepcopy
15
+ from fractions import Fraction
16
+
17
+ import torch
18
+ from torch import nn
19
+ from torch.nn import functional as F
20
+ from einops import rearrange
21
+ from openunmix.filtering import wiener
22
+
23
+ from .transformer import CrossTransformerEncoder
24
+ from .demucs import DConv, rescale_module
25
+ from .states import capture_init
26
+ from .spec import spectro, ispectro
27
+
28
+
29
+ def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = 'constant', value: float = 0.):
30
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
31
+ If this is the case, we insert extra 0 padding to the right before the reflection happens."""
32
+ x0 = x
33
+ length = x.shape[-1]
34
+ padding_left, padding_right = paddings
35
+ if mode == 'reflect':
36
+ max_pad = max(padding_left, padding_right)
37
+ if length <= max_pad:
38
+ extra_pad = max_pad - length + 1
39
+ extra_pad_right = min(padding_right, extra_pad)
40
+ extra_pad_left = extra_pad - extra_pad_right
41
+ paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
42
+ x = F.pad(x, (extra_pad_left, extra_pad_right))
43
+ out = F.pad(x, paddings, mode, value)
44
+ assert out.shape[-1] == length + padding_left + padding_right
45
+ assert (out[..., padding_left: padding_left + length] == x0).all()
46
+ return out
47
+
48
+
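A tiny sketch of why `pad1d` exists (illustrative only; the import path is an assumption): plain reflect padding fails when the requested pad is at least as long as the input.

import torch
import torch.nn.functional as F
# from third_party.demucs.models.htdemucs import pad1d   # assumed import path

x = torch.randn(1, 1, 3)                  # shorter than the requested padding
# F.pad(x, (4, 4), mode="reflect")        # would raise: padding size must be less than input size
y = pad1d(x, (4, 4), mode="reflect")      # zero-extends first, then reflects safely
print(y.shape)                            # torch.Size([1, 1, 11])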
49
+ class ScaledEmbedding(nn.Module):
50
+ """
51
+ Boost learning rate for embeddings (with `scale`).
52
+ Also, can make embeddings continuous with `smooth`.
53
+ """
54
+ def __init__(self, num_embeddings: int, embedding_dim: int,
55
+ scale: float = 10., smooth=False):
56
+ super().__init__()
57
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
58
+ if smooth:
59
+ weight = torch.cumsum(self.embedding.weight.data, dim=0)
60
+ # when summing Gaussians, the overall scale grows as sqrt(n), so we normalize by that.
61
+ weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
62
+ self.embedding.weight.data[:] = weight
63
+ self.embedding.weight.data /= scale
64
+ self.scale = scale
65
+
66
+ @property
67
+ def weight(self):
68
+ return self.embedding.weight * self.scale
69
+
70
+ def forward(self, x):
71
+ out = self.embedding(x) * self.scale
72
+ return out
73
+
74
+
75
+ class HEncLayer(nn.Module):
76
+ def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False,
77
+ freq=True, dconv=True, norm=True, context=0, dconv_kw=None, pad=True,
78
+ rewrite=True):
79
+ """Encoder layer. This used both by the time and the frequency branch.
80
+ """
81
+ super().__init__()
82
+ norm_fn = lambda d: nn.Identity() # noqa
83
+ if norm:
84
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
85
+ if pad:
86
+ pad = kernel_size // 4
87
+ else:
88
+ pad = 0
89
+ klass = nn.Conv1d
90
+ self.freq = freq
91
+ self.kernel_size = kernel_size
92
+ self.stride = stride
93
+ self.empty = empty
94
+ self.norm = norm
95
+ self.pad = pad
96
+ if freq:
97
+ kernel_size = [kernel_size, 1]
98
+ stride = [stride, 1]
99
+ pad = [pad, 0]
100
+ klass = nn.Conv2d
101
+ self.conv = klass(chin, chout, kernel_size, stride, pad)
102
+ if self.empty:
103
+ return
104
+ self.norm1 = norm_fn(chout)
105
+ self.rewrite = None
106
+ if rewrite:
107
+ self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
108
+ self.norm2 = norm_fn(2 * chout)
109
+
110
+ self.dconv = None
111
+ if dconv:
112
+ self.dconv = DConv(chout, **dconv_kw)
113
+
114
+ def forward(self, x, inject=None):
115
+ """
116
+ `inject` is used to inject the result from the time branch into the frequency branch,
117
+ when both have the same stride.
118
+ """
119
+ if not self.freq and x.dim() == 4:
120
+ b, c, fr, t = x.shape
121
+ x = x.view(b, -1, t)
122
+
123
+ if not self.freq:
124
+ le = x.shape[-1]
125
+ if not le % self.stride == 0:
126
+ x = F.pad(x, (0, self.stride - (le % self.stride)))
127
+ y = self.conv(x)
128
+ if self.empty:
129
+ return y
130
+ if inject is not None:
131
+ assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
132
+ if inject.dim() == 3 and y.dim() == 4:
133
+ inject = inject[:, :, None]
134
+ y = y + inject
135
+ y = F.gelu(self.norm1(y))
136
+ if self.dconv:
137
+ if self.freq:
138
+ b, c, fr, t = y.shape
139
+ y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
140
+ y = self.dconv(y)
141
+ if self.freq:
142
+ y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
143
+ if self.rewrite:
144
+ z = self.norm2(self.rewrite(y))
145
+ z = F.glu(z, dim=1)
146
+ else:
147
+ z = y
148
+ return z
149
+
150
+
151
+ class MultiWrap(nn.Module):
152
+ """
153
+ Takes one layer and replicates it N times. Each replica acts
154
+ on a frequency band. Everything is arranged so that if the N replicas have the same weights,
155
+ then this is exactly equivalent to applying the original module to all frequencies.
156
+ """
157
+ def __init__(self, layer, split_ratios):
158
+ super().__init__()
159
+ self.split_ratios = split_ratios
160
+ self.layers = nn.ModuleList()
161
+ self.conv = isinstance(layer, HEncLayer)
162
+ assert not layer.norm
163
+ assert layer.freq
164
+ assert layer.pad
165
+ if not self.conv:
166
+ assert not layer.context_freq
167
+ for _ in range(len(split_ratios) + 1):
168
+ lay = deepcopy(layer)
169
+ if self.conv:
170
+ lay.conv.padding = (0, 0)
171
+ else:
172
+ lay.pad = False
173
+ for m in lay.modules():
174
+ if hasattr(m, 'reset_parameters'):
175
+ m.reset_parameters()
176
+ self.layers.append(lay)
177
+
178
+ def forward(self, x, skip=None, length=None):
179
+ _, _, fr, _ = x.shape
180
+
181
+ ratios = list(self.split_ratios) + [1]
182
+ start = 0
183
+ outs = []
184
+ for ratio, layer in zip(ratios, self.layers):
185
+ if self.conv:
186
+ pad = layer.kernel_size // 4
187
+ if ratio == 1:
188
+ limit = fr
189
+ frames = -1
190
+ else:
191
+ limit = int(round(fr * ratio))
192
+ le = limit - start
193
+ if start == 0:
194
+ le += pad
195
+ frames = round((le - layer.kernel_size) / layer.stride + 1)
196
+ limit = start + (frames - 1) * layer.stride + layer.kernel_size
197
+ if start == 0:
198
+ limit -= pad
199
+ assert limit - start > 0, (limit, start)
200
+ assert limit <= fr, (limit, fr)
201
+ y = x[:, :, start:limit, :]
202
+ if start == 0:
203
+ y = F.pad(y, (0, 0, pad, 0))
204
+ if ratio == 1:
205
+ y = F.pad(y, (0, 0, 0, pad))
206
+ outs.append(layer(y))
207
+ start = limit - layer.kernel_size + layer.stride
208
+ else:
209
+ if ratio == 1:
210
+ limit = fr
211
+ else:
212
+ limit = int(round(fr * ratio))
213
+ last = layer.last
214
+ layer.last = True
215
+
216
+ y = x[:, :, start:limit]
217
+ s = skip[:, :, start:limit]
218
+ out, _ = layer(y, s, None)
219
+ if outs:
220
+ outs[-1][:, :, -layer.stride:] += (
221
+ out[:, :, :layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1))
222
+ out = out[:, :, layer.stride:]
223
+ if ratio == 1:
224
+ out = out[:, :, :-layer.stride // 2, :]
225
+ if start == 0:
226
+ out = out[:, :, layer.stride // 2:, :]
227
+ outs.append(out)
228
+ layer.last = last
229
+ start = limit
230
+ out = torch.cat(outs, dim=2)
231
+ if not self.conv and not last:
232
+ out = F.gelu(out)
233
+ if self.conv:
234
+ return out
235
+ else:
236
+ return out, None
237
+
238
+
239
+ class HDecLayer(nn.Module):
240
+ def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False,
241
+ freq=True, dconv=True, norm=True, context=1, dconv_kw=None, pad=True,
242
+ context_freq=True, rewrite=True):
243
+ """
244
+ Same as HEncLayer but for decoder. See `HEncLayer` for documentation.
245
+ """
246
+ super().__init__()
247
+ norm_fn = lambda d: nn.Identity() # noqa
248
+ if norm:
249
+ norm_fn = lambda d: nn.GroupNorm(norm_groups, d) # noqa
250
+ if pad:
251
+ pad = kernel_size // 4
252
+ else:
253
+ pad = 0
254
+ self.pad = pad
255
+ self.last = last
256
+ self.freq = freq
257
+ self.chin = chin
258
+ self.empty = empty
259
+ self.stride = stride
260
+ self.kernel_size = kernel_size
261
+ self.norm = norm
262
+ self.context_freq = context_freq
263
+ klass = nn.Conv1d
264
+ klass_tr = nn.ConvTranspose1d
265
+ if freq:
266
+ kernel_size = [kernel_size, 1]
267
+ stride = [stride, 1]
268
+ klass = nn.Conv2d
269
+ klass_tr = nn.ConvTranspose2d
270
+ self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
271
+ self.norm2 = norm_fn(chout)
272
+ if self.empty:
273
+ return
274
+ self.rewrite = None
275
+ if rewrite:
276
+ if context_freq:
277
+ self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
278
+ else:
279
+ self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1,
280
+ [0, context])
281
+ self.norm1 = norm_fn(2 * chin)
282
+
283
+ self.dconv = None
284
+ if dconv:
285
+ self.dconv = DConv(chin, **dconv_kw)
286
+
287
+ def forward(self, x, skip, length):
288
+ if self.freq and x.dim() == 3:
289
+ b, c, t = x.shape
290
+ x = x.view(b, self.chin, -1, t)
291
+
292
+ if not self.empty:
293
+ x = x + skip
294
+
295
+ if self.rewrite:
296
+ y = F.glu(self.norm1(self.rewrite(x)), dim=1)
297
+ else:
298
+ y = x
299
+ if self.dconv:
300
+ if self.freq:
301
+ b, c, fr, t = y.shape
302
+ y = y.permute(0, 2, 1, 3).reshape(-1, c, t)
303
+ y = self.dconv(y)
304
+ if self.freq:
305
+ y = y.view(b, fr, c, t).permute(0, 2, 1, 3)
306
+ else:
307
+ y = x
308
+ assert skip is None
309
+ z = self.norm2(self.conv_tr(y))
310
+ if self.freq:
311
+ if self.pad:
312
+ z = z[..., self.pad:-self.pad, :]
313
+ else:
314
+ z = z[..., self.pad:self.pad + length]
315
+ assert z.shape[-1] == length, (z.shape[-1], length)
316
+ if not self.last:
317
+ z = F.gelu(z)
318
+ return z, y
319
+
320
+
321
+ class HTDemucs(nn.Module):
322
+ """
323
+ Spectrogram and hybrid Demucs model.
324
+ The spectrogram model has the same structure as Demucs, except the first few layers are over the
325
+ frequency axis, until there is only 1 frequency, and then it moves to time convolutions.
326
+ Frequency layers can still access information across time steps thanks to the DConv residual.
327
+
328
+ Hybrid models have a parallel time branch. At some layer, the time branch has the same stride
329
+ as the frequency branch and then the two are combined. The opposite happens in the decoder.
330
+
331
+ Models can either use naive iSTFT from masking, Wiener filtering ([Uhlich et al. 2017]),
332
+ or complex as channels (CaC) [Choi et al. 2020]. Wiener filtering is based on
333
+ Open Unmix implementation [Stoter et al. 2019].
334
+
335
+ The loss is always on the temporal domain, by backpropagating through the above
336
+ output methods and the iSTFT. This allows hybrid models to be defined nicely. However, this slightly breaks
337
+ Wiener filtering, as doing more iterations at test time will change the spectrogram
338
+ contribution, without changing the one from the waveform, which will lead to worse performance.
339
+ I tried using the residual option in OpenUnmix Wiener implementation, but it didn't improve.
340
+ CaC on the other hand provides similar performance for hybrid, and works naturally with
341
+ hybrid models.
342
+
343
+ This model also uses frequency embeddings to improve the efficiency of convolutions
344
+ over the freq. axis, following [Isik et al. 2020] (https://arxiv.org/pdf/2008.04470.pdf).
345
+
346
+ Unlike classic Demucs, there is no resampling here, and normalization is always applied.
347
+ """
348
+
349
+ @capture_init
350
+ def __init__(
351
+ self,
352
+ sources,
353
+ # Channels
354
+ audio_channels=2,
355
+ channels=48,
356
+ channels_time=None,
357
+ growth=2,
358
+ # STFT
359
+ nfft=4096,
360
+ wiener_iters=0,
361
+ end_iters=0,
362
+ wiener_residual=False,
363
+ cac=True,
364
+ # Main structure
365
+ depth=4,
366
+ rewrite=True,
367
+ # Frequency branch
368
+ multi_freqs=None,
369
+ multi_freqs_depth=3,
370
+ freq_emb=0.2,
371
+ emb_scale=10,
372
+ emb_smooth=True,
373
+ # Convolutions
374
+ kernel_size=8,
375
+ time_stride=2,
376
+ stride=4,
377
+ context=1,
378
+ context_enc=0,
379
+ # Normalization
380
+ norm_starts=4,
381
+ norm_groups=4,
382
+ # DConv residual branch
383
+ dconv_mode=1,
384
+ dconv_depth=2,
385
+ dconv_comp=8,
386
+ dconv_init=1e-3,
387
+ # Before the Transformer
388
+ bottom_channels=0,
389
+ # Transformer
390
+ t_layers=5,
391
+ t_emb="sin",
392
+ t_hidden_scale=4.0,
393
+ t_heads=8,
394
+ t_dropout=0.0,
395
+ t_max_positions=10000,
396
+ t_norm_in=True,
397
+ t_norm_in_group=False,
398
+ t_group_norm=False,
399
+ t_norm_first=True,
400
+ t_norm_out=True,
401
+ t_max_period=10000.0,
402
+ t_weight_decay=0.0,
403
+ t_lr=None,
404
+ t_layer_scale=True,
405
+ t_gelu=True,
406
+ t_weight_pos_embed=1.0,
407
+ t_sin_random_shift=0,
408
+ t_cape_mean_normalize=True,
409
+ t_cape_augment=True,
410
+ t_cape_glob_loc_scale=None,
411
+ t_sparse_self_attn=False,
412
+ t_sparse_cross_attn=False,
413
+ t_mask_type="diag",
414
+ t_mask_random_seed=42,
415
+ t_sparse_attn_window=500,
416
+ t_global_window=100,
417
+ t_sparsity=0.95,
418
+ t_auto_sparsity=False,
419
+ # ------ Particular parameters
420
+ t_cross_first=False,
421
+ # Weight init
422
+ rescale=0.1,
423
+ # Metadata
424
+ samplerate=44100,
425
+ segment=10,
426
+ use_train_segment=True,
427
+ ):
428
+ """
429
+ Args:
430
+ sources (list[str]): list of source names.
431
+ audio_channels (int): input/output audio channels.
432
+ channels (int): initial number of hidden channels.
433
+ channels_time: if not None, use a different `channels` value for the time branch.
434
+ growth: increase the number of hidden channels by this factor at each layer.
435
+ nfft: number of fft bins. Note that changing this requires careful computation of
436
+ various shape parameters and will not work out of the box for hybrid models.
437
+ wiener_iters: when using Wiener filtering, number of iterations at test time.
438
+ end_iters: same but at train time. For a hybrid model, must be equal to `wiener_iters`.
439
+ wiener_residual: add residual source before wiener filtering.
440
+ cac: uses complex as channels, i.e. complex numbers are 2 channels each
441
+ in input and output. No further processing is done before the iSTFT.
442
+ depth (int): number of layers in the encoder and in the decoder.
443
+ rewrite (bool): add 1x1 convolution to each layer.
444
+ multi_freqs: list of frequency ratios for splitting frequency bands with `MultiWrap`.
445
+ multi_freqs_depth: how many layers to wrap with `MultiWrap`. Only the outermost
446
+ layers will be wrapped.
447
+ freq_emb: add frequency embedding after the first frequency layer if > 0,
448
+ the actual value controls the weight of the embedding.
449
+ emb_scale: equivalent to scaling the embedding learning rate
450
+ emb_smooth: initialize the embedding with a smooth one (with respect to frequencies).
451
+ kernel_size: kernel_size for encoder and decoder layers.
452
+ stride: stride for encoder and decoder layers.
453
+ time_stride: stride for the final time layer, after the merge.
454
+ context: context for 1x1 conv in the decoder.
455
+ context_enc: context for 1x1 conv in the encoder.
456
+ norm_starts: layer at which group norm starts being used.
457
+ decoder layers are numbered in reverse order.
458
+ norm_groups: number of groups for group norm.
459
+ dconv_mode: if 1: dconv in encoder only, 2: decoder only, 3: both.
460
+ dconv_depth: depth of residual DConv branch.
461
+ dconv_comp: compression of DConv branch.
462
+ dconv_attn: adds attention layers in DConv branch starting at this layer.
463
+ dconv_lstm: adds a LSTM layer in DConv branch starting at this layer.
464
+ dconv_init: initial scale for the DConv branch LayerScale.
465
+ bottom_channels: if >0 it adds a linear layer (1x1 Conv) before and after the
466
+ transformer in order to change the number of channels
467
+ t_layers: number of layers in each branch (waveform and spec) of the transformer
468
+ t_emb: "sin", "cape" or "scaled"
469
+ t_hidden_scale: the hidden scale of the Feedforward parts of the transformer
470
+ for instance if C = 384 (the number of channels in the transformer) and
471
+ t_hidden_scale = 4.0 then the intermediate layer of the FFN has dimension
472
+ 384 * 4 = 1536
473
+ t_heads: number of heads for the transformer
474
+ t_dropout: dropout in the transformer
475
+ t_max_positions: max_positions for the "scaled" positional embedding, only
476
+ useful if t_emb="scaled"
477
+ t_norm_in: (bool) norm before adding the positional embedding and getting into the
478
+ transformer layers
479
+ t_norm_in_group: (bool) if True while t_norm_in=True, the norm is on all the
480
+ timesteps (GroupNorm with group=1)
481
+ t_group_norm: (bool) if True, the norms of the Encoder Layers are on all the
482
+ timesteps (GroupNorm with group=1)
483
+ t_norm_first: (bool) if True the norm is before the attention and before the FFN
484
+ t_norm_out: (bool) if True, there is a GroupNorm (group=1) at the end of each layer
485
+ t_max_period: (float) denominator in the sinusoidal embedding expression
486
+ t_weight_decay: (float) weight decay for the transformer
487
+ t_lr: (float) specific learning rate for the transformer
488
+ t_layer_scale: (bool) Layer Scale for the transformer
489
+ t_gelu: (bool) activations of the transformer are GeLU if True, ReLU else
490
+ t_weight_pos_embed: (float) weighting of the positional embedding
491
+ t_cape_mean_normalize: (bool) if t_emb="cape", normalisation of positional embeddings
492
+ see: https://arxiv.org/abs/2106.03143
493
+ t_cape_augment: (bool) if t_emb="cape", must be True during training and False
494
+ during the inference, see: https://arxiv.org/abs/2106.03143
495
+ t_cape_glob_loc_scale: (list of 3 floats) if t_emb="cape", CAPE parameters
496
+ see: https://arxiv.org/abs/2106.03143
497
+ t_sparse_self_attn: (bool) if True, the self attentions are sparse
498
+ t_sparse_cross_attn: (bool) if True, the cross-attentions are sparse (don't use it
499
+ unless you designed really specific masks)
500
+ t_mask_type: (str) can be "diag", "jmask", "random", "global" or any combination
501
+ with '_' between: i.e. "diag_jmask_random" (note that this is permutation
502
+ invariant i.e. "diag_jmask_random" is equivalent to "jmask_random_diag")
503
+ t_mask_random_seed: (int) if "random" is in t_mask_type, controls the seed
504
+ that generated the random part of the mask
505
+ t_sparse_attn_window: (int) if "diag" is in t_mask_type, for a query (i), and
506
+ a key (j), the mask is True if |i-j|<=t_sparse_attn_window
507
+ t_global_window: (int) if "global" is in t_mask_type, mask[:t_global_window, :]
508
+ and mask[:, :t_global_window] will be True
509
+ t_sparsity: (float) if "random" is in t_mask_type, t_sparsity is the sparsity
510
+ level of the random part of the mask.
511
+ t_cross_first: (bool) if True cross attention is the first layer of the
512
+ transformer (False seems to be better)
513
+ rescale: weight rescaling trick
514
+ use_train_segment: (bool) if True, the actual size that is used during the
515
+ training is used during inference.
516
+ """
517
+ super().__init__()
518
+ self.cac = cac
519
+ self.wiener_residual = wiener_residual
520
+ self.audio_channels = audio_channels
521
+ self.sources = sources
522
+ self.kernel_size = kernel_size
523
+ self.context = context
524
+ self.stride = stride
525
+ self.depth = depth
526
+ self.bottom_channels = bottom_channels
527
+ self.channels = channels
528
+ self.samplerate = samplerate
529
+ self.segment = segment
530
+ self.use_train_segment = use_train_segment
531
+ self.nfft = nfft
532
+ self.hop_length = nfft // 4
533
+ self.wiener_iters = wiener_iters
534
+ self.end_iters = end_iters
535
+ self.freq_emb = None
536
+ assert wiener_iters == end_iters
537
+
538
+ self.encoder = nn.ModuleList()
539
+ self.decoder = nn.ModuleList()
540
+
541
+ self.tencoder = nn.ModuleList()
542
+ self.tdecoder = nn.ModuleList()
543
+
544
+ chin = audio_channels
545
+ chin_z = chin # number of channels for the freq branch
546
+ if self.cac:
547
+ chin_z *= 2
548
+ chout = channels_time or channels
549
+ chout_z = channels
550
+ freqs = nfft // 2
551
+
552
+ for index in range(depth):
553
+ norm = index >= norm_starts
554
+ freq = freqs > 1
555
+ stri = stride
556
+ ker = kernel_size
557
+ if not freq:
558
+ assert freqs == 1
559
+ ker = time_stride * 2
560
+ stri = time_stride
561
+
562
+ pad = True
563
+ last_freq = False
564
+ if freq and freqs <= kernel_size:
565
+ ker = freqs
566
+ pad = False
567
+ last_freq = True
568
+
569
+ kw = {
570
+ "kernel_size": ker,
571
+ "stride": stri,
572
+ "freq": freq,
573
+ "pad": pad,
574
+ "norm": norm,
575
+ "rewrite": rewrite,
576
+ "norm_groups": norm_groups,
577
+ "dconv_kw": {
578
+ "depth": dconv_depth,
579
+ "compress": dconv_comp,
580
+ "init": dconv_init,
581
+ "gelu": True,
582
+ },
583
+ }
584
+ kwt = dict(kw)
585
+ kwt["freq"] = 0
586
+ kwt["kernel_size"] = kernel_size
587
+ kwt["stride"] = stride
588
+ kwt["pad"] = True
589
+ kw_dec = dict(kw)
590
+ multi = False
591
+ if multi_freqs and index < multi_freqs_depth:
592
+ multi = True
593
+ kw_dec["context_freq"] = False
594
+
595
+ if last_freq:
596
+ chout_z = max(chout, chout_z)
597
+ chout = chout_z
598
+
599
+ enc = HEncLayer(
600
+ chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw
601
+ )
602
+ if freq:
603
+ tenc = HEncLayer(
604
+ chin,
605
+ chout,
606
+ dconv=dconv_mode & 1,
607
+ context=context_enc,
608
+ empty=last_freq,
609
+ **kwt
610
+ )
611
+ self.tencoder.append(tenc)
612
+
613
+ if multi:
614
+ enc = MultiWrap(enc, multi_freqs)
615
+ self.encoder.append(enc)
616
+ if index == 0:
617
+ chin = self.audio_channels * len(self.sources)
618
+ chin_z = chin
619
+ if self.cac:
620
+ chin_z *= 2
621
+ dec = HDecLayer(
622
+ chout_z,
623
+ chin_z,
624
+ dconv=dconv_mode & 2,
625
+ last=index == 0,
626
+ context=context,
627
+ **kw_dec
628
+ )
629
+ if multi:
630
+ dec = MultiWrap(dec, multi_freqs)
631
+ if freq:
632
+ tdec = HDecLayer(
633
+ chout,
634
+ chin,
635
+ dconv=dconv_mode & 2,
636
+ empty=last_freq,
637
+ last=index == 0,
638
+ context=context,
639
+ **kwt
640
+ )
641
+ self.tdecoder.insert(0, tdec)
642
+ self.decoder.insert(0, dec)
643
+
644
+ chin = chout
645
+ chin_z = chout_z
646
+ chout = int(growth * chout)
647
+ chout_z = int(growth * chout_z)
648
+ if freq:
649
+ if freqs <= kernel_size:
650
+ freqs = 1
651
+ else:
652
+ freqs //= stride
653
+ if index == 0 and freq_emb:
654
+ self.freq_emb = ScaledEmbedding(
655
+ freqs, chin_z, smooth=emb_smooth, scale=emb_scale
656
+ )
657
+ self.freq_emb_scale = freq_emb
658
+
659
+ if rescale:
660
+ rescale_module(self, reference=rescale)
661
+
662
+ transformer_channels = channels * growth ** (depth - 1)
663
+ if bottom_channels:
664
+ self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
665
+ self.channel_downsampler = nn.Conv1d(
666
+ bottom_channels, transformer_channels, 1
667
+ )
668
+ self.channel_upsampler_t = nn.Conv1d(
669
+ transformer_channels, bottom_channels, 1
670
+ )
671
+ self.channel_downsampler_t = nn.Conv1d(
672
+ bottom_channels, transformer_channels, 1
673
+ )
674
+
675
+ transformer_channels = bottom_channels
676
+
677
+ if t_layers > 0:
678
+ if t_cape_glob_loc_scale is None:
679
+ t_cape_glob_loc_scale = [5000.0, 1.0, 1.4]
680
+ self.crosstransformer = CrossTransformerEncoder(
681
+ dim=transformer_channels,
682
+ emb=t_emb,
683
+ hidden_scale=t_hidden_scale,
684
+ num_heads=t_heads,
685
+ num_layers=t_layers,
686
+ cross_first=t_cross_first,
687
+ dropout=t_dropout,
688
+ max_positions=t_max_positions,
689
+ norm_in=t_norm_in,
690
+ norm_in_group=t_norm_in_group,
691
+ group_norm=t_group_norm,
692
+ norm_first=t_norm_first,
693
+ norm_out=t_norm_out,
694
+ max_period=t_max_period,
695
+ weight_decay=t_weight_decay,
696
+ lr=t_lr,
697
+ layer_scale=t_layer_scale,
698
+ gelu=t_gelu,
699
+ sin_random_shift=t_sin_random_shift,
700
+ weight_pos_embed=t_weight_pos_embed,
701
+ cape_mean_normalize=t_cape_mean_normalize,
702
+ cape_augment=t_cape_augment,
703
+ cape_glob_loc_scale=t_cape_glob_loc_scale,
704
+ sparse_self_attn=t_sparse_self_attn,
705
+ sparse_cross_attn=t_sparse_cross_attn,
706
+ mask_type=t_mask_type,
707
+ mask_random_seed=t_mask_random_seed,
708
+ sparse_attn_window=t_sparse_attn_window,
709
+ global_window=t_global_window,
710
+ sparsity=t_sparsity,
711
+ auto_sparsity=t_auto_sparsity,
712
+ )
713
+ else:
714
+ self.crosstransformer = None
715
+
716
+ def _spec(self, x):
717
+ hl = self.hop_length
718
+ nfft = self.nfft
719
+
720
+ # We re-pad the signal in order to keep the property
721
+ # that the size of the output is exactly the size of the input
722
+ # divided by the stride (here hop_length), when divisible.
723
+ # This is achieved by padding by 1/4th of the kernel size (here nfft),
724
+ # which is not directly supported by torch.stft.
725
+ # Having all convolution operations follow this convention makes it easy to
726
+ # align the time and frequency branches later on.
727
+ assert hl == nfft // 4
728
+ le = int(math.ceil(x.shape[-1] / hl))
729
+ pad = hl // 2 * 3
730
+ x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
731
+
732
+ z = spectro(x, nfft, hl)[..., :-1, :]
733
+ assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
734
+ z = z[..., 2: 2 + le]
735
+ return z
736
+
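A quick check of the padding property described in the comment above (illustrative only; the instantiation below is a hypothetical configuration and pulls in this module's dependencies such as einops and openunmix):

import torch

model = HTDemucs(sources=["drums", "bass", "other", "vocals"])   # hypothetical hyperparameters
hop = model.hop_length                                           # nfft // 4 == 1024 by default
z = model._spec(torch.randn(1, 2, 10 * hop))
print(z.shape[-1])                                               # 10: exactly length / hop_length frames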
737
+ def _ispec(self, z, length=None, scale=0):
738
+ hl = self.hop_length // (4**scale)
739
+ z = F.pad(z, (0, 0, 0, 1))
740
+ z = F.pad(z, (2, 2))
741
+ pad = hl // 2 * 3
742
+ le = hl * int(math.ceil(length / hl)) + 2 * pad
743
+ x = ispectro(z, hl, length=le)
744
+ x = x[..., pad: pad + length]
745
+ return x
746
+
747
+ def _magnitude(self, z):
748
+ # return the magnitude of the spectrogram, except when cac is True,
749
+ # in which case we just move the complex dimension to the channel one.
750
+ if self.cac:
751
+ b, c, fr, t = z.shape
752
+ m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
753
+ m = m.reshape(b, c * 2, fr, t)
754
+ else:
755
+ m = z.abs()
756
+ return m
757
+
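A small sketch of the "complex as channels" (CaC) reshaping performed above (illustrative only):

import torch

z = torch.randn(1, 2, 2048, 100, dtype=torch.complex64)           # (batch, channels, freq, time)
m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(1, 4, 2048, 100)
print(m.shape)   # real/imag become two extra channels per input channel: (1, 4, 2048, 100)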
758
+ def _mask(self, z, m):
759
+ # Apply masking given the mixture spectrogram `z` and the estimated mask `m`.
760
+ # If `cac` is True, `m` is actually a full spectrogram and `z` is ignored.
761
+ niters = self.wiener_iters
762
+ if self.cac:
763
+ b, s, _, fr, t = m.shape
764
+ out = m.view(b, s, -1, 2, fr, t).permute(0, 1, 2, 4, 5, 3)
765
+ out = torch.view_as_complex(out.contiguous())
766
+ return out
767
+ if self.training:
768
+ niters = self.end_iters
769
+ if niters < 0:
770
+ z = z[:, None]
771
+ return z / (1e-8 + z.abs()) * m
772
+ else:
773
+ return self._wiener(m, z, niters)
774
+
775
+ def _wiener(self, mag_out, mix_stft, niters):
776
+ # apply wiener filtering from OpenUnmix.
777
+ init = mix_stft.dtype
778
+ wiener_win_len = 300
779
+ residual = self.wiener_residual
780
+
781
+ b, s, c, fq, t = mag_out.shape
782
+ mag_out = mag_out.permute(0, 4, 3, 2, 1)
783
+ mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
784
+
785
+ outs = []
786
+ for sample in range(b):
787
+ pos = 0
788
+ out = []
789
+ for pos in range(0, t, wiener_win_len):
790
+ frame = slice(pos, pos + wiener_win_len)
791
+ z_out = wiener(
792
+ mag_out[sample, frame],
793
+ mix_stft[sample, frame],
794
+ niters,
795
+ residual=residual,
796
+ )
797
+ out.append(z_out.transpose(-1, -2))
798
+ outs.append(torch.cat(out, dim=0))
799
+ out = torch.view_as_complex(torch.stack(outs, 0))
800
+ out = out.permute(0, 4, 3, 2, 1).contiguous()
801
+ if residual:
802
+ out = out[:, :-1]
803
+ assert list(out.shape) == [b, s, c, fq, t]
804
+ return out.to(init)
805
+
806
+ def valid_length(self, length: int):
807
+ """
808
+ Return a length that is appropriate for evaluation.
809
+ In our case, always return the training length, unless
810
+ it is smaller than the given length, in which case this
811
+ raises an error.
812
+ """
813
+ if not self.use_train_segment:
814
+ return length
815
+ training_length = int(self.segment * self.samplerate)
816
+ if training_length < length:
817
+ raise ValueError(
818
+ f"Given length {length} is longer than "
819
+ f"training length {training_length}")
820
+ return training_length
821
+
822
+ def forward(self, mix):
823
+ length = mix.shape[-1]
824
+ length_pre_pad = None
825
+ if self.use_train_segment:
826
+ if self.training:
827
+ self.segment = Fraction(mix.shape[-1], self.samplerate)
828
+ else:
829
+ training_length = int(self.segment * self.samplerate)
830
+ if mix.shape[-1] < training_length:
831
+ length_pre_pad = mix.shape[-1]
832
+ mix = F.pad(mix, (0, training_length - length_pre_pad))
833
+ z = self._spec(mix)
834
+ mag = self._magnitude(z).to(mix.device)
835
+ x = mag
836
+
837
+ b, _, fq, t = x.shape
838
+
839
+ # unlike previous Demucs, we always normalize because it is easier.
840
+ mean = x.mean(dim=(1, 2, 3), keepdim=True)
841
+ std = x.std(dim=(1, 2, 3), keepdim=True)
842
+ x = (x - mean) / (1e-5 + std)
843
+ # x will be the freq. branch input.
844
+
845
+ # Prepare the time branch input.
846
+ xt = mix
847
+ meant = xt.mean(dim=(1, 2), keepdim=True)
848
+ stdt = xt.std(dim=(1, 2), keepdim=True)
849
+ xt = (xt - meant) / (1e-5 + stdt)
850
+
851
+ # okay, this is a giant mess I know...
852
+ saved = [] # skip connections, freq.
853
+ saved_t = [] # skip connections, time.
854
+ lengths = [] # saved lengths to properly remove padding, freq branch.
855
+ lengths_t = [] # saved lengths for time branch.
856
+ for idx, encode in enumerate(self.encoder):
857
+ lengths.append(x.shape[-1])
858
+ inject = None
859
+ if idx < len(self.tencoder):
860
+ # we have not yet merged branches.
861
+ lengths_t.append(xt.shape[-1])
862
+ tenc = self.tencoder[idx]
863
+ xt = tenc(xt)
864
+ if not tenc.empty:
865
+ # save for skip connection
866
+ saved_t.append(xt)
867
+ else:
868
+ # tenc contains just the first conv., so that now time and freq.
869
+ # branches have the same shape and can be merged.
870
+ inject = xt
871
+ x = encode(x, inject)
872
+ if idx == 0 and self.freq_emb is not None:
873
+ # add frequency embedding to allow for non equivariant convolutions
874
+ # over the frequency axis.
875
+ frs = torch.arange(x.shape[-2], device=x.device)
876
+ emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
877
+ x = x + self.freq_emb_scale * emb
878
+
879
+ saved.append(x)
880
+ if self.crosstransformer:
881
+ if self.bottom_channels:
882
+ _, _, f, _ = x.shape
883
+ x = rearrange(x, "b c f t-> b c (f t)")
884
+ x = self.channel_upsampler(x)
885
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
886
+ xt = self.channel_upsampler_t(xt)
887
+
888
+ x, xt = self.crosstransformer(x, xt)
889
+
890
+ if self.bottom_channels:
891
+ x = rearrange(x, "b c f t-> b c (f t)")
892
+ x = self.channel_downsampler(x)
893
+ x = rearrange(x, "b c (f t)-> b c f t", f=f)
894
+ xt = self.channel_downsampler_t(xt)
895
+
896
+ for idx, decode in enumerate(self.decoder):
897
+ skip = saved.pop(-1)
898
+ x, pre = decode(x, skip, lengths.pop(-1))
899
+ # `pre` contains the output just before final transposed convolution,
900
+ # which is used when the freq. and time branch separate.
901
+
902
+ offset = self.depth - len(self.tdecoder)
903
+ if idx >= offset:
904
+ tdec = self.tdecoder[idx - offset]
905
+ length_t = lengths_t.pop(-1)
906
+ if tdec.empty:
907
+ assert pre.shape[2] == 1, pre.shape
908
+ pre = pre[:, :, 0]
909
+ xt, _ = tdec(pre, None, length_t)
910
+ else:
911
+ skip = saved_t.pop(-1)
912
+ xt, _ = tdec(xt, skip, length_t)
913
+
914
+ # Let's make sure we used all stored skip connections.
915
+ assert len(saved) == 0
916
+ assert len(lengths_t) == 0
917
+ assert len(saved_t) == 0
918
+
919
+ s = len(self.sources)
920
+ x = x.view(b, s, -1, fq, t)
921
+ x = x * std[:, None] + mean[:, None]
922
+
923
+ # to cpu as mps doesn't support complex numbers
924
+ # demucs issue #435 ##432
925
+ # NOTE: in this case z already is on cpu
926
+ # TODO: remove this when mps supports complex numbers
927
+ x_is_mps = x.device.type == "mps"
928
+ if x_is_mps:
929
+ x = x.cpu()
930
+
931
+ zout = self._mask(z, x)
932
+ if self.use_train_segment:
933
+ if self.training:
934
+ x = self._ispec(zout, length)
935
+ else:
936
+ x = self._ispec(zout, training_length)
937
+ else:
938
+ x = self._ispec(zout, length)
939
+
940
+ # back to mps device
941
+ if x_is_mps:
942
+ x = x.to("mps")
943
+
944
+ if self.use_train_segment:
945
+ if self.training:
946
+ xt = xt.view(b, s, -1, length)
947
+ else:
948
+ xt = xt.view(b, s, -1, training_length)
949
+ else:
950
+ xt = xt.view(b, s, -1, length)
951
+ xt = xt * stdt[:, None] + meant[:, None]
952
+ x = xt + x
953
+ if length_pre_pad:
954
+ x = x[..., :length_pre_pad]
955
+ return x
third_party/demucs/models/pretrained.py ADDED
@@ -0,0 +1,34 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : pretrained.py
5
+ @Time : 2023/8/8 7:22 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Loading pretrained models.
10
+ """
11
+ from pathlib import Path
12
+
13
+ import yaml
14
+
15
+ from .apply import BagOfModels
16
+ from .htdemucs import HTDemucs
17
+ from .states import load_state_dict
18
+
19
+
20
+ def add_model_flags(parser):
21
+ group = parser.add_mutually_exclusive_group(required=False)
22
+ group.add_argument("-s", "--sig", help="Locally trained XP signature.")
23
+ group.add_argument("-n", "--name", default=None,
24
+ help="Pretrained model name or signature. Default is htdemucs.")
25
+ parser.add_argument("--repo", type=Path,
26
+ help="Folder containing all pre-trained models for use with -n.")
27
+
28
+
29
+ def get_model_from_yaml(yaml_file, model_file):
30
+ bag = yaml.safe_load(open(yaml_file))
31
+ model = load_state_dict(HTDemucs, model_file)
32
+ weights = bag.get('weights')
33
+ segment = bag.get('segment')
34
+ return BagOfModels([model], weights, segment)
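A usage sketch for `get_model_from_yaml` (illustrative only; the import path and checkpoint locations are assumptions, not part of this commit):

# Assumed import path and file locations; point these at the real yaml/weights on disk.
from third_party.demucs.models.pretrained import get_model_from_yaml

separator = get_model_from_yaml("ckpt/htdemucs.yaml", "ckpt/htdemucs.pth")
separator.eval()
# `separator` is a BagOfModels wrapping a single HTDemucs, meant to be driven by the
# chunked inference helper in apply.py.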
third_party/demucs/models/spec.py ADDED
@@ -0,0 +1,51 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : spec.py
5
+ @Time : 2023/8/8 5:10 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Spec
10
+ """
11
+
12
+ import torch as th
13
+
14
+
15
+ def spectro(x, n_fft=512, hop_length=None, pad=0):
16
+ *other, length = x.shape
17
+ x = x.reshape(-1, length)
18
+ is_mps = x.device.type == 'mps'
19
+ if is_mps:
20
+ x = x.cpu()
21
+ z = th.stft(x,
22
+ n_fft * (1 + pad),
23
+ hop_length or n_fft // 4,
24
+ window=th.hann_window(n_fft).to(x),
25
+ win_length=n_fft,
26
+ normalized=True,
27
+ center=True,
28
+ return_complex=True,
29
+ pad_mode='reflect')
30
+ _, freqs, frame = z.shape
31
+ return z.view(*other, freqs, frame)
32
+
33
+
34
+ def ispectro(z, hop_length=None, length=None, pad=0):
35
+ *other, freqs, frames = z.shape
36
+ n_fft = 2 * freqs - 2
37
+ z = z.view(-1, freqs, frames)
38
+ win_length = n_fft // (1 + pad)
39
+ is_mps = z.device.type == 'mps'
40
+ if is_mps:
41
+ z = z.cpu()
42
+ x = th.istft(z,
43
+ n_fft,
44
+ hop_length,
45
+ window=th.hann_window(win_length).to(z.real),
46
+ win_length=win_length,
47
+ normalized=True,
48
+ length=length,
49
+ center=True)
50
+ _, length = x.shape
51
+ return x.view(*other, length)
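A round-trip sanity check for the two helpers above (illustrative only; the import path is an assumption):

import torch as th
# from third_party.demucs.models.spec import spectro, ispectro   # assumed import path

x = th.randn(2, 16 * 1024)                       # (channels, time)
z = spectro(x, n_fft=4096, hop_length=1024)      # (channels, freqs, frames)
y = ispectro(z, hop_length=1024, length=x.shape[-1])
print(th.allclose(x, y, atol=1e-4))              # True: the normalized STFT inverts cleanly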
third_party/demucs/models/states.py ADDED
@@ -0,0 +1,102 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : states.py
5
+ @Time : 2023/8/8 7:01 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Utilities to save and load models.
10
+ """
11
+ import functools
12
+ import inspect
13
+ import warnings
14
+ from pathlib import Path
15
+ from fractions import Fraction
16
+
17
+ import torch
18
+
19
+
20
+ def load_state_dict(net, pth_path):
21
+ kwargs = {'sources': ['drums', 'bass', 'other', 'vocal'], 'audio_channels': 2, 'samplerate': 44100,
22
+ 'segment': Fraction(39, 5), 'channels': 48, 'channels_time': None, 'growth': 2, 'nfft': 4096,
23
+ 'wiener_iters': 0, 'end_iters': 0, 'wiener_residual': False, 'cac': True, 'depth': 4, 'rewrite': True,
24
+ 'multi_freqs': [], 'multi_freqs_depth': 3, 'freq_emb': 0.2, 'emb_scale': 10, 'emb_smooth': True,
25
+ 'kernel_size': 8, 'stride': 4, 'time_stride': 2, 'context': 1, 'context_enc': 0, 'norm_starts': 4,
26
+ 'norm_groups': 4, 'dconv_mode': 3, 'dconv_depth': 2, 'dconv_comp': 8, 'dconv_init': 0.001,
27
+ 'bottom_channels': 512, 't_layers': 5, 't_hidden_scale': 4.0, 't_heads': 8, 't_dropout': 0.02,
28
+ 't_layer_scale': True, 't_gelu': True, 't_emb': 'sin', 't_max_positions': 10000, 't_max_period': 10000.0,
29
+ 't_weight_pos_embed': 1.0, 't_cape_mean_normalize': True, 't_cape_augment': True,
30
+ 't_cape_glob_loc_scale': [5000.0, 1.0, 1.4], 't_sin_random_shift': 0, 't_norm_in': True,
31
+ 't_norm_in_group': False, 't_group_norm': False, 't_norm_first': True, 't_norm_out': True,
32
+ 't_weight_decay': 0.0, 't_lr': None, 't_sparse_self_attn': False, 't_sparse_cross_attn': False,
33
+ 't_mask_type': 'diag', 't_mask_random_seed': 42, 't_sparse_attn_window': 400, 't_global_window': 100,
34
+ 't_sparsity': 0.95, 't_auto_sparsity': False, 't_cross_first': False, 'rescale': 0.1}
35
+ model = net(**kwargs)
36
+ state_dict = torch.load(pth_path)
37
+ model.load_state_dict(state_dict)
38
+ return model
39
+
40
+
41
+ def load_model(path_or_package, strict=False):
42
+ """Load a model from the given serialized model, either given as a dict (already loaded)
43
+ or a path to a file on disk."""
44
+ if isinstance(path_or_package, dict):
45
+ package = path_or_package
46
+ elif isinstance(path_or_package, (str, Path)):
47
+ with warnings.catch_warnings():
48
+ warnings.simplefilter("ignore")
49
+ path = path_or_package
50
+ package = torch.load(path, 'cpu')
51
+ else:
52
+ raise ValueError(f"Invalid type for {path_or_package}.")
53
+
54
+ klass = package["klass"]
55
+ args = package["args"]
56
+ kwargs = package["kwargs"]
57
+
58
+ if strict:
59
+ model = klass(*args, **kwargs)
60
+ else:
61
+ sig = inspect.signature(klass)
62
+ for key in list(kwargs):
63
+ if key not in sig.parameters:
64
+ warnings.warn("Dropping inexistant parameter " + key)
65
+ del kwargs[key]
66
+ model = klass(*args, **kwargs)
67
+
68
+ state = package["state"]
69
+
70
+ set_state(model, state)
71
+ return model
72
+
73
+
74
+ def get_state(model, quantizer, half=False):
75
+ """Get the state from a model, potentially with quantization applied.
76
+ If `half` is True, the model is stored in half precision, which shouldn't impact performance
77
+ but halves the state size."""
78
+ if quantizer is None:
79
+ dtype = torch.half if half else None
80
+ state = {k: p.data.to(device='cpu', dtype=dtype) for k, p in model.state_dict().items()}
81
+ else:
82
+ state = quantizer.get_quantized_state()
83
+ state['__quantized'] = True
84
+ return state
85
+
86
+
87
+ def set_state(model, state, quantizer=None):
88
+ """Set the state on a given model."""
89
+ if state.get('__quantized'):
90
+ quantizer.restore_quantized_state(model, state['quantized'])
91
+ else:
92
+ model.load_state_dict(state)
93
+ return state
94
+
95
+
96
+ def capture_init(init):
97
+ @functools.wraps(init)
98
+ def __init__(self, *args, **kwargs):
99
+ self._init_args_kwargs = (args, kwargs)
100
+ init(self, *args, **kwargs)
101
+
102
+ return __init__
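A minimal sketch of what `capture_init` records (illustrative only; `Toy` is a made-up module and the import path is an assumption):

from torch import nn
# from third_party.demucs.models.states import capture_init   # assumed import path

class Toy(nn.Module):
    @capture_init
    def __init__(self, dim, layers=2):
        super().__init__()
        self.net = nn.Linear(dim, dim)

toy = Toy(8, layers=3)
print(toy._init_args_kwargs)   # ((8,), {'layers': 3}) -- enough to re-instantiate the model later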
third_party/demucs/models/transformer.py ADDED
@@ -0,0 +1,765 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : transformer.py
5
+ @Time : 2023/8/8 5:05 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : Transformer
10
+ """
11
+ import math
12
+ import random
13
+ import typing as tp
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+ import numpy as np
19
+ from einops import rearrange
20
+ from torch.nn import TransformerEncoderLayer, MultiheadAttention, Linear, LayerNorm
21
+
22
+
23
+ def create_sin_embedding(
24
+ length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
25
+ ):
26
+ # We aim for TBC format
27
+ assert dim % 2 == 0
28
+ pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
29
+ half_dim = dim // 2
30
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
31
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
32
+ return torch.cat(
33
+ [
34
+ torch.cos(phase),
35
+ torch.sin(phase),
36
+ ],
37
+ dim=-1,
38
+ )
39
+
40
+
41
+ def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
42
+ """
43
+ :param d_model: dimension of the model
44
+ :param height: height of the positions
45
+ :param width: width of the positions
46
+ :return: a (1, d_model, height, width) positional encoding tensor
47
+ """
48
+ if d_model % 4 != 0:
49
+ raise ValueError(
50
+ "Cannot use sin/cos positional encoding with "
51
+ "odd dimension (got dim={:d})".format(d_model)
52
+ )
53
+ pe = torch.zeros(d_model, height, width)
54
+ # Each dimension use half of d_model
55
+ d_model = int(d_model / 2)
56
+ div_term = torch.exp(
57
+ torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model)
58
+ )
59
+ pos_w = torch.arange(0.0, width).unsqueeze(1)
60
+ pos_h = torch.arange(0.0, height).unsqueeze(1)
61
+ pe[0:d_model:2, :, :] = (
62
+ torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
63
+ )
64
+ pe[1:d_model:2, :, :] = (
65
+ torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
66
+ )
67
+ pe[d_model::2, :, :] = (
68
+ torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
69
+ )
70
+ pe[d_model + 1:: 2, :, :] = (
71
+ torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
72
+ )
73
+
74
+ return pe[None, :].to(device)
75
+
76
+
77
+ def create_sin_embedding_cape(
78
+ length: int,
79
+ dim: int,
80
+ batch_size: int,
81
+ mean_normalize: bool,
82
+ augment: bool, # True during training
83
+ max_global_shift: float = 0.0, # delta max
84
+ max_local_shift: float = 0.0, # epsilon max
85
+ max_scale: float = 1.0,
86
+ device: str = "cpu",
87
+ max_period: float = 10000.0,
88
+ ):
89
+ # We aim for TBC format
90
+ assert dim % 2 == 0
91
+ pos = 1.0 * torch.arange(length).view(-1, 1, 1) # (length, 1, 1)
92
+ pos = pos.repeat(1, batch_size, 1) # (length, batch_size, 1)
93
+ if mean_normalize:
94
+ pos -= torch.nanmean(pos, dim=0, keepdim=True)
95
+
96
+ if augment:
97
+ delta = np.random.uniform(
98
+ -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
99
+ )
100
+ delta_local = np.random.uniform(
101
+ -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
102
+ )
103
+ log_lambdas = np.random.uniform(
104
+ -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
105
+ )
106
+ pos = (pos + delta + delta_local) * np.exp(log_lambdas)
107
+
108
+ pos = pos.to(device)
109
+
110
+ half_dim = dim // 2
111
+ adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
112
+ phase = pos / (max_period ** (adim / (half_dim - 1)))
113
+ return torch.cat(
114
+ [
115
+ torch.cos(phase),
116
+ torch.sin(phase),
117
+ ],
118
+ dim=-1,
119
+ ).float()
120
+
121
+
122
+ def get_causal_mask(length):
123
+ pos = torch.arange(length)
124
+ return pos > pos[:, None]
125
+
126
+
127
+ def get_elementary_mask(
128
+ t1,
129
+ t2,
130
+ mask_type,
131
+ sparse_attn_window,
132
+ global_window,
133
+ mask_random_seed,
134
+ sparsity,
135
+ device,
136
+ ):
137
+ """
138
+ When the input of the Decoder has length T1 and the output T2
139
+ The mask matrix has shape (T2, T1)
140
+ """
141
+ assert mask_type in ["diag", "jmask", "random", "global"]
142
+
143
+ if mask_type == "global":
144
+ mask = torch.zeros(t2, t1, dtype=torch.bool)
145
+ mask[:, :global_window] = True
146
+ line_window = int(global_window * t2 / t1)
147
+ mask[:line_window, :] = True
148
+
149
+ if mask_type == "diag":
150
+
151
+ mask = torch.zeros(t2, t1, dtype=torch.bool)
152
+ rows = torch.arange(t2)[:, None]
153
+ cols = (
154
+ (t1 / t2 * rows + torch.arange(-sparse_attn_window, sparse_attn_window + 1))
155
+ .long()
156
+ .clamp(0, t1 - 1)
157
+ )
158
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
159
+
160
+ elif mask_type == "jmask":
161
+ mask = torch.zeros(t2 + 2, t1 + 2, dtype=torch.bool)
162
+ rows = torch.arange(t2 + 2)[:, None]
163
+ t = torch.arange(0, int((2 * t1) ** 0.5 + 1))
164
+ t = (t * (t + 1) / 2).int()
165
+ t = torch.cat([-t.flip(0)[:-1], t])
166
+ cols = (t1 / t2 * rows + t).long().clamp(0, t1 + 1)
167
+ mask.scatter_(1, cols, torch.ones(1, dtype=torch.bool).expand_as(cols))
168
+ mask = mask[1:-1, 1:-1]
169
+
170
+ elif mask_type == "random":
171
+ gene = torch.Generator(device=device)
172
+ gene.manual_seed(mask_random_seed)
173
+ mask = (
174
+ torch.rand(t1 * t2, generator=gene, device=device).reshape(t2, t1)
175
+ > sparsity
176
+ )
177
+
178
+ mask = mask.to(device)
179
+ return mask
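To make the mask shapes concrete, a small editorial sketch of the "diag" variant (sizes here are arbitrary; unlike `get_mask` below, this path needs only plain PyTorch, not xformers):

```python
# A 16x16 self-attention pattern with a +/-2 band around the diagonal.
mask = get_elementary_mask(
    t1=16, t2=16, mask_type="diag",
    sparse_attn_window=2, global_window=4,
    mask_random_seed=42, sparsity=0.95, device="cpu",
)
print(mask.shape, mask.dtype)  # torch.Size([16, 16]) torch.bool
print(mask[0, :6])             # tensor([True, True, True, False, False, False])
```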
180
+
181
+
182
+ def get_mask(
183
+ t1,
184
+ t2,
185
+ mask_type,
186
+ sparse_attn_window,
187
+ global_window,
188
+ mask_random_seed,
189
+ sparsity,
190
+ device,
191
+ ):
192
+ """
193
+ Return a SparseCSRTensor mask that is a combination of elementary masks
194
+ mask_type can be a combination of multiple masks: for instance "diag_jmask_random"
195
+ """
196
+ from xformers.sparse import SparseCSRTensor
197
+ # create a list
198
+ mask_types = mask_type.split("_")
199
+
200
+ all_masks = [
201
+ get_elementary_mask(
202
+ t1,
203
+ t2,
204
+ mask,
205
+ sparse_attn_window,
206
+ global_window,
207
+ mask_random_seed,
208
+ sparsity,
209
+ device,
210
+ )
211
+ for mask in mask_types
212
+ ]
213
+
214
+ final_mask = torch.stack(all_masks).sum(axis=0) > 0
215
+
216
+ return SparseCSRTensor.from_dense(final_mask[None])
217
+
218
+
219
+ class ScaledEmbedding(nn.Module):
220
+ def __init__(
221
+ self,
222
+ num_embeddings: int,
223
+ embedding_dim: int,
224
+ scale: float = 1.0,
225
+ boost: float = 3.0,
226
+ ):
227
+ super().__init__()
228
+ self.embedding = nn.Embedding(num_embeddings, embedding_dim)
229
+ self.embedding.weight.data *= scale / boost
230
+ self.boost = boost
231
+
232
+ @property
233
+ def weight(self):
234
+ return self.embedding.weight * self.boost
235
+
236
+ def forward(self, x):
237
+ return self.embedding(x) * self.boost
238
+
239
+
240
+ class LayerScale(nn.Module):
241
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
242
+ This diagonally rescales residual outputs; the scale is initialised close to 0, then learnt.
243
+ """
244
+
245
+ def __init__(self, channels: int, init: float = 0, channel_last=False):
246
+ """
247
+ channel_last = False corresponds to (B, C, T) tensors
248
+ channel_last = True corresponds to (T, B, C) tensors
249
+ """
250
+ super().__init__()
251
+ self.channel_last = channel_last
252
+ self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
253
+ self.scale.data[:] = init
254
+
255
+ def forward(self, x):
256
+ if self.channel_last:
257
+ return self.scale * x
258
+ else:
259
+ return self.scale[:, None] * x
260
+
261
+
262
+ class MyGroupNorm(nn.GroupNorm):
263
+ def __init__(self, *args, **kwargs):
264
+ super().__init__(*args, **kwargs)
265
+
266
+ def forward(self, x):
267
+ """
268
+ x: (B, T, C)
269
+ if num_groups=1: Normalisation on all T and C together for each B
270
+ """
271
+ x = x.transpose(1, 2)
272
+ return super().forward(x).transpose(1, 2)
273
+
274
+
275
+ class MyTransformerEncoderLayer(TransformerEncoderLayer):
276
+ def __init__(
277
+ self,
278
+ d_model,
279
+ nhead,
280
+ dim_feedforward=2048,
281
+ dropout=0.1,
282
+ activation=F.relu,
283
+ group_norm=0,
284
+ norm_first=False,
285
+ norm_out=False,
286
+ layer_norm_eps=1e-5,
287
+ layer_scale=False,
288
+ init_values=1e-4,
289
+ device=None,
290
+ dtype=None,
291
+ sparse=False,
292
+ mask_type="diag",
293
+ mask_random_seed=42,
294
+ sparse_attn_window=500,
295
+ global_window=50,
296
+ auto_sparsity=False,
297
+ sparsity=0.95,
298
+ batch_first=False,
299
+ ):
300
+ factory_kwargs = {"device": device, "dtype": dtype}
301
+ super().__init__(
302
+ d_model=d_model,
303
+ nhead=nhead,
304
+ dim_feedforward=dim_feedforward,
305
+ dropout=dropout,
306
+ activation=activation,
307
+ layer_norm_eps=layer_norm_eps,
308
+ batch_first=batch_first,
309
+ norm_first=norm_first,
310
+ device=device,
311
+ dtype=dtype,
312
+ )
313
+ self.sparse = sparse
314
+ self.auto_sparsity = auto_sparsity
315
+ if sparse:
316
+ if not auto_sparsity:
317
+ self.mask_type = mask_type
318
+ self.sparse_attn_window = sparse_attn_window
319
+ self.global_window = global_window
320
+ self.sparsity = sparsity
321
+ if group_norm:
322
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
323
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
324
+
325
+ self.norm_out = None
326
+ if self.norm_first & norm_out:
327
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
328
+ self.gamma_1 = (
329
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
330
+ )
331
+ self.gamma_2 = (
332
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
333
+ )
334
+
335
+ if sparse:
336
+ self.self_attn = MultiheadAttention(
337
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
338
+ auto_sparsity=sparsity if auto_sparsity else 0,
339
+ )
340
+ self.__setattr__("src_mask", torch.zeros(1, 1))
341
+ self.mask_random_seed = mask_random_seed
342
+
343
+ def forward(self, src, src_mask=None, src_key_padding_mask=None):
344
+ """
345
+ if batch_first = False, src shape is (t, b, c)
346
+ the case where batch_first=True is not covered
347
+ """
348
+ device = src.device
349
+ x = src
350
+ t, _, _ = x.shape
351
+ if self.sparse and not self.auto_sparsity:
352
+ assert src_mask is None
353
+ src_mask = self.src_mask
354
+ if src_mask.shape[-1] != t:
355
+ src_mask = get_mask(
356
+ t,
357
+ t,
358
+ self.mask_type,
359
+ self.sparse_attn_window,
360
+ self.global_window,
361
+ self.mask_random_seed,
362
+ self.sparsity,
363
+ device,
364
+ )
365
+ self.__setattr__("src_mask", src_mask)
366
+
367
+ if self.norm_first:
368
+ x = x + self.gamma_1(
369
+ self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
370
+ )
371
+ x = x + self.gamma_2(self._ff_block(self.norm2(x)))
372
+
373
+ if self.norm_out:
374
+ x = self.norm_out(x)
375
+ else:
376
+ x = self.norm1(
377
+ x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
378
+ )
379
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
380
+
381
+ return x
382
+
383
+
384
+ class CrossTransformerEncoderLayer(nn.Module):
385
+ def __init__(
386
+ self,
387
+ d_model: int,
388
+ nhead: int,
389
+ dim_feedforward: int = 2048,
390
+ dropout: float = 0.1,
391
+ activation=F.relu,
392
+ layer_norm_eps: float = 1e-5,
393
+ layer_scale: bool = False,
394
+ init_values: float = 1e-4,
395
+ norm_first: bool = False,
396
+ group_norm: bool = False,
397
+ norm_out: bool = False,
398
+ sparse=False,
399
+ mask_type="diag",
400
+ mask_random_seed=42,
401
+ sparse_attn_window=500,
402
+ global_window=50,
403
+ sparsity=0.95,
404
+ auto_sparsity=None,
405
+ device=None,
406
+ dtype=None,
407
+ batch_first=False,
408
+ ):
409
+ factory_kwargs = {"device": device, "dtype": dtype}
410
+ super().__init__()
411
+
412
+ self.sparse = sparse
413
+ self.auto_sparsity = auto_sparsity
414
+ if sparse:
415
+ if not auto_sparsity:
416
+ self.mask_type = mask_type
417
+ self.sparse_attn_window = sparse_attn_window
418
+ self.global_window = global_window
419
+ self.sparsity = sparsity
420
+
421
+ self.cross_attn: nn.Module
422
+ self.cross_attn = MultiheadAttention(
423
+ d_model, nhead, dropout=dropout, batch_first=batch_first)
424
+ # Implementation of Feedforward model
425
+ self.linear1 = Linear(d_model, dim_feedforward, **factory_kwargs)
426
+ self.dropout = nn.Dropout(dropout)
427
+ self.linear2 = Linear(dim_feedforward, d_model, **factory_kwargs)
428
+
429
+ self.norm_first = norm_first
430
+ self.norm1: nn.Module
431
+ self.norm2: nn.Module
432
+ self.norm3: nn.Module
433
+ if group_norm:
434
+ self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
435
+ self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
436
+ self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
437
+ else:
438
+ self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
439
+ self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
440
+ self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
441
+
442
+ self.norm_out = None
443
+ if self.norm_first & norm_out:
444
+ self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
445
+
446
+ self.gamma_1 = (
447
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
448
+ )
449
+ self.gamma_2 = (
450
+ LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
451
+ )
452
+
453
+ self.dropout1 = nn.Dropout(dropout)
454
+ self.dropout2 = nn.Dropout(dropout)
455
+
456
+ # Legacy string support for activation function.
457
+ if isinstance(activation, str):
458
+ self.activation = self._get_activation_fn(activation)
459
+ else:
460
+ self.activation = activation
461
+
462
+ if sparse:
463
+ self.cross_attn = MultiheadAttention(
464
+ d_model, nhead, dropout=dropout, batch_first=batch_first,
465
+ auto_sparsity=sparsity if auto_sparsity else 0)
466
+ if not auto_sparsity:
467
+ self.__setattr__("mask", torch.zeros(1, 1))
468
+ self.mask_random_seed = mask_random_seed
469
+
470
+ def forward(self, q, k, mask=None):
471
+ """
472
+ Args:
473
+ q: tensor of shape (T, B, C)
474
+ k: tensor of shape (S, B, C)
475
+ mask: tensor of shape (T, S)
476
+
477
+ """
478
+ device = q.device
479
+ t, _, _ = q.shape
480
+ s, _, _ = k.shape
481
+ if self.sparse and not self.auto_sparsity:
482
+ assert mask is None
483
+ mask = self.mask
484
+ if mask.shape[-1] != s or mask.shape[-2] != t:
485
+ mask = get_mask(
486
+ s,
487
+ t,
488
+ self.mask_type,
489
+ self.sparse_attn_window,
490
+ self.global_window,
491
+ self.mask_random_seed,
492
+ self.sparsity,
493
+ device,
494
+ )
495
+ self.__setattr__("mask", mask)
496
+
497
+ if self.norm_first:
498
+ x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
499
+ x = x + self.gamma_2(self._ff_block(self.norm3(x)))
500
+ if self.norm_out:
501
+ x = self.norm_out(x)
502
+ else:
503
+ x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
504
+ x = self.norm2(x + self.gamma_2(self._ff_block(x)))
505
+
506
+ return x
507
+
508
+ # cross-attention block
509
+ def _ca_block(self, q, k, attn_mask=None):
510
+ x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
511
+ return self.dropout1(x)
512
+
513
+ # feed forward block
514
+ def _ff_block(self, x):
515
+ x = self.linear2(self.dropout(self.activation(self.linear1(x))))
516
+ return self.dropout2(x)
517
+
518
+ def _get_activation_fn(self, activation):
519
+ if activation == "relu":
520
+ return F.relu
521
+ elif activation == "gelu":
522
+ return F.gelu
523
+
524
+ raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
525
+
526
+
527
+ # ----------------- MULTI-BLOCKS MODELS: -----------------------
528
+
529
+
530
+ class CrossTransformerEncoder(nn.Module):
531
+ def __init__(
532
+ self,
533
+ dim: int,
534
+ emb: str = "sin",
535
+ hidden_scale: float = 4.0,
536
+ num_heads: int = 8,
537
+ num_layers: int = 6,
538
+ cross_first: bool = False,
539
+ dropout: float = 0.0,
540
+ max_positions: int = 1000,
541
+ norm_in: bool = True,
542
+ norm_in_group: bool = False,
543
+ group_norm: int = False,
544
+ norm_first: bool = False,
545
+ norm_out: bool = False,
546
+ max_period: float = 10000.0,
547
+ weight_decay: float = 0.0,
548
+ lr: tp.Optional[float] = None,
549
+ layer_scale: bool = False,
550
+ gelu: bool = True,
551
+ sin_random_shift: int = 0,
552
+ weight_pos_embed: float = 1.0,
553
+ cape_mean_normalize: bool = True,
554
+ cape_augment: bool = True,
555
+ cape_glob_loc_scale: list = None,
556
+ sparse_self_attn: bool = False,
557
+ sparse_cross_attn: bool = False,
558
+ mask_type: str = "diag",
559
+ mask_random_seed: int = 42,
560
+ sparse_attn_window: int = 500,
561
+ global_window: int = 50,
562
+ auto_sparsity: bool = False,
563
+ sparsity: float = 0.95,
564
+ ):
565
+ super().__init__()
566
+ """
567
+ """
568
+ assert dim % num_heads == 0
569
+
570
+ hidden_dim = int(dim * hidden_scale)
571
+
572
+ self.num_layers = num_layers
573
+ # classic parity = 1 means that if idx%2 == 1 there is a
574
+ # classical encoder else there is a cross encoder
575
+ self.classic_parity = 1 if cross_first else 0
576
+ self.emb = emb
577
+ self.max_period = max_period
578
+ self.weight_decay = weight_decay
579
+ self.weight_pos_embed = weight_pos_embed
580
+ self.sin_random_shift = sin_random_shift
581
+ if emb == "cape":
582
+ self.cape_mean_normalize = cape_mean_normalize
583
+ self.cape_augment = cape_augment
584
+ if cape_glob_loc_scale is None:
585
+ self.cape_glob_loc_scale = [5000.0, 1.0, 1.4]
586
+ else:
587
+ self.cape_glob_loc_scale = cape_glob_loc_scale
588
+ if emb == "scaled":
589
+ self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
590
+
591
+ self.lr = lr
592
+
593
+ activation: tp.Any = F.gelu if gelu else F.relu
594
+
595
+ self.norm_in: nn.Module
596
+ self.norm_in_t: nn.Module
597
+ if norm_in:
598
+ self.norm_in = LayerNorm(dim)
599
+ self.norm_in_t = LayerNorm(dim)
600
+ elif norm_in_group:
601
+ self.norm_in = MyGroupNorm(int(norm_in_group), dim)
602
+ self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
603
+ else:
604
+ self.norm_in = nn.Identity()
605
+ self.norm_in_t = nn.Identity()
606
+
607
+ # spectrogram layers
608
+ self.layers = nn.ModuleList()
609
+ # temporal layers
610
+ self.layers_t = nn.ModuleList()
611
+
612
+ kwargs_common = {
613
+ "d_model": dim,
614
+ "nhead": num_heads,
615
+ "dim_feedforward": hidden_dim,
616
+ "dropout": dropout,
617
+ "activation": activation,
618
+ "group_norm": group_norm,
619
+ "norm_first": norm_first,
620
+ "norm_out": norm_out,
621
+ "layer_scale": layer_scale,
622
+ "mask_type": mask_type,
623
+ "mask_random_seed": mask_random_seed,
624
+ "sparse_attn_window": sparse_attn_window,
625
+ "global_window": global_window,
626
+ "sparsity": sparsity,
627
+ "auto_sparsity": auto_sparsity,
628
+ "batch_first": True,
629
+ }
630
+
631
+ kwargs_classic_encoder = dict(kwargs_common)
632
+ kwargs_classic_encoder.update({
633
+ "sparse": sparse_self_attn,
634
+ })
635
+ kwargs_cross_encoder = dict(kwargs_common)
636
+ kwargs_cross_encoder.update({
637
+ "sparse": sparse_cross_attn,
638
+ })
639
+
640
+ for idx in range(num_layers):
641
+ if idx % 2 == self.classic_parity:
642
+
643
+ self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
644
+ self.layers_t.append(
645
+ MyTransformerEncoderLayer(**kwargs_classic_encoder)
646
+ )
647
+
648
+ else:
649
+ self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
650
+
651
+ self.layers_t.append(
652
+ CrossTransformerEncoderLayer(**kwargs_cross_encoder)
653
+ )
654
+
655
+ def forward(self, x, xt):
656
+ _, c, fr, t1 = x.shape
657
+ pos_emb_2d = create_2d_sin_embedding(
658
+ c, fr, t1, x.device, self.max_period
659
+ ) # (1, C, Fr, T1)
660
+ pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
661
+ x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
662
+ x = self.norm_in(x)
663
+ x = x + self.weight_pos_embed * pos_emb_2d
664
+
665
+ b, c, t2 = xt.shape
666
+ xt = rearrange(xt, "b c t2 -> b t2 c") # now (B, T2, C)
667
+ pos_emb = self._get_pos_embedding(t2, b, c, x.device)
668
+ pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
669
+ xt = self.norm_in_t(xt)
670
+ xt = xt + self.weight_pos_embed * pos_emb
671
+
672
+ for idx in range(self.num_layers):
673
+ if idx % 2 == self.classic_parity:
674
+ x = self.layers[idx](x)
675
+ xt = self.layers_t[idx](xt)
676
+ else:
677
+ old_x = x
678
+ x = self.layers[idx](x, xt)
679
+ xt = self.layers_t[idx](xt, old_x)
680
+
681
+ x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=t1)
682
+ xt = rearrange(xt, "b t2 c -> b c t2")
683
+ return x, xt
684
+
685
+ def _get_pos_embedding(self, t, b, c, device):
686
+ if self.emb == "sin":
687
+ shift = random.randrange(self.sin_random_shift + 1)
688
+ pos_emb = create_sin_embedding(
689
+ t, c, shift=shift, device=device, max_period=self.max_period
690
+ )
691
+ elif self.emb == "cape":
692
+ if self.training:
693
+ pos_emb = create_sin_embedding_cape(
694
+ t,
695
+ c,
696
+ b,
697
+ device=device,
698
+ max_period=self.max_period,
699
+ mean_normalize=self.cape_mean_normalize,
700
+ augment=self.cape_augment,
701
+ max_global_shift=self.cape_glob_loc_scale[0],
702
+ max_local_shift=self.cape_glob_loc_scale[1],
703
+ max_scale=self.cape_glob_loc_scale[2],
704
+ )
705
+ else:
706
+ pos_emb = create_sin_embedding_cape(
707
+ t,
708
+ c,
709
+ b,
710
+ device=device,
711
+ max_period=self.max_period,
712
+ mean_normalize=self.cape_mean_normalize,
713
+ augment=False,
714
+ )
715
+
716
+ elif self.emb == "scaled":
717
+ pos = torch.arange(t, device=device)
718
+ pos_emb = self.position_embeddings(pos)[:, None]
719
+
720
+ return pos_emb
721
+
722
+ def make_optim_group(self):
723
+ group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
724
+ if self.lr is not None:
725
+ group["lr"] = self.lr
726
+ return group
727
+
728
+
729
+ def scaled_query_key_softmax(q, k, att_mask):
730
+ from xformers.ops import masked_matmul
731
+ q = q / (k.size(-1)) ** 0.5
732
+ att = masked_matmul(q, k.transpose(-2, -1), att_mask)
733
+ att = torch.nn.functional.softmax(att, -1)
734
+ return att
735
+
736
+
737
+ def scaled_dot_product_attention(q, k, v, att_mask, dropout):
738
+ att = scaled_query_key_softmax(q, k, att_mask=att_mask)
739
+ att = dropout(att)
740
+ y = att @ v
741
+ return y
742
+
743
+
744
+ def _compute_buckets(x, r):
745
+ qq = torch.einsum('btf,bfhi->bhti', x, r)
746
+ qq = torch.cat([qq, -qq], dim=-1)
747
+ buckets = qq.argmax(dim=-1)
748
+
749
+ return buckets.permute(0, 2, 1).byte().contiguous()
750
+
751
+
752
+ def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
753
+ # assert False, "The code for the custom sparse kernel is not ready for release yet."
754
+ from xformers.ops import find_locations, sparse_memory_efficient_attention
755
+ n_hashes = 32
756
+ proj_size = 4
757
+ query, key, value = [x.contiguous() for x in [query, key, value]]
758
+ with torch.no_grad():
759
+ r = torch.randn(1, query.shape[-1], n_hashes, proj_size // 2, device=query.device)
760
+ bucket_query = _compute_buckets(query, r)
761
+ bucket_key = _compute_buckets(key, r)
762
+ row_offsets, column_indices = find_locations(
763
+ bucket_query, bucket_key, sparsity, infer_sparsity)
764
+ return sparse_memory_efficient_attention(
765
+ query, key, value, row_offsets, column_indices, attn_bias)
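For reference, a short shape check of the two sinusoidal embedding helpers defined near the top of this file (an editorial sketch with arbitrary sizes, not part of the upload):

```python
# 1-D sinusoidal embedding in (T, B, C) layout: 100 steps, 512 channels.
emb = create_sin_embedding(length=100, dim=512)
print(emb.shape)    # torch.Size([100, 1, 512])

# 2-D embedding for a spectrogram grid; d_model must be divisible by 4.
emb2d = create_2d_sin_embedding(d_model=64, height=32, width=20)
print(emb2d.shape)  # torch.Size([1, 64, 32, 20])
```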
third_party/demucs/models/utils.py ADDED
@@ -0,0 +1,125 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : utils.py
5
+ @Time : 2023/8/8 4:26 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2023, Tencent
9
+ @Desc : utils
10
+ """
11
+ from contextlib import contextmanager
12
+ import math
13
+ import os
14
+ import tempfile
15
+ import typing as tp
16
+ import json
17
+ import subprocess
18
+
19
+ import torch
20
+ from torch.nn import functional as F
21
+ from torch.utils.data import Subset
22
+
23
+
24
+ def unfold(a, kernel_size, stride):
25
+ """Given input of size [*OT, T], output Tensor of size [*OT, F, K]
26
+ with K the kernel size, by extracting frames with the given stride.
27
+
28
+ This will pad the input so that `F = ceil(T / K)`.
29
+
30
+ see https://github.com/pytorch/pytorch/issues/60466
31
+ """
32
+ *shape, length = a.shape
33
+ n_frames = math.ceil(length / stride)
34
+ tgt_length = (n_frames - 1) * stride + kernel_size
35
+ a = F.pad(a, (0, tgt_length - length))
36
+ strides = list(a.stride())
37
+ assert strides[-1] == 1, 'data should be contiguous'
38
+ strides = strides[:-1] + [stride, 1]
39
+ return a.as_strided([*shape, n_frames, kernel_size], strides)
40
+
41
+
42
+ def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
43
+ """
44
+ Center trim `tensor` with respect to `reference`, along the last dimension.
45
+ `reference` can also be a number, representing the length to trim to.
46
+ If the size difference != 0 mod 2, the extra sample is removed on the right side.
47
+ """
48
+ ref_size: int
49
+ if isinstance(reference, torch.Tensor):
50
+ ref_size = reference.size(-1)
51
+ else:
52
+ ref_size = reference
53
+ delta = tensor.size(-1) - ref_size
54
+ if delta < 0:
55
+ raise ValueError("tensor must be larger than reference. " f"Delta is {delta}.")
56
+ if delta:
57
+ tensor = tensor[..., delta // 2:-(delta - delta // 2)]
58
+ return tensor
59
+
60
+
61
+ def pull_metric(history: tp.List[dict], name: str):
62
+ out = []
63
+ for metrics in history:
64
+ metric = metrics
65
+ for part in name.split("."):
66
+ metric = metric[part]
67
+ out.append(metric)
68
+ return out
69
+
70
+
71
+ def sizeof_fmt(num: float, suffix: str = 'B'):
72
+ """
73
+ Given `num` bytes, return human readable size.
74
+ Taken from https://stackoverflow.com/a/1094933
75
+ """
76
+ for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
77
+ if abs(num) < 1024.0:
78
+ return "%3.1f%s%s" % (num, unit, suffix)
79
+ num /= 1024.0
80
+ return "%.1f%s%s" % (num, 'Yi', suffix)
81
+
82
+
83
+ @contextmanager
84
+ def temp_filenames(count: int, delete=True):
85
+ names = []
86
+ try:
87
+ for _ in range(count):
88
+ names.append(tempfile.NamedTemporaryFile(delete=False).name)
89
+ yield names
90
+ finally:
91
+ if delete:
92
+ for name in names:
93
+ os.unlink(name)
94
+
95
+
96
+ def random_subset(dataset, max_samples: int, seed: int = 42):
97
+ if max_samples >= len(dataset):
98
+ return dataset
99
+
100
+ generator = torch.Generator().manual_seed(seed)
101
+ perm = torch.randperm(len(dataset), generator=generator)
102
+ return Subset(dataset, perm[:max_samples].tolist())
103
+
104
+
105
+ class DummyPoolExecutor:
106
+ class DummyResult:
107
+ def __init__(self, func, *args, **kwargs):
108
+ self.func = func
109
+ self.args = args
110
+ self.kwargs = kwargs
111
+
112
+ def result(self):
113
+ return self.func(*self.args, **self.kwargs)
114
+
115
+ def __init__(self, workers=0):
116
+ pass
117
+
118
+ def submit(self, func, *args, **kwargs):
119
+ return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
120
+
121
+ def __enter__(self):
122
+ return self
123
+
124
+ def __exit__(self, exc_type, exc_value, exc_tb):
125
+ return
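A quick editorial sanity check of the framing and trimming helpers above (shapes chosen arbitrarily):

```python
import torch

# unfold pads to a whole number of frames: 10 samples, kernel 4, stride 2 -> 5 frames of 4.
a = torch.arange(10.).view(1, 10)
print(unfold(a, kernel_size=4, stride=2).shape)  # torch.Size([1, 5, 4])

# center_trim removes the excess evenly from both ends (the extra sample from the right).
x = torch.randn(1, 2, 103)
print(center_trim(x, 100).shape)                 # torch.Size([1, 2, 100])
```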
third_party/demucs/run.py ADDED
@@ -0,0 +1,109 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ @File : run.py
5
+ @Time : 2024/4/22 2:40 PM
6
+ @Author : waytan
7
+ @Contact : [email protected]
8
+ @License : (C)Copyright 2024, Tencent
9
+ """
10
+ import os
11
+ import json
12
+ import time
13
+ import logging
14
+ import argparse
15
+ from datetime import datetime
16
+
17
+
18
+ import torch
19
+
20
+ from models.apply import BagOfModels
21
+ from models.pretrained import get_model_from_yaml
22
+
23
+
24
+ class Separator:
25
+ def __init__(self, dm_model_path, dm_config_path, gpu_id=0) -> None:
26
+ if torch.cuda.is_available() and gpu_id < torch.cuda.device_count():
27
+ self.device = torch.device(f"cuda:{gpu_id}")
28
+ else:
29
+ self.device = torch.device("cpu")
30
+ self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path)
31
+
32
+ def init_demucs_model(self, model_path, config_path) -> BagOfModels:
33
+ model = get_model_from_yaml(config_path, model_path)
34
+ model.to(self.device)
35
+ model.eval()
36
+ return model
37
+
38
+ def run(self, audio_path, output_dir, ext=".flac"):
39
+ name, _ = os.path.splitext(os.path.split(audio_path)[-1])
40
+ output_paths = []
41
+ for stem in self.demucs_model.sources:
42
+ output_path = os.path.join(output_dir, f"{name}_{stem}{ext}")
43
+ if os.path.exists(output_path):
44
+ output_paths.append(output_path)
45
+ if len(output_paths) == 4:
46
+ drums_path, bass_path, other_path, vocal_path = output_paths
47
+ else:
48
+ drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device)
49
+ data_dict = {
50
+ "vocal_path": vocal_path,
51
+ "bgm_path": [drums_path, bass_path, other_path]
52
+ }
53
+ return data_dict
54
+
55
+
56
+ def json_io(input_json, output_json, model_dir, dst_dir, gpu_id=0):
57
+ current_datetime = datetime.now()
58
+ current_datetime_str = current_datetime.strftime('%Y-%m-%d-%H:%M')
59
+ logging.basicConfig(filename=os.path.join(dst_dir, f'logger-separate-{os.path.split(input_json)[1]}-{current_datetime_str}.log'), level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
60
+
61
+ sp = Separator(os.path.join(model_dir, "htdemucs.pth"), os.path.join(model_dir, "htdemucs.yaml"), gpu_id=gpu_id)
62
+ with open(input_json, "r") as fp:
63
+ lines = fp.readlines()
64
+ t1 = time.time()
65
+ success_num = 0
66
+ fail_num = 0
67
+ total_num = len(lines)
68
+ sep_items = []
69
+ for line in lines:
70
+ item = json.loads(line)
71
+ flac_file = item["path"]
72
+ try:
73
+ fix_data = sp.run(flac_file, dst_dir)
74
+ except Exception as e:
75
+ fail_num += 1
76
+ logging.error(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process fail for {str(e)}")
77
+ continue
78
+
79
+ item["vocal_path"] = fix_data["vocal_path"]
80
+ item["bgm_path"] = fix_data["bgm_path"]
81
+ sep_items.append(item)
82
+ success_num += 1
83
+ logging.debug(f"process-{success_num + fail_num}/{total_num}|success-{success_num}|fail-{fail_num}|{item['idx']} process success")
84
+
85
+ with open(output_json, "w", encoding='utf-8') as fw:
86
+ for item in sep_items:
87
+ fw.write(json.dumps(item, ensure_ascii=False) + "\n")
88
+
89
+ t2 = time.time()
90
+ logging.debug(f"total cost {round(t2-t1, 3)}s")
91
+
92
+
93
+ if __name__ == "__main__":
94
+ parser = argparse.ArgumentParser(description='')
95
+ parser.add_argument("-m", dest="model_dir")
96
+ parser.add_argument("-d", dest="dst_dir")
97
+ parser.add_argument("-j", dest="input_json")
98
+ parser.add_argument("-o", dest="output_json")
99
+ parser.add_argument("-gid", dest="gpu_id", default=0, type=int)
100
+ args = parser.parse_args()
101
+
102
+ if not args.dst_dir:
103
+ dst_dir = os.path.join(os.getcwd(), "separate_result")
104
+ os.makedirs(dst_dir, exist_ok=True)
105
+ else:
106
+ dst_dir = os.path.join(args.dst_dir, "separate_result")
107
+ os.makedirs(dst_dir, exist_ok=True)
108
+
109
+ json_io(args.input_json, args.output_json, args.model_dir, dst_dir, gpu_id=args.gpu_id)
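For context, a sketch of the JSON-lines input that `json_io` consumes and a typical invocation; every path and idx below is a placeholder, and the model directory is assumed to hold the `htdemucs.pth` / `htdemucs.yaml` pair shipped under `third_party/demucs/ckpt`:

```python
import json

# One JSON object per line, with at least "idx" and "path" (placeholders shown).
items = [
    {"idx": "song_000", "path": "/data/songs/song_000.flac"},
    {"idx": "song_001", "path": "/data/songs/song_001.flac"},
]
with open("separate_input.jsonl", "w", encoding="utf-8") as fp:
    for item in items:
        fp.write(json.dumps(item, ensure_ascii=False) + "\n")

# Roughly equivalent CLI call (run from third_party/demucs):
#   python run.py -m ckpt -j separate_input.jsonl -o separate_output.jsonl -d /data -gid 0
# Each output line gains "vocal_path" and "bgm_path" entries for the separated stems.
```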
third_party/hub/version.txt ADDED
@@ -0,0 +1 @@
1
+ 1
third_party/stable_audio_tools/.gitignore ADDED
@@ -0,0 +1,164 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
161
+
162
+ *.ckpt
163
+ *.wav
164
+ wandb/*
third_party/stable_audio_tools/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023 Stability AI
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/LICENSES/LICENSE_ADP.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 archinet.ai
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/LICENSES/LICENSE_AURALOSS.txt ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
third_party/stable_audio_tools/LICENSES/LICENSE_DESCRIPT.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2023-present, Descript
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/LICENSES/LICENSE_META.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) Meta Platforms, Inc. and affiliates.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/LICENSES/LICENSE_NVIDIA.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/LICENSES/LICENSE_XTRANSFORMERS.txt ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2020 Phil Wang
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
third_party/stable_audio_tools/README.md ADDED
@@ -0,0 +1,157 @@
1
+ # stable-audio-tools
2
+ Training and inference code for audio generation models
3
+
4
+ # Install
5
+
6
+ The library can be installed from PyPI with:
7
+ ```bash
8
+ $ pip install stable-audio-tools
9
+ ```
10
+
11
+ To run the training scripts or inference code, you'll want to clone this repository, navigate to the root, and run:
12
+ ```bash
13
+ $ pip install .
14
+ ```
15
+
16
+ # Requirements
17
+ Requires PyTorch 2.0 or later for Flash Attention support
18
+
19
+ Development for the repo is done in Python 3.8.10
20
+
21
+ # Interface
22
+
23
+ A basic Gradio interface is provided to test out trained models.
24
+
25
+ For example, to create an interface for the [`stable-audio-open-1.0`](https://huggingface.co/stabilityai/stable-audio-open-1.0) model, once you've accepted the terms for the model on Hugging Face, you can run:
26
+ ```bash
27
+ $ python3 ./run_gradio.py --pretrained-name stabilityai/stable-audio-open-1.0
28
+ ```
29
+
30
+ The `run_gradio.py` script accepts the following command line arguments:
31
+
32
+ - `--pretrained-name`
33
+ - Hugging Face repository name for a Stable Audio Tools model
34
+ - Will prioritize `model.safetensors` over `model.ckpt` in the repo
35
+ - Optional, used in place of `model-config` and `ckpt-path` when using pre-trained model checkpoints on Hugging Face
36
+ - `--model-config`
37
+ - Path to the model config file for a local model
38
+ - `--ckpt-path`
39
+ - Path to unwrapped model checkpoint file for a local model
40
+ - `--pretransform-ckpt-path`
41
+ - Path to an unwrapped pretransform checkpoint, replaces the pretransform in the model, useful for testing out fine-tuned decoders
42
+ - Optional
43
+ - `--share`
44
+ - If true, a publicly shareable link will be created for the Gradio demo
45
+ - Optional
46
+ - `--username` and `--password`
47
+ - Used together to set a login for the Gradio demo
48
+ - Optional
49
+ - `--model-half`
50
+ - If true, the model weights are converted to half-precision
51
+ - Optional
52
+
53
+ # Training
54
+
55
+ ## Prerequisites
56
+ Before starting your training run, you'll need a model config file and a dataset config file. For more information about those, refer to the Configurations section below.
57
+
58
+ The training code also requires a Weights & Biases account to log the training outputs and demos. Create an account and log in with:
59
+ ```bash
60
+ $ wandb login
61
+ ```
62
+
63
+ ## Start training
64
+ To start a training run, run the `train.py` script in the repo root with:
65
+ ```bash
66
+ $ python3 ./train.py --dataset-config /path/to/dataset/config --model-config /path/to/model/config --name harmonai_train
67
+ ```
68
+
69
+ The `--name` parameter will set the project name for your Weights and Biases run.
70
+
71
+ ## Training wrappers and model unwrapping
72
+ `stable-audio-tools` uses PyTorch Lightning to facilitate multi-GPU and multi-node training.
73
+
74
+ When a model is being trained, it is wrapped in a "training wrapper", which is a `pl.LightningModule` that contains all of the relevant objects needed only for training. That includes things like discriminators for autoencoders, EMA copies of models, and all of the optimizer states.
75
+
76
+ The checkpoint files created during training include this training wrapper, which greatly increases the size of the checkpoint file.
77
+
78
+ `unwrap_model.py` in the repo root will take in a wrapped model checkpoint and save a new checkpoint file including only the model itself.
79
+
80
+ It can be run from the repo root with:
81
+ ```bash
82
+ $ python3 ./unwrap_model.py --model-config /path/to/model/config --ckpt-path /path/to/wrapped/ckpt --name model_unwrap
83
+ ```
84
+
85
+ Unwrapped model checkpoints are required for:
86
+ - Inference scripts
87
+ - Using a model as a pretransform for another model (e.g. using an autoencoder model for latent diffusion)
88
+ - Fine-tuning a pre-trained model with a modified configuration (i.e. partial initialization)
89
+
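+ For illustration only, unwrapping amounts to extracting the inner model's weights from the Lightning checkpoint. The key prefix used for the inner model depends on the training wrapper, so the snippet below is a hedged sketch with placeholder names; use `unwrap_model.py` in practice:
+
+ ```python
+ import torch
+
+ # Hypothetical sketch of checkpoint unwrapping; the real logic lives in unwrap_model.py.
+ WRAPPER_PREFIX = "autoencoder."  # placeholder; the actual key prefix depends on the wrapper
+
+ ckpt = torch.load("/path/to/wrapped/ckpt", map_location="cpu")
+ wrapped_state = ckpt["state_dict"]  # Lightning stores the wrapper's weights here
+
+ # Keep only the inner model's weights and strip the wrapper prefix from the keys.
+ unwrapped_state = {
+     k[len(WRAPPER_PREFIX):]: v
+     for k, v in wrapped_state.items()
+     if k.startswith(WRAPPER_PREFIX)
+ }
+
+ torch.save({"state_dict": unwrapped_state}, "model_unwrapped.ckpt")
+ ```
+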
90
+ ## Fine-tuning
91
+ Fine-tuning a model involves continuing a training run from a pre-trained checkpoint.
92
+
93
+ To continue a training run from a wrapped model checkpoint, you can pass in the checkpoint path to `train.py` with the `--ckpt-path` flag.
94
+
95
+ To start a fresh training run using a pre-trained unwrapped model, you can pass in the unwrapped checkpoint to `train.py` with the `--pretrained-ckpt-path` flag.
96
+
97
+ ## Additional training flags
98
+
99
+ Additional optional flags for `train.py` include:
100
+ - `--config-file`
101
+ - The path to the defaults.ini file in the repo root, required if running `train.py` from a directory other than the repo root
102
+ - `--pretransform-ckpt-path`
103
+ - Used in various model types such as latent diffusion models to load a pre-trained autoencoder. Requires an unwrapped model checkpoint.
104
+ - `--save-dir`
105
+ - The directory in which to save the model checkpoints
106
+ - `--checkpoint-every`
107
+ - The number of steps between saved checkpoints.
108
+ - *Default*: 10000
109
+ - `--batch-size`
110
+ - Number of samples per-GPU during training. Should be set as large as your GPU VRAM will allow.
111
+ - *Default*: 8
112
+ - `--num-gpus`
113
+ - Number of GPUs per-node to use for training
114
+ - *Default*: 1
115
+ - `--num-nodes`
116
+ - Number of GPU nodes being used for training
117
+ - *Default*: 1
118
+ - `--accum-batches`
119
+ - Enables and sets the number of batches for gradient batch accumulation. Useful for increasing effective batch size when training on smaller GPUs.
120
+ - `--strategy`
121
+ - Multi-GPU strategy for distributed training. Setting to `deepspeed` will enable DeepSpeed ZeRO Stage 2.
122
+ - *Default*: `ddp` if `--num-gpus` > 1, else None
123
+ - `--precision`
124
+ - Floating-point precision to use during training
125
+ - *Default*: 16
126
+ - `--num-workers`
127
+ - Number of CPU workers used by the data loader
128
+ - `--seed`
129
+ - RNG seed for PyTorch, helps with deterministic training
130
+
131
+ # Configurations
132
+ Training and inference code for `stable-audio-tools` is based around JSON configuration files that define model hyperparameters, training settings, and information about your training dataset.
133
+
134
+ ## Model config
135
+ The model config file defines all of the information needed to load a model for training or inference. It also contains the training configuration needed to fine-tune a model or train from scratch.
136
+
137
+ The following properties are defined in the top level of the model configuration:
138
+
139
+ - `model_type`
140
+ - The type of model being defined, currently limited to one of `"autoencoder", "diffusion_uncond", "diffusion_cond", "diffusion_cond_inpaint", "diffusion_autoencoder", "lm"`.
141
+ - `sample_size`
142
+ - The length of the audio provided to the model during training, in samples. For diffusion models, this is also the raw audio sample length used for inference.
143
+ - `sample_rate`
144
+ - The sample rate of the audio provided to the model during training, and generated during inference, in Hz.
145
+ - `audio_channels`
146
+ - The number of channels of audio provided to the model during training, and generated during inference. Defaults to 2. Set to 1 for mono.
147
+ - `model`
148
+ - The specific configuration for the model being defined, varies based on `model_type`
149
+ - `training`
150
+ - The training configuration for the model, varies based on `model_type`. Provides parameters for training as well as demos.
151
+
152
+ ## Dataset config
153
+ `stable-audio-tools` currently supports two kinds of data sources: local directories of audio files, and WebDataset datasets stored in Amazon S3. More information can be found in [the dataset config documentation](docs/datasets.md)
154
+
155
+ # Todo
156
+ - [ ] Add troubleshooting section
157
+ - [ ] Add contribution guidelines
third_party/stable_audio_tools/config/model_1920.json ADDED
@@ -0,0 +1,122 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 403200,
4
+ "sample_rate": 48000,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8, 16],
13
+ "strides": [2, 4, 4, 6, 10],
14
+ "latent_dim": 128,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 6, 10],
25
+ "latent_dim": 64,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 64,
34
+ "downsampling_ratio": 1920,
35
+ "io_channels": 2
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": true,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [2048, 1024, 512, 256, 128],
85
+ "hop_lengths": [512, 256, 128, 64, 32],
86
+ "win_lengths": [2048, 1024, 512, 256, 128]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
97
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
98
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 2000
120
+ }
121
+ }
122
+ }
third_party/stable_audio_tools/config/model_config.json ADDED
@@ -0,0 +1,122 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 409600,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8, 16],
13
+ "strides": [2, 4, 4, 8, 8],
14
+ "latent_dim": 128,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 8, 8],
25
+ "latent_dim": 64,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 64,
34
+ "downsampling_ratio": 2048,
35
+ "io_channels": 2
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": true,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [2048, 1024, 512, 256, 128],
85
+ "hop_lengths": [512, 256, 128, 64, 32],
86
+ "win_lengths": [2048, 1024, 512, 256, 128]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
97
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
98
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 2000
120
+ }
121
+ }
122
+ }
third_party/stable_audio_tools/defaults.ini ADDED
@@ -0,0 +1,56 @@
1
+
2
+ [DEFAULTS]
3
+
4
+ # name of the run
5
+ name = stable_audio_tools
6
+
7
+ # the batch size
8
+ batch_size = 8
9
+
10
+ # number of GPUs to use for training
11
+ num_gpus = 1
12
+
13
+ # number of nodes to use for training
14
+ num_nodes = 1
15
+
16
+ # Multi-GPU strategy for PyTorch Lightning
17
+ strategy = ""
18
+
19
+ # Precision to use for training
20
+ precision = "16-mixed"
21
+
22
+ # number of CPU workers for the DataLoader
23
+ num_workers = 8
24
+
25
+ # the random seed
26
+ seed = 42
27
+
28
+ # Batches for gradient accumulation
29
+ accum_batches = 1
30
+
31
+ # Number of steps between checkpoints
32
+ checkpoint_every = 10000
33
+
34
+ # trainer checkpoint file to restart training from
35
+ ckpt_path = ''
36
+
37
+ # model checkpoint file to start a new training run from
38
+ pretrained_ckpt_path = ''
39
+
40
+ # Checkpoint path for the pretransform model if needed
41
+ pretransform_ckpt_path = ''
42
+
43
+ # configuration model specifying model hyperparameters
44
+ model_config = ''
45
+
46
+ # configuration for datasets
47
+ dataset_config = ''
48
+
49
+ # directory to save the checkpoints in
50
+ save_dir = ''
51
+
52
+ # gradient_clip_val passed into PyTorch Lightning Trainer
53
+ gradient_clip_val = 0.0
54
+
55
+ # remove the weight norm from the pretransform model
56
+ remove_pretransform_weight_norm = ''
third_party/stable_audio_tools/docs/autoencoders.md ADDED
@@ -0,0 +1,357 @@
1
+ # Autoencoders
2
+ At a high level, autoencoders are models constructed of two parts: an *encoder*, and a *decoder*.
3
+
4
+ The *encoder* takes in a sequence (such as mono or stereo audio) and outputs a compressed representation of that sequence as a d-channel "latent sequence", usually heavily downsampled by a constant factor.
5
+
6
+ The *decoder* takes in a d-channel latent sequence and upsamples it back to the original input sequence length, reversing the compression of the encoder.
7
+
8
+ Autoencoders are trained with a combination of reconstruction and adversarial losses in order to create a compact and invertible representation of raw audio data that allows downstream models to work in a data-compressed "latent space", with various desirable and controllable properties such as reduced sequence length, noise resistance, and discretization.
9
+
10
+ The autoencoder architectures defined in `stable-audio-tools` are largely fully-convolutional, which allows autoencoders trained on short input lengths to be applied to arbitrary-length sequences. For example, an autoencoder trained on 1-second samples could be used to encode 45-second inputs to a latent diffusion model.
11
+
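+ To make the compression concrete, the latent sequence length is simply the input length divided by the downsampling ratio. A minimal sketch, using the values from `config/model_1920.json` in this repo:
+
+ ```python
+ # Latent sequence length for a fully-convolutional autoencoder.
+ # Values taken from config/model_1920.json (48 kHz audio, 1920x downsampling).
+ sample_rate = 48000
+ downsampling_ratio = 1920  # product of the encoder strides: 2 * 4 * 4 * 6 * 10
+
+ seconds = 45
+ input_samples = seconds * sample_rate                 # 2,160,000 samples
+ latent_frames = input_samples // downsampling_ratio   # 1,125 latent frames
+ print(latent_frames)
+ ```
+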
12
+ # Model configs
13
+ The model config file for an autoencoder should set the `model_type` to `autoencoder`, and the `model` object should have the following properties:
14
+
15
+ - `encoder`
16
+ - Configuration for the autoencoder's encoder half
17
+ - `decoder`
18
+ - Configuration for the autoencoder's decoder half
19
+ - `latent_dim`
20
+ - Latent dimension of the autoencoder, used by inference scripts and downstream models
21
+ - `downsampling_ratio`
22
+ - Downsampling ratio between the input sequence and the latent sequence, used by inference scripts and downstream models
23
+ - `io_channels`
24
+ - Number of input and output channels for the autoencoder when they're the same, used by inference scripts and downstream models
25
+ - `bottleneck`
26
+ - Configuration for the autoencoder's bottleneck
27
+ - Optional
28
+ - `pretransform`
29
+ - A pretransform definition for the autoencoder, such as wavelet decomposition or another autoencoder
30
+ - See [pretransforms.md](pretransforms.md) for more information
31
+ - Optional
32
+ - `in_channels`
33
+ - Specifies the number of input channels for the autoencoder, when it's different from `io_channels`, such as in a mono-to-stereo model
34
+ - Optional
35
+ - `out_channels`
36
+ - Specifies the number of output channels for the autoencoder, when it's different from `io_channels`
37
+ - Optional
38
+
39
+ # Training configs
40
+ The `training` config in the autoencoder model config file should have the following properties:
41
+ - `learning_rate`
42
+ - The learning rate to use during training
43
+ - `use_ema`
44
+ - If true, a copy of the model weights is maintained during training and updated as an exponential moving average of the trained model's weights.
45
+ - Optional. Default: `false`
46
+ - `warmup_steps`
47
+ - The number of training steps before turning on adversarial losses
48
+ - Optional. Default: `0`
49
+ - `encoder_freeze_on_warmup`
50
+ - If true, freezes the encoder after the warmup steps have completed, so adversarial training only affects the decoder.
51
+ - Optional. Default: `false`
52
+ - `loss_configs`
53
+ - Configurations for the loss function calculation
54
+ - Optional
55
+ - `optimizer_configs`
56
+ - Configuration for optimizers and schedulers
57
+ - Optional
58
+
59
+ ## Loss configs
60
+ There are a few different types of losses that are used for autoencoder training, including spectral losses, time-domain losses, adversarial losses, and bottleneck-specific losses.
61
+
62
+ Hyperparameters for these losses, as well as loss weighting factors, can be configured in the `loss_configs` property in the `training` config.
63
+
64
+ ### Spectral losses
65
+ Multi-resolution STFT losses are the main reconstruction loss used for our audio autoencoders. We use the [auraloss](https://github.com/csteinmetz1/auraloss/tree/main/auraloss) library for our spectral loss functions.
66
+
67
+ For mono autoencoders (`io_channels` == 1), we use the [MultiResolutionSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L329) module.
68
+
69
+ For stereo autoencoders (`io_channels` == 2), we use the [SumAndDifferenceSTFTLoss](https://github.com/csteinmetz1/auraloss/blob/1576b0cd6e927abc002b23cf3bfc455b660f663c/auraloss/freq.py#L533) module.
70
+
71
+ #### Example config
72
+ ```json
73
+ "spectral": {
74
+ "type": "mrstft",
75
+ "config": {
76
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
77
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
78
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
79
+ "perceptual_weighting": true
80
+ },
81
+ "weights": {
82
+ "mrstft": 1.0
83
+ }
84
+ }
85
+ ```
86
+
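+ As a rough usage sketch of the auraloss module behind this config (assuming an auraloss version that accepts the `perceptual_weighting` and `sample_rate` arguments):
+
+ ```python
+ import torch
+ import auraloss
+
+ # Multi-resolution STFT loss matching the "mrstft" config above (mono case).
+ mrstft = auraloss.freq.MultiResolutionSTFTLoss(
+     fft_sizes=[2048, 1024, 512, 256, 128, 64, 32],
+     hop_sizes=[512, 256, 128, 64, 32, 16, 8],
+     win_lengths=[2048, 1024, 512, 256, 128, 64, 32],
+     perceptual_weighting=True,
+     sample_rate=44100,
+ )
+
+ decoded = torch.randn(4, 1, 44100)    # (batch, channels, samples)
+ reference = torch.randn(4, 1, 44100)
+ loss = mrstft(decoded, reference)     # scalar, scaled by the "mrstft" weight during training
+ ```
+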
87
+ ### Time-domain loss
88
+ We compute the L1 distance between the original audio and the decoded audio to provide a time-domain loss.
89
+
90
+ #### Example config
91
+ ```json
92
+ "time": {
93
+ "type": "l1",
94
+ "weights": {
95
+ "l1": 0.1
96
+ }
97
+ }
98
+ ```
99
+
100
+ ### Adversarial losses
101
+ Adversarial losses bring in an ensemble of discriminator models to discriminate between real and fake audio, providing a signal to the autoencoder on perceptual discrepancies to fix.
102
+
103
+ We largely rely on the [multi-scale STFT discriminator](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/msstftd.py#L99) from the EnCodec repo
104
+
105
+ #### Example config
106
+ ```json
107
+ "discriminator": {
108
+ "type": "encodec",
109
+ "config": {
110
+ "filters": 32,
111
+ "n_ffts": [2048, 1024, 512, 256, 128],
112
+ "hop_lengths": [512, 256, 128, 64, 32],
113
+ "win_lengths": [2048, 1024, 512, 256, 128]
114
+ },
115
+ "weights": {
116
+ "adversarial": 0.1,
117
+ "feature_matching": 5.0
118
+ }
119
+ }
120
+ ```
121
+
122
+ ## Demo config
123
+ The only property to set for autoencoder training demos is the `demo_every` property, determining the number of steps between demos.
124
+
125
+ ### Example config
126
+ ```json
127
+ "demo": {
128
+ "demo_every": 2000
129
+ }
130
+ ```
131
+
132
+ # Encoder and decoder types
133
+ Encoders and decoders are defined separately in the model configuration, so encoders and decoders from different model architectures and libraries can be used interchangeably.
134
+
135
+ ## Oobleck
136
+ Oobleck is Harmonai's in-house autoencoder architecture, implementing features from a variety of other autoencoder architectures.
137
+
138
+ ### Example config
139
+ ```json
140
+ "encoder": {
141
+ "type": "oobleck",
142
+ "config": {
143
+ "in_channels": 2,
144
+ "channels": 128,
145
+ "c_mults": [1, 2, 4, 8],
146
+ "strides": [2, 4, 8, 8],
147
+ "latent_dim": 128,
148
+ "use_snake": true
149
+ }
150
+ },
151
+ "decoder": {
152
+ "type": "oobleck",
153
+ "config": {
154
+ "out_channels": 2,
155
+ "channels": 128,
156
+ "c_mults": [1, 2, 4, 8],
157
+ "strides": [2, 4, 8, 8],
158
+ "latent_dim": 64,
159
+ "use_snake": true,
160
+ "use_nearest_upsample": false
161
+ }
162
+ }
163
+ ```
164
+
165
+ ## DAC
166
+ These are the Encoder and Decoder definitions from the `descript-audio-codec` repo: a simple fully-convolutional autoencoder with channels doubling at every level. The encoder and decoder configs are passed directly into the constructors for the DAC [Encoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L64) and [Decoder](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/model/dac.py#L115).
167
+
168
+ **Note: This does not include the DAC quantizer and does not load pre-trained DAC models; it provides only the encoder and decoder definitions.**
169
+
170
+ ### Example config
171
+ ```json
172
+ "encoder": {
173
+ "type": "dac",
174
+ "config": {
175
+ "in_channels": 2,
176
+ "latent_dim": 32,
177
+ "d_model": 128,
178
+ "strides": [2, 4, 4, 4]
179
+ }
180
+ },
181
+ "decoder": {
182
+ "type": "dac",
183
+ "config": {
184
+ "out_channels": 2,
185
+ "latent_dim": 32,
186
+ "channels": 1536,
187
+ "rates": [4, 4, 4, 2]
188
+ }
189
+ }
190
+ ```
191
+
192
+ ## SEANet
193
+ These are the SEANetEncoder and SEANetDecoder definitions from Meta's EnCodec repo: the same encoder and decoder architecture used by the EnCodec models in MusicGen, without the quantizer.
194
+
195
+ The encoder and decoder configs are passed directly into the [SEANetEncoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L66C12-L66C12) and [SEANetDecoder](https://github.com/facebookresearch/encodec/blob/0e2d0aed29362c8e8f52494baf3e6f99056b214f/encodec/modules/seanet.py#L147) constructors, though we reverse the input order of the strides (ratios) in the encoder to make it consistent with the order in the decoder.
196
+
197
+ ### Example config
198
+ ```json
199
+ "encoder": {
200
+ "type": "seanet",
201
+ "config": {
202
+ "channels": 2,
203
+ "dimension": 128,
204
+ "n_filters": 64,
205
+ "ratios": [4, 4, 8, 8],
206
+ "n_residual_layers": 1,
207
+ "dilation_base": 2,
208
+ "lstm": 2,
209
+ "norm": "weight_norm"
210
+ }
211
+ },
212
+ "decoder": {
213
+ "type": "seanet",
214
+ "config": {
215
+ "channels": 2,
216
+ "dimension": 64,
217
+ "n_filters": 64,
218
+ "ratios": [4, 4, 8, 8],
219
+ "n_residual_layers": 1,
220
+ "dilation_base": 2,
221
+ "lstm": 2,
222
+ "norm": "weight_norm"
223
+ }
224
+ }
225
+ ```
226
+
227
+ # Bottlenecks
228
+ In our terminology, the "bottleneck" of an autoencoder is a module placed between the encoder and decoder to enforce particular constraints on the latent space the encoder creates.
229
+
230
+ Bottlenecks have a similar interface to the autoencoder with `encode()` and `decode()` functions defined. Some bottlenecks return extra information in addition to the output latent series, such as quantized token indices, or additional losses to be considered during training.
231
+
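+ As an illustrative sketch of that interface shape (not the library's actual classes, which live in `stable_audio_tools`):
+
+ ```python
+ import torch
+ import torch.nn as nn
+
+ # Illustrative no-op bottleneck; real bottlenecks may also return losses or token indices.
+ class IdentityBottleneck(nn.Module):
+     def encode(self, latents: torch.Tensor):
+         info = {}  # e.g. a "kl" loss for VAE, or quantizer indices for RVQ
+         return latents, info
+
+     def decode(self, latents: torch.Tensor) -> torch.Tensor:
+         return latents
+ ```
+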
232
+ To define a bottleneck for the autoencoder, you can provide the `bottleneck` object in the autoencoder's model configuration, with a `type` property selecting the bottleneck type and, where required, a `config` object holding the bottleneck's hyperparameters.
233
+
234
+ ## VAE
235
+
236
+ The Variational Autoencoder (VAE) bottleneck splits the encoder's output in half along the channel dimension, treats the two halves as the "mean" and "scale" parameters for VAE sampling, and performs the latent sampling. At a basic level, the "scale" values determine the amount of noise to add to the "mean" latents, which creates a noise-resistant latent space where more of the latent space decodes to perceptually "valid" audio. This is particularly helpful for diffusion models, where the outputs of the diffusion sampling process retain a bit of Gaussian error noise.
237
+
238
+ **Note: For the VAE bottleneck to work, the output dimension of the encoder must be twice the size of the input dimension for the decoder.**
239
+
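+ A minimal sketch of the sampling step described above (names and the exact scale parameterization are illustrative, not the library's implementation):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ def vae_sample(encoder_output: torch.Tensor) -> torch.Tensor:
+     # The encoder output has 2 * latent_dim channels; split into "mean" and "scale" halves.
+     mean, scale = encoder_output.chunk(2, dim=1)
+     stdev = F.softplus(scale) + 1e-4              # assumed positive-scale mapping (illustrative)
+     return mean + stdev * torch.randn_like(mean)  # noisy latents passed to the decoder
+ ```
+
+ With the Oobleck example configs earlier in this document (encoder `latent_dim: 128`, decoder `latent_dim: 64`), `mean` and `stdev` each end up with 64 channels.
+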
240
+ ### Example config
241
+ ```json
242
+ "bottleneck": {
243
+ "type": "vae"
244
+ }
245
+ ```
246
+
247
+ ### Extra info
248
+ The VAE bottleneck also returns a `kl` value in the encoder info. This is the [KL divergence](https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence) between encoded/sampled latent space and a Gaussian distribution. By including this value as a loss value to optimize, we push our latent distribution closer to a normal distribution, potentially trading off reconstruction quality.
249
+
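+ For reference, the standard closed-form KL term between a diagonal Gaussian and a unit Gaussian looks like the following sketch (the library may reduce or weight it differently):
+
+ ```python
+ import torch
+
+ # Closed-form KL divergence between N(mean, stdev^2) and N(0, 1), averaged over elements.
+ def gaussian_kl(mean: torch.Tensor, stdev: torch.Tensor) -> torch.Tensor:
+     var = stdev * stdev
+     return 0.5 * (mean * mean + var - torch.log(var) - 1).mean()
+ ```
+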
250
+ ### Example loss config
251
+ ```json
252
+ "bottleneck": {
253
+ "type": "kl",
254
+ "weights": {
255
+ "kl": 1e-4
256
+ }
257
+ }
258
+ ```
259
+
260
+ ## Tanh
261
+ This bottleneck applies the tanh function to the latent series, "soft-clipping" the latent values to be between -1 and 1. This is a quick and dirty way to enforce a limit on the variance of the latent space, but training these models can be unstable as it's seemingly easy for the latent space to saturate the values to -1 or 1 and never recover.
262
+
263
+ ### Example config
264
+ ```json
265
+ "bottleneck": {
266
+ "type": "tanh"
267
+ }
268
+ ```
269
+
270
+ ## Wasserstein
271
+ The Wasserstein bottleneck implements the WAE-MMD regularization method from the [Wasserstein Auto-Encoders](https://arxiv.org/abs/1711.01558) paper, calculating the Maximum Mean Discrepancy (MMD) between the latent space and a Gaussian distribution. Including this value as a loss value to optimize leads to a more Gaussian latent space, but does not require stochastic sampling as with a VAE, so the encoder is deterministic.
272
+
273
+ The Wasserstein bottleneck also exposes the `noise_augment_dim` property, which concatenates `noise_augment_dim` channels of Gaussian noise to the latent series before passing into the decoder. This adds some stochasticity to the latents which can be helpful for adversarial training, while keeping the encoder outputs deterministic.
274
+
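+ A minimal sketch of the noise augmentation step (shapes follow the usual `(batch, channels, time)` latent layout; the values are illustrative):
+
+ ```python
+ import torch
+
+ latents = torch.randn(1, 64, 256)   # (batch, latent_dim, time)
+ noise_augment_dim = 4               # illustrative value
+
+ # Concatenate noise_augment_dim channels of Gaussian noise before decoding.
+ noise = torch.randn(latents.shape[0], noise_augment_dim, latents.shape[2])
+ decoder_input = torch.cat([latents, noise], dim=1)   # (batch, latent_dim + 4, time)
+ ```
+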
275
+ **Note: The MMD calculation is very VRAM-intensive for longer sequence lengths, so training a Wasserstein autoencoder is best done on autoencoders with a decent downsampling factor, or on short sequence lengths. For inference, the MMD calculation is disabled.**
276
+
277
+ ### Example config
278
+ ```json
279
+ "bottleneck": {
280
+ "type": "wasserstein"
281
+ }
282
+ ```
283
+
284
+ ### Extra info
285
+ This bottleneck adds the `mmd` value to the encoder info, representing the Maximum Mean Discrepancy.
286
+
287
+ ### Example loss config
288
+ ```json
289
+ "bottleneck": {
290
+ "type": "mmd",
291
+ "weights": {
292
+ "mmd": 100
293
+ }
294
+ }
295
+ ```
296
+
297
+ ## L2 normalization (Spherical autoencoder)
298
+ The L2 normalization bottleneck normalizes the latents across the channel-dimension, projecting the latents to a d-dimensional hypersphere. This acts as a form of latent space normalization.
299
+
300
+
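+ A minimal sketch of the operation (L2 normalization along the channel dimension):
+
+ ```python
+ import torch
+ import torch.nn.functional as F
+
+ latents = torch.randn(1, 64, 256)                # (batch, channels, time)
+ normalized = F.normalize(latents, p=2.0, dim=1)  # each latent frame now lies on the unit hypersphere
+ ```
+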
301
+ ### Example config
302
+ ```json
303
+ "bottleneck": {
304
+ "type": "l2_norm"
305
+ }
306
+ ```
307
+
308
+
309
+ ## RVQ
310
+ Residual vector quantization (RVQ) is currently the leading method for learning discrete neural audio codecs (tokenizers for audio). In vector quantization, each item in the latent sequence is individually "snapped" to the nearest vector in a discrete "codebook" of learned vectors. The index of the vector in the codebook can then be used as a token index for things like autoregressive transformers. Residual vector quantization improves the precision of normal vector quantization by adding additional codebooks. For a deeper dive into RVQ, check out [this blog post by Dr. Scott Hawley](https://drscotthawley.github.io/blog/posts/2023-06-12-RVQ.html).
311
+
312
+ This RVQ bottleneck uses [lucidrains' implementation](https://github.com/lucidrains/vector-quantize-pytorch/tree/master) from the `vector-quantize-pytorch` repo, which provides a lot of different quantizer options. The bottleneck config is passed through to the `ResidualVQ` [constructor](https://github.com/lucidrains/vector-quantize-pytorch/blob/0c6cea24ce68510b607f2c9997e766d9d55c085b/vector_quantize_pytorch/residual_vq.py#L26).
313
+
314
+ **Note: This RVQ implementation uses manual replacement of codebook vectors to reduce codebook collapse. This does not work with multi-GPU training as the random replacement is not synchronized across devices.**
315
+
316
+ ### Example config
317
+ ```json
318
+ "bottleneck": {
319
+ "type": "rvq",
320
+ "config": {
321
+ "num_quantizers": 4,
322
+ "codebook_size": 2048,
323
+ "dim": 1024,
324
+ "decay": 0.99,
325
+ }
326
+ }
327
+ ```
328
+
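+ A rough usage sketch of the underlying `ResidualVQ` module from `vector-quantize-pytorch` with the settings above (note that the module expects `(batch, sequence, dim)` inputs, so channels-first latents would need a transpose):
+
+ ```python
+ import torch
+ from vector_quantize_pytorch import ResidualVQ
+
+ # Mirrors the bottleneck config above: 4 quantizers, 2048-entry codebooks, 1024-dim latents.
+ rvq = ResidualVQ(dim=1024, num_quantizers=4, codebook_size=2048, decay=0.99)
+
+ latents = torch.randn(1, 128, 1024)             # (batch, sequence, dim)
+ quantized, indices, commit_loss = rvq(latents)  # indices: (batch, sequence, num_quantizers)
+ ```
+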
329
+ ## DAC RVQ
330
+ This is the residual vector quantization implementation from the `descript-audio-codec` repo. It differs from the above implementation in that it does not use manual replacements to improve codebook usage, but instead uses learnable linear layers to project the latents down to a lower-dimensional space before performing the individual quantization operations. This means it's compatible with distributed training.
331
+
332
+ The bottleneck config is passed directly into the `ResidualVectorQuantize` [constructor](https://github.com/descriptinc/descript-audio-codec/blob/c7cfc5d2647e26471dc394f95846a0830e7bec34/dac/nn/quantize.py#L97).
333
+
334
+ The `quantize_on_decode` property is also exposed, which moves the quantization process to the decoder. This should not be used during training, but is helpful when training latent diffusion models that use the quantization process as a way to remove error after the diffusion sampling process.
335
+
336
+ ### Example config
337
+ ```json
338
+ "bottleneck": {
339
+ "type": "dac_rvq",
340
+ "config": {
341
+ "input_dim": 64,
342
+ "n_codebooks": 9,
343
+ "codebook_dim": 32,
344
+ "codebook_size": 1024,
345
+ "quantizer_dropout": 0.5
346
+ }
347
+ }
348
+ ```
349
+
350
+ ### Extra info
351
+ The DAC RVQ bottleneck also adds the following properties to the `info` object:
352
+ - `pre_quantizer`
353
+ - The pre-quantization latent series, useful in combination with `quantize_on_decode` for training latent diffusion models.
354
+ - `vq/commitment_loss`
355
+ - Commitment loss for the quantizer
356
+ - `vq/codebook_loss`
357
+ - Codebook loss for the quantizer