LPDoctor committed on
Commit b5eac81 · 1 Parent(s): ecbda4a

Add ThinkSound module files to repository

Files changed (44)
  1. ThinkSound +0 -1
  2. ThinkSound/.DS_Store +0 -0
  3. ThinkSound/__init__.py +2 -0
  4. ThinkSound/__pycache__/__init__.cpython-313.pyc +0 -0
  5. ThinkSound/configs/model_configs/stable_audio_2_0_vae.json +122 -0
  6. ThinkSound/configs/model_configs/thinksound.json +147 -0
  7. ThinkSound/configs/multimodal_dataset_demo.json +53 -0
  8. ThinkSound/data/__init__.py +0 -0
  9. ThinkSound/data/datamodule.py +194 -0
  10. ThinkSound/data/dataset.py +1266 -0
  11. ThinkSound/data/utils.py +378 -0
  12. ThinkSound/inference/__init__.py +0 -0
  13. ThinkSound/inference/generation.py +274 -0
  14. ThinkSound/inference/sampling.py +232 -0
  15. ThinkSound/inference/utils.py +35 -0
  16. ThinkSound/models/__init__.py +1 -0
  17. ThinkSound/models/__pycache__/__init__.cpython-313.pyc +0 -0
  18. ThinkSound/models/__pycache__/factory.cpython-313.pyc +0 -0
  19. ThinkSound/models/__pycache__/pretrained.cpython-313.pyc +0 -0
  20. ThinkSound/models/__pycache__/utils.cpython-313.pyc +0 -0
  21. ThinkSound/models/autoencoders.py +800 -0
  22. ThinkSound/models/blocks.py +430 -0
  23. ThinkSound/models/bottleneck.py +355 -0
  24. ThinkSound/models/codebook_patterns.py +545 -0
  25. ThinkSound/models/conditioners.py +1005 -0
  26. ThinkSound/models/diffusion.py +920 -0
  27. ThinkSound/models/dit.py +439 -0
  28. ThinkSound/models/embeddings.py +85 -0
  29. ThinkSound/models/factory.py +156 -0
  30. ThinkSound/models/local_attention.py +278 -0
  31. ThinkSound/models/mmdit.py +578 -0
  32. ThinkSound/models/pretrained.py +25 -0
  33. ThinkSound/models/pretransforms.py +258 -0
  34. ThinkSound/models/transformer.py +821 -0
  35. ThinkSound/models/transformer_layers.py +271 -0
  36. ThinkSound/models/utils.py +164 -0
  37. ThinkSound/training/__init__.py +1 -0
  38. ThinkSound/training/autoencoders.py +504 -0
  39. ThinkSound/training/diffusion.py +1076 -0
  40. ThinkSound/training/factory.py +262 -0
  41. ThinkSound/training/losses/__init__.py +1 -0
  42. ThinkSound/training/losses/auraloss.py +691 -0
  43. ThinkSound/training/losses/losses.py +100 -0
  44. ThinkSound/training/utils.py +200 -0
ThinkSound DELETED
@@ -1 +0,0 @@
1
- Subproject commit 600962ed922a87bf416a6c152a64a35756c9c97e
 
 
ThinkSound/.DS_Store ADDED
Binary file (6.15 kB).
 
ThinkSound/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .models.factory import create_model_from_config, create_model_from_config_path
2
+ from .models.pretrained import get_pretrained_model
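
A note on the two exports above: a minimal usage sketch, assuming the factory helpers follow stable-audio-tools-style signatures (create_model_from_config takes the parsed config dict, create_model_from_config_path takes a path to the JSON file); nothing here is taken from the repository beyond the import names and config paths.

import json
from ThinkSound import create_model_from_config, create_model_from_config_path

# Build a model from an already-parsed config dict...
with open("ThinkSound/configs/model_configs/thinksound.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)

# ...or let the helper read the JSON itself (assumed to be equivalent).
model = create_model_from_config_path("ThinkSound/configs/model_configs/thinksound.json")
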
ThinkSound/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (326 Bytes).
 
ThinkSound/configs/model_configs/stable_audio_2_0_vae.json ADDED
@@ -0,0 +1,122 @@
1
+ {
2
+ "model_type": "autoencoder",
3
+ "sample_size": 65536,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "encoder": {
8
+ "type": "oobleck",
9
+ "config": {
10
+ "in_channels": 2,
11
+ "channels": 128,
12
+ "c_mults": [1, 2, 4, 8, 16],
13
+ "strides": [2, 4, 4, 8, 8],
14
+ "latent_dim": 128,
15
+ "use_snake": true
16
+ }
17
+ },
18
+ "decoder": {
19
+ "type": "oobleck",
20
+ "config": {
21
+ "out_channels": 2,
22
+ "channels": 128,
23
+ "c_mults": [1, 2, 4, 8, 16],
24
+ "strides": [2, 4, 4, 8, 8],
25
+ "latent_dim": 64,
26
+ "use_snake": true,
27
+ "final_tanh": false
28
+ }
29
+ },
30
+ "bottleneck": {
31
+ "type": "vae"
32
+ },
33
+ "latent_dim": 64,
34
+ "downsampling_ratio": 2048,
35
+ "io_channels": 2
36
+ },
37
+ "training": {
38
+ "learning_rate": 1.5e-4,
39
+ "warmup_steps": 0,
40
+ "use_ema": true,
41
+ "optimizer_configs": {
42
+ "autoencoder": {
43
+ "optimizer": {
44
+ "type": "AdamW",
45
+ "config": {
46
+ "betas": [0.8, 0.99],
47
+ "lr": 1.5e-4,
48
+ "weight_decay": 1e-3
49
+ }
50
+ },
51
+ "scheduler": {
52
+ "type": "InverseLR",
53
+ "config": {
54
+ "inv_gamma": 200000,
55
+ "power": 0.5,
56
+ "warmup": 0.999
57
+ }
58
+ }
59
+ },
60
+ "discriminator": {
61
+ "optimizer": {
62
+ "type": "AdamW",
63
+ "config": {
64
+ "betas": [0.8, 0.99],
65
+ "lr": 3e-4,
66
+ "weight_decay": 1e-3
67
+ }
68
+ },
69
+ "scheduler": {
70
+ "type": "InverseLR",
71
+ "config": {
72
+ "inv_gamma": 200000,
73
+ "power": 0.5,
74
+ "warmup": 0.999
75
+ }
76
+ }
77
+ }
78
+ },
79
+ "loss_configs": {
80
+ "discriminator": {
81
+ "type": "encodec",
82
+ "config": {
83
+ "filters": 64,
84
+ "n_ffts": [2048, 1024, 512, 256, 128],
85
+ "hop_lengths": [512, 256, 128, 64, 32],
86
+ "win_lengths": [2048, 1024, 512, 256, 128]
87
+ },
88
+ "weights": {
89
+ "adversarial": 0.1,
90
+ "feature_matching": 5.0
91
+ }
92
+ },
93
+ "spectral": {
94
+ "type": "mrstft",
95
+ "config": {
96
+ "fft_sizes": [2048, 1024, 512, 256, 128, 64, 32],
97
+ "hop_sizes": [512, 256, 128, 64, 32, 16, 8],
98
+ "win_lengths": [2048, 1024, 512, 256, 128, 64, 32],
99
+ "perceptual_weighting": true
100
+ },
101
+ "weights": {
102
+ "mrstft": 1.0
103
+ }
104
+ },
105
+ "time": {
106
+ "type": "l1",
107
+ "weights": {
108
+ "l1": 0.0
109
+ }
110
+ },
111
+ "bottleneck": {
112
+ "type": "kl",
113
+ "weights": {
114
+ "kl": 1e-4
115
+ }
116
+ }
117
+ },
118
+ "demo": {
119
+ "demo_every": 10000
120
+ }
121
+ }
122
+ }
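
A quick consistency check on the VAE config above, as a standalone sketch: the encoder strides multiply out to the stated downsampling_ratio of 2048, so the 65536-sample training window corresponds to 32 latent frames of dimension 64.

import math

strides = [2, 4, 4, 8, 8]                 # encoder strides from the config
downsampling_ratio = math.prod(strides)   # 2 * 4 * 4 * 8 * 8 = 2048
assert downsampling_ratio == 2048

sample_size = 65536                       # samples per training window
latent_frames = sample_size // downsampling_ratio
print(latent_frames)                      # 32 latent frames, each 64-dimensional
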
ThinkSound/configs/model_configs/thinksound.json ADDED
@@ -0,0 +1,147 @@
1
+ {
2
+ "model_type": "mm_diffusion_cond",
3
+ "sample_size": 397312,
4
+ "sample_rate": 44100,
5
+ "audio_channels": 2,
6
+ "model": {
7
+ "pretransform": {
8
+ "type": "autoencoder",
9
+ "iterate_batch": true,
10
+ "config": {
11
+ "encoder": {
12
+ "type": "oobleck",
13
+ "config": {
14
+ "in_channels": 2,
15
+ "channels": 128,
16
+ "c_mults": [1, 2, 4, 8, 16],
17
+ "strides": [2, 4, 4, 8, 8],
18
+ "latent_dim": 128,
19
+ "use_snake": true
20
+ }
21
+ },
22
+ "decoder": {
23
+ "type": "oobleck",
24
+ "config": {
25
+ "out_channels": 2,
26
+ "channels": 128,
27
+ "c_mults": [1, 2, 4, 8, 16],
28
+ "strides": [2, 4, 4, 8, 8],
29
+ "latent_dim": 64,
30
+ "use_snake": true,
31
+ "final_tanh": false
32
+ }
33
+ },
34
+ "bottleneck": {
35
+ "type": "vae"
36
+ },
37
+ "latent_dim": 64,
38
+ "downsampling_ratio": 2048,
39
+ "io_channels": 2
40
+ }
41
+ },
42
+ "conditioning": {
43
+ "configs": [
44
+ {
45
+ "id": "metaclip_features",
46
+ "type": "mm_unchang",
47
+ "config": {
48
+ "dim": 1024,
49
+ "output_dim": 1024
50
+ }
51
+ },
52
+ {
53
+ "id": "metaclip_text_features",
54
+ "type": "mm_unchang",
55
+ "config": {
56
+ "dim": 1024,
57
+ "output_dim": 1024
58
+ }
59
+ },
60
+ {
61
+ "id": "sync_features",
62
+ "type": "mm_unchang",
63
+ "config": {
64
+ "dim": 768,
65
+ "output_dim": 768
66
+ }
67
+ },
68
+ {
69
+ "id": "t5_features",
70
+ "type": "mm_unchang",
71
+ "config": {
72
+ "dim": 2048,
73
+ "output_dim": 2048
74
+ }
75
+ }
76
+ ],
77
+ "cond_dim": 768
78
+ },
79
+ "diffusion": {
80
+ "mm_cond_ids": ["metaclip_features", "sync_features", "metaclip_text_features","t5_features"],
81
+ "type": "mmdit",
82
+ "diffusion_objective": "rectified_flow",
83
+ "config": {
84
+ "latent_dim":64,
85
+ "clip_dim":1024,
86
+ "sync_dim":768,
87
+ "text_dim":2048,
88
+ "hidden_dim":1024,
89
+ "depth":21,
90
+ "fused_depth":14,
91
+ "num_heads":16,
92
+ "latent_seq_len":194,
93
+ "clip_seq_len":72,
94
+ "sync_seq_len":216,
95
+ "v2": true,
96
+ "kernel_size": 3
97
+ }
98
+ },
99
+ "io_channels": 64
100
+ },
101
+ "training": {
102
+ "use_ema": true,
103
+ "log_loss_info": false,
104
+ "cfg_dropout_prob": 0.2,
105
+ "pre_encoded": true,
106
+ "timestep_sampler": "logit_normal",
107
+ "optimizer_configs": {
108
+ "diffusion": {
109
+ "optimizer": {
110
+ "type": "AdamW",
111
+ "config": {
112
+ "lr": 5e-5,
113
+ "betas": [0.9, 0.95],
114
+ "weight_decay": 1e-4,
115
+ "eps": 1e-6
116
+ }
117
+ },
118
+ "scheduler": {
119
+ "type": "InverseLR",
120
+ "config": {
121
+ "inv_gamma": 1000000,
122
+ "power": 0.5,
123
+ "warmup": 0.99
124
+ }
125
+ }
126
+ }
127
+ },
128
+ "demo": {
129
+ "demo_every": 5000,
130
+ "demo_steps": 24,
131
+ "num_demos": 10,
132
+ "demo_cond": [
133
+ "dataset/vggsound/video_latents_t5_clip_npz/test/0Cu33yBwAPg_000060.npz",
134
+ "dataset/vggsound/video_latents_t5_clip_npz/test/bmKtI808DsU_000009.npz",
135
+ "dataset/vggsound/video_latents_t5_clip_npz/test/VC0c22cJTbM_000424.npz",
136
+ "dataset/vggsound/video_latents_t5_clip_npz/test/F3gsbUTdc2U_000090.npz",
137
+ "dataset/vggsound/video_latents_t5_clip_npz/test/WatvT8A8iug_000100.npz",
138
+ "dataset/vggsound/video_latents_t5_clip_npz/test/0nvBTp-q7tU_000112.npz",
139
+ "dataset/vggsound/video_latents_t5_clip_npz/test/3-PFuDkTM48_000080.npz",
140
+ "dataset/vggsound/video_latents_t5_clip_npz/test/luSAuu-BoPs_000232.npz",
141
+ "dataset/vggsound/video_latents_t5_clip_npz/test/__8UJxW0aOQ_000002.npz",
142
+ "dataset/vggsound/video_latents_t5_clip_npz/test/_0m_YMpQayA_000168.npz"
143
+ ],
144
+ "demo_cfg_scales": [5]
145
+ }
146
+ }
147
+ }
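
The shapes in this config are internally consistent: 397312 samples at 44.1 kHz is about 9.0 s of audio, and dividing by the pretransform's downsampling ratio of 2048 gives exactly the 194 latent frames that latent_seq_len declares. A standalone sanity check:

sample_size = 397312
sample_rate = 44100
downsampling_ratio = 2048                        # from the autoencoder pretransform above

print(sample_size / sample_rate)                 # ~9.01 seconds per training clip
assert sample_size // downsampling_ratio == 194  # matches latent_seq_len in the mmdit config
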
ThinkSound/configs/multimodal_dataset_demo.json ADDED
@@ -0,0 +1,53 @@
1
+ {
2
+ "dataset_type": "multimodal_dir",
3
+ "video_datasets": [
4
+ {
5
+ "id": "vggsound",
6
+ "path": "dataset/vggsound/video_latents_t5_clip_npz/train",
7
+ "split_path": "dataset/vggsound/split_txt/train_cot.txt"
8
+ }
9
+ ],
10
+ "audio_datasets": [
11
+ {
12
+ "id": "audiostock",
13
+ "path": "dataset/Laion-Audio-630k/audiostock_latents_npz",
14
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_audiostock_1.txt"
15
+ },
16
+ {
17
+ "id": "freesound_no_overlap",
18
+ "path": "dataset/Laion-Audio-630k/freesound_no_overlap_latents_npz",
19
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_freesound.txt"
20
+ },
21
+ {
22
+ "id": "audioset_sl",
23
+ "path": "dataset/wavcaps/audioset_sl_latents_npz",
24
+ "split_path": "dataset/wavcaps/split_txt/cot_audio_sl_1.txt"
25
+ },
26
+ {
27
+ "id": "audiocaps",
28
+ "path": "dataset/1_audiocaps/audiocaps_latents_npz",
29
+ "split_path": "dataset/1_audiocaps/split_txt/train_cot.txt"
30
+ },
31
+ {
32
+ "id": "bbc",
33
+ "path": "dataset/Laion-Audio-630k/bbc_latents_npz",
34
+ "split_path": "dataset/Laion-Audio-630k/split_txt/cot_bbc_1.txt"
35
+ }
36
+ ],
37
+ "val_datasets": [
38
+ {
39
+ "id": "vggsound",
40
+ "path": "dataset/vggsound/video_latents_t5_clip_npz/test",
41
+ "split_path": "dataset/vggsound/split_txt/test_cot.txt"
42
+ }
43
+ ],
44
+ "test_datasets": [
45
+ {
46
+ "id": "vggsound",
47
+ "path": "cot_coarse",
48
+ "split_path": "cot_vgg_demo_caption.txt"
49
+ }
50
+ ],
51
+ "random_crop": true,
52
+ "input_type": "prompt"
53
+ }
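
This demo config pairs one video-latent dataset (VGGSound) with several audio-only latent datasets and is consumed by the multimodal_dir branch of the DataModule added below. A minimal sketch of inspecting it, assuming the JSON is read from the repository root:

import json

with open("ThinkSound/configs/multimodal_dataset_demo.json") as f:
    dataset_config = json.load(f)

print(dataset_config["dataset_type"])                       # "multimodal_dir"
print([d["id"] for d in dataset_config["video_datasets"]])  # ['vggsound']
print([d["id"] for d in dataset_config["audio_datasets"]])  # audio-only latent sets
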
ThinkSound/data/__init__.py ADDED
File without changes
ThinkSound/data/datamodule.py ADDED
@@ -0,0 +1,194 @@
1
+ import lightning as L
2
+ from .dataset import LatentDataset, SampleDataset, VideoDataset, AudioDataset, MultiModalDataset, LocalDatasetConfig, collation_fn
3
+ import importlib
4
+ from torch.utils.data import DataLoader
5
+
6
+
7
+ def get_configs(audio_configs):
8
+ configs = []
9
+ for config in audio_configs:
10
+ data_dir_path = config.get("path", None)
11
+ audio_dir_path = config.get("audio_dir", None)
12
+ split_path = config.get("split_path", None)
13
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
14
+
15
+ custom_metadata_fn = None
16
+ custom_metadata_module_path = config.get("custom_metadata_module", None)
17
+
18
+ if custom_metadata_module_path:
19
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
20
+ metadata_module = importlib.util.module_from_spec(spec)
21
+ spec.loader.exec_module(metadata_module)
22
+ custom_metadata_fn = metadata_module.get_custom_metadata
23
+
24
+ configs.append(
25
+ LocalDatasetConfig(
26
+ id=config["id"],
27
+ path=data_dir_path,
28
+ split_path=split_path,
29
+ custom_metadata_fn=custom_metadata_fn,
30
+ audio_dir=audio_dir_path
31
+ )
32
+ )
33
+ return configs
34
+
35
+ class DataModule(L.LightningDataModule):
36
+ def __init__(self, dataset_config, batch_size, test_batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4,repeat_num=5,latent_length=194):
37
+ super().__init__()
38
+ dataset_type = dataset_config.get("dataset_type", None)
39
+ self.batch_size = batch_size
40
+ self.num_workers = num_workers
41
+ self.test_batch_size = test_batch_size
42
+ self.repeat_num = repeat_num
43
+ self.latent_length = latent_length
44
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
45
+
46
+ if audio_channels == 1:
47
+ force_channels = "mono"
48
+ elif audio_channels == 2:
49
+ force_channels = "stereo"
50
+ else:
51
+ force_channels = "foa"
52
+ val_dir_configs = dataset_config.get("val_datasets", None)
53
+ test_dir_configs = dataset_config.get("test_datasets", None)
54
+ configs = []
55
+ val_configs = []
56
+ test_configs = []
57
+ if dataset_type == "audio_dir":
58
+ audio_dir_configs = dataset_config.get("datasets", None)
59
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
60
+ configs = get_configs(audio_dir_configs)
61
+ val_configs = get_configs(val_dir_configs)
62
+ test_configs = get_configs(test_dir_configs)
63
+ elif dataset_type == "latent_dir" or dataset_type == "video_dataset":
64
+ audio_dir_configs = dataset_config.get("datasets", None)
65
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
66
+ for i, dataset in enumerate((audio_dir_configs, val_dir_configs, test_dir_configs)):
67
+ for config in dataset:
68
+ data_dir_path = config.get("path", None)
69
+ audio_dir_path = config.get("audio_dir", None)
70
+ split_path = config.get("split_path", None)
71
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
72
+
73
+ content = LocalDatasetConfig(
74
+ id=config["id"],
75
+ path=data_dir_path,
76
+ split_path=split_path,
77
+ audio_dir=audio_dir_path,
78
+ extra_cot=config.get("extra_cot", None)
79
+ )
80
+ if i == 0:
81
+ configs.append(content)
82
+ elif i == 1:
83
+ val_configs.append(content)
84
+ else:
85
+ test_configs.append(content)
86
+ elif dataset_type == "multimodal_dir":
87
+ self.audio_configs = []
88
+ self.video_configs = []
89
+ audio_dir_configs = dataset_config.get("audio_datasets", None)
90
+ video_dir_configs = dataset_config.get("video_datasets", None)
91
+ assert audio_dir_configs is not None and video_dir_configs is not None, "Directory configuration must be specified in video_datasets and audio_datasets"
92
+ for i, dataset in enumerate((audio_dir_configs, video_dir_configs, val_dir_configs, test_dir_configs)):
93
+ for config in dataset:
94
+ data_dir_path = config.get("path", None)
95
+ audio_dir_path = config.get("audio_dir", None)
96
+ split_path = config.get("split_path", None)
97
+ assert data_dir_path is not None, "Path must be set for local audio directory configuration"
98
+ print(f'extra cot: {config.get("extra_cot", None)}')
99
+ content = LocalDatasetConfig(
100
+ id=config["id"],
101
+ path=data_dir_path,
102
+ split_path=split_path,
103
+ audio_dir=audio_dir_path,
104
+ extra_cot=config.get("extra_cot", None)
105
+ )
106
+ if i == 0:
107
+ self.audio_configs.append(content)
108
+ elif i == 1:
109
+ self.video_configs.append(content)
110
+ elif i == 2:
111
+ val_configs.append(content)
112
+ else:
113
+ test_configs.append(content)
114
+ self.dataset_type = dataset_type
115
+ self.configs = configs
116
+ self.val_configs = val_configs
117
+ self.test_configs = test_configs
118
+ self.sample_rate = sample_rate
119
+ self.sample_size = sample_size
120
+ self.random_crop = dataset_config.get("random_crop", True)
121
+ self.input_type = dataset_config.get("input_type", "video")
122
+ self.fps = dataset_config.get("fps", 4)
123
+ self.force_channels = force_channels
124
+
125
+
126
+ def setup(self, stage: str):
127
+ if self.dataset_type == 'audio_dir':
128
+ dataset_class = SampleDataset
129
+ elif self.dataset_type == 'latent_dir':
130
+ dataset_class = LatentDataset
131
+ elif self.dataset_type == 'video_dataset':
132
+ dataset_class = VideoDataset
133
+ elif self.dataset_type == 'multimodal_dir':
134
+ dataset_class = VideoDataset
135
+
136
+ def create_dataset(configs, random_crop):
137
+ return dataset_class(
138
+ configs,
139
+ sample_rate=self.sample_rate,
140
+ sample_size=self.sample_size,
141
+ random_crop=random_crop,
142
+ input_type=self.input_type,
143
+ fps=self.fps,
144
+ force_channels=self.force_channels,
145
+ latent_length=self.latent_length
146
+ )
147
+
148
+ if stage == 'fit':
149
+ if self.dataset_type != 'multimodal_dir':
150
+ self.train_set = create_dataset(self.configs, random_crop=self.random_crop)
151
+ else:
152
+ self.video_set = VideoDataset(
153
+ self.video_configs,
154
+ sample_rate=self.sample_rate,
155
+ sample_size=self.sample_size,
156
+ random_crop=self.random_crop,
157
+ input_type=self.input_type,
158
+ fps=self.fps,
159
+ force_channels=self.force_channels
160
+ )
161
+ self.audio_set = AudioDataset(
162
+ self.audio_configs,
163
+ sample_rate=self.sample_rate,
164
+ sample_size=self.sample_size,
165
+ random_crop=self.random_crop,
166
+ input_type=self.input_type,
167
+ fps=self.fps,
168
+ force_channels=self.force_channels
169
+ )
170
+ self.train_set = MultiModalDataset([self.video_set]*self.repeat_num, [self.audio_set])
171
+ self.val_set = create_dataset(self.val_configs, random_crop=False)
172
+ elif stage == 'validate':
173
+ self.val_set = create_dataset(self.val_configs, random_crop=False)
174
+ elif stage == 'predict':
175
+ self.test_set = create_dataset(self.test_configs, random_crop=False)
176
+
177
+ def train_dataloader(self):
178
+ return DataLoader(self.train_set, self.batch_size, shuffle=True,
179
+ num_workers=self.num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
180
+
181
+ def val_dataloader(self):
182
+ return DataLoader(self.val_set, self.batch_size, shuffle=False,
183
+ num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)
184
+
185
+ def predict_dataloader(self):
186
+ return DataLoader(self.test_set, batch_size=self.test_batch_size, shuffle=False,
187
+ num_workers=self.num_workers, persistent_workers=False, pin_memory=False, drop_last=False, collate_fn=collation_fn)
188
+
189
+ # def predict_dataloader(self):
190
+ # return DataLoader(self.mnist_predict, batch_size=self.batch_size)
191
+
192
+ # def teardown(self, stage: str):
193
+ # # Used to clean-up when the run is finished
194
+ # ...
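
A hedged usage sketch for the DataModule above, assuming the multimodal demo config and the precomputed .npz latents it references are available locally; the batch sizes below are illustrative and not taken from the repository, while sample_size and sample_rate mirror thinksound.json.

import json
from ThinkSound.data.datamodule import DataModule

with open("ThinkSound/configs/multimodal_dataset_demo.json") as f:
    dataset_config = json.load(f)

dm = DataModule(
    dataset_config,
    batch_size=4,            # illustrative value
    test_batch_size=1,       # illustrative value
    sample_size=397312,      # matches thinksound.json
    sample_rate=44100,
    audio_channels=2,
    num_workers=4,
)
dm.setup("fit")              # builds MultiModalDataset (repeated VideoDataset + AudioDataset) and the val set
batch = next(iter(dm.train_dataloader()))
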
ThinkSound/data/dataset.py ADDED
@@ -0,0 +1,1266 @@
1
+ import importlib
2
+ import numpy as np
3
+ import io
4
+ import os
5
+ import posixpath
6
+ import random
7
+ import re
8
+ import subprocess
9
+ import time
10
+ import torch
11
+ import torchaudio
12
+ import webdataset as wds
13
+ import pandas as pd
14
+ from aeiou.core import is_silence
15
+ from os import path
16
+ from pathlib import Path
17
+ from pedalboard.io import AudioFile
18
+ from torchaudio import transforms as T
19
+ from typing import Optional, Callable, List
20
+ import bisect
21
+
22
+ from .utils import FOA, Stereo, Mono, PhaseFlipper, PadCrop_Normalized_T, PadCrop_Video_Normalized_T, PadCrop_Video_Hiera_Normalized_T, PadCrop_Video_Image_Normalized_T, PadCrop_DualVideo_Normalized_T
23
+
24
+ AUDIO_KEYS = ("flac", "wav", "mp3", "m4a", "ogg", "opus")
25
+
26
+ # fast_scandir implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
27
+
28
+ def fast_scandir(
29
+ dir:str, # top-level directory at which to begin scanning
30
+ ext:list, # list of allowed file extensions,
31
+ #max_size = 1 * 1000 * 1000 * 1000 # Only files < 1 GB
32
+ ):
33
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
34
+ subfolders, files = [], []
35
+ ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed
36
+ try: # hope to avoid 'permission denied' by this try
37
+ for f in os.scandir(dir):
38
+ try: # 'hope to avoid too many levels of symbolic links' error
39
+ if f.is_dir():
40
+ subfolders.append(f.path)
41
+ elif f.is_file():
42
+ file_ext = os.path.splitext(f.name)[1].lower()
43
+ is_hidden = os.path.basename(f.path).startswith(".")
44
+
45
+ if file_ext in ext and not is_hidden:
46
+ files.append(f.path)
47
+ except:
48
+ pass
49
+ except:
50
+ pass
51
+
52
+ for dir in list(subfolders):
53
+ sf, f = fast_scandir(dir, ext)
54
+ subfolders.extend(sf)
55
+ files.extend(f)
56
+ return subfolders, files
57
+
58
+ def keyword_scandir(
59
+ dir: str, # top-level directory at which to begin scanning
60
+ ext: list, # list of allowed file extensions
61
+ keywords: list, # list of keywords to search for in the file name
62
+ ):
63
+ "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243"
64
+ subfolders, files = [], []
65
+ # make keywords case insensitive
66
+ keywords = [keyword.lower() for keyword in keywords]
67
+ # add starting period to extensions if needed
68
+ ext = ['.'+x if x[0] != '.' else x for x in ext]
69
+ banned_words = ["paxheader", "__macosx"]
70
+ try: # hope to avoid 'permission denied' by this try
71
+ for f in os.scandir(dir):
72
+ try: # 'hope to avoid too many levels of symbolic links' error
73
+ if f.is_dir():
74
+ subfolders.append(f.path)
75
+ elif f.is_file():
76
+ is_hidden = f.name.split("/")[-1][0] == '.'
77
+ has_ext = os.path.splitext(f.name)[1].lower() in ext
78
+ name_lower = f.name.lower()
79
+ has_keyword = any(
80
+ [keyword in name_lower for keyword in keywords])
81
+ has_banned = any(
82
+ [banned_word in name_lower for banned_word in banned_words])
83
+ if has_ext and has_keyword and not has_banned and not is_hidden and not os.path.basename(f.path).startswith("._"):
84
+ files.append(f.path)
85
+ except:
86
+ pass
87
+ except:
88
+ pass
89
+
90
+ for dir in list(subfolders):
91
+ sf, f = keyword_scandir(dir, ext, keywords)
92
+ subfolders.extend(sf)
93
+ files.extend(f)
94
+ return subfolders, files
95
+
96
+ def get_audio_filenames(
97
+ paths: list, # directories in which to search
98
+ keywords=None,
99
+ exts=['.wav', '.mp3', '.flac', '.ogg', '.aif', '.opus']
100
+ ):
101
+ "recursively get a list of audio filenames"
102
+ filenames = []
103
+ if type(paths) is str:
104
+ paths = [paths]
105
+ for path in paths: # get a list of relevant filenames
106
+ if keywords is not None:
107
+ subfolders, files = keyword_scandir(path, exts, keywords)
108
+ else:
109
+ subfolders, files = fast_scandir(path, exts)
110
+ filenames.extend(files)
111
+ return filenames
112
+
113
+ class LocalDatasetConfig:
114
+ def __init__(
115
+ self,
116
+ id: str,
117
+ path: str,
118
+ split_path: str,
119
+ audio_dir: str = None,
120
+ extra_cot: str = None,
121
+ custom_metadata_fn: Optional[Callable[[str], str]] = None
122
+ ):
123
+ self.id = id
124
+ self.path = path
125
+ self.split_path = split_path
126
+ self.audio_dir = audio_dir
127
+ self.custom_metadata_fn = custom_metadata_fn
128
+ self.extra_cot = extra_cot
129
+ class SampleDataset(torch.utils.data.Dataset):
130
+ def __init__(
131
+ self,
132
+ configs,
133
+ sample_size=65536,
134
+ sample_rate=48000,
135
+ keywords=None,
136
+ random_crop=True,
137
+ input_type="prompt",
138
+ fps=4,
139
+ force_channels="stereo"
140
+ ):
141
+ super().__init__()
142
+ self.filenames = []
143
+
144
+ self.augs = torch.nn.Sequential(
145
+ PhaseFlipper(),
146
+ )
147
+
148
+ self.root_paths = []
149
+ if input_type == 'video':
150
+ self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
151
+ elif input_type == 'video_hiera':
152
+ self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
153
+ elif input_type == 'video_image':
154
+ self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
155
+ elif input_type == 'dual_video':
156
+ self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
157
+ else:
158
+ self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
159
+
160
+ self.force_channels = force_channels
161
+ print('######################')
162
+ print(f'input channels is: {force_channels}')
163
+ print('######################')
164
+ self.encoding = torch.nn.Sequential(
165
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
166
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
167
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
168
+ )
169
+ self.input_type = input_type
170
+ self.sr = sample_rate
171
+ self.custom_metadata_fns = {}
172
+
173
+ for config in configs:
174
+ self.root_paths.append(config.path)
175
+ def add_prefix(s):
176
+ return str(os.path.join(config.path,f'{s.strip()}'))
177
+ with open(config.split_path,'r') as f:
178
+ item_names = f.readlines()
179
+ filenames = list(map(add_prefix, item_names))
180
+ self.filenames.extend(filenames)
181
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
182
+ if config.custom_metadata_fn is not None:
183
+ self.custom_metadata_fns[config.path] = config.custom_metadata_fn
184
+
185
+ print(f'Found {len(self.filenames)} files')
186
+
187
+ def load_file(self, filename):
188
+ ext = filename.split(".")[-1]
189
+ if ext == "mp3":
190
+ with AudioFile(filename) as f:
191
+ audio = f.read(f.frames)
192
+ audio = torch.from_numpy(audio)
193
+ in_sr = f.samplerate
194
+ else:
195
+ audio, in_sr = torchaudio.load(filename, format=ext)
196
+
197
+ if in_sr != self.sr:
198
+ try:
199
+ resample_tf = T.Resample(in_sr, self.sr)
200
+ audio = resample_tf(audio)
201
+ except:
202
+ print(f'{filename} resample errors')
203
+
204
+ assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
205
+ return audio
206
+
207
+ def __len__(self):
208
+ return len(self.filenames)
209
+
210
+ def __getitem__(self, idx):
211
+ audio_filename = self.filenames[idx]
212
+ assert os.path.exists(audio_filename), f'{audio_filename}: file does not exist'
213
+ try:
214
+ start_time = time.time()
215
+ audio = self.load_file(audio_filename)
216
+ info = {}
217
+ info["path"] = audio_filename
218
+
219
+ for root_path in self.root_paths:
220
+ if root_path in audio_filename:
221
+ info["relpath"] = path.relpath(audio_filename, root_path)
222
+
223
+
224
+ for custom_md_path in self.custom_metadata_fns.keys():
225
+ if custom_md_path in audio_filename:
226
+ custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
227
+ custom_metadata = custom_metadata_fn(info, audio)
228
+ info.update(custom_metadata)
229
+
230
+ if "__reject__" in info and info["__reject__"]:
231
+ return self[random.randrange(len(self))]
232
+ if self.input_type == 'video':
233
+ audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'])
234
+ info['video'] = video
235
+ elif self.input_type == 'dual_video':
236
+ audio, video_360, video_fov, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['video'], info['video_fov'])
237
+ info['video_360'] = video_360
238
+ info['video_fov'] = video_fov
239
+ else:
240
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio)
241
+ assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
242
+ # Run augmentations on this sample (including random crop)
243
+ if self.augs is not None:
244
+ audio = self.augs(audio)
245
+
246
+ audio = audio.clamp(-1, 1)
247
+
248
+ # Encode the file to assist in prediction
249
+ if self.encoding is not None:
250
+ audio = self.encoding(audio)
251
+
252
+
253
+
254
+ info["timestamps"] = (t_start, t_end)
255
+ info["seconds_start"] = seconds_start
256
+ info["seconds_total"] = seconds_total
257
+ info["padding_mask"] = padding_mask
258
+
259
+ end_time = time.time()
260
+ info["load_time"] = end_time - start_time
261
+
262
+
263
+ return (audio, info)
264
+ except Exception as e:
265
+ print(f'Couldn\'t load file {audio_filename}: {e}')
266
+ return self[random.randrange(len(self))]
267
+
268
+ class LatentDataset(torch.utils.data.Dataset):
269
+ def __init__(
270
+ self,
271
+ configs,
272
+ sample_size=65536,
273
+ sample_rate=48000,
274
+ keywords=None,
275
+ random_crop=True,
276
+ input_type="prompt",
277
+ fps=4,
278
+ force_channels="stereo"
279
+ ):
280
+ super().__init__()
281
+ self.filenames = []
282
+
283
+ self.augs = torch.nn.Sequential(
284
+ PhaseFlipper(),
285
+ )
286
+
287
+ self.root_paths = []
288
+
289
+ self.force_channels = force_channels
290
+ print('######################')
291
+ print(f'input channels is: {force_channels}')
292
+ print('######################')
293
+ self.encoding = torch.nn.Sequential(
294
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
295
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
296
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
297
+ )
298
+ self.input_type = input_type
299
+ self.sr = sample_rate
300
+ for config in configs:
301
+ self.root_paths.append(config.path)
302
+ def add_prefix(s):
303
+ return str(os.path.join(config.path,f'{s.strip()}'))
304
+ with open(config.split_path,'r') as f:
305
+ item_names = f.readlines()
306
+ filenames = list(map(add_prefix, item_names))
307
+ self.filenames.extend(filenames)
308
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
309
+
310
+
311
+ print(f'Found {len(self.filenames)} files')
312
+
313
+ def load_file(self, filename, info):
314
+ # try:
315
+ npz_file = filename.replace('.pth','.npz')
316
+ if os.path.exists(filename) and '.npz' not in filename:
317
+ data = torch.load(filename, weights_only=False)
318
+ elif os.path.exists(npz_file):
319
+ # print(filename)
320
+ npz_data = np.load(npz_file,allow_pickle=True)
321
+ data = {key: npz_data[key] for key in npz_data.files}
322
+ # print("data.keys()",data.keys())
323
+ for key in data.keys():
324
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
325
+ data[key] = torch.from_numpy(data[key])
326
+ else:
327
+ raise ValueError(f'error load file: {filename}')
328
+ info.update(data)
329
+ audio = data['latent']
330
+ # except:
331
+ # print(f'error load file: {filename}')
332
+ return audio, info['metaclip_features']
333
+
334
+ def __len__(self):
335
+ return len(self.filenames)
336
+
337
+ def __getitem__(self, idx):
338
+ audio_filename = self.filenames[idx]
339
+ assert os.path.exists(audio_filename) or os.path.exists(audio_filename.replace('.pth','.npz')), f'{audio_filename}: file does not exist'
340
+ # try:
341
+ start_time = time.time()
342
+ info = {}
343
+ audio, video = self.load_file(audio_filename, info)
344
+ info["path"] = audio_filename
345
+
346
+ info['id'] = Path(audio_filename).stem
347
+ for root_path in self.root_paths:
348
+ if root_path in audio_filename:
349
+ info["relpath"] = path.relpath(audio_filename, root_path)
350
+
351
+ return (audio, info)
352
+
353
+ class AudioDataset(torch.utils.data.Dataset):
354
+ def __init__(
355
+ self,
356
+ configs,
357
+ sample_size=65536,
358
+ sample_rate=48000,
359
+ keywords=None,
360
+ random_crop=True,
361
+ input_type="prompt",
362
+ fps=4,
363
+ force_channels="stereo"
364
+ ):
365
+ super().__init__()
366
+ self.filenames = []
367
+
368
+ self.augs = torch.nn.Sequential(
369
+ PhaseFlipper(),
370
+ )
371
+
372
+ self.root_paths = []
373
+
374
+ self.force_channels = force_channels
375
+ print('######################')
376
+ print(f'input channels is: {force_channels}')
377
+ print('######################')
378
+ self.encoding = torch.nn.Sequential(
379
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
380
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
381
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
382
+ )
383
+ self.fake_clip_features = torch.zeros(72, 1024)
384
+ self.fake_sync_features = torch.zeros(216, 768)
385
+ self.video_exist = torch.tensor(0, dtype=torch.bool)
386
+ self.input_type = input_type
387
+ self.sr = sample_rate
388
+ for config in configs:
389
+ self.root_paths.append(config.path)
390
+ def add_prefix(s):
391
+ return str(os.path.join(config.path,f'{s.strip()}'))
392
+ with open(config.split_path,'r') as f:
393
+ item_names = f.readlines()
394
+ filenames = list(map(add_prefix, item_names))
395
+ self.filenames.extend(filenames)
396
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
397
+
398
+
399
+ print(f'Found {len(self.filenames)} files')
400
+
401
+ def load_file(self, filename, info):
402
+ # try:
403
+ npz_file = filename.replace('.pth','.npz')
404
+ if os.path.exists(filename) and '.npz' not in filename:
405
+ data = torch.load(filename, weights_only=False)
406
+ elif os.path.exists(npz_file):
407
+ # print(filename)
408
+ npz_data = np.load(npz_file,allow_pickle=True)
409
+ data = {key: npz_data[key] for key in npz_data.files}
410
+ # print("data.keys()",data.keys())
411
+ for key in data.keys():
412
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
413
+ data[key] = torch.from_numpy(data[key])
414
+ else:
415
+ raise ValueError(f'error load file: {filename}')
416
+ info.update(data)
417
+ audio = data['latent']
418
+ info['metaclip_features'] = self.fake_clip_features
419
+ info['sync_features'] = self.fake_sync_features
420
+ info['video_exist'] = self.video_exist
421
+ # except:
422
+ # print(f'error load file: {filename}')
423
+ return audio, info['metaclip_features']
424
+
425
+ def __len__(self):
426
+ return len(self.filenames)
427
+
428
+ def __getitem__(self, idx):
429
+ audio_filename = self.filenames[idx]
430
+ assert os.path.exists(audio_filename) or os.path.exists(audio_filename.replace('.pth','.npz')), f'{audio_filename}: file does not exist'
431
+ # try:
432
+ start_time = time.time()
433
+ info = {}
434
+ audio, video = self.load_file(audio_filename, info)
435
+ info["path"] = audio_filename
436
+
437
+ info['id'] = Path(audio_filename).stem
438
+ for root_path in self.root_paths:
439
+ if root_path in audio_filename:
440
+ info["relpath"] = path.relpath(audio_filename, root_path)
441
+
442
+ return (audio, info)
443
+
444
+ class VideoDataset(torch.utils.data.Dataset):
445
+ def __init__(
446
+ self,
447
+ configs,
448
+ sample_size=65536,
449
+ sample_rate=48000,
450
+ keywords=None,
451
+ random_crop=True,
452
+ input_type="prompt",
453
+ fps=4,
454
+ force_channels="stereo",
455
+ latent_length=194, # default latent length for video dataset
456
+ ):
457
+ self.latent_length = latent_length
458
+ super().__init__()
459
+ self.filenames = []
460
+ print(f'configs: {configs[0]}')
461
+ if configs[0].extra_cot is not None:
462
+ self.extra_cot = configs[0].extra_cot
463
+ print(f'load extra cot from {self.extra_cot}')
464
+ else:
465
+ self.extra_cot = None
466
+ self.augs = torch.nn.Sequential(
467
+ PhaseFlipper(),
468
+ )
469
+
470
+ self.root_paths = []
471
+
472
+ self.force_channels = force_channels
473
+ print('######################')
474
+ print(f'input channels is: {force_channels}')
475
+ print('######################')
476
+ self.encoding = torch.nn.Sequential(
477
+ FOA() if self.force_channels == "foa" else torch.nn.Identity(),
478
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
479
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
480
+ )
481
+ self.input_type = input_type
482
+ self.sr = sample_rate
483
+ self.video_exist = torch.tensor(1, dtype=torch.bool)
484
+ for config in configs:
485
+ self.root_paths.append(config.path)
486
+ def add_prefix(s):
487
+ return str(os.path.join(config.path,f'{s.strip()}'))
488
+ with open(config.split_path,'r') as f:
489
+ item_names = f.readlines()
490
+ filenames = list(map(add_prefix, item_names))
491
+ self.filenames.extend(filenames)
492
+ # self.filenames.extend(get_audio_filenames(config.path, keywords))
493
+
494
+
495
+ print(f'Found {len(self.filenames)} files')
496
+
497
+ def load_file(self, filename, info):
498
+ # try:
499
+ npz_file = filename.replace('.pth','.npz')
500
+ if os.path.exists(filename) and '.npz' not in filename:
501
+ data = torch.load(filename, weights_only=False)
502
+ elif os.path.exists(npz_file):
503
+ # print(filename)
504
+ npz_data = np.load(npz_file,allow_pickle=True)
505
+ data = {key: npz_data[key] for key in npz_data.files}
506
+ # print("data.keys()",data.keys())
507
+ for key in data.keys():
508
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
509
+ data[key] = torch.from_numpy(data[key])
510
+ if self.extra_cot is not None:
511
+ extra_pth = filename.replace('.npz','.pth')
512
+ extra_pth = os.path.join(self.extra_cot, os.path.basename(extra_pth))
513
+ if os.path.exists(extra_pth):
514
+ extra_data = torch.load(extra_pth, weights_only=False)
515
+ for key in extra_data.keys():
516
+ if isinstance(extra_data[key], torch.Tensor):
517
+ # print(f'load extra cot {key}')
518
+ data[key] = extra_data[key]
519
+ else:
520
+ raise ValueError(f'error load file: {filename}')
521
+ info.update(data)
522
+ if 'latent' in data.keys():
523
+ audio = data['latent']
524
+ else:
525
+ audio = torch.zeros(64,self.latent_length)
526
+ info['video_exist'] = self.video_exist
527
+ # except:
528
+ # print(f'error load file: {filename}')
529
+ return audio, info['metaclip_features']
530
+
531
+ def __len__(self):
532
+ return len(self.filenames)
533
+
534
+ def __getitem__(self, idx):
535
+ audio_filename = self.filenames[idx]
536
+ assert os.path.exists(audio_filename) or os.path.exists(audio_filename.replace('.pth','.npz')), f'{audio_filename}: file does not exist'
537
+ # try:
538
+ start_time = time.time()
539
+ info = {}
540
+ audio, video = self.load_file(audio_filename, info)
541
+ info["path"] = audio_filename
542
+
543
+ info['id'] = Path(audio_filename).stem
544
+ for root_path in self.root_paths:
545
+ if root_path in audio_filename:
546
+ info["relpath"] = path.relpath(audio_filename, root_path)
547
+
548
+ return (audio, info)
549
+
550
+ # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
551
+ class MultiModalDataset(torch.utils.data.Dataset):
552
+ datasets: list[torch.utils.data.Dataset]
553
+ cumulative_sizes: list[int]
554
+
555
+ @staticmethod
556
+ def cumsum(sequence):
557
+ r, s = [], 0
558
+ for e in sequence:
559
+ l = len(e)
560
+ r.append(l + s)
561
+ s += l
562
+ return r
563
+
564
+ def __init__(self, video_datasets: list[torch.utils.data.Dataset], audio_datasets: list[torch.utils.data.Dataset]):
565
+ super().__init__()
566
+ self.video_datasets = list(video_datasets)
567
+ self.audio_datasets = list(audio_datasets)
568
+ self.datasets = self.video_datasets + self.audio_datasets
569
+
570
+ self.cumulative_sizes = self.cumsum(self.datasets)
571
+ print(f'Found {self.cumulative_sizes[-1]} files')
572
+
573
+ def __len__(self):
574
+ return self.cumulative_sizes[-1]
575
+
576
+ def __getitem__(self, idx):
577
+ if idx < 0:
578
+ if -idx > len(self):
579
+ raise ValueError("absolute value of index should not exceed dataset length")
580
+ idx = len(self) + idx
581
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
582
+ if dataset_idx == 0:
583
+ sample_idx = idx
584
+ else:
585
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
586
+ return self.datasets[dataset_idx][sample_idx]
587
+
588
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
589
+ return self.video_datasets[0].compute_latent_stats()
590
+
591
+
592
+ # class MultiModalDataset(torch.utils.data.Dataset):
593
+ # def __init__(
594
+ # self,
595
+ # configs,
596
+ # sample_size=65536,
597
+ # sample_rate=48000,
598
+ # keywords=None,
599
+ # random_crop=True,
600
+ # input_type="prompt",
601
+ # fps=4,
602
+ # force_channels="stereo"
603
+ # ):
604
+ # super().__init__()
605
+ # self.filenames = []
606
+ # self.captions = []
607
+ # self.caption_t5s = []
608
+ # self.ids = []
609
+ # self.augs = torch.nn.Sequential(
610
+ # PhaseFlipper(),
611
+ # )
612
+
613
+ # self.root_paths = []
614
+ # if input_type == 'video':
615
+ # self.pad_crop = PadCrop_Video_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
616
+ # elif input_type == 'video_hiera':
617
+ # self.pad_crop = PadCrop_Video_Hiera_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
618
+ # elif input_type == 'video_image':
619
+ # self.pad_crop = PadCrop_Video_Image_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
620
+ # elif input_type == 'dual_video':
621
+ # self.pad_crop = PadCrop_DualVideo_Normalized_T(sample_size, sample_rate, fps, randomize=random_crop)
622
+ # else:
623
+ # self.pad_crop = PadCrop_Normalized_T(sample_size, sample_rate, randomize=random_crop)
624
+
625
+ # self.force_channels = force_channels
626
+ # print('######################')
627
+ # print(f'input channels is: {force_channels}')
628
+ # print('######################')
629
+ # self.encoding = torch.nn.Sequential(
630
+ # FOA() if self.force_channels == "foa" else torch.nn.Identity(),
631
+ # Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
632
+ # Mono() if self.force_channels == "mono" else torch.nn.Identity(),
633
+ # )
634
+ # self.input_type = input_type
635
+ # self.sr = sample_rate
636
+ # self.custom_metadata_fns = {}
637
+
638
+ # for config in configs:
639
+ # print(config.split_path)
640
+ # self.root_paths.append(config.path)
641
+ # def add_prefix(s):
642
+ # return str(os.path.join(config.path,f'{s.strip()}'))
643
+ # with open(config.split_path,'r') as f:
644
+ # item_names = f.readlines()
645
+ # csv_path = config.split_path.replace('.txt','.csv')
646
+ # df = pd.read_csv(csv_path)
647
+ # # 检查是否存在 'caption_t5' 列,如果不存在则创建并复制 'caption' 的值
648
+ # if 'caption_t5' not in df.columns:
649
+ # df['caption_t5'] = df['caption']
650
+
651
+ # captions = df['caption'].tolist()
652
+ # caption_t5s = df['caption_t5'].tolist()
653
+ # filenames = list(map(add_prefix, item_names))
654
+ # assert len(captions) == len(caption_t5s) and len(captions) == len(filenames), f'{config.path} has wrong filename and caption'
655
+ # if config.id == 'vggsound':
656
+ # self.filenames.extend(filenames*5)
657
+ # self.captions.extend(captions*5)
658
+ # self.caption_t5s.extend(caption_t5s*5)
659
+ # self.ids.extend(df['id'].tolist()*5)
660
+ # else:
661
+ # self.filenames.extend(filenames)
662
+ # self.captions.extend(captions)
663
+ # self.caption_t5s.extend(caption_t5s)
664
+ # self.ids.extend(df['id'].tolist())
665
+ # # self.filenames.extend(get_audio_filenames(config.path, keywords))
666
+ # if config.custom_metadata_fn is not None:
667
+ # self.custom_metadata_fns[config.path] = config.custom_metadata_fn
668
+
669
+ # assert len(self.ids) == len(self.captions) and len(self.caption_t5s) == len(self.filenames), 'length need to be same'
670
+ # print(f'Found {len(self.filenames)} files')
671
+
672
+
673
+ # def load_file(self, filename):
674
+ # ext = filename.split(".")[-1]
675
+ # if ext == "mp3":
676
+ # with AudioFile(filename) as f:
677
+ # audio = f.read(f.frames)
678
+ # audio = torch.from_numpy(audio)
679
+ # in_sr = f.samplerate
680
+ # else:
681
+ # audio, in_sr = torchaudio.load(filename, format=ext)
682
+
683
+ # if in_sr != self.sr:
684
+ # try:
685
+ # resample_tf = T.Resample(in_sr, self.sr)
686
+ # audio = resample_tf(audio)
687
+ # except:
688
+ # print(f'{filename} resample errors')
689
+
690
+ # assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
691
+ # return audio
692
+
693
+ # def __len__(self):
694
+ # return len(self.filenames)
695
+
696
+ # def __getitem__(self, idx):
697
+ # audio_filename = self.filenames[idx]
698
+ # id = self.ids[idx]
699
+ # assert str(id) == str(Path(audio_filename).stem), f'audio_file: {audio_filename} needs to be same as {id} '
700
+ # assert os.path.exists(audio_filename), f'{audio_filename}: file not exists'
701
+ # try:
702
+ # start_time = time.time()
703
+ # audio = self.load_file(audio_filename)
704
+ # caption = self.captions[idx]
705
+ # caption_t5 = self.caption_t5s[idx]
706
+ # if pd.isna(caption_t5) or caption_t5 == '':
707
+ # caption_t5 = caption
708
+ # info = {}
709
+ # info["path"] = audio_filename
710
+ # info['caption'] = caption
711
+ # info['caption_t5'] = caption_t5
712
+
713
+ # for root_path in self.root_paths:
714
+ # if root_path in audio_filename:
715
+ # info["relpath"] = path.relpath(audio_filename, root_path)
716
+
717
+
718
+ # for custom_md_path in self.custom_metadata_fns.keys():
719
+ # if custom_md_path in audio_filename:
720
+ # custom_metadata_fn = self.custom_metadata_fns[custom_md_path]
721
+ # custom_metadata = custom_metadata_fn(info, audio)
722
+ # info.update(custom_metadata)
723
+
724
+ # if "__reject__" in info and info["__reject__"]:
725
+ # return self[random.randrange(len(self))]
726
+ # # if self.input_type == 'video':
727
+ # # audio, video, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio, info['clip_features'])
728
+ # # info['clip_features'] = video
729
+ # # else:
730
+ # if info['flag']:
731
+ # audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=False)
732
+ # else:
733
+ # audio, t_start, t_end, seconds_start, seconds_total, padding_mask = self.pad_crop(audio,randomize=True)
734
+ # assert not (torch.isnan(audio).any() or torch.isinf(audio).any()), f'file-{filename} contains nan or inf number, check it!'
735
+ # # Run augmentations on this sample (including random crop)
736
+ # if self.augs is not None:
737
+ # audio = self.augs(audio)
738
+
739
+ # audio = audio.clamp(-1, 1)
740
+
741
+ # # Encode the file to assist in prediction
742
+ # if self.encoding is not None:
743
+ # audio = self.encoding(audio)
744
+
745
+
746
+
747
+ # info["timestamps"] = (t_start, t_end)
748
+ # info["seconds_start"] = seconds_start
749
+ # info["seconds_total"] = seconds_total
750
+ # info["padding_mask"] = padding_mask
751
+
752
+ # end_time = time.time()
753
+ # info["load_time"] = end_time - start_time
754
+
755
+
756
+ # return (audio, info)
757
+ # except Exception as e:
758
+ # print(f'Couldn\'t load file {audio_filename}: {e}')
759
+ # return self[random.randrange(len(self))]
760
+
761
+ def group_by_keys(data, keys=wds.tariterators.base_plus_ext, lcase=True, suffixes=None, handler=None):
762
+ """Return function over iterator that groups key, value pairs into samples.
763
+ :param keys: function that splits the key into key and extension (base_plus_ext)
764
+ :param lcase: convert suffixes to lower case (Default value = True)
765
+ """
766
+ current_sample = None
767
+ for filesample in data:
768
+ assert isinstance(filesample, dict)
769
+ fname, value = filesample["fname"], filesample["data"]
770
+ prefix, suffix = keys(fname)
771
+ if wds.tariterators.trace:
772
+ print(
773
+ prefix,
774
+ suffix,
775
+ current_sample.keys() if isinstance(current_sample, dict) else None,
776
+ )
777
+ if prefix is None:
778
+ continue
779
+ if lcase:
780
+ suffix = suffix.lower()
781
+ if current_sample is None or prefix != current_sample["__key__"]:
782
+ if wds.tariterators.valid_sample(current_sample):
783
+ yield current_sample
784
+ current_sample = dict(__key__=prefix, __url__=filesample["__url__"])
785
+ if suffix in current_sample:
786
+ print(f"{fname}: duplicate file name in tar file {suffix} {current_sample.keys()}")
787
+ if suffixes is None or suffix in suffixes:
788
+ current_sample[suffix] = value
789
+ if wds.tariterators.valid_sample(current_sample):
790
+ yield current_sample
791
+
792
+ wds.tariterators.group_by_keys = group_by_keys
793
+
794
+ # S3 code and WDS preprocessing code based on implementation by Scott Hawley originally in https://github.com/zqevans/audio-diffusion/blob/main/dataset/dataset.py
795
+
796
+ def get_s3_contents(dataset_path, s3_url_prefix=None, filter='', recursive=True, debug=False, profile=None):
797
+ """
798
+ Returns a list of full S3 paths to files in a given S3 bucket and directory path.
799
+ """
800
+ # Ensure dataset_path ends with a trailing slash
801
+ if dataset_path != '' and not dataset_path.endswith('/'):
802
+ dataset_path += '/'
803
+ # Use posixpath to construct the S3 URL path
804
+ bucket_path = posixpath.join(s3_url_prefix or '', dataset_path)
805
+ # Construct the `aws s3 ls` command
806
+ cmd = ['aws', 's3', 'ls', bucket_path]
807
+
808
+ if profile is not None:
809
+ cmd.extend(['--profile', profile])
810
+
811
+ if recursive:
812
+ # Add the --recursive flag if requested
813
+ cmd.append('--recursive')
814
+
815
+ # Run the `aws s3 ls` command and capture the output
816
+ run_ls = subprocess.run(cmd, capture_output=True, check=True)
817
+ # Split the output into lines and strip whitespace from each line
818
+ contents = run_ls.stdout.decode('utf-8').split('\n')
819
+ contents = [x.strip() for x in contents if x]
820
+ # Remove the timestamp from lines that begin with a timestamp
821
+ contents = [re.sub(r'^\S+\s+\S+\s+\d+\s+', '', x)
822
+ if re.match(r'^\S+\s+\S+\s+\d+\s+', x) else x for x in contents]
823
+ # Construct a full S3 path for each file in the contents list
824
+ contents = [posixpath.join(s3_url_prefix or '', x)
825
+ for x in contents if not x.endswith('/')]
826
+ # Apply the filter, if specified
827
+ if filter:
828
+ contents = [x for x in contents if filter in x]
829
+ # Remove redundant directory names in the S3 URL
830
+ if recursive:
831
+ # Get the main directory name from the S3 URL
832
+ main_dir = "/".join(bucket_path.split('/')[3:])
833
+ # Remove the redundant directory names from each file path
834
+ contents = [x.replace(f'{main_dir}', '').replace(
835
+ '//', '/') for x in contents]
836
+ # Print debugging information, if requested
837
+ if debug:
838
+ print("contents = \n", contents)
839
+ # Return the list of S3 paths to files
840
+ return contents
841
+
842
+
843
+ def get_all_s3_urls(
844
+ names=[], # list of all valid [LAION AudioDataset] dataset names
845
+ # list of subsets you want from those datasets, e.g. ['train','valid']
846
+ subsets=[''],
847
+ s3_url_prefix=None, # prefix for those dataset names
848
+ recursive=True, # recursively list all tar files in all subdirs
849
+ filter_str='tar', # only grab files with this substring
850
+ # print debugging info -- note: info displayed likely to change at dev's whims
851
+ debug=False,
852
+ profiles={}, # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
853
+ ):
854
+ "get urls of shards (tar files) for multiple datasets in one s3 bucket"
855
+ urls = []
856
+ for name in names:
857
+ # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
858
+ if s3_url_prefix is None:
859
+ contents_str = name
860
+ else:
861
+ # Construct the S3 path using the s3_url_prefix and the current name value
862
+ contents_str = posixpath.join(s3_url_prefix, name)
863
+ if debug:
864
+ print(f"get_all_s3_urls: {contents_str}:")
865
+ for subset in subsets:
866
+ subset_str = posixpath.join(contents_str, subset)
867
+ if debug:
868
+ print(f"subset_str = {subset_str}")
869
+ # Get the list of tar files in the current subset directory
870
+ profile = profiles.get(name, None)
871
+ tar_list = get_s3_contents(
872
+ subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
873
+ for tar in tar_list:
874
+ # Escape spaces and parentheses in the tar filename for use in the shell command
875
+ tar = tar.replace(" ", "\ ").replace(
876
+ "(", "\(").replace(")", "\)")
877
+ # Construct the S3 path to the current tar file
878
+ s3_path = posixpath.join(name, subset, tar) + " -"
879
+ # Construct the AWS CLI command to download the current tar file
880
+ if s3_url_prefix is None:
881
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
882
+ else:
883
+ request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
884
+ if profiles.get(name):
885
+ request_str += f" --profile {profiles.get(name)}"
886
+ if debug:
887
+ print("request_str = ", request_str)
888
+ # Add the constructed URL to the list of URLs
889
+ urls.append(request_str)
890
+ return urls
891
+
892
+
893
+ def log_and_continue(exn):
894
+ """Call in an exception handler to ignore any exception, isssue a warning, and continue."""
895
+ print(f"Handling webdataset error ({repr(exn)}). Ignoring.")
896
+ return True
897
+
898
+
899
+ def is_valid_sample(sample):
900
+ has_json = "json" in sample
901
+ has_audio = "audio" in sample
902
+ is_silent = is_silence(sample["audio"])
903
+ is_rejected = "__reject__" in sample["json"] and sample["json"]["__reject__"]
904
+
905
+ return has_json and has_audio and not is_silent and not is_rejected
906
+
907
+ class S3DatasetConfig:
908
+ def __init__(
909
+ self,
910
+ id: str,
911
+ s3_path: str,
912
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
913
+ profile: Optional[str] = None,
914
+ ):
915
+ self.id = id
916
+ self.path = s3_path
917
+ self.custom_metadata_fn = custom_metadata_fn
918
+ self.profile = profile
919
+ self.urls = []
920
+
921
+ def load_data_urls(self):
922
+ self.urls = get_all_s3_urls(
923
+ names=[self.path],
924
+ s3_url_prefix=None,
925
+ recursive=True,
926
+ profiles={self.path: self.profile} if self.profile else {},
927
+ )
928
+
929
+ return self.urls
930
+
931
+ class LocalWebDatasetConfig:
932
+ def __init__(
933
+ self,
934
+ id: str,
935
+ path: str,
936
+ custom_metadata_fn: Optional[Callable[[str], str]] = None,
937
+ profile: Optional[str] = None,
938
+ ):
939
+ self.id = id
940
+ self.path = path
941
+ self.custom_metadata_fn = custom_metadata_fn
942
+ self.urls = []
943
+
944
+ def load_data_urls(self):
945
+
946
+ self.urls = fast_scandir(self.path, ["tar"])[1]
947
+
948
+ return self.urls
949
+
950
+ def audio_decoder(key, value):
951
+ # Get file extension from key
952
+ ext = key.split(".")[-1]
953
+
954
+ if ext in AUDIO_KEYS:
955
+ return torchaudio.load(io.BytesIO(value))
956
+ else:
957
+ return None
958
+
959
+ def collation_fn(samples):
960
+ batched = list(zip(*samples))
961
+ result = []
962
+ for b in batched:
963
+ if isinstance(b[0], (int, float)):
964
+ b = np.array(b)
965
+ elif isinstance(b[0], torch.Tensor):
966
+ b = torch.stack(b)
967
+ elif isinstance(b[0], np.ndarray):
968
+ b = np.array(b)
969
+ else:
970
+ b = b
971
+ result.append(b)
972
+ return result
973
+
974
+ class WebDatasetDataLoader():
975
+ def __init__(
976
+ self,
977
+ datasets: List[S3DatasetConfig],
978
+ batch_size,
979
+ sample_size,
980
+ sample_rate=48000,
981
+ num_workers=8,
982
+ epoch_steps=1000,
983
+ random_crop=True,
984
+ force_channels="stereo",
985
+ augment_phase=True,
986
+ **data_loader_kwargs
987
+ ):
988
+
989
+ self.datasets = datasets
990
+
991
+ self.sample_size = sample_size
992
+ self.sample_rate = sample_rate
993
+ self.random_crop = random_crop
994
+ self.force_channels = force_channels
995
+ self.augment_phase = augment_phase
996
+
997
+ urls = [dataset.load_data_urls() for dataset in datasets]
998
+
999
+ # Flatten the list of lists of URLs
1000
+ urls = [url for dataset_urls in urls for url in dataset_urls]
1001
+
1002
+ # Shuffle the urls
1003
+ random.shuffle(urls)
1004
+
1005
+ self.dataset = wds.DataPipeline(
1006
+ wds.ResampledShards(urls),
1007
+ wds.tarfile_to_samples(handler=log_and_continue),
1008
+ wds.decode(audio_decoder, handler=log_and_continue),
1009
+ wds.map(self.wds_preprocess, handler=log_and_continue),
1010
+ wds.select(is_valid_sample),
1011
+ wds.to_tuple("audio", "json", handler=log_and_continue),
1012
+ #wds.shuffle(bufsize=1000, initial=5000),
1013
+ wds.batched(batch_size, partial=False, collation_fn=collation_fn),
1014
+ ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps)
1015
+
1016
+ self.data_loader = wds.WebLoader(self.dataset, num_workers=num_workers, **data_loader_kwargs)
1017
+
1018
+ def wds_preprocess(self, sample):
1019
+
1020
+ found_key, rewrite_key = '', ''
1021
+ for k, v in sample.items(): # look through all entries in the dict for an audio key
1022
+ for akey in AUDIO_KEYS:
1023
+ if k.endswith(akey):
1024
+ # to rename long/weird key with its simpler counterpart
1025
+ found_key, rewrite_key = k, akey
1026
+ break
1027
+ if '' != found_key:
1028
+ break
1029
+ if '' == found_key: # got no audio!
1030
+ return None # try returning None to tell WebDataset to skip this one
1031
+
1032
+ audio, in_sr = sample[found_key]
1033
+ if in_sr != self.sample_rate:
1034
+ resample_tf = T.Resample(in_sr, self.sample_rate)
1035
+ audio = resample_tf(audio)
1036
+
1037
+ if self.sample_size is not None:
1038
+ # Pad/crop and get the relative timestamp
1039
+ pad_crop = PadCrop_Normalized_T(
1040
+ self.sample_size, randomize=self.random_crop, sample_rate=self.sample_rate)
1041
+ audio, t_start, t_end, seconds_start, seconds_total, padding_mask = pad_crop(
1042
+ audio)
1043
+ sample["json"]["seconds_start"] = seconds_start
1044
+ sample["json"]["seconds_total"] = seconds_total
1045
+ sample["json"]["padding_mask"] = padding_mask
1046
+ else:
1047
+ t_start, t_end = 0, 1
1048
+
1049
+ # Check if audio is length zero, initialize to a single zero if so
1050
+ if audio.shape[-1] == 0:
1051
+ audio = torch.zeros(1, 1)
1052
+
1053
+ # Make the audio stereo and augment by randomly inverting phase
1054
+ augs = torch.nn.Sequential(
1055
+ Stereo() if self.force_channels == "stereo" else torch.nn.Identity(),
1056
+ Mono() if self.force_channels == "mono" else torch.nn.Identity(),
1057
+ PhaseFlipper() if self.augment_phase else torch.nn.Identity()
1058
+ )
1059
+
1060
+ audio = augs(audio)
1061
+
1062
+ sample["json"]["timestamps"] = (t_start, t_end)
1063
+
1064
+ if "text" in sample["json"]:
1065
+ sample["json"]["prompt"] = sample["json"]["text"]
1066
+
1067
+ # Check for custom metadata functions
1068
+ for dataset in self.datasets:
1069
+ if dataset.custom_metadata_fn is None:
1070
+ continue
1071
+
1072
+ if dataset.path in sample["__url__"]:
1073
+ custom_metadata = dataset.custom_metadata_fn(sample["json"], audio)
1074
+ sample["json"].update(custom_metadata)
1075
+
1076
+ if found_key != rewrite_key: # rename long/weird key with its simpler counterpart
1077
+ del sample[found_key]
1078
+
1079
+ sample["audio"] = audio
1080
+
1081
+ # Add audio to the metadata as well for conditioning
1082
+ sample["json"]["audio"] = audio
1083
+
1084
+ return sample
1085
+
1086
+ def create_dataloader_from_config(dataset_config, batch_size, sample_size, sample_rate, audio_channels=2, num_workers=4):
1087
+
1088
+ dataset_type = dataset_config.get("dataset_type", None)
1089
+
1090
+ assert dataset_type is not None, "Dataset type must be specified in dataset config"
1091
+
1092
+ if audio_channels == 1:
1093
+ force_channels = "mono"
1094
+ elif audio_channels == 2:
1095
+ force_channels = "stereo"
1096
+ else:
1097
+ force_channels = "foa"
1098
+
1099
+ if dataset_type == "audio_dir":
1100
+
1101
+ audio_dir_configs = dataset_config.get("datasets", None)
1102
+
1103
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
1104
+
1105
+ configs = []
1106
+
1107
+ for audio_dir_config in audio_dir_configs:
1108
+ audio_dir_path = audio_dir_config.get("path", None)
1109
+ split_path = audio_dir_config.get("split_path", None)
1110
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
1111
+ custom_metadata_fn = None
1112
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
1113
+
1114
+ if custom_metadata_module_path is not None:
1115
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
1116
+ metadata_module = importlib.util.module_from_spec(spec)
1117
+ spec.loader.exec_module(metadata_module)
1118
+
1119
+ custom_metadata_fn = metadata_module.get_custom_metadata
1120
+
1121
+ configs.append(
1122
+ LocalDatasetConfig(
1123
+ id=audio_dir_config["id"],
1124
+ path=audio_dir_path,
1125
+ split_path=split_path,
1126
+ custom_metadata_fn=custom_metadata_fn
1127
+ )
1128
+ )
1129
+
1130
+ train_set = SampleDataset(
1131
+ configs,
1132
+ sample_rate=sample_rate,
1133
+ sample_size=sample_size,
1134
+ random_crop=dataset_config.get("random_crop", True),
1135
+ input_type=dataset_config.get("input_type", "video"),
1136
+ fps=dataset_config.get("fps", 4),
1137
+ force_channels=force_channels
1138
+ )
1139
+
1140
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
1141
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
1142
+
1143
+ elif dataset_type in ["s3", "wds"]: # Support "s3" type for backwards compatibility
1144
+
1145
+ wds_configs = []
1146
+
1147
+ for wds_config in dataset_config["datasets"]:
1148
+
1149
+ custom_metadata_fn = None
1150
+ custom_metadata_module_path = wds_config.get("custom_metadata_module", None)
1151
+
1152
+ if custom_metadata_module_path is not None:
1153
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
1154
+ metadata_module = importlib.util.module_from_spec(spec)
1155
+ spec.loader.exec_module(metadata_module)
1156
+
1157
+ custom_metadata_fn = metadata_module.get_custom_metadata
1158
+
1159
+ if "s3_path" in wds_config:
1160
+
1161
+ wds_configs.append(
1162
+ S3DatasetConfig(
1163
+ id=wds_config["id"],
1164
+ s3_path=wds_config["s3_path"],
1165
+ custom_metadata_fn=custom_metadata_fn,
1166
+ profile=wds_config.get("profile", None),
1167
+ )
1168
+ )
1169
+
1170
+ elif "path" in wds_config:
1171
+
1172
+ wds_configs.append(
1173
+ LocalWebDatasetConfig(
1174
+ id=wds_config["id"],
1175
+ path=wds_config["path"],
1176
+ custom_metadata_fn=custom_metadata_fn
1177
+ )
1178
+ )
1179
+
1180
+ return WebDatasetDataLoader(
1181
+ wds_configs,
1182
+ sample_rate=sample_rate,
1183
+ sample_size=sample_size,
1184
+ batch_size=batch_size,
1185
+ random_crop=dataset_config.get("random_crop", True),
1186
+ num_workers=num_workers,
1187
+ persistent_workers=True,
1188
+ force_channels=force_channels,
1189
+ epoch_steps=dataset_config.get("epoch_steps", 2000)
1190
+ ).data_loader
1191
+
1192
+ elif dataset_type == "latent_dir":
1193
+
1194
+ audio_dir_configs = dataset_config.get("datasets", None)
1195
+
1196
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
1197
+
1198
+ configs = []
1199
+
1200
+ for audio_dir_config in audio_dir_configs:
1201
+ audio_dir_path = audio_dir_config.get("path", None)
1202
+ split_path = audio_dir_config.get("split_path", None)
1203
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
1204
+
1205
+ configs.append(
1206
+ LocalDatasetConfig(
1207
+ id=audio_dir_config["id"],
1208
+ path=audio_dir_path,
1209
+ split_path=split_path,
1210
+ )
1211
+ )
1212
+
1213
+ train_set = LatentDataset(
1214
+ configs,
1215
+ sample_rate=sample_rate,
1216
+ sample_size=sample_size,
1217
+ random_crop=dataset_config.get("random_crop", True),
1218
+ input_type=dataset_config.get("input_type", "video"),
1219
+ fps=dataset_config.get("fps", 4),
1220
+ force_channels=force_channels
1221
+ )
1222
+
1223
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
1224
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
1225
+ elif dataset_type == 'multimodal_dir':
1226
+ audio_dir_configs = dataset_config.get("datasets", None)
1227
+
1228
+ assert audio_dir_configs is not None, "Directory configuration must be specified in datasets[\"dataset\"]"
1229
+
1230
+ configs = []
1231
+
1232
+ for audio_dir_config in audio_dir_configs:
1233
+ audio_dir_path = audio_dir_config.get("path", None)
1234
+ split_path = audio_dir_config.get("split_path", None)
1235
+ assert audio_dir_path is not None, "Path must be set for local audio directory configuration"
1236
+ custom_metadata_fn = None
1237
+ custom_metadata_module_path = audio_dir_config.get("custom_metadata_module", None)
1238
+
1239
+ if custom_metadata_module_path is not None:
1240
+ spec = importlib.util.spec_from_file_location("metadata_module", custom_metadata_module_path)
1241
+ metadata_module = importlib.util.module_from_spec(spec)
1242
+ spec.loader.exec_module(metadata_module)
1243
+
1244
+ custom_metadata_fn = metadata_module.get_custom_metadata
1245
+
1246
+ configs.append(
1247
+ LocalDatasetConfig(
1248
+ id=audio_dir_config["id"],
1249
+ path=audio_dir_path,
1250
+ split_path=split_path,
1251
+ custom_metadata_fn=custom_metadata_fn
1252
+ )
1253
+ )
1254
+
1255
+ train_set = MultiModalDataset(
1256
+ configs,
1257
+ sample_rate=sample_rate,
1258
+ sample_size=sample_size,
1259
+ random_crop=dataset_config.get("random_crop", True),
1260
+ input_type=dataset_config.get("input_type", "video"),
1261
+ fps=dataset_config.get("fps", 4),
1262
+ force_channels=force_channels
1263
+ )
1264
+
1265
+ return torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
1266
+ num_workers=num_workers, persistent_workers=True, pin_memory=True, drop_last=True, collate_fn=collation_fn)
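A minimal sketch of the dataset config shape consumed by the multimodal_dir branch of create_dataloader_from_config above; the ids, paths, and sizes are placeholders rather than values from this repository:

from ThinkSound.data.dataset import create_dataloader_from_config

dataset_config = {
    "dataset_type": "multimodal_dir",
    "input_type": "video",
    "fps": 4,
    "random_crop": True,
    "datasets": [
        {"id": "demo", "path": "/data/demo_features", "split_path": "/data/splits/train.csv"}
    ],
}
loader = create_dataloader_from_config(
    dataset_config, batch_size=8, sample_size=441000, sample_rate=44100, audio_channels=2
)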
ThinkSound/data/utils.py ADDED
@@ -0,0 +1,378 @@
1
+ import math
2
+ import random
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from torch import nn
6
+ from typing import Tuple
7
+ import numpy as np
8
+
9
+ class PadCrop(nn.Module):
10
+ def __init__(self, n_samples, randomize=True):
11
+ super().__init__()
12
+ self.n_samples = n_samples
13
+ self.randomize = randomize
14
+
15
+ def __call__(self, signal):
16
+ n, s = signal.shape
17
+ start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item()
18
+ end = start + self.n_samples
19
+ output = signal.new_zeros([n, self.n_samples])
20
+ output[:, :min(s, self.n_samples)] = signal[:, start:end]
21
+ return output
22
+
23
+ class PadCrop_Normalized_T(nn.Module):
24
+
25
+ def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True):
26
+
27
+ super().__init__()
28
+
29
+ self.n_samples = n_samples
30
+ self.sample_rate = sample_rate
31
+ self.randomize = randomize
32
+
33
+ def __call__(self, source: torch.Tensor, randomize=True) -> Tuple[torch.Tensor, float, float, int, int]:
34
+
35
+ n_channels, n_samples = source.shape
36
+
37
+ # If the audio is shorter than the desired length, pad it
38
+ upper_bound = max(0, n_samples - self.n_samples)
39
+
40
+ # If randomize is False, always start at the beginning of the audio
41
+ offset = 0
42
+ if(randomize and n_samples > self.n_samples):
43
+ offset = random.randint(0, upper_bound)
44
+
45
+ # Calculate the start and end times of the chunk
46
+ t_start = offset / (upper_bound + self.n_samples)
47
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
48
+
49
+ # Create the chunk
50
+ chunk = source.new_zeros([n_channels, self.n_samples])
51
+
52
+ # Copy the audio into the chunk
53
+ chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples]
54
+
55
+ # Calculate the start and end times of the chunk in seconds
56
+ seconds_start = math.floor(offset / self.sample_rate)
57
+ seconds_total = math.ceil(n_samples / self.sample_rate)
58
+
59
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
60
+ padding_mask = torch.zeros([self.n_samples])
61
+ padding_mask[:min(n_samples, self.n_samples)] = 1
62
+
63
+
64
+ return (
65
+ chunk,
66
+ t_start,
67
+ t_end,
68
+ seconds_start,
69
+ seconds_total,
70
+ padding_mask
71
+ )
72
+
73
+ class PadCrop_Video_Normalized_T(nn.Module):
74
+
75
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
76
+
77
+ super().__init__()
78
+
79
+ self.n_samples = n_samples
80
+ self.sample_rate = sample_rate
81
+ self.randomize = randomize
82
+ self.fps = fps
83
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
84
+
85
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
86
+ n_channels, n_samples = audio.shape
87
+ # print(video.shape)
88
+ n_frames, dim = video.shape
89
+ if not torch.is_tensor(video):
90
+ video = torch.from_numpy(video)
91
+ # If the audio is shorter than the desired length, pad it
92
+ audio_upper_bound = max(0, n_samples - self.n_samples)
93
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
94
+ upper_bound = min(audio_upper_bound,video_upper_bound)
95
+
96
+ # If randomize is False, always start at the beginning of the audio
97
+ offset = 0
98
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
99
+ offset = random.randint(0, upper_bound)
100
+
101
+ # Calculate the start and end times of the chunk
102
+ t_start = offset / (upper_bound + self.n_samples)
103
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
104
+ frame_offset = int(self.fps * offset / self.sample_rate)
105
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
106
+ # Create the chunk
107
+ chunk = audio.new_zeros([n_channels, self.n_samples])
108
+ video_chunk = video.new_zeros([self.n_frames, video.shape[1]])
109
+ # Copy the audio into the chunk
110
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
111
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames,:]
112
+ # Calculate the start and end times of the chunk in seconds
113
+ seconds_start = math.floor(offset / self.sample_rate)
114
+ seconds_total = math.ceil(n_samples / self.sample_rate)
115
+
116
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
117
+ padding_mask = torch.zeros([self.n_samples])
118
+ padding_mask[:min(n_samples, self.n_samples)] = 1
119
+
120
+
121
+ return (
122
+ chunk,
123
+ video_chunk,
124
+ t_start,
125
+ t_end,
126
+ seconds_start,
127
+ seconds_total,
128
+ padding_mask
129
+ )
130
+
131
+ class PadCrop_Video_Image_Normalized_T(nn.Module):
132
+
133
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
134
+
135
+ super().__init__()
136
+
137
+ self.n_samples = n_samples
138
+ self.sample_rate = sample_rate
139
+ self.randomize = randomize
140
+ self.fps = fps
141
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
142
+
143
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
144
+ n_channels, n_samples = audio.shape
145
+ # import ipdb
146
+ # ipdb.set_trace()
147
+ n_frames, channel, width, height= video.shape
148
+ video = torch.from_numpy(video)
149
+ # If the audio is shorter than the desired length, pad it
150
+ audio_upper_bound = max(0, n_samples - self.n_samples)
151
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
152
+ upper_bound = min(audio_upper_bound,video_upper_bound)
153
+
154
+ # If randomize is False, always start at the beginning of the audio
155
+ offset = 0
156
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
157
+ offset = random.randint(0, upper_bound)
158
+
159
+ # Calculate the start and end times of the chunk
160
+ t_start = offset / (upper_bound + self.n_samples)
161
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
162
+ frame_offset = int(self.fps * offset / self.sample_rate)
163
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
164
+ # Create the chunk
165
+ chunk = audio.new_zeros([n_channels, self.n_samples])
166
+ video_chunk = video.new_zeros([self.n_frames, channel, width, height])
167
+ # Copy the audio into the chunk
168
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
169
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames]
170
+ # Calculate the start and end times of the chunk in seconds
171
+ seconds_start = math.floor(offset / self.sample_rate)
172
+ seconds_total = math.ceil(n_samples / self.sample_rate)
173
+
174
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
175
+ padding_mask = torch.zeros([self.n_samples])
176
+ padding_mask[:min(n_samples, self.n_samples)] = 1
177
+
178
+
179
+ return (
180
+ chunk,
181
+ video_chunk,
182
+ t_start,
183
+ t_end,
184
+ seconds_start,
185
+ seconds_total,
186
+ padding_mask
187
+ )
188
+
189
+ class PadCrop_Video_Hiera_Normalized_T(nn.Module):
190
+
191
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
192
+
193
+ super().__init__()
194
+
195
+ self.n_samples = n_samples
196
+ self.sample_rate = sample_rate
197
+ self.randomize = randomize
198
+ self.fps = fps
199
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
200
+
201
+ def __call__(self, audio: torch.Tensor, video: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
202
+
203
+ n_channels, n_samples = audio.shape
204
+ n_frames, height, width, channel = video.shape
205
+ video = torch.from_numpy(video)
206
+ # If the audio is shorter than the desired length, pad it
207
+ audio_upper_bound = max(0, n_samples - self.n_samples)
208
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
209
+ upper_bound = min(audio_upper_bound,video_upper_bound)
210
+
211
+ # If randomize is False, always start at the beginning of the audio
212
+ offset = 0
213
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
214
+ offset = random.randint(0, upper_bound)
215
+
216
+ # Calculate the start and end times of the chunk
217
+ t_start = offset / (upper_bound + self.n_samples)
218
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
219
+ frame_offset = int(self.fps * offset / self.sample_rate)
220
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
221
+ # Create the chunk
222
+ chunk = audio.new_zeros([n_channels, self.n_samples])
223
+ video_chunk = video.new_zeros([self.n_frames, height, width, channel])
224
+ # Copy the audio into the chunk
225
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
226
+ video_chunk[:min(n_frames, self.n_frames)] = video[frame_offset:frame_offset + self.n_frames]
227
+ # video_chunk = video_chunk[None].permute(0, 4, 1, 2, 3).contiguous()
228
+ # print(video_chunk.shape)
229
+ # video_chunk = F.interpolate(
230
+ # video_chunk[0],
231
+ # size=(224, 224, 3), # output spatial size
232
+ # scale_factor=(target_frames / video_tensor.shape[1], 1, 1), # scaling factor along the time axis
233
+ # mode='trilinear', # use trilinear interpolation
234
+ # align_corners=False
235
+ # )
236
+
237
+ # video_chunk = F.interpolate(video_chunk, size=(64, 224, 224), mode="trilinear")[0]
238
+ # video_chunk = video_chunk.view(3,4,16,224,224).transpose(0,1)
239
+ # Calculate the start and end times of the chunk in seconds
240
+ seconds_start = math.floor(offset / self.sample_rate)
241
+ seconds_total = math.ceil(n_samples / self.sample_rate)
242
+
243
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
244
+ padding_mask = torch.zeros([self.n_samples])
245
+ padding_mask[:min(n_samples, self.n_samples)] = 1
246
+
247
+
248
+ return (
249
+ chunk,
250
+ video_chunk,
251
+ t_start,
252
+ t_end,
253
+ seconds_start,
254
+ seconds_total,
255
+ padding_mask
256
+ )
257
+
258
+ class PadCrop_DualVideo_Normalized_T(nn.Module):
259
+
260
+ def __init__(self, n_samples: int, sample_rate: int, fps: int, randomize: bool = True):
261
+
262
+ super().__init__()
263
+
264
+ self.n_samples = n_samples
265
+ self.sample_rate = sample_rate
266
+ self.randomize = randomize
267
+ self.fps = fps
268
+ self.n_frames = int(self.fps * self.n_samples / self.sample_rate)
269
+
270
+ def __call__(self, audio: torch.Tensor, video_360: torch.Tensor, video_fov: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]:
271
+ n_channels, n_samples = audio.shape
272
+ # print(video.shape)
273
+ n_frames, dim = video_360.shape
274
+ video_360 = torch.from_numpy(video_360)
275
+ video_fov = torch.from_numpy(video_fov)
276
+ # If the audio is shorter than the desired length, pad it
277
+ audio_upper_bound = max(0, n_samples - self.n_samples)
278
+ video_upper_bound = int(max(0, n_frames - self.n_frames) * self.sample_rate / self.fps)
279
+ upper_bound = min(audio_upper_bound,video_upper_bound)
280
+
281
+ # If randomize is False, always start at the beginning of the audio
282
+ offset = 0
283
+ if(self.randomize and n_samples > self.n_samples and n_frames > self.n_frames):
284
+ offset = random.randint(0, upper_bound)
285
+
286
+ # Calculate the start and end times of the chunk
287
+ t_start = offset / (upper_bound + self.n_samples)
288
+ t_end = (offset + self.n_samples) / (upper_bound + self.n_samples)
289
+ frame_offset = int(self.fps * offset / self.sample_rate)
290
+ # frame_end = frame_offset + int(self.fps * self.n_samples / self.sample_rate)
291
+ # Create the chunk
292
+ chunk = audio.new_zeros([n_channels, self.n_samples])
293
+ video_360_chunk = video_360.new_zeros([self.n_frames, video_360.shape[1]])
294
+ video_fov_chunk = video_fov.new_zeros([self.n_frames, video_fov.shape[1]])
295
+ # Copy the audio into the chunk
296
+ chunk[:, :min(n_samples, self.n_samples)] = audio[:, offset:offset + self.n_samples]
297
+ video_360_chunk[:min(n_frames, self.n_frames)] = video_360[frame_offset:frame_offset + self.n_frames,:]
298
+ video_fov_chunk[:min(n_frames, self.n_frames)] = video_fov[frame_offset:frame_offset + self.n_frames,:]
299
+ # Calculate the start and end times of the chunk in seconds
300
+ seconds_start = math.floor(offset / self.sample_rate)
301
+ seconds_total = math.ceil(n_samples / self.sample_rate)
302
+
303
+ # Create a mask the same length as the chunk with 1s where the audio is and 0s where it isn't
304
+ padding_mask = torch.zeros([self.n_samples])
305
+ padding_mask[:min(n_samples, self.n_samples)] = 1
306
+
307
+
308
+ return (
309
+ chunk,
310
+ video_360_chunk,
311
+ video_fov_chunk,
312
+ t_start,
313
+ t_end,
314
+ seconds_start,
315
+ seconds_total,
316
+ padding_mask
317
+ )
318
+
319
+ class PhaseFlipper(nn.Module):
320
+ "Randomly invert the phase of a signal"
321
+ def __init__(self, p=0.5):
322
+ super().__init__()
323
+ self.p = p
324
+ def __call__(self, signal):
325
+ return -signal if (random.random() < self.p) else signal
326
+
327
+ class Mono(nn.Module):
328
+ def __call__(self, signal):
329
+ return torch.mean(signal, dim=0, keepdims=True) if len(signal.shape) > 1 else signal
330
+
331
+ class Stereo(nn.Module):
332
+ def __call__(self, signal):
333
+ signal_shape = signal.shape
334
+ # Check if it's mono
335
+ if len(signal_shape) == 1: # s -> 2, s
336
+ signal = signal.unsqueeze(0).repeat(2, 1)
337
+ elif len(signal_shape) == 2:
338
+ if signal_shape[0] == 1: #1, s -> 2, s
339
+ signal = signal.repeat(2, 1)
340
+ elif signal_shape[0] > 2: #?, s -> 2,s
341
+ signal = signal[:2, :]
342
+
343
+ return signal
344
+
345
+ class FOA(nn.Module):
346
+ def __call__(self, signal):
347
+ signal_shape = signal.shape
348
+ # Check if it's mono
349
+ if len(signal_shape) == 1: # s -> (4, s)
350
+ foa = torch.zeros(4, signal_shape[0], device=signal.device) # 与输入信号一致的设备类型
351
+ foa[0, :] = signal # W通道: 全方位声源
352
+ foa[1, :] = 0 # X通道
353
+ foa[2, :] = 0 # Y通道
354
+ foa[3, :] = 0 # Z通道
355
+ elif len(signal_shape) == 2:
356
+ foa = torch.zeros(4, signal_shape[1], device=signal.device) # 与输入信号一致的设备类型
357
+ if signal_shape[0] == 1: # (1, s) -> (4, s)
358
+ foa[0, :] = signal[0] # W通道: 全方位声源
359
+ foa[1, :] = 0 # X通道
360
+ foa[2, :] = 0 # Y通道
361
+ foa[3, :] = 0 # Z通道
362
+ elif signal_shape[0] == 2: # (2, s) -> (4, s)
363
+ left = signal[0]
364
+ right = signal[1]
365
+ # 将立体声信号映射到FOA信号通道
366
+ foa[0, :] = (left + right) / np.sqrt(2) # W通道: 全方位声源
367
+ foa[1, :] = (left - right) / np.sqrt(2) # X通道: 前后方向
368
+ foa[2, :] = 0 # Y通道: 左右方向,简单实现先置零
369
+ foa[3, :] = 0 # Z通道: 垂直方向,这里置零
370
+ else:
371
+ foa = signal
372
+
373
+ else:
374
+ raise ValueError(f"Unsupported signal shape: {signal_shape}")
375
+
376
+ assert foa.shape[0] == 4, 'inputs are not in FOA format'
377
+
378
+ return foa
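A small sketch of how the crop and channel/phase transforms defined above compose on a dummy clip; the durations are illustrative and the imports assume the package layout in this commit:

import torch
from ThinkSound.data.utils import PadCrop_Normalized_T, Stereo, PhaseFlipper

crop = PadCrop_Normalized_T(n_samples=44100 * 5, sample_rate=44100, randomize=True)
audio = torch.randn(1, 44100 * 12)  # 12 s of mono noise
chunk, t_start, t_end, seconds_start, seconds_total, padding_mask = crop(audio)

augs = torch.nn.Sequential(Stereo(), PhaseFlipper(p=0.5))
stereo_chunk = augs(chunk)  # shape (2, 220500)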
ThinkSound/inference/__init__.py ADDED
File without changes
ThinkSound/inference/generation.py ADDED
@@ -0,0 +1,274 @@
1
+ import numpy as np
2
+ import torch
3
+ import typing as tp
4
+ import math
5
+ from torchaudio import transforms as T
6
+
7
+ from .utils import prepare_audio
8
+ from .sampling import sample, sample_k, sample_rf
9
+ from ..data.utils import PadCrop
10
+
11
+ def generate_diffusion_uncond(
12
+ model,
13
+ steps: int = 250,
14
+ batch_size: int = 1,
15
+ sample_size: int = 2097152,
16
+ seed: int = -1,
17
+ device: str = "cuda",
18
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
19
+ init_noise_level: float = 1.0,
20
+ return_latents = False,
21
+ **sampler_kwargs
22
+ ) -> torch.Tensor:
23
+
24
+ # The length of the output in audio samples
25
+ audio_sample_size = sample_size
26
+
27
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
28
+ if model.pretransform is not None:
29
+ sample_size = sample_size // model.pretransform.downsampling_ratio
30
+
31
+ # Seed
32
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
33
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
34
+ print(seed)
35
+ torch.manual_seed(seed)
36
+ # Define the initial noise immediately after setting the seed
37
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
38
+
39
+ if init_audio is not None:
40
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
41
+ in_sr, init_audio = init_audio
42
+
43
+ io_channels = model.io_channels
44
+
45
+ # For latent models, set the io_channels to the autoencoder's io_channels
46
+ if model.pretransform is not None:
47
+ io_channels = model.pretransform.io_channels
48
+
49
+ # Prepare the initial audio for use by the model
50
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
51
+
52
+ # For latent models, encode the initial audio into latents
53
+ if model.pretransform is not None:
54
+ init_audio = model.pretransform.encode(init_audio)
55
+
56
+ init_audio = init_audio.repeat(batch_size, 1, 1)
57
+ else:
58
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
59
+ init_audio = None
60
+ init_noise_level = None
61
+
62
+ # Inpainting mask
63
+
64
+ if init_audio is not None:
65
+ # variations
66
+ sampler_kwargs["sigma_max"] = init_noise_level
67
+ mask = None
68
+ else:
69
+ mask = None
70
+
71
+ # Now the generative AI part:
72
+
73
+ diff_objective = model.diffusion_objective
74
+
75
+ if diff_objective == "v":
76
+ # k-diffusion denoising process go!
77
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, device=device)
78
+ elif diff_objective == "rectified_flow":
79
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, device=device)
80
+
81
+ # Denoising process done.
82
+ # If this is latent diffusion, decode latents back into audio
83
+ if model.pretransform is not None and not return_latents:
84
+ sampled = model.pretransform.decode(sampled)
85
+
86
+ # Return audio
87
+ return sampled
88
+
89
+
90
+ def generate_diffusion_cond(
91
+ model,
92
+ steps: int = 250,
93
+ cfg_scale=6,
94
+ conditioning: dict = None,
95
+ conditioning_tensors: tp.Optional[dict] = None,
96
+ negative_conditioning: dict = None,
97
+ negative_conditioning_tensors: tp.Optional[dict] = None,
98
+ batch_size: int = 1,
99
+ sample_size: int = 2097152,
100
+ sample_rate: int = 48000,
101
+ seed: int = -1,
102
+ device: str = "cuda",
103
+ init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
104
+ init_noise_level: float = 1.0,
105
+ mask_args: dict = None,
106
+ return_latents = False,
107
+ **sampler_kwargs
108
+ ) -> torch.Tensor:
109
+ """
110
+ Generate audio from a prompt using a diffusion model.
111
+
112
+ Args:
113
+ model: The diffusion model to use for generation.
114
+ steps: The number of diffusion steps to use.
115
+ cfg_scale: Classifier-free guidance scale
116
+ conditioning: A dictionary of conditioning parameters to use for generation.
117
+ conditioning_tensors: A dictionary of precomputed conditioning tensors to use for generation.
118
+ batch_size: The batch size to use for generation.
119
+ sample_size: The length of the audio to generate, in samples.
120
+ sample_rate: The sample rate of the audio to generate (Deprecated, now pulled from the model directly)
121
+ seed: The random seed to use for generation, or -1 to use a random seed.
122
+ device: The device to use for generation.
123
+ init_audio: A tuple of (sample_rate, audio) to use as the initial audio for generation.
124
+ init_noise_level: The noise level to use when generating from an initial audio sample.
125
+ return_latents: Whether to return the latents used for generation instead of the decoded audio.
126
+ **sampler_kwargs: Additional keyword arguments to pass to the sampler.
127
+ """
128
+
129
+ # The length of the output in audio samples
130
+ audio_sample_size = sample_size
131
+
132
+ # If this is latent diffusion, change sample_size instead to the downsampled latent size
133
+ if model.pretransform is not None:
134
+ sample_size = sample_size // model.pretransform.downsampling_ratio
135
+
136
+ # Seed
137
+ # The user can explicitly set the seed to deterministically generate the same output. Otherwise, use a random seed.
138
+ seed = seed if seed != -1 else np.random.randint(0, 2**32 - 1, dtype=np.uint32)
139
+ print(seed)
140
+ torch.manual_seed(seed)
141
+ # Define the initial noise immediately after setting the seed
142
+ noise = torch.randn([batch_size, model.io_channels, sample_size], device=device)
143
+
144
+ torch.backends.cuda.matmul.allow_tf32 = False
145
+ torch.backends.cudnn.allow_tf32 = False
146
+ torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
147
+ torch.backends.cudnn.benchmark = False
150
+ # Conditioning
151
+ assert conditioning is not None or conditioning_tensors is not None, "Must provide either conditioning or conditioning_tensors"
152
+ if conditioning_tensors is None:
153
+ conditioning_tensors = model.conditioner(conditioning, device)
154
+ conditioning_inputs = model.get_conditioning_inputs(conditioning_tensors)
155
+
156
+ if negative_conditioning is not None or negative_conditioning_tensors is not None:
157
+
158
+ if negative_conditioning_tensors is None:
159
+ negative_conditioning_tensors = model.conditioner(negative_conditioning, device)
160
+
161
+ negative_conditioning_tensors = model.get_conditioning_inputs(negative_conditioning_tensors, negative=True)
162
+ else:
163
+ negative_conditioning_tensors = {}
164
+
165
+ if init_audio is not None:
166
+ # The user supplied some initial audio (for inpainting or variation). Let us prepare the input audio.
167
+ in_sr, init_audio = init_audio
168
+
169
+ io_channels = model.io_channels
170
+
171
+ # For latent models, set the io_channels to the autoencoder's io_channels
172
+ if model.pretransform is not None:
173
+ io_channels = model.pretransform.io_channels
174
+
175
+ # Prepare the initial audio for use by the model
176
+ init_audio = prepare_audio(init_audio, in_sr=in_sr, target_sr=model.sample_rate, target_length=audio_sample_size, target_channels=io_channels, device=device)
177
+
178
+ # For latent models, encode the initial audio into latents
179
+ if model.pretransform is not None:
180
+ init_audio = model.pretransform.encode(init_audio)
181
+
182
+ init_audio = init_audio.repeat(batch_size, 1, 1)
183
+ else:
184
+ # The user did not supply any initial audio for inpainting or variation. Generate new output from scratch.
185
+ init_audio = None
186
+ init_noise_level = None
187
+ mask_args = None
188
+
189
+ # Inpainting mask
190
+ if init_audio is not None and mask_args is not None:
191
+ # Cut and paste init_audio according to cropfrom, pastefrom, pasteto
192
+ # This is helpful for forward and reverse outpainting
193
+ cropfrom = math.floor(mask_args["cropfrom"]/100.0 * sample_size)
194
+ pastefrom = math.floor(mask_args["pastefrom"]/100.0 * sample_size)
195
+ pasteto = math.ceil(mask_args["pasteto"]/100.0 * sample_size)
196
+ assert pastefrom < pasteto, "Paste From should be less than Paste To"
197
+ croplen = pasteto - pastefrom
198
+ if cropfrom + croplen > sample_size:
199
+ croplen = sample_size - cropfrom
200
+ cropto = cropfrom + croplen
201
+ pasteto = pastefrom + croplen
202
+ cutpaste = init_audio.new_zeros(init_audio.shape)
203
+ cutpaste[:, :, pastefrom:pasteto] = init_audio[:,:,cropfrom:cropto]
204
+ #print(cropfrom, cropto, pastefrom, pasteto)
205
+ init_audio = cutpaste
206
+ # Build a soft mask (list of floats 0 to 1, the size of the latent) from the given args
207
+ mask = build_mask(sample_size, mask_args)
208
+ mask = mask.to(device)
209
+ elif init_audio is not None and mask_args is None:
210
+ # variations
211
+ sampler_kwargs["sigma_max"] = init_noise_level
212
+ mask = None
213
+ else:
214
+ mask = None
215
+
216
+ model_dtype = next(model.model.parameters()).dtype
217
+ noise = noise.type(model_dtype)
218
+ conditioning_inputs = {k: v.type(model_dtype) if v is not None else v for k, v in conditioning_inputs.items()}
219
+ # Now the generative AI part:
220
+ # k-diffusion denoising process go!
221
+ diff_objective = model.diffusion_objective
222
+ if diff_objective == "v":
223
+ # k-diffusion denoising process go!
224
+ # sampled = sample(model.model, noise, steps, 0, **conditioning_inputs)
225
+ sampled = sample_k(model.model, noise, init_audio, mask, steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
226
+ elif diff_objective == "rectified_flow":
227
+
228
+ if "sigma_min" in sampler_kwargs:
229
+ del sampler_kwargs["sigma_min"]
230
+
231
+ if "sampler_type" in sampler_kwargs:
232
+ del sampler_kwargs["sampler_type"]
233
+
234
+ sampled = sample_rf(model.model, noise, init_data=init_audio, steps=steps, **sampler_kwargs, **conditioning_inputs, **negative_conditioning_tensors, cfg_scale=cfg_scale, batch_cfg=True, rescale_cfg=True, device=device)
235
+
236
+ # v-diffusion:
237
+ #sampled = sample(model.model, noise, steps, 0, **conditioning_tensors, embedding_scale=cfg_scale)
238
+ del noise
239
+ del conditioning_tensors
240
+ del conditioning_inputs
241
+ torch.cuda.empty_cache()
242
+ # Denoising process done.
243
+ # If this is latent diffusion, decode latents back into audio
244
+ if model.pretransform is not None and not return_latents:
245
+ #cast sampled latents to pretransform dtype
246
+ sampled = sampled.to(next(model.pretransform.parameters()).dtype)
247
+ sampled = model.pretransform.decode(sampled)
248
+
249
+ # Return audio
250
+ return sampled
251
+
252
+ # builds a softmask given the parameters
253
+ # returns array of values 0 to 1, size sample_size, where 0 means noise / fresh generation, 1 means keep the input audio,
254
+ # and anything between is a mixture of old/new
255
+ # ideally 0.5 is half/half mixture but i haven't figured this out yet
256
+ def build_mask(sample_size, mask_args):
257
+ maskstart = math.floor(mask_args["maskstart"]/100.0 * sample_size)
258
+ maskend = math.ceil(mask_args["maskend"]/100.0 * sample_size)
259
+ softnessL = round(mask_args["softnessL"]/100.0 * sample_size)
260
+ softnessR = round(mask_args["softnessR"]/100.0 * sample_size)
261
+ marination = mask_args["marination"]
262
+ # use hann windows for softening the transition (i don't know if this is correct)
263
+ hannL = torch.hann_window(softnessL*2, periodic=False)[:softnessL]
264
+ hannR = torch.hann_window(softnessR*2, periodic=False)[softnessR:]
265
+ # build the mask.
266
+ mask = torch.zeros((sample_size))
267
+ mask[maskstart:maskend] = 1
268
+ mask[maskstart:maskstart+softnessL] = hannL
269
+ mask[maskend-softnessR:maskend] = hannR
270
+ # marination finishes the inpainting early in the denoising schedule, and lets audio get changed in the final rounds
271
+ if marination > 0:
272
+ mask = mask * (1-marination)
273
+ #print(mask)
274
+ return mask
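A quick sketch of the mask_args dictionary read by build_mask (and by the cut/paste block in generate_diffusion_cond). The crop/paste/mask values are percentages of the latent length, marination is a 0-1 factor, and the numbers below are only illustrative:

from ThinkSound.inference.generation import build_mask

mask_args = {
    "cropfrom": 0, "pastefrom": 0, "pasteto": 100,  # used by the cut/paste step in generate_diffusion_cond
    "maskstart": 25, "maskend": 75,                 # keep roughly the middle half of the input
    "softnessL": 5, "softnessR": 5,                 # Hann-windowed transitions at each edge
    "marination": 0.0,
}
mask = build_mask(sample_size=1024, mask_args=mask_args)  # float tensor in [0, 1], shape (1024,)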
ThinkSound/inference/sampling.py ADDED
@@ -0,0 +1,232 @@
1
+ import torch
2
+ import math
3
+ from tqdm import trange, tqdm
4
+
5
+ import k_diffusion as K
6
+
7
+ # Define the noise schedule and sampling loop
8
+ def get_alphas_sigmas(t):
9
+ """Returns the scaling factors for the clean image (alpha) and for the
10
+ noise (sigma), given a timestep."""
11
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
12
+
13
+ def alpha_sigma_to_t(alpha, sigma):
14
+ """Returns a timestep, given the scaling factors for the clean image and for
15
+ the noise."""
16
+ return torch.atan2(sigma, alpha) / math.pi * 2
17
+
18
+ def t_to_alpha_sigma(t):
19
+ """Returns the scaling factors for the clean image and for the noise, given
20
+ a timestep."""
21
+ return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
22
+
23
+
24
+ @torch.no_grad()
25
+ def sample_discrete_euler(model, x, steps, sigma_max=1, **extra_args):
26
+ """Draws samples from a model given starting noise. Euler method"""
27
+
28
+ # Make tensor of ones to broadcast the single t values
29
+ ts = x.new_ones([x.shape[0]])
30
+
31
+ # Create the noise schedule
32
+ t = torch.linspace(sigma_max, 0, steps + 1)
33
+
34
+ #alphas, sigmas = 1-t, t
35
+
36
+ for t_curr, t_prev in tqdm(zip(t[:-1], t[1:])):
37
+ # Broadcast the current timestep to the correct shape
38
+ t_curr_tensor = t_curr * torch.ones(
39
+ (x.shape[0],), dtype=x.dtype, device=x.device
40
+ )
41
+ dt = t_prev - t_curr # we solve backwards in our formulation
42
+ x = x + dt * model(x, t_curr_tensor, **extra_args) #.denoise(x, denoiser, t_curr_tensor, cond, uc)
43
+
44
+ # If we are on the last timestep, output the denoised image
45
+ return x
46
+
47
+ @torch.no_grad()
48
+ def sample(model, x, steps, eta, **extra_args):
49
+ """Draws samples from a model given starting noise. v-diffusion"""
50
+ ts = x.new_ones([x.shape[0]])
51
+
52
+ # Create the noise schedule
53
+ t = torch.linspace(1, 0, steps + 1)[:-1]
54
+
55
+ alphas, sigmas = get_alphas_sigmas(t)
56
+
57
+ # The sampling loop
58
+ for i in trange(steps):
59
+
60
+ # Get the model output (v, the predicted velocity)
61
+ with torch.cuda.amp.autocast():
62
+ v = model(x, ts * t[i], **extra_args).float()
63
+
64
+ # Predict the noise and the denoised image
65
+ pred = x * alphas[i] - v * sigmas[i]
66
+ eps = x * sigmas[i] + v * alphas[i]
67
+
68
+ # If we are not on the last timestep, compute the noisy image for the
69
+ # next timestep.
70
+ if i < steps - 1:
71
+ # If eta > 0, adjust the scaling factor for the predicted noise
72
+ # downward according to the amount of additional noise to add
73
+ ddim_sigma = eta * (sigmas[i + 1]**2 / sigmas[i]**2).sqrt() * \
74
+ (1 - alphas[i]**2 / alphas[i + 1]**2).sqrt()
75
+ adjusted_sigma = (sigmas[i + 1]**2 - ddim_sigma**2).sqrt()
76
+
77
+ # Recombine the predicted noise and predicted denoised image in the
78
+ # correct proportions for the next step
79
+ x = pred * alphas[i + 1] + eps * adjusted_sigma
80
+
81
+ # Add the correct amount of fresh noise
82
+ if eta:
83
+ x += torch.randn_like(x) * ddim_sigma
84
+
85
+ # If we are on the last timestep, output the denoised image
86
+ return pred
87
+
88
+ # Soft mask inpainting is just shrinking hard (binary) mask inpainting
89
+ # Given a float-valued soft mask (values between 0 and 1), get the binary mask for this particular step
90
+ def get_bmask(i, steps, mask):
91
+ strength = (i+1)/(steps)
92
+ # convert to binary mask
93
+ bmask = torch.where(mask<=strength,1,0)
94
+ return bmask
95
+
96
+ def make_cond_model_fn(model, cond_fn):
97
+ def cond_model_fn(x, sigma, **kwargs):
98
+ with torch.enable_grad():
99
+ x = x.detach().requires_grad_()
100
+ denoised = model(x, sigma, **kwargs)
101
+ cond_grad = cond_fn(x, sigma, denoised=denoised, **kwargs).detach()
102
+ cond_denoised = denoised.detach() + cond_grad * K.utils.append_dims(sigma**2, x.ndim)
103
+ return cond_denoised
104
+ return cond_model_fn
105
+
106
+ # Uses k-diffusion from https://github.com/crowsonkb/k-diffusion
107
+ # init_data is init_audio as latents (if this is latent diffusion)
108
+ # For sampling, set both init_data and mask to None
109
+ # For variations, set init_data
110
+ # For inpainting, set both init_data & mask
111
+ def sample_k(
112
+ model_fn,
113
+ noise,
114
+ init_data=None,
115
+ mask=None,
116
+ steps=100,
117
+ sampler_type="dpmpp-2m-sde",
118
+ sigma_min=0.5,
119
+ sigma_max=50,
120
+ rho=1.0, device="cuda",
121
+ callback=None,
122
+ cond_fn=None,
123
+ **extra_args
124
+ ):
125
+
126
+ denoiser = K.external.VDenoiser(model_fn)
127
+
128
+ if cond_fn is not None:
129
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
130
+
131
+ # Make the list of sigmas. Sigma values are scalars related to the amount of noise each denoising step has
132
+ sigmas = K.sampling.get_sigmas_polyexponential(steps, sigma_min, sigma_max, rho, device=device)
133
+ # Scale the initial noise by sigma
134
+ noise = noise * sigmas[0]
135
+
136
+ wrapped_callback = callback
137
+
138
+ if mask is None and init_data is not None:
139
+ # VARIATION (no inpainting)
140
+ # set the initial latent to the init_data, and noise it with initial sigma
141
+ x = init_data + noise
142
+ elif mask is not None and init_data is not None:
143
+ # INPAINTING
144
+ bmask = get_bmask(0, steps, mask)
145
+ # initial noising
146
+ input_noised = init_data + noise
147
+ # set the initial latent to a mix of init_data and noise, based on step 0's binary mask
148
+ x = input_noised * bmask + noise * (1-bmask)
149
+ # define the inpainting callback function (Note: side effects, it mutates x)
150
+ # See https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py#L596C13-L596C105
151
+ # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
152
+ # This is called immediately after `denoised = model(x, sigmas[i] * s_in, **extra_args)`
153
+ def inpainting_callback(args):
154
+ i = args["i"]
155
+ x = args["x"]
156
+ sigma = args["sigma"]
157
+ #denoised = args["denoised"]
158
+ # noise the init_data input with this step's appropriate amount of noise
159
+ input_noised = init_data + torch.randn_like(init_data) * sigma
160
+ # shrinking hard mask
161
+ bmask = get_bmask(i, steps, mask)
162
+ # mix input_noise with x, using binary mask
163
+ new_x = input_noised * bmask + x * (1-bmask)
164
+ # mutate x
165
+ x[:,:,:] = new_x[:,:,:]
166
+ # wrap together the inpainting callback and the user-submitted callback.
167
+ if callback is None:
168
+ wrapped_callback = inpainting_callback
169
+ else:
170
+ wrapped_callback = lambda args: (inpainting_callback(args), callback(args))
171
+ else:
172
+ # SAMPLING
173
+ # set the initial latent to noise
174
+ x = noise
175
+
176
+
177
+ with torch.cuda.amp.autocast():
178
+ if sampler_type == "k-heun":
179
+ return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
180
+ elif sampler_type == "k-lms":
181
+ return K.sampling.sample_lms(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
182
+ elif sampler_type == "k-dpmpp-2s-ancestral":
183
+ return K.sampling.sample_dpmpp_2s_ancestral(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
184
+ elif sampler_type == "k-dpm-2":
185
+ return K.sampling.sample_dpm_2(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
186
+ elif sampler_type == "k-dpm-fast":
187
+ return K.sampling.sample_dpm_fast(denoiser, x, sigma_min, sigma_max, steps, disable=False, callback=wrapped_callback, extra_args=extra_args)
188
+ elif sampler_type == "k-dpm-adaptive":
189
+ return K.sampling.sample_dpm_adaptive(denoiser, x, sigma_min, sigma_max, rtol=0.01, atol=0.01, disable=False, callback=wrapped_callback, extra_args=extra_args)
190
+ elif sampler_type == "dpmpp-2m-sde":
191
+ return K.sampling.sample_dpmpp_2m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
192
+ elif sampler_type == "dpmpp-3m-sde":
193
+ return K.sampling.sample_dpmpp_3m_sde(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
194
+
195
+ # Uses discrete Euler sampling for rectified flow models
196
+ # init_data is init_audio as latents (if this is latent diffusion)
197
+ # For sampling, set both init_data and mask to None
198
+ # For variations, set init_data
199
+ # For inpainting, set both init_data & mask
200
+ def sample_rf(
201
+ model_fn,
202
+ noise,
203
+ init_data=None,
204
+ steps=100,
205
+ sigma_max=1,
206
+ device="cuda",
207
+ callback=None,
208
+ cond_fn=None,
209
+ **extra_args
210
+ ):
211
+
212
+ if sigma_max > 1:
213
+ sigma_max = 1
214
+
215
+ if cond_fn is not None:
216
+ denoiser = make_cond_model_fn(denoiser, cond_fn)
217
+
218
+ wrapped_callback = callback
219
+
220
+ if init_data is not None:
221
+ # VARIATION (no inpainting)
222
+ # Interpolate the init data and the noise for init audio
223
+ x = init_data * (1 - sigma_max) + noise * sigma_max
224
+ else:
225
+ # SAMPLING
226
+ # set the initial latent to noise
227
+ x = noise
228
+
229
+ with torch.cuda.amp.autocast():
230
+ # TODO: Add callback support
231
+ #return sample_discrete_euler(model_fn, x, steps, sigma_max, callback=wrapped_callback, **extra_args)
232
+ return sample_discrete_euler(model_fn, x, steps, sigma_max, **extra_args)
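A toy check of the rectified-flow sampler above with a stand-in velocity model; in real use model_fn is the wrapped diffusion transformer and the conditioning tensors go through **extra_args. The latent shape is illustrative, and importing this module also needs k_diffusion installed:

import torch
from ThinkSound.inference.sampling import sample_rf

def dummy_velocity(x, t, **kwargs):
    return -x  # stand-in model that always points back toward zero

noise = torch.randn(2, 64, 215)  # (batch, latent channels, latent length), illustrative only
latents = sample_rf(dummy_velocity, noise, init_data=None, steps=50, sigma_max=1, device="cpu")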
ThinkSound/inference/utils.py ADDED
@@ -0,0 +1,35 @@
1
+ from ..data.utils import PadCrop
2
+
3
+ from torchaudio import transforms as T
4
+
5
+ def set_audio_channels(audio, target_channels):
6
+ if target_channels == 1:
7
+ # Convert to mono
8
+ audio = audio.mean(1, keepdim=True)
9
+ elif target_channels == 2:
10
+ # Convert to stereo
11
+ if audio.shape[1] == 1:
12
+ audio = audio.repeat(1, 2, 1)
13
+ elif audio.shape[1] > 2:
14
+ audio = audio[:, :2, :]
15
+ return audio
16
+
17
+ def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device):
18
+
19
+ audio = audio.to(device)
20
+
21
+ if in_sr != target_sr:
22
+ resample_tf = T.Resample(in_sr, target_sr).to(device)
23
+ audio = resample_tf(audio)
24
+
25
+ audio = PadCrop(target_length, randomize=False)(audio)
26
+
27
+ # Add batch dimension
28
+ if audio.dim() == 1:
29
+ audio = audio.unsqueeze(0).unsqueeze(0)
30
+ elif audio.dim() == 2:
31
+ audio = audio.unsqueeze(0)
32
+
33
+ audio = set_audio_channels(audio, target_channels)
34
+
35
+ return audio
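A minimal sketch of prepare_audio on an arbitrary clip, assuming the import path below matches this commit; sample rates and lengths are illustrative:

import torch
from ThinkSound.inference.utils import prepare_audio

raw = torch.randn(1, 32000)  # 2 s of mono audio at 16 kHz
batch = prepare_audio(
    raw, in_sr=16000, target_sr=44100,
    target_length=44100 * 2, target_channels=2, device="cpu",
)
# batch has shape (1, 2, 88200): resampled, padded/cropped, made stereo, with a batch dim added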
ThinkSound/models/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_model_from_config, create_model_from_config_path
ThinkSound/models/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (266 Bytes). View file
 
ThinkSound/models/__pycache__/factory.cpython-313.pyc ADDED
Binary file (5.77 kB). View file
 
ThinkSound/models/__pycache__/pretrained.cpython-313.pyc ADDED
Binary file (1.2 kB). View file
 
ThinkSound/models/__pycache__/utils.cpython-313.pyc ADDED
Binary file (8.39 kB). View file
 
ThinkSound/models/autoencoders.py ADDED
@@ -0,0 +1,800 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+ from torchaudio import transforms as T
8
+ from alias_free_torch import Activation1d
9
+ from dac.nn.layers import WNConv1d, WNConvTranspose1d
10
+ from typing import Literal, Dict, Any
11
+
12
+ from ..inference.sampling import sample
13
+ from ..inference.utils import prepare_audio
14
+ from .blocks import SnakeBeta
15
+ from .bottleneck import Bottleneck, DiscreteBottleneck
16
+ from .diffusion import ConditionedDiffusionModel, DAU1DCondWrapper, UNet1DCondWrapper, DiTWrapper
17
+ from .factory import create_pretransform_from_config, create_bottleneck_from_config
18
+ from .pretransforms import Pretransform
19
+
20
+ def checkpoint(function, *args, **kwargs):
21
+ kwargs.setdefault("use_reentrant", False)
22
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
23
+
24
+ def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module:
25
+ if activation == "elu":
26
+ act = nn.ELU()
27
+ elif activation == "snake":
28
+ act = SnakeBeta(channels)
29
+ elif activation == "none":
30
+ act = nn.Identity()
31
+ else:
32
+ raise ValueError(f"Unknown activation {activation}")
33
+
34
+ if antialias:
35
+ act = Activation1d(act)
36
+
37
+ return act
38
+
39
+ class ResidualUnit(nn.Module):
40
+ def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False):
41
+ super().__init__()
42
+
43
+ self.dilation = dilation
44
+
45
+ padding = (dilation * (7-1)) // 2
46
+
47
+ self.layers = nn.Sequential(
48
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
49
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
50
+ kernel_size=7, dilation=dilation, padding=padding),
51
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels),
52
+ WNConv1d(in_channels=out_channels, out_channels=out_channels,
53
+ kernel_size=1)
54
+ )
55
+
56
+ def forward(self, x):
57
+ res = x
58
+
59
+ #x = checkpoint(self.layers, x)
60
+ x = self.layers(x)
61
+
62
+ return x + res
63
+
64
+ class EncoderBlock(nn.Module):
65
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False):
66
+ super().__init__()
67
+
68
+ self.layers = nn.Sequential(
69
+ ResidualUnit(in_channels=in_channels,
70
+ out_channels=in_channels, dilation=1, use_snake=use_snake),
71
+ ResidualUnit(in_channels=in_channels,
72
+ out_channels=in_channels, dilation=3, use_snake=use_snake),
73
+ ResidualUnit(in_channels=in_channels,
74
+ out_channels=in_channels, dilation=9, use_snake=use_snake),
75
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
76
+ WNConv1d(in_channels=in_channels, out_channels=out_channels,
77
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)),
78
+ )
79
+
80
+ def forward(self, x):
81
+ return self.layers(x)
82
+
83
+ class DecoderBlock(nn.Module):
84
+ def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False):
85
+ super().__init__()
86
+
87
+ if use_nearest_upsample:
88
+ upsample_layer = nn.Sequential(
89
+ nn.Upsample(scale_factor=stride, mode="nearest"),
90
+ WNConv1d(in_channels=in_channels,
91
+ out_channels=out_channels,
92
+ kernel_size=2*stride,
93
+ stride=1,
94
+ bias=False,
95
+ padding='same')
96
+ )
97
+ else:
98
+ upsample_layer = WNConvTranspose1d(in_channels=in_channels,
99
+ out_channels=out_channels,
100
+ kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2))
101
+
102
+ self.layers = nn.Sequential(
103
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels),
104
+ upsample_layer,
105
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
106
+ dilation=1, use_snake=use_snake),
107
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
108
+ dilation=3, use_snake=use_snake),
109
+ ResidualUnit(in_channels=out_channels, out_channels=out_channels,
110
+ dilation=9, use_snake=use_snake),
111
+ )
112
+
113
+ def forward(self, x):
114
+ return self.layers(x)
115
+
116
+ class OobleckEncoder(nn.Module):
117
+ def __init__(self,
118
+ in_channels=2,
119
+ channels=128,
120
+ latent_dim=32,
121
+ c_mults = [1, 2, 4, 8],
122
+ strides = [2, 4, 8, 8],
123
+ use_snake=False,
124
+ antialias_activation=False
125
+ ):
126
+ super().__init__()
127
+
128
+ c_mults = [1] + c_mults
129
+
130
+ self.depth = len(c_mults)
131
+
132
+ layers = [
133
+ WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3)
134
+ ]
135
+
136
+ for i in range(self.depth-1):
137
+ layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)]
138
+
139
+ layers += [
140
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels),
141
+ WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1)
142
+ ]
143
+
144
+ self.layers = nn.Sequential(*layers)
145
+
146
+ def forward(self, x):
147
+ return self.layers(x)
148
+
149
+
150
+ class OobleckDecoder(nn.Module):
151
+ def __init__(self,
152
+ out_channels=2,
153
+ channels=128,
154
+ latent_dim=32,
155
+ c_mults = [1, 2, 4, 8],
156
+ strides = [2, 4, 8, 8],
157
+ use_snake=False,
158
+ antialias_activation=False,
159
+ use_nearest_upsample=False,
160
+ final_tanh=True):
161
+ super().__init__()
162
+
163
+ c_mults = [1] + c_mults
164
+
165
+ self.depth = len(c_mults)
166
+
167
+ layers = [
168
+ WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3),
169
+ ]
170
+
171
+ for i in range(self.depth-1, 0, -1):
172
+ layers += [DecoderBlock(
173
+ in_channels=c_mults[i]*channels,
174
+ out_channels=c_mults[i-1]*channels,
175
+ stride=strides[i-1],
176
+ use_snake=use_snake,
177
+ antialias_activation=antialias_activation,
178
+ use_nearest_upsample=use_nearest_upsample
179
+ )
180
+ ]
181
+
182
+ layers += [
183
+ get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels),
184
+ WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False),
185
+ nn.Tanh() if final_tanh else nn.Identity()
186
+ ]
187
+
188
+ self.layers = nn.Sequential(*layers)
189
+
190
+ def forward(self, x):
191
+ return self.layers(x)
192
+
193
+
194
+ class DACEncoderWrapper(nn.Module):
195
+ def __init__(self, in_channels=1, **kwargs):
196
+ super().__init__()
197
+
198
+ from dac.model.dac import Encoder as DACEncoder
199
+
200
+ latent_dim = kwargs.pop("latent_dim", None)
201
+
202
+ encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"]))
203
+ self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs)
204
+ self.latent_dim = latent_dim
205
+
206
+ # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility
207
+ self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity()
208
+
209
+ if in_channels != 1:
210
+ self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3)
211
+
212
+ def forward(self, x):
213
+ x = self.encoder(x)
214
+ x = self.proj_out(x)
215
+ return x
216
+
217
+ class DACDecoderWrapper(nn.Module):
218
+ def __init__(self, latent_dim, out_channels=1, **kwargs):
219
+ super().__init__()
220
+
221
+ from dac.model.dac import Decoder as DACDecoder
222
+
223
+ self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels)
224
+
225
+ self.latent_dim = latent_dim
226
+
227
+ def forward(self, x):
228
+ return self.decoder(x)
229
+
230
+ class AudioAutoencoder(nn.Module):
231
+ def __init__(
232
+ self,
233
+ encoder,
234
+ decoder,
235
+ latent_dim,
236
+ downsampling_ratio,
237
+ sample_rate,
238
+ io_channels=2,
239
+ bottleneck: Bottleneck = None,
240
+ pretransform: Pretransform = None,
241
+ in_channels = None,
242
+ out_channels = None,
243
+ soft_clip = False
244
+ ):
245
+ super().__init__()
246
+
247
+ self.downsampling_ratio = downsampling_ratio
248
+ self.sample_rate = sample_rate
249
+
250
+ self.latent_dim = latent_dim
251
+ self.io_channels = io_channels
252
+ self.in_channels = io_channels
253
+ self.out_channels = io_channels
254
+
255
+ self.min_length = self.downsampling_ratio
256
+
257
+ if in_channels is not None:
258
+ self.in_channels = in_channels
259
+
260
+ if out_channels is not None:
261
+ self.out_channels = out_channels
262
+
263
+ self.bottleneck = bottleneck
264
+
265
+ self.encoder = encoder
266
+
267
+ self.decoder = decoder
268
+
269
+ self.pretransform = pretransform
270
+
271
+ self.soft_clip = soft_clip
272
+
273
+ self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete
274
+
275
+ def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs):
276
+
277
+ info = {}
278
+ # import ipdb
279
+ # ipdb.set_trace()
280
+ if self.pretransform is not None and not skip_pretransform:
281
+ if self.pretransform.enable_grad:
282
+ if iterate_batch:
283
+ audios = []
284
+ for i in range(audio.shape[0]):
285
+ audios.append(self.pretransform.encode(audio[i:i+1]))
286
+ audio = torch.cat(audios, dim=0)
287
+ else:
288
+ audio = self.pretransform.encode(audio)
289
+ else:
290
+ with torch.no_grad():
291
+ if iterate_batch:
292
+ audios = []
293
+ for i in range(audio.shape[0]):
294
+ audios.append(self.pretransform.encode(audio[i:i+1]))
295
+ audio = torch.cat(audios, dim=0)
296
+ else:
297
+ audio = self.pretransform.encode(audio)
298
+
299
+ if self.encoder is not None:
300
+ if iterate_batch:
301
+ latents = []
302
+ for i in range(audio.shape[0]):
303
+ latents.append(self.encoder(audio[i:i+1]))
304
+ latents = torch.cat(latents, dim=0)
305
+ else:
306
+ latents = self.encoder(audio)
307
+ else:
308
+ latents = audio
309
+
310
+ if self.bottleneck is not None:
311
+ # TODO: Add iterate batch logic, needs to merge the info dicts
312
+ latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs)
313
+
314
+ info.update(bottleneck_info)
315
+
316
+ if return_info:
317
+ return latents, info
318
+
319
+ return latents
320
+
321
+ def decode(self, latents, iterate_batch=False, **kwargs):
322
+
323
+ if self.bottleneck is not None:
324
+ if iterate_batch:
325
+ decoded = []
326
+ for i in range(latents.shape[0]):
327
+ decoded.append(self.bottleneck.decode(latents[i:i+1]))
328
+ latents = torch.cat(decoded, dim=0)
329
+ else:
330
+ latents = self.bottleneck.decode(latents)
331
+
332
+ if iterate_batch:
333
+ decoded = []
334
+ for i in range(latents.shape[0]):
335
+ decoded.append(self.decoder(latents[i:i+1]))
336
+ decoded = torch.cat(decoded, dim=0)
337
+ else:
338
+ decoded = self.decoder(latents, **kwargs)
339
+
340
+ if self.pretransform is not None:
341
+ if self.pretransform.enable_grad:
342
+ if iterate_batch:
343
+ decodeds = []
344
+ for i in range(decoded.shape[0]):
345
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
346
+ decoded = torch.cat(decodeds, dim=0)
347
+ else:
348
+ decoded = self.pretransform.decode(decoded)
349
+ else:
350
+ with torch.no_grad():
351
+ if iterate_batch:
352
+ decodeds = []
353
+ for i in range(decoded.shape[0]):
354
+ decodeds.append(self.pretransform.decode(decoded[i:i+1]))
355
+ decoded = torch.cat(decodeds, dim=0)
356
+ else:
357
+ decoded = self.pretransform.decode(decoded)
358
+
359
+ if self.soft_clip:
360
+ decoded = torch.tanh(decoded)
361
+
362
+ return decoded
363
+
364
+ def decode_tokens(self, tokens, **kwargs):
365
+ '''
366
+ Decode discrete tokens to audio
367
+ Only works with discrete autoencoders
368
+ '''
369
+
370
+ assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders"
371
+
372
+ latents = self.bottleneck.decode_tokens(tokens, **kwargs)
373
+
374
+ return self.decode(latents, **kwargs)
375
+
376
+
377
+ def preprocess_audio_for_encoder(self, audio, in_sr):
378
+ '''
379
+ Preprocess single audio tensor (Channels x Length) to be compatible with the encoder.
380
+ If the model is mono, stereo audio will be converted to mono.
381
+ Audio will be silence-padded to be a multiple of the model's downsampling ratio.
382
+ Audio will be resampled to the model's sample rate.
383
+ The output will have batch size 1 and be shape (1 x Channels x Length)
384
+ '''
385
+ return self.preprocess_audio_list_for_encoder([audio], [in_sr])
386
+
387
+ def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
388
+ '''
389
+ Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
390
+ The audio in that list can be of different lengths and channels.
391
+ in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
392
+ All audio will be resampled to the model's sample rate.
393
+ Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
394
+ If the model is mono, all audio will be converted to mono.
395
+ The output will be a tensor of shape (Batch x Channels x Length)
396
+ '''
397
+ batch_size = len(audio_list)
398
+ if isinstance(in_sr_list, int):
399
+ in_sr_list = [in_sr_list]*batch_size
400
+ assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list"
401
+ new_audio = []
402
+ max_length = 0
403
+ # resample & find the max length
404
+ for i in range(batch_size):
405
+ audio = audio_list[i]
406
+ in_sr = in_sr_list[i]
407
+ if len(audio.shape) == 3 and audio.shape[0] == 1:
408
+ # batchsize 1 was given by accident. Just squeeze it.
409
+ audio = audio.squeeze(0)
410
+ elif len(audio.shape) == 1:
411
+ # Mono signal, channel dimension is missing, unsqueeze it in
412
+ audio = audio.unsqueeze(0)
413
+ assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
414
+ # Resample audio
415
+ if in_sr != self.sample_rate:
416
+ resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
417
+ audio = resample_tf(audio)
418
+ new_audio.append(audio)
419
+ if audio.shape[-1] > max_length:
420
+ max_length = audio.shape[-1]
421
+ # Pad every audio to the same length, multiple of model's downsampling ratio
422
+ padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
423
+ for i in range(batch_size):
424
+ # Pad it and, if necessary, mix down or duplicate channels to match the model's channel count
425
+ new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
426
+ target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
427
+ # convert to tensor
428
+ return torch.stack(new_audio)
429
+
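+ # Illustrative usage sketch (assumes `model` is a loaded AudioAutoencoder and torchaudio is available):
+ # audio, sr = torchaudio.load("input.wav") # (channels, samples)
+ # batch = model.preprocess_audio_for_encoder(audio, sr) # (1, in_channels, padded_length)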
430
+ def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
431
+ '''
432
+ Encode audio into latents. The audio should already be preprocessed by preprocess_audio_for_encoder.
433
+ If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
434
+ Overlap and chunk_size params are both measured in number of latents (not audio samples)
435
+ and therefore you likely could use the same values with decode_audio.
436
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
437
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
438
+ You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
439
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
440
+ Smaller chunk_size uses less memory, but more compute.
441
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
442
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it uses more chunks.
443
+ '''
444
+ if not chunked:
445
+ # default behavior. Encode the entire audio in parallel
446
+ return self.encode(audio, **kwargs)
447
+ else:
448
+ # CHUNKED ENCODING
449
+ # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio)
450
+ # import ipdb
451
+ # ipdb.set_trace()
452
+ samples_per_latent = self.downsampling_ratio
453
+ total_size = audio.shape[2] # in samples
454
+ print(f'audio shape: {audio.shape}')
455
+ batch_size = audio.shape[0]
456
+ chunk_size *= samples_per_latent # converting metric in latents to samples
457
+ overlap *= samples_per_latent # converting metric in latents to samples
458
+ hop_size = chunk_size - overlap
459
+ chunks = []
460
+ for i in range(0, total_size - chunk_size + 1, hop_size):
461
+ chunk = audio[:,:,i:i+chunk_size]
462
+ chunks.append(chunk)
463
+ if i+chunk_size != total_size:
464
+ # Final chunk
465
+ chunk = audio[:,:,-chunk_size:]
466
+ chunks.append(chunk)
467
+ chunks = torch.stack(chunks)
468
+ num_chunks = chunks.shape[0]
469
+ # Note: y_size might be a different value from the latent length used in diffusion training
470
+ # because we can encode audio of varying lengths
471
+ # However, the audio should've been padded to a multiple of samples_per_latent by now.
472
+ y_size = total_size // samples_per_latent
473
+ # Create an empty latent, we will populate it with chunks as we encode them
474
+ y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device)
475
+ print(f'y_final shape: {y_final.shape}')
476
+ for i in range(num_chunks):
477
+ x_chunk = chunks[i,:]
478
+ # encode the chunk
479
+ y_chunk = self.encode(x_chunk)
480
+ print(f'y_chunk shape: {y_chunk.shape}')
481
+ # figure out where to put the audio along the time domain
482
+ if i == num_chunks-1:
483
+ # final chunk always goes at the end
484
+ t_end = y_size
485
+ t_start = t_end - y_chunk.shape[2]
486
+ else:
487
+ t_start = i * hop_size // samples_per_latent
488
+ t_end = t_start + chunk_size // samples_per_latent
489
+ # remove the edges of the overlaps
490
+ ol = overlap//samples_per_latent//2
491
+ chunk_start = 0
492
+ chunk_end = y_chunk.shape[2]
493
+ if i > 0:
494
+ # no overlap for the start of the first chunk
495
+ t_start += ol
496
+ chunk_start += ol
497
+ if i < num_chunks-1:
498
+ # no overlap for the end of the last chunk
499
+ t_end -= ol
500
+ chunk_end -= ol
501
+ # paste the chunked audio into our y_final output audio
502
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
503
+ return y_final
504
+
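+ # Note on the chunking above: with the default overlap=32 and chunk_size=128 (both in latents),
+ # consecutive chunks start hop_size = 96 latents apart, and overlap//2 = 16 latents are trimmed
+ # from each interior chunk edge before pasting into y_final, so only the centre of each
+ # overlapping region is kept.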
505
+ def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs):
506
+ '''
507
+ Decode latents to audio.
508
+ If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents.
509
+ An overlap of zero will cause discontinuity artefacts. Overlap should be >= the receptive field size.
510
+ Every autoencoder will have a different receptive field size, and thus ideal overlap.
511
+ You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff.
512
+ The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
513
+ Smaller chunk_size uses less memory, but more compute.
514
+ The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version
515
+ For example, on an A6000, chunk_size 128 is overall faster than 256 and 512 even though it uses more chunks.
516
+ '''
517
+ if not chunked:
518
+ # default behavior. Decode the entire latent in parallel
519
+ return self.decode(latents, **kwargs)
520
+ else:
521
+ # chunked decoding
522
+ hop_size = chunk_size - overlap
523
+ total_size = latents.shape[2]
524
+ batch_size = latents.shape[0]
525
+ chunks = []
526
+ for i in range(0, total_size - chunk_size + 1, hop_size):
527
+ chunk = latents[:,:,i:i+chunk_size]
528
+ chunks.append(chunk)
529
+ if i+chunk_size != total_size:
530
+ # Final chunk
531
+ chunk = latents[:,:,-chunk_size:]
532
+ chunks.append(chunk)
533
+ chunks = torch.stack(chunks)
534
+ num_chunks = chunks.shape[0]
535
+ # samples_per_latent is just the downsampling ratio
536
+ samples_per_latent = self.downsampling_ratio
537
+ # Create an empty waveform; we will populate it with chunks as we decode them
538
+ y_size = total_size * samples_per_latent
539
+ y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device)
540
+ for i in range(num_chunks):
541
+ x_chunk = chunks[i,:]
542
+ # decode the chunk
543
+ y_chunk = self.decode(x_chunk)
544
+ # figure out where to put the audio along the time domain
545
+ if i == num_chunks-1:
546
+ # final chunk always goes at the end
547
+ t_end = y_size
548
+ t_start = t_end - y_chunk.shape[2]
549
+ else:
550
+ t_start = i * hop_size * samples_per_latent
551
+ t_end = t_start + chunk_size * samples_per_latent
552
+ # remove the edges of the overlaps
553
+ ol = (overlap//2) * samples_per_latent
554
+ chunk_start = 0
555
+ chunk_end = y_chunk.shape[2]
556
+ if i > 0:
557
+ # no overlap for the start of the first chunk
558
+ t_start += ol
559
+ chunk_start += ol
560
+ if i < num_chunks-1:
561
+ # no overlap for the end of the last chunk
562
+ t_end -= ol
563
+ chunk_end -= ol
564
+ # paste the chunked audio into our y_final output audio
565
+ y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end]
566
+ return y_final
567
+
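+ # Illustrative chunked roundtrip (assumes `model` is a loaded AudioAutoencoder and `batch`
+ # came from preprocess_audio_for_encoder / preprocess_audio_list_for_encoder):
+ # latents = model.encode_audio(batch, chunked=True, overlap=32, chunk_size=128)
+ # recon = model.decode_audio(latents, chunked=True, overlap=32, chunk_size=128)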
568
+
569
+ class DiffusionAutoencoder(AudioAutoencoder):
570
+ def __init__(
571
+ self,
572
+ diffusion: ConditionedDiffusionModel,
573
+ diffusion_downsampling_ratio,
574
+ *args,
575
+ **kwargs
576
+ ):
577
+ super().__init__(*args, **kwargs)
578
+
579
+ self.diffusion = diffusion
580
+
581
+ self.min_length = self.downsampling_ratio * diffusion_downsampling_ratio
582
+
583
+ if self.encoder is not None:
584
+ # Shrink the initial encoder parameters to avoid saturated latents
585
+ with torch.no_grad():
586
+ for param in self.encoder.parameters():
587
+ param *= 0.5
588
+
589
+ def decode(self, latents, steps=100):
590
+
591
+ upsampled_length = latents.shape[2] * self.downsampling_ratio
592
+
593
+ if self.bottleneck is not None:
594
+ latents = self.bottleneck.decode(latents)
595
+
596
+ if self.decoder is not None:
597
+ latents = self.decoder(latents)
598
+
599
+ # Upsample latents to match diffusion length
600
+ if latents.shape[2] != upsampled_length:
601
+ latents = F.interpolate(latents, size=upsampled_length, mode='nearest')
602
+
603
+ noise = torch.randn(latents.shape[0], self.io_channels, upsampled_length, device=latents.device)
604
+ decoded = sample(self.diffusion, noise, steps, 0, input_concat_cond=latents)
605
+
606
+ if self.pretransform is not None:
607
+ if self.pretransform.enable_grad:
608
+ decoded = self.pretransform.decode(decoded)
609
+ else:
610
+ with torch.no_grad():
611
+ decoded = self.pretransform.decode(decoded)
612
+
613
+ return decoded
614
+
615
+ # AE factories
616
+
617
+ def create_encoder_from_config(encoder_config: Dict[str, Any]):
618
+ encoder_type = encoder_config.get("type", None)
619
+ assert encoder_type is not None, "Encoder type must be specified"
620
+
621
+ if encoder_type == "oobleck":
622
+ encoder = OobleckEncoder(
623
+ **encoder_config["config"]
624
+ )
625
+
626
+ elif encoder_type == "seanet":
627
+ from encodec.modules import SEANetEncoder
628
+ seanet_encoder_config = encoder_config["config"]
629
+
630
+ #SEANet encoder expects strides in reverse order
631
+ seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2])))
632
+ encoder = SEANetEncoder(
633
+ **seanet_encoder_config
634
+ )
635
+ elif encoder_type == "dac":
636
+ dac_config = encoder_config["config"]
637
+
638
+ encoder = DACEncoderWrapper(**dac_config)
639
+ elif encoder_type == "local_attn":
640
+ from .local_attention import TransformerEncoder1D
641
+
642
+ local_attn_config = encoder_config["config"]
643
+
644
+ encoder = TransformerEncoder1D(
645
+ **local_attn_config
646
+ )
647
+ else:
648
+ raise ValueError(f"Unknown encoder type {encoder_type}")
649
+
650
+ requires_grad = encoder_config.get("requires_grad", True)
651
+ if not requires_grad:
652
+ for param in encoder.parameters():
653
+ param.requires_grad = False
654
+
655
+ return encoder
656
+
657
+ def create_decoder_from_config(decoder_config: Dict[str, Any]):
658
+ decoder_type = decoder_config.get("type", None)
659
+ assert decoder_type is not None, "Decoder type must be specified"
660
+
661
+ if decoder_type == "oobleck":
662
+ decoder = OobleckDecoder(
663
+ **decoder_config["config"]
664
+ )
665
+ elif decoder_type == "seanet":
666
+ from encodec.modules import SEANetDecoder
667
+
668
+ decoder = SEANetDecoder(
669
+ **decoder_config["config"]
670
+ )
671
+ elif decoder_type == "dac":
672
+ dac_config = decoder_config["config"]
673
+
674
+ decoder = DACDecoderWrapper(**dac_config)
675
+ elif decoder_type == "local_attn":
676
+ from .local_attention import TransformerDecoder1D
677
+
678
+ local_attn_config = decoder_config["config"]
679
+
680
+ decoder = TransformerDecoder1D(
681
+ **local_attn_config
682
+ )
683
+ else:
684
+ raise ValueError(f"Unknown decoder type {decoder_type}")
685
+
686
+ requires_grad = decoder_config.get("requires_grad", True)
687
+ if not requires_grad:
688
+ for param in decoder.parameters():
689
+ param.requires_grad = False
690
+
691
+ return decoder
692
+
693
+ def create_autoencoder_from_config(config: Dict[str, Any]):
694
+
695
+ ae_config = config["model"]
696
+
697
+ encoder = create_encoder_from_config(ae_config["encoder"])
698
+ decoder = create_decoder_from_config(ae_config["decoder"])
699
+
700
+ bottleneck = ae_config.get("bottleneck", None)
701
+
702
+ latent_dim = ae_config.get("latent_dim", None)
703
+ assert latent_dim is not None, "latent_dim must be specified in model config"
704
+ downsampling_ratio = ae_config.get("downsampling_ratio", None)
705
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
706
+ io_channels = ae_config.get("io_channels", None)
707
+ assert io_channels is not None, "io_channels must be specified in model config"
708
+ sample_rate = config.get("sample_rate", None)
709
+ assert sample_rate is not None, "sample_rate must be specified in model config"
710
+
711
+ in_channels = ae_config.get("in_channels", None)
712
+ out_channels = ae_config.get("out_channels", None)
713
+
714
+ pretransform = ae_config.get("pretransform", None)
715
+
716
+ if pretransform is not None:
717
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
718
+
719
+ if bottleneck is not None:
720
+ bottleneck = create_bottleneck_from_config(bottleneck)
721
+
722
+ soft_clip = ae_config["decoder"].get("soft_clip", False)
723
+
724
+ return AudioAutoencoder(
725
+ encoder,
726
+ decoder,
727
+ io_channels=io_channels,
728
+ latent_dim=latent_dim,
729
+ downsampling_ratio=downsampling_ratio,
730
+ sample_rate=sample_rate,
731
+ bottleneck=bottleneck,
732
+ pretransform=pretransform,
733
+ in_channels=in_channels,
734
+ out_channels=out_channels,
735
+ soft_clip=soft_clip
736
+ )
737
+
738
+ def create_diffAE_from_config(config: Dict[str, Any]):
739
+
740
+ diffae_config = config["model"]
741
+
742
+ if "encoder" in diffae_config:
743
+ encoder = create_encoder_from_config(diffae_config["encoder"])
744
+ else:
745
+ encoder = None
746
+
747
+ if "decoder" in diffae_config:
748
+ decoder = create_decoder_from_config(diffae_config["decoder"])
749
+ else:
750
+ decoder = None
751
+
752
+ diffusion_model_type = diffae_config["diffusion"]["type"]
753
+
754
+ if diffusion_model_type == "DAU1d":
755
+ diffusion = DAU1DCondWrapper(**diffae_config["diffusion"]["config"])
756
+ elif diffusion_model_type == "adp_1d":
757
+ diffusion = UNet1DCondWrapper(**diffae_config["diffusion"]["config"])
758
+ elif diffusion_model_type == "dit":
759
+ diffusion = DiTWrapper(**diffae_config["diffusion"]["config"])
760
+
761
+ latent_dim = diffae_config.get("latent_dim", None)
762
+ assert latent_dim is not None, "latent_dim must be specified in model config"
763
+ downsampling_ratio = diffae_config.get("downsampling_ratio", None)
764
+ assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config"
765
+ io_channels = diffae_config.get("io_channels", None)
766
+ assert io_channels is not None, "io_channels must be specified in model config"
767
+ sample_rate = config.get("sample_rate", None)
768
+ assert sample_rate is not None, "sample_rate must be specified in model config"
769
+
770
+ bottleneck = diffae_config.get("bottleneck", None)
771
+
772
+ pretransform = diffae_config.get("pretransform", None)
773
+
774
+ if pretransform is not None:
775
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
776
+
777
+ if bottleneck is not None:
778
+ bottleneck = create_bottleneck_from_config(bottleneck)
779
+
780
+ diffusion_downsampling_ratio = None
781
+
782
+ if diffusion_model_type == "DAU1d":
783
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["strides"])
784
+ elif diffusion_model_type == "adp_1d":
785
+ diffusion_downsampling_ratio = np.prod(diffae_config["diffusion"]["config"]["factors"])
786
+ elif diffusion_model_type == "dit":
787
+ diffusion_downsampling_ratio = 1
788
+
789
+ return DiffusionAutoencoder(
790
+ encoder=encoder,
791
+ decoder=decoder,
792
+ diffusion=diffusion,
793
+ io_channels=io_channels,
794
+ sample_rate=sample_rate,
795
+ latent_dim=latent_dim,
796
+ downsampling_ratio=downsampling_ratio,
797
+ diffusion_downsampling_ratio=diffusion_downsampling_ratio,
798
+ bottleneck=bottleneck,
799
+ pretransform=pretransform
800
+ )
ThinkSound/models/blocks.py ADDED
@@ -0,0 +1,430 @@
1
+ from functools import reduce
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ from torch import nn
6
+ from torch.nn import functional as F
7
+
8
+ from torch.backends.cuda import sdp_kernel
9
+ from packaging import version
10
+
11
+ from dac.nn.layers import Snake1d
12
+
13
+ class ResidualBlock(nn.Module):
14
+ def __init__(self, main, skip=None):
15
+ super().__init__()
16
+ self.main = nn.Sequential(*main)
17
+ self.skip = skip if skip else nn.Identity()
18
+
19
+ def forward(self, input):
20
+ return self.main(input) + self.skip(input)
21
+
22
+ class ResConvBlock(ResidualBlock):
23
+ def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False):
24
+ skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False)
25
+ super().__init__([
26
+ nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias),
27
+ nn.GroupNorm(1, c_mid),
28
+ Snake1d(c_mid) if use_snake else nn.GELU(),
29
+ nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias),
30
+ nn.GroupNorm(1, c_out) if not is_last else nn.Identity(),
31
+ (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(),
32
+ ], skip)
33
+
34
+ class SelfAttention1d(nn.Module):
35
+ def __init__(self, c_in, n_head=1, dropout_rate=0.):
36
+ super().__init__()
37
+ assert c_in % n_head == 0
38
+ self.norm = nn.GroupNorm(1, c_in)
39
+ self.n_head = n_head
40
+ self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1)
41
+ self.out_proj = nn.Conv1d(c_in, c_in, 1)
42
+ self.dropout = nn.Dropout(dropout_rate, inplace=True)
43
+
44
+ self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
45
+
46
+ if not self.use_flash:
47
+ return
48
+
49
+ device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
50
+
51
+ if device_properties.major == 8 and device_properties.minor == 0:
52
+ # Use flash attention for A100 GPUs
53
+ self.sdp_kernel_config = (True, False, False)
54
+ else:
55
+ # Don't use flash attention for other GPUs
56
+ self.sdp_kernel_config = (False, True, True)
57
+
58
+ def forward(self, input):
59
+ n, c, s = input.shape
60
+ qkv = self.qkv_proj(self.norm(input))
61
+ qkv = qkv.view(
62
+ [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3)
63
+ q, k, v = qkv.chunk(3, dim=1)
64
+ scale = k.shape[3]**-0.25
65
+
66
+ if self.use_flash:
67
+ with sdp_kernel(*self.sdp_kernel_config):
68
+ y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s])
69
+ else:
70
+ att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3)
71
+ y = (att @ v).transpose(2, 3).contiguous().view([n, c, s])
72
+
73
+
74
+ return input + self.dropout(self.out_proj(y))
75
+
76
+ class SkipBlock(nn.Module):
77
+ def __init__(self, *main):
78
+ super().__init__()
79
+ self.main = nn.Sequential(*main)
80
+
81
+ def forward(self, input):
82
+ return torch.cat([self.main(input), input], dim=1)
83
+
84
+ class FourierFeatures(nn.Module):
85
+ def __init__(self, in_features, out_features, std=1.):
86
+ super().__init__()
87
+ assert out_features % 2 == 0
88
+ self.weight = nn.Parameter(torch.randn(
89
+ [out_features // 2, in_features]) * std)
90
+
91
+ def forward(self, input):
92
+ f = 2 * math.pi * input @ self.weight.T
93
+ return torch.cat([f.cos(), f.sin()], dim=-1)
94
+
95
+ def expand_to_planes(input, shape):
96
+ return input[..., None].repeat([1, 1, shape[2]])
97
+
98
+ _kernels = {
99
+ 'linear':
100
+ [1 / 8, 3 / 8, 3 / 8, 1 / 8],
101
+ 'cubic':
102
+ [-0.01171875, -0.03515625, 0.11328125, 0.43359375,
103
+ 0.43359375, 0.11328125, -0.03515625, -0.01171875],
104
+ 'lanczos3':
105
+ [0.003689131001010537, 0.015056144446134567, -0.03399861603975296,
106
+ -0.066637322306633, 0.13550527393817902, 0.44638532400131226,
107
+ 0.44638532400131226, 0.13550527393817902, -0.066637322306633,
108
+ -0.03399861603975296, 0.015056144446134567, 0.003689131001010537]
109
+ }
110
+
111
+ class Downsample1d(nn.Module):
112
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
113
+ super().__init__()
114
+ self.pad_mode = pad_mode
115
+ kernel_1d = torch.tensor(_kernels[kernel])
116
+ self.pad = kernel_1d.shape[0] // 2 - 1
117
+ self.register_buffer('kernel', kernel_1d)
118
+ self.channels_last = channels_last
119
+
120
+ def forward(self, x):
121
+ if self.channels_last:
122
+ x = x.permute(0, 2, 1)
123
+ x = F.pad(x, (self.pad,) * 2, self.pad_mode)
124
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
125
+ indices = torch.arange(x.shape[1], device=x.device)
126
+ weight[indices, indices] = self.kernel.to(weight)
127
+ x = F.conv1d(x, weight, stride=2)
128
+ if self.channels_last:
129
+ x = x.permute(0, 2, 1)
130
+ return x
131
+
132
+
133
+ class Upsample1d(nn.Module):
134
+ def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False):
135
+ super().__init__()
136
+ self.pad_mode = pad_mode
137
+ kernel_1d = torch.tensor(_kernels[kernel]) * 2
138
+ self.pad = kernel_1d.shape[0] // 2 - 1
139
+ self.register_buffer('kernel', kernel_1d)
140
+ self.channels_last = channels_last
141
+
142
+ def forward(self, x):
143
+ if self.channels_last:
144
+ x = x.permute(0, 2, 1)
145
+ x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode)
146
+ weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]])
147
+ indices = torch.arange(x.shape[1], device=x.device)
148
+ weight[indices, indices] = self.kernel.to(weight)
149
+ x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1)
150
+ if self.channels_last:
151
+ x = x.permute(0, 2, 1)
152
+ return x
153
+
154
+ def Downsample1d_2(
155
+ in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2
156
+ ) -> nn.Module:
157
+ assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even"
158
+
159
+ return nn.Conv1d(
160
+ in_channels=in_channels,
161
+ out_channels=out_channels,
162
+ kernel_size=factor * kernel_multiplier + 1,
163
+ stride=factor,
164
+ padding=factor * (kernel_multiplier // 2),
165
+ )
166
+
167
+
168
+ def Upsample1d_2(
169
+ in_channels: int, out_channels: int, factor: int, use_nearest: bool = False
170
+ ) -> nn.Module:
171
+
172
+ if factor == 1:
173
+ return nn.Conv1d(
174
+ in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1
175
+ )
176
+
177
+ if use_nearest:
178
+ return nn.Sequential(
179
+ nn.Upsample(scale_factor=factor, mode="nearest"),
180
+ nn.Conv1d(
181
+ in_channels=in_channels,
182
+ out_channels=out_channels,
183
+ kernel_size=3,
184
+ padding=1,
185
+ ),
186
+ )
187
+ else:
188
+ return nn.ConvTranspose1d(
189
+ in_channels=in_channels,
190
+ out_channels=out_channels,
191
+ kernel_size=factor * 2,
192
+ stride=factor,
193
+ padding=factor // 2 + factor % 2,
194
+ output_padding=factor % 2,
195
+ )
196
+
197
+ def zero_init(layer):
198
+ nn.init.zeros_(layer.weight)
199
+ if layer.bias is not None:
200
+ nn.init.zeros_(layer.bias)
201
+ return layer
202
+
203
+ def rms_norm(x, scale, eps):
204
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
205
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
206
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
207
+ return x * scale.to(x.dtype)
208
+
209
+ #rms_norm = torch.compile(rms_norm)
210
+
211
+ class AdaRMSNorm(nn.Module):
212
+ def __init__(self, features, cond_features, eps=1e-6):
213
+ super().__init__()
214
+ self.eps = eps
215
+ self.linear = zero_init(nn.Linear(cond_features, features, bias=False))
216
+
217
+ def extra_repr(self):
218
+ return f"eps={self.eps},"
219
+
220
+ def forward(self, x, cond):
221
+ return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps)
222
+
223
+ def normalize(x, eps=1e-4):
224
+ dim = list(range(1, x.ndim))
225
+ n = torch.linalg.vector_norm(x, dim=dim, keepdim=True)
226
+ alpha = np.sqrt(n.numel() / x.numel())
227
+ return x / torch.add(eps, n, alpha=alpha)
228
+
229
+ class ForcedWNConv1d(nn.Module):
230
+ def __init__(self, in_channels, out_channels, kernel_size=1):
231
+ super().__init__()
232
+ self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size]))
233
+
234
+ def forward(self, x):
235
+ if self.training:
236
+ with torch.no_grad():
237
+ self.weight.copy_(normalize(self.weight))
238
+
239
+ fan_in = self.weight[0].numel()
240
+
241
+ w = normalize(self.weight) / math.sqrt(fan_in)
242
+
243
+ return F.conv1d(x, w, padding='same')
244
+
245
+ # Kernels
246
+
247
+ use_compile = True
248
+
249
+ def compile(function, *args, **kwargs):
250
+ if not use_compile:
251
+ return function
252
+ try:
253
+ return torch.compile(function, *args, **kwargs)
254
+ except RuntimeError:
255
+ return function
256
+
257
+
258
+ @compile
259
+ def linear_geglu(x, weight, bias=None):
260
+ x = x @ weight.mT
261
+ if bias is not None:
262
+ x = x + bias
263
+ x, gate = x.chunk(2, dim=-1)
264
+ return x * F.gelu(gate)
265
+
266
+
267
+ @compile
268
+ def rms_norm(x, scale, eps):
269
+ dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32))
270
+ mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True)
271
+ scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps)
272
+ return x * scale.to(x.dtype)
273
+
274
+ # Layers
275
+
276
+ class LinearGEGLU(nn.Linear):
277
+ def __init__(self, in_features, out_features, bias=True):
278
+ super().__init__(in_features, out_features * 2, bias=bias)
279
+ self.out_features = out_features
280
+
281
+ def forward(self, x):
282
+ return linear_geglu(x, self.weight, self.bias)
283
+
284
+
285
+ class RMSNorm(nn.Module):
286
+ def __init__(self, shape, fix_scale = False, eps=1e-6):
287
+ super().__init__()
288
+ self.eps = eps
289
+
290
+ if fix_scale:
291
+ self.register_buffer("scale", torch.ones(shape))
292
+ else:
293
+ self.scale = nn.Parameter(torch.ones(shape))
294
+
295
+ def extra_repr(self):
296
+ return f"shape={tuple(self.scale.shape)}, eps={self.eps}"
297
+
298
+ def forward(self, x):
299
+ return rms_norm(x, self.scale, self.eps)
300
+
301
+ def snake_beta(x, alpha, beta):
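+ # snake_beta(x) = x + (1 / (beta + eps)) * sin^2(alpha * x): alpha sets the frequency of the
+ # periodic component and beta scales its contribution (see SnakeBeta below for the log-scale parameterization).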
302
+ return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2)
303
+
304
+ # try:
305
+ # snake_beta = torch.compile(snake_beta)
306
+ # except RuntimeError:
307
+ # pass
308
+
309
+ # Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license
310
+ # License available in LICENSES/LICENSE_NVIDIA.txt
311
+ class SnakeBeta(nn.Module):
312
+
313
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True):
314
+ super(SnakeBeta, self).__init__()
315
+ self.in_features = in_features
316
+
317
+ # initialize alpha
318
+ self.alpha_logscale = alpha_logscale
319
+ if self.alpha_logscale: # log scale alphas initialized to zeros
320
+ self.alpha = nn.Parameter(torch.zeros(in_features) * alpha)
321
+ self.beta = nn.Parameter(torch.zeros(in_features) * alpha)
322
+ else: # linear scale alphas initialized to ones
323
+ self.alpha = nn.Parameter(torch.ones(in_features) * alpha)
324
+ self.beta = nn.Parameter(torch.ones(in_features) * alpha)
325
+
326
+ self.alpha.requires_grad = alpha_trainable
327
+ self.beta.requires_grad = alpha_trainable
328
+
329
+ self.no_div_by_zero = 0.000000001
330
+
331
+ def forward(self, x):
332
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
333
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
334
+ if self.alpha_logscale:
335
+ alpha = torch.exp(alpha)
336
+ beta = torch.exp(beta)
337
+ x = snake_beta(x, alpha, beta)
338
+
339
+ return x
340
+
341
+ class ChannelLastConv1d(nn.Conv1d):
342
+
343
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
344
+ x = x.permute(0, 2, 1)
345
+ x = super().forward(x)
346
+ x = x.permute(0, 2, 1)
347
+ return x
348
+
349
+
350
+ # https://github.com/Stability-AI/sd3-ref
351
+ class MLP(nn.Module):
352
+
353
+ def __init__(
354
+ self,
355
+ dim: int,
356
+ hidden_dim: int,
357
+ multiple_of: int = 256,
358
+ ):
359
+ """
360
+ Initialize the FeedForward module.
361
+
362
+ Args:
363
+ dim (int): Input dimension.
364
+ hidden_dim (int): Hidden dimension of the feedforward layer.
365
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
366
+
367
+ Attributes:
368
+ w1 (nn.Linear): Linear transformation for the first layer.
369
+ w2 (nn.Linear): Linear transformation for the second layer.
370
+ w3 (nn.Linear): Linear transformation for the third layer.
371
+
372
+ """
373
+ super().__init__()
374
+ hidden_dim = int(2 * hidden_dim / 3)
375
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
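+ # e.g. hidden_dim=4096 -> int(2 * 4096 / 3) = 2730, rounded up to the next multiple of 256 -> 2816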
376
+
377
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
378
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
379
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
380
+
381
+ def forward(self, x):
382
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
383
+
384
+
385
+ class ConvMLP(nn.Module):
386
+
387
+ def __init__(
388
+ self,
389
+ dim: int,
390
+ hidden_dim: int,
391
+ multiple_of: int = 256,
392
+ kernel_size: int = 3,
393
+ padding: int = 1,
394
+ ):
395
+ """
396
+ Initialize the FeedForward module.
397
+
398
+ Args:
399
+ dim (int): Input dimension.
400
+ hidden_dim (int): Hidden dimension of the feedforward layer.
401
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
402
+
403
+ Attributes:
404
+ w1 (ChannelLastConv1d): Convolutional transformation for the first layer.
405
+ w2 (ChannelLastConv1d): Convolutional transformation for the second layer.
406
+ w3 (ChannelLastConv1d): Convolutional transformation for the third layer.
407
+
408
+ """
409
+ super().__init__()
410
+ hidden_dim = int(2 * hidden_dim / 3)
411
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
412
+
413
+ self.w1 = ChannelLastConv1d(dim,
414
+ hidden_dim,
415
+ bias=False,
416
+ kernel_size=kernel_size,
417
+ padding=padding)
418
+ self.w2 = ChannelLastConv1d(hidden_dim,
419
+ dim,
420
+ bias=False,
421
+ kernel_size=kernel_size,
422
+ padding=padding)
423
+ self.w3 = ChannelLastConv1d(dim,
424
+ hidden_dim,
425
+ bias=False,
426
+ kernel_size=kernel_size,
427
+ padding=padding)
428
+
429
+ def forward(self, x):
430
+ return self.w2(F.silu(self.w1(x)) * self.w3(x))
ThinkSound/models/bottleneck.py ADDED
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from einops import rearrange
7
+ from vector_quantize_pytorch import ResidualVQ, FSQ
8
+ from dac.nn.quantize import ResidualVectorQuantize as DACResidualVQ
9
+
10
+ class Bottleneck(nn.Module):
11
+ def __init__(self, is_discrete: bool = False):
12
+ super().__init__()
13
+
14
+ self.is_discrete = is_discrete
15
+
16
+ def encode(self, x, return_info=False, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, x):
20
+ raise NotImplementedError
21
+
22
+ class DiscreteBottleneck(Bottleneck):
23
+ def __init__(self, num_quantizers, codebook_size, tokens_id):
24
+ super().__init__(is_discrete=True)
25
+
26
+ self.num_quantizers = num_quantizers
27
+ self.codebook_size = codebook_size
28
+ self.tokens_id = tokens_id
29
+
30
+ def decode_tokens(self, codes, **kwargs):
31
+ raise NotImplementedError
32
+
33
+ class TanhBottleneck(Bottleneck):
34
+ def __init__(self):
35
+ super().__init__(is_discrete=False)
36
+ self.tanh = nn.Tanh()
37
+
38
+ def encode(self, x, return_info=False):
39
+ info = {}
40
+
41
+ x = torch.tanh(x)
42
+
43
+ if return_info:
44
+ return x, info
45
+ else:
46
+ return x
47
+
48
+ def decode(self, x):
49
+ return x
50
+
51
+ def vae_sample(mean, scale):
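+ # Reparameterization: scale is mapped through softplus to a positive stdev, and latents are
+ # sampled as mean + stdev * eps with eps ~ N(0, I). The kl term below is (up to a factor of 1/2)
+ # the closed-form KL divergence between N(mean, stdev^2) and a standard normal, summed over
+ # channels and averaged over the batch.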
52
+ stdev = nn.functional.softplus(scale) + 1e-4
53
+ var = stdev * stdev
54
+ logvar = torch.log(var)
55
+ latents = torch.randn_like(mean) * stdev + mean
56
+
57
+ kl = (mean * mean + var - logvar - 1).sum(1).mean()
58
+
59
+ return latents, kl
60
+
61
+ class VAEBottleneck(Bottleneck):
62
+ def __init__(self):
63
+ super().__init__(is_discrete=False)
64
+
65
+ def encode(self, x, return_info=False, **kwargs):
66
+ info = {}
67
+
68
+ mean, scale = x.chunk(2, dim=1)
69
+
70
+ x, kl = vae_sample(mean, scale)
71
+
72
+ info["kl"] = kl
73
+
74
+ if return_info:
75
+ return x, info
76
+ else:
77
+ return x
78
+
79
+ def decode(self, x):
80
+ return x
81
+
82
+ def compute_mean_kernel(x, y):
83
+ kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1]
84
+ return torch.exp(-kernel_input).mean()
85
+
86
+ def compute_mmd(latents):
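+ # RBF-kernel estimate of the maximum mean discrepancy (MMD) between the flattened latents and
+ # samples from a standard normal; used by WassersteinBottleneck as a regularizer that pushes
+ # the latent distribution toward N(0, I).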
87
+ latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1])
88
+ noise = torch.randn_like(latents_reshaped)
89
+
90
+ latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped)
91
+ noise_kernel = compute_mean_kernel(noise, noise)
92
+ latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise)
93
+
94
+ mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel
95
+ return mmd.mean()
96
+
97
+ class WassersteinBottleneck(Bottleneck):
98
+ def __init__(self, noise_augment_dim: int = 0, bypass_mmd: bool = False):
99
+ super().__init__(is_discrete=False)
100
+
101
+ self.noise_augment_dim = noise_augment_dim
102
+ self.bypass_mmd = bypass_mmd
103
+
104
+ def encode(self, x, return_info=False):
105
+ info = {}
106
+
107
+ if self.training and return_info:
108
+ if self.bypass_mmd:
109
+ mmd = torch.tensor(0.0)
110
+ else:
111
+ mmd = compute_mmd(x)
112
+
113
+ info["mmd"] = mmd
114
+
115
+ if return_info:
116
+ return x, info
117
+
118
+ return x
119
+
120
+ def decode(self, x):
121
+
122
+ if self.noise_augment_dim > 0:
123
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
124
+ x.shape[-1]).type_as(x)
125
+ x = torch.cat([x, noise], dim=1)
126
+
127
+ return x
128
+
129
+ class L2Bottleneck(Bottleneck):
130
+ def __init__(self):
131
+ super().__init__(is_discrete=False)
132
+
133
+ def encode(self, x, return_info=False):
134
+ info = {}
135
+
136
+ x = F.normalize(x, dim=1)
137
+
138
+ if return_info:
139
+ return x, info
140
+ else:
141
+ return x
142
+
143
+ def decode(self, x):
144
+ return F.normalize(x, dim=1)
145
+
146
+ class RVQBottleneck(DiscreteBottleneck):
147
+ def __init__(self, **quantizer_kwargs):
148
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
149
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
150
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
151
+
152
+ def encode(self, x, return_info=False, **kwargs):
153
+ info = {}
154
+
155
+ x = rearrange(x, "b c n -> b n c")
156
+ x, indices, loss = self.quantizer(x)
157
+ x = rearrange(x, "b n c -> b c n")
158
+
159
+ info["quantizer_indices"] = indices
160
+ info["quantizer_loss"] = loss.mean()
161
+
162
+ if return_info:
163
+ return x, info
164
+ else:
165
+ return x
166
+
167
+ def decode(self, x):
168
+ return x
169
+
170
+ def decode_tokens(self, codes, **kwargs):
171
+ latents = self.quantizer.get_outputs_from_indices(codes)
172
+
173
+ return self.decode(latents, **kwargs)
174
+
175
+ class RVQVAEBottleneck(DiscreteBottleneck):
176
+ def __init__(self, **quantizer_kwargs):
177
+ super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices")
178
+ self.quantizer = ResidualVQ(**quantizer_kwargs)
179
+ self.num_quantizers = quantizer_kwargs["num_quantizers"]
180
+
181
+ def encode(self, x, return_info=False):
182
+ info = {}
183
+
184
+ x, kl = vae_sample(*x.chunk(2, dim=1))
185
+
186
+ info["kl"] = kl
187
+
188
+ x = rearrange(x, "b c n -> b n c")
189
+ x, indices, loss = self.quantizer(x)
190
+ x = rearrange(x, "b n c -> b c n")
191
+
192
+ info["quantizer_indices"] = indices
193
+ info["quantizer_loss"] = loss.mean()
194
+
195
+ if return_info:
196
+ return x, info
197
+ else:
198
+ return x
199
+
200
+ def decode(self, x):
201
+ return x
202
+
203
+ def decode_tokens(self, codes, **kwargs):
204
+ latents = self.quantizer.get_outputs_from_indices(codes)
205
+
206
+ return self.decode(latents, **kwargs)
207
+
208
+ class DACRVQBottleneck(DiscreteBottleneck):
209
+ def __init__(self, quantize_on_decode=False, noise_augment_dim=0, **quantizer_kwargs):
210
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
211
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
212
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
213
+ self.quantize_on_decode = quantize_on_decode
214
+ self.noise_augment_dim = noise_augment_dim
215
+
216
+ def encode(self, x, return_info=False, **kwargs):
217
+ info = {}
218
+
219
+ info["pre_quantizer"] = x
220
+
221
+ if self.quantize_on_decode:
222
+ return (x, info) if return_info else x
223
+
224
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs)
225
+
226
+ output = {
227
+ "z": z,
228
+ "codes": codes,
229
+ "latents": latents,
230
+ "vq/commitment_loss": commitment_loss,
231
+ "vq/codebook_loss": codebook_loss,
232
+ }
233
+
234
+ output["vq/commitment_loss"] /= self.num_quantizers
235
+ output["vq/codebook_loss"] /= self.num_quantizers
236
+
237
+ info.update(output)
238
+
239
+ if return_info:
240
+ return output["z"], info
241
+
242
+ return output["z"]
243
+
244
+ def decode(self, x):
245
+
246
+ if self.quantize_on_decode:
247
+ x = self.quantizer(x)[0]
248
+
249
+ if self.noise_augment_dim > 0:
250
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
251
+ x.shape[-1]).type_as(x)
252
+ x = torch.cat([x, noise], dim=1)
253
+
254
+ return x
255
+
256
+ def decode_tokens(self, codes, **kwargs):
257
+ latents, _, _ = self.quantizer.from_codes(codes)
258
+
259
+ return self.decode(latents, **kwargs)
260
+
261
+ class DACRVQVAEBottleneck(DiscreteBottleneck):
262
+ def __init__(self, quantize_on_decode=False, **quantizer_kwargs):
263
+ super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes")
264
+ self.quantizer = DACResidualVQ(**quantizer_kwargs)
265
+ self.num_quantizers = quantizer_kwargs["n_codebooks"]
266
+ self.quantize_on_decode = quantize_on_decode
267
+
268
+ def encode(self, x, return_info=False, n_quantizers: int = None):
269
+ info = {}
270
+
271
+ mean, scale = x.chunk(2, dim=1)
272
+
273
+ x, kl = vae_sample(mean, scale)
274
+
275
+ info["pre_quantizer"] = x
276
+ info["kl"] = kl
277
+
278
+ if self.quantize_on_decode:
279
+ return (x, info) if return_info else x
280
+
281
+ z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers)
282
+
283
+ output = {
284
+ "z": z,
285
+ "codes": codes,
286
+ "latents": latents,
287
+ "vq/commitment_loss": commitment_loss,
288
+ "vq/codebook_loss": codebook_loss,
289
+ }
290
+
291
+ output["vq/commitment_loss"] /= self.num_quantizers
292
+ output["vq/codebook_loss"] /= self.num_quantizers
293
+
294
+ info.update(output)
295
+
296
+ if return_info:
297
+ return output["z"], info
298
+
299
+ return output["z"]
300
+
301
+ def decode(self, x):
302
+
303
+ if self.quantize_on_decode:
304
+ x = self.quantizer(x)[0]
305
+
306
+ return x
307
+
308
+ def decode_tokens(self, codes, **kwargs):
309
+ latents, _, _ = self.quantizer.from_codes(codes)
310
+
311
+ return self.decode(latents, **kwargs)
312
+
313
+ class FSQBottleneck(DiscreteBottleneck):
314
+ def __init__(self, noise_augment_dim=0, **kwargs):
315
+ super().__init__(num_quantizers = kwargs.get("num_codebooks", 1), codebook_size = np.prod(kwargs["levels"]), tokens_id = "quantizer_indices")
316
+
317
+ self.noise_augment_dim = noise_augment_dim
318
+
319
+ self.quantizer = FSQ(**kwargs, allowed_dtypes=[torch.float16, torch.float32, torch.float64])
320
+
321
+ def encode(self, x, return_info=False):
322
+ info = {}
323
+
324
+ orig_dtype = x.dtype
325
+ x = x.float()
326
+
327
+ x = rearrange(x, "b c n -> b n c")
328
+ x, indices = self.quantizer(x)
329
+ x = rearrange(x, "b n c -> b c n")
330
+
331
+ x = x.to(orig_dtype)
332
+
333
+ # Reorder indices to match the expected format
334
+ indices = rearrange(indices, "b n q -> b q n")
335
+
336
+ info["quantizer_indices"] = indices
337
+
338
+ if return_info:
339
+ return x, info
340
+ else:
341
+ return x
342
+
343
+ def decode(self, x):
344
+
345
+ if self.noise_augment_dim > 0:
346
+ noise = torch.randn(x.shape[0], self.noise_augment_dim,
347
+ x.shape[-1]).type_as(x)
348
+ x = torch.cat([x, noise], dim=1)
349
+
350
+ return x
351
+
352
+ def decode_tokens(self, tokens, **kwargs):
353
+ latents = self.quantizer.indices_to_codes(tokens)
354
+
355
+ return self.decode(latents, **kwargs)
ThinkSound/models/codebook_patterns.py ADDED
@@ -0,0 +1,545 @@
1
+ # Copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/codebooks_patterns.py under MIT License
2
+ # License available in LICENSES/LICENSE_META.txt
3
+
4
+ from collections import namedtuple
5
+ from dataclasses import dataclass
6
+ from functools import lru_cache
7
+ import logging
8
+ import typing as tp
9
+
10
+ from abc import ABC, abstractmethod
11
+ import torch
12
+
13
+ LayoutCoord = namedtuple('LayoutCoord', ['t', 'q']) # (timestep, codebook index)
14
+ PatternLayout = tp.List[tp.List[LayoutCoord]] # Sequence of coordinates
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ @dataclass
19
+ class Pattern:
20
+ """Base implementation of a pattern over a sequence with multiple codebooks.
21
+
22
+ The codebook pattern consists of a layout, defining for each sequence step
23
+ the list of coordinates of each codebook timestep in the resulting interleaved sequence.
24
+ The first item of the pattern is always an empty list in order to properly insert a special token
25
+ to start with. For convenience, we also keep track of ``n_q`` the number of codebooks used for the pattern
26
+ and ``timesteps`` the number of timesteps corresponding to the original sequence.
27
+
28
+ The pattern provides convenient methods to build and revert interleaved sequences from it:
29
+ ``build_pattern_sequence`` maps a given dense input tensor of a multi-codebook sequence from [B, K, T]
30
+ to the interleaved sequence of shape [B, K, S] applying the pattern, with B being the batch size,
31
+ K being the number of codebooks, T the number of original timesteps and S the number of sequence steps
32
+ for the output sequence. The unfilled positions are replaced with a special token and the built sequence
33
+ is returned along with a mask indicating valid tokens.
34
+ ``revert_pattern_sequence`` maps back an interleaved sequence of shape [B, K, S] to the original alignment
35
+ of codebooks across timesteps to an output tensor of shape [B, K, T], using again a special token and a mask
36
+ to fill and specify invalid positions if needed.
37
+ See the dedicated methods for more details.
38
+ """
39
+ # Pattern layout, for each sequence step, we have a list of coordinates
40
+ # corresponding to the original codebook timestep and position.
41
+ # The first list is always an empty list in order to properly insert
42
+ # a special token to start with.
43
+ layout: PatternLayout
44
+ timesteps: int
45
+ n_q: int
46
+
47
+ def __post_init__(self):
48
+ assert len(self.layout) > 0
49
+ self._validate_layout()
50
+ self._build_reverted_sequence_scatter_indexes = lru_cache(100)(self._build_reverted_sequence_scatter_indexes)
51
+ self._build_pattern_sequence_scatter_indexes = lru_cache(100)(self._build_pattern_sequence_scatter_indexes)
52
+ logger.info("New pattern, time steps: %d, sequence steps: %d", self.timesteps, len(self.layout))
53
+
54
+ def _validate_layout(self):
55
+ """Runs checks on the layout to ensure a valid pattern is defined.
56
+ A pattern is considered invalid if:
57
+ - Multiple timesteps for a same codebook are defined in the same sequence step
58
+ - The timesteps for a given codebook are not in ascending order as we advance in the sequence
59
+ (this would mean that we have future timesteps before past timesteps).
60
+ """
61
+ q_timesteps = {q: 0 for q in range(self.n_q)}
62
+ for s, seq_coords in enumerate(self.layout):
63
+ if len(seq_coords) > 0:
64
+ qs = set()
65
+ for coord in seq_coords:
66
+ qs.add(coord.q)
67
+ last_q_timestep = q_timesteps[coord.q]
68
+ assert coord.t >= last_q_timestep, \
69
+ f"Past timesteps are found in the sequence for codebook = {coord.q} at step {s}"
70
+ q_timesteps[coord.q] = coord.t
71
+ # each sequence step contains at max 1 coordinate per codebook
72
+ assert len(qs) == len(seq_coords), \
73
+ f"Multiple entries for a same codebook are found at step {s}"
74
+
75
+ @property
76
+ def num_sequence_steps(self):
77
+ return len(self.layout) - 1
78
+
79
+ @property
80
+ def max_delay(self):
81
+ max_t_in_seq_coords = 0
82
+ for seq_coords in self.layout[1:]:
83
+ for coords in seq_coords:
84
+ max_t_in_seq_coords = max(max_t_in_seq_coords, coords.t + 1)
85
+ return max_t_in_seq_coords - self.timesteps
86
+
87
+ @property
88
+ def valid_layout(self):
89
+ valid_step = len(self.layout) - self.max_delay
90
+ return self.layout[:valid_step]
91
+
92
+ def starts_with_special_token(self):
93
+ return self.layout[0] == []
94
+
95
+ def get_sequence_coords_with_timestep(self, t: int, q: tp.Optional[int] = None):
96
+ """Get codebook coordinates in the layout that corresponds to the specified timestep t
97
+ and optionally to the codebook q. Coordinates are returned as a tuple with the sequence step
98
+ and the actual codebook coordinates.
99
+ """
100
+ assert t <= self.timesteps, "provided timesteps is greater than the pattern's number of timesteps"
101
+ if q is not None:
102
+ assert q <= self.n_q, "provided number of codebooks is greater than the pattern's number of codebooks"
103
+ coords = []
104
+ for s, seq_codes in enumerate(self.layout):
105
+ for code in seq_codes:
106
+ if code.t == t and (q is None or code.q == q):
107
+ coords.append((s, code))
108
+ return coords
109
+
110
+ def get_steps_with_timestep(self, t: int, q: tp.Optional[int] = None) -> tp.List[int]:
111
+ return [step for step, coords in self.get_sequence_coords_with_timestep(t, q)]
112
+
113
+ def get_first_step_with_timesteps(self, t: int, q: tp.Optional[int] = None) -> tp.Optional[int]:
114
+ steps_with_timesteps = self.get_steps_with_timestep(t, q)
115
+ return steps_with_timesteps[0] if len(steps_with_timesteps) > 0 else None
116
+
117
+ def _build_pattern_sequence_scatter_indexes(self, timesteps: int, n_q: int, keep_only_valid_steps: bool,
118
+ device: tp.Union[torch.device, str] = 'cpu'):
119
+ """Build scatter indexes corresponding to the pattern, up to the provided sequence_steps.
120
+
121
+ Args:
122
+ timesteps (int): Maximum number of timesteps steps to consider.
123
+ keep_only_valid_steps (bool): Restrict the pattern layout to match only valid steps.
124
+ device (torch.device or str): Device for created tensors.
125
+ Returns:
126
+ indexes (torch.Tensor): Indexes corresponding to the sequence, of shape [K, S].
127
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes, of shape [K, S].
128
+ """
129
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
130
+ assert timesteps <= self.timesteps, "invalid number of timesteps used to build the sequence from the pattern"
131
+ # use the proper layout based on whether we limit ourselves to valid steps only or not,
132
+ # note that using the valid_layout will result in a truncated sequence up to the valid steps
133
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
134
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
135
+ indexes = torch.zeros(n_q, len(ref_layout), dtype=torch.long).numpy()
136
+ mask = torch.zeros(n_q, len(ref_layout), dtype=torch.bool).numpy()
137
+ # fill indexes with last sequence step value that will correspond to our special token
138
+ # the last value is n_q * timesteps as we have flattened z and append special token as the last token
139
+ # which will correspond to the index: n_q * timesteps
140
+ indexes[:] = n_q * timesteps
141
+ # iterate over the pattern and fill scattered indexes and mask
142
+ for s, sequence_coords in enumerate(ref_layout):
143
+ for coords in sequence_coords:
144
+ if coords.t < timesteps:
145
+ indexes[coords.q, s] = coords.t + coords.q * timesteps
146
+ mask[coords.q, s] = 1
147
+ indexes = torch.from_numpy(indexes).to(device)
148
+ mask = torch.from_numpy(mask).to(device)
149
+ return indexes, mask
150
+
151
+ def build_pattern_sequence(self, z: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
152
+ """Build sequence corresponding to the pattern from the input tensor z.
153
+ The sequence is built using up to sequence_steps if specified, and non-pattern
154
+ coordinates are filled with the special token.
155
+
156
+ Args:
157
+ z (torch.Tensor): Input tensor of multi-codebooks sequence, of shape [B, K, T].
158
+ special_token (int): Special token used to fill non-pattern coordinates in the new sequence.
159
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
160
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
161
+ Returns:
162
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, S] with S
163
+ corresponding either to the sequence_steps if provided, otherwise to the length of the pattern.
164
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, S].
165
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, S].
166
+ """
167
+ B, K, T = z.shape
168
+ indexes, mask = self._build_pattern_sequence_scatter_indexes(
169
+ T, K, keep_only_valid_steps=keep_only_valid_steps, device=str(z.device)
170
+ )
171
+ z = z.view(B, -1)
172
+ # we append the special token as the last index of our flattened z tensor
173
+ z = torch.cat([z, torch.zeros_like(z[:, :1]) + special_token], dim=1)
174
+ values = z[:, indexes.view(-1)]
175
+ values = values.view(B, K, indexes.shape[-1])
176
+ return values, indexes, mask
177
+
178
+ def _build_reverted_sequence_scatter_indexes(self, sequence_steps: int, n_q: int,
179
+ keep_only_valid_steps: bool = False,
180
+ is_model_output: bool = False,
181
+ device: tp.Union[torch.device, str] = 'cpu'):
182
+ """Builds scatter indexes required to retrieve the original multi-codebook sequence
183
+ from interleaving pattern.
184
+
185
+ Args:
186
+ sequence_steps (int): Sequence steps.
187
+ n_q (int): Number of codebooks.
188
+ keep_only_valid_steps (bool): Build a sequence from the pattern up to valid (= fully defined) steps.
189
+ Steps that are beyond valid steps will be replaced by the special_token in that case.
190
+ is_model_output (bool): Whether to keep the sequence item corresponding to initial special token or not.
191
+ device (torch.device or str): Device for created tensors.
192
+ Returns:
193
+ indexes (torch.Tensor): Indexes for reconstructing the output, of shape [K, T].
194
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
195
+ """
196
+ ref_layout = self.valid_layout if keep_only_valid_steps else self.layout
197
+ # TODO(jade): Do we want to further truncate to only valid timesteps here as well?
198
+ timesteps = self.timesteps
199
+ assert n_q == self.n_q, f"invalid number of codebooks for the sequence and the pattern: {n_q} != {self.n_q}"
200
+ assert sequence_steps <= len(ref_layout), \
201
+ f"sequence to revert is longer than the defined pattern: {sequence_steps} > {len(ref_layout)}"
202
+
203
+ # ensure we take the appropriate indexes to keep the model output from the first special token as well
204
+ if is_model_output and self.starts_with_special_token():
205
+ ref_layout = ref_layout[1:]
206
+
207
+ # single item indexing being super slow with pytorch vs. numpy, so we use numpy here
208
+ indexes = torch.zeros(n_q, timesteps, dtype=torch.long).numpy()
209
+ mask = torch.zeros(n_q, timesteps, dtype=torch.bool).numpy()
210
+ # fill indexes with last sequence step value that will correspond to our special token
211
+ indexes[:] = n_q * sequence_steps
212
+ for s, sequence_codes in enumerate(ref_layout):
213
+ if s < sequence_steps:
214
+ for code in sequence_codes:
215
+ if code.t < timesteps:
216
+ indexes[code.q, code.t] = s + code.q * sequence_steps
217
+ mask[code.q, code.t] = 1
218
+ indexes = torch.from_numpy(indexes).to(device)
219
+ mask = torch.from_numpy(mask).to(device)
220
+ return indexes, mask
221
+
222
+ def revert_pattern_sequence(self, s: torch.Tensor, special_token: int, keep_only_valid_steps: bool = False):
223
+ """Revert a sequence built from the pattern back to the original multi-codebook sequence without interleaving.
224
+ The sequence is reverted using up to timesteps if specified, and non-pattern coordinates
225
+ are filled with the special token.
226
+
227
+ Args:
228
+ s (torch.Tensor): Interleaved sequence tensor obtained from the pattern, of shape [B, K, S].
229
+ special_token (int or float): Special token used to fill non-pattern coordinates in the new sequence.
230
+ Returns:
231
+ values (torch.Tensor): Interleaved sequence matching the pattern, of shape [B, K, T] with T
232
+ corresponding either to the timesteps if provided, or the total timesteps in pattern otherwise.
233
+ indexes (torch.Tensor): Indexes corresponding to the interleaved sequence, of shape [K, T].
234
+ mask (torch.Tensor): Mask corresponding to indexes that matches valid indexes of shape [K, T].
235
+ """
236
+ B, K, S = s.shape
237
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
238
+ S, K, keep_only_valid_steps, is_model_output=False, device=str(s.device)
239
+ )
240
+ s = s.view(B, -1)
241
+ # we append the special token as the last index of our flattened z tensor
242
+ s = torch.cat([s, torch.zeros_like(s[:, :1]) + special_token], dim=1)
243
+ values = s[:, indexes.view(-1)]
244
+ values = values.view(B, K, indexes.shape[-1])
245
+ return values, indexes, mask
246
+
247
+ def revert_pattern_logits(self, logits: torch.Tensor, special_token: float, keep_only_valid_steps: bool = False):
248
+ """Revert model logits obtained on a sequence built from the pattern
249
+ back to a tensor matching the original sequence.
250
+
251
+ This method is similar to ``revert_pattern_sequence`` with the following specificities:
252
+ 1. It is designed to work with the extra cardinality dimension
253
+ 2. We return the logits for the first sequence item that matches the special_token and
254
+ which matching target in the original sequence is the first item of the sequence,
255
+ while we skip the last logits as there is no matching target
256
+ """
257
+ B, card, K, S = logits.shape
258
+ indexes, mask = self._build_reverted_sequence_scatter_indexes(
259
+ S, K, keep_only_valid_steps, is_model_output=True, device=logits.device
260
+ )
261
+ logits = logits.reshape(B, card, -1)
262
+ # we append the special token as the last index of our flattened z tensor
263
+ logits = torch.cat([logits, torch.zeros_like(logits[:, :, :1]) + special_token], dim=-1) # [B, card, K x S]
264
+ values = logits[:, :, indexes.view(-1)]
265
+ values = values.view(B, card, K, indexes.shape[-1])
266
+ return values, indexes, mask
267
+
268
+
269
+ class CodebooksPatternProvider(ABC):
270
+ """Abstraction around providing pattern for interleaving codebooks.
271
+
272
+ The CodebooksPatternProvider abstraction allows to implement various strategies to
273
+ define interleaving pattern of sequences composed of multiple codebooks. For a given
274
+ number of codebooks `n_q`, the pattern provider can generate a specified pattern
275
+ corresponding to a sequence of `T` timesteps with `n_q` parallel codebooks. This pattern
276
+ can be used to construct a new sequence from the original codes respecting the specified
277
+ pattern. The pattern is defined as a list of list of code coordinates, code coordinate
278
+ being a tuple with the original timestep and codebook to build the new sequence.
279
+ Note that all patterns must start with an empty list that is then used to insert a first
280
+ sequence step of special tokens in the newly generated sequence.
281
+
282
+ Args:
283
+ n_q (int): number of codebooks.
284
+ cached (bool): if True, patterns for a given length are cached. In general
285
+ that should be true for efficiency reason to avoid synchronization points.
286
+ """
287
+ def __init__(self, n_q: int, cached: bool = True):
288
+ assert n_q > 0
289
+ self.n_q = n_q
290
+ self.get_pattern = lru_cache(100)(self.get_pattern) # type: ignore
291
+
292
+ @abstractmethod
293
+ def get_pattern(self, timesteps: int) -> Pattern:
294
+ """Builds pattern with specific interleaving between codebooks.
295
+
296
+ Args:
297
+ timesteps (int): Total number of timesteps.
298
+ """
299
+ raise NotImplementedError()
300
+
301
+
302
+ class DelayedPatternProvider(CodebooksPatternProvider):
303
+ """Provider for delayed pattern across delayed codebooks.
304
+ Codebooks are delayed in the sequence and sequence steps will contain codebooks
305
+ from different timesteps.
306
+
307
+ Example:
308
+ Taking timesteps=4 and n_q=3, delays=None, the multi-codebook sequence:
309
+ [[1, 2, 3, 4],
310
+ [1, 2, 3, 4],
311
+ [1, 2, 3, 4]]
312
+ The resulting sequence obtained from the returned pattern is:
313
+ [[S, 1, 2, 3, 4],
314
+ [S, S, 1, 2, 3],
315
+ [S, S, S, 1, 2]]
316
+ (with S being a special token)
317
+
318
+ Args:
319
+ n_q (int): Number of codebooks.
320
+ delays (list of int, optional): Delay for each of the codebooks.
321
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
322
+ flatten_first (int): Flatten the first N timesteps.
323
+ empty_initial (int): Prepend with N empty list of coordinates.
324
+ """
325
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None,
326
+ flatten_first: int = 0, empty_initial: int = 0):
327
+ super().__init__(n_q)
328
+ if delays is None:
329
+ delays = list(range(n_q))
330
+ self.delays = delays
331
+ self.flatten_first = flatten_first
332
+ self.empty_initial = empty_initial
333
+ assert len(self.delays) == self.n_q
334
+ assert sorted(self.delays) == self.delays
335
+
336
+ def get_pattern(self, timesteps: int) -> Pattern:
337
+ omit_special_token = self.empty_initial < 0
338
+ out: PatternLayout = [] if omit_special_token else [[]]
339
+ max_delay = max(self.delays)
340
+ if self.empty_initial:
341
+ out += [[] for _ in range(self.empty_initial)]
342
+ if self.flatten_first:
343
+ for t in range(min(timesteps, self.flatten_first)):
344
+ for q in range(self.n_q):
345
+ out.append([LayoutCoord(t, q)])
346
+ for t in range(self.flatten_first, timesteps + max_delay):
347
+ v = []
348
+ for q, delay in enumerate(self.delays):
349
+ t_for_q = t - delay
350
+ if t_for_q >= self.flatten_first:
351
+ v.append(LayoutCoord(t_for_q, q))
352
+ out.append(v)
353
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
354
+
355
+
356
+ class ParallelPatternProvider(DelayedPatternProvider):
357
+ """Provider for parallel pattern across codebooks.
358
+ This pattern provider is a special case of the delayed pattern with actually no delay,
359
+ hence delays=repeat(0, n_q).
360
+
361
+ Args:
362
+ n_q (int): Number of codebooks.
363
+ empty_initial (int): Prepend with N empty list of coordinates.
364
+ """
365
+ def __init__(self, n_q: int, empty_initial: int = 0):
366
+ super().__init__(n_q, [0] * n_q, empty_initial=empty_initial)
367
+
368
+
369
+ class UnrolledPatternProvider(CodebooksPatternProvider):
370
+ """Provider for unrolling codebooks pattern.
371
+ This pattern provider enables to represent the codebook flattened completely or only to some extend
372
+ while also specifying a given delay between the flattened codebooks representation, allowing to
373
+ unroll the codebooks in the sequence.
374
+
375
+ Example:
376
+ 1. Flattening of the codebooks.
377
+ By default, the pattern provider will fully flatten the codebooks such as flattening=range(n_q),
378
+ taking n_q = 3 and timesteps = 4:
379
+ [[1, 2, 3, 4],
380
+ [1, 2, 3, 4],
381
+ [1, 2, 3, 4]]
382
+ will result into:
383
+ [[S, S, 1, S, S, 2, S, S, 3, S, S, 4],
384
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
385
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
386
+ 2. Partial flattening of the codebooks. The ``flattening`` parameter allows to specify the inner step
387
+ for each of the codebook, allowing to define which codebook to flatten (or keep in parallel), for example
388
+ taking n_q = 3, timesteps = 4 and flattening = [0, 1, 1]:
389
+ [[1, 2, 3, 4],
390
+ [1, 2, 3, 4],
391
+ [1, 2, 3, 4]]
392
+ will result into:
393
+ [[S, 1, S, S, 2, S, S, 3, S, S, 4, S],
394
+ [S, 1, S, S, 2, S, S, 3, S, S, 4, S],
395
+ [1, S, S, 2, S, S, 3, S, S, 4, S, S]]
396
+ 3. Flattening with delay. The ``delay`` parameter allows to further unroll the sequence of codebooks
397
+ allowing to specify the delay per codebook. Note that the delay between codebooks flattened to the
398
+ same inner timestep should be coherent. For example, taking n_q = 3, timesteps = 4, flattening = [0, 1, 1]
399
+ and delays = [0, 3, 3]:
400
+ [[1, 2, 3, 4],
401
+ [1, 2, 3, 4],
402
+ [1, 2, 3, 4]]
403
+ will result into:
404
+ [[S, S, S, 1, S, 2, S, 3, S, 4],
405
+ [S, S, S, 1, S, 2, S, 3, S, 4],
406
+ [1, 2, 3, S, 4, S, 5, S, 6, S]]
407
+
408
+ Args:
409
+ n_q (int): Number of codebooks.
410
+ flattening (list of int, optional): Flattening schema over the codebooks. If not defined,
411
+ the codebooks will be flattened to 1 codebook per step, meaning that the sequence will
412
+ have n_q extra steps for each timestep.
413
+ delays (list of int, optional): Delay for each of the codebooks. If not defined,
414
+ no delay is added and therefore will default to [0] * ``n_q``.
415
+ Note that two codebooks that will be flattened to the same inner step
416
+ should have the same delay, otherwise the pattern is considered as invalid.
417
+ """
418
+ FlattenedCodebook = namedtuple('FlattenedCodebook', ['codebooks', 'delay'])
419
+
420
+ def __init__(self, n_q: int, flattening: tp.Optional[tp.List[int]] = None,
421
+ delays: tp.Optional[tp.List[int]] = None):
422
+ super().__init__(n_q)
423
+ if flattening is None:
424
+ flattening = list(range(n_q))
425
+ if delays is None:
426
+ delays = [0] * n_q
427
+ assert len(flattening) == n_q
428
+ assert len(delays) == n_q
429
+ assert sorted(flattening) == flattening
430
+ assert sorted(delays) == delays
431
+ self._flattened_codebooks = self._build_flattened_codebooks(delays, flattening)
432
+ self.max_delay = max(delays)
433
+
434
+ def _build_flattened_codebooks(self, delays: tp.List[int], flattening: tp.List[int]):
435
+ """Build a flattened codebooks representation as a dictionary of inner step
436
+ and the actual codebook indices corresponding to the flattened codebook. For convenience, we
437
+ also store the delay associated to the flattened codebook to avoid maintaining an extra mapping.
438
+ """
439
+ flattened_codebooks: dict = {}
440
+ for q, (inner_step, delay) in enumerate(zip(flattening, delays)):
441
+ if inner_step not in flattened_codebooks:
442
+ flat_codebook = UnrolledPatternProvider.FlattenedCodebook(codebooks=[q], delay=delay)
443
+ else:
444
+ flat_codebook = flattened_codebooks[inner_step]
445
+ assert flat_codebook.delay == delay, (
446
+ "Delay and flattening between codebooks is inconsistent: ",
447
+ "two codebooks flattened to the same position should have the same delay."
448
+ )
449
+ flat_codebook.codebooks.append(q)
450
+ flattened_codebooks[inner_step] = flat_codebook
451
+ return flattened_codebooks
452
+
453
+ @property
454
+ def _num_inner_steps(self):
455
+ """Number of inner steps to unroll between timesteps in order to flatten the codebooks.
456
+ """
457
+ return max([inner_step for inner_step in self._flattened_codebooks.keys()]) + 1
458
+
459
+ def num_virtual_steps(self, timesteps: int) -> int:
460
+ return timesteps * self._num_inner_steps + 1
461
+
462
+ def get_pattern(self, timesteps: int) -> Pattern:
463
+ """Builds pattern for delay across codebooks.
464
+
465
+ Args:
466
+ timesteps (int): Total number of timesteps.
467
+ """
468
+ # the PatternLayout is built as a tuple of sequence position and list of coordinates
469
+ # so that it can be reordered properly given the required delay between codebooks of given timesteps
470
+ indexed_out: list = [(-1, [])]
471
+ max_timesteps = timesteps + self.max_delay
472
+ for t in range(max_timesteps):
473
+ # for each timestep, we unroll the flattened codebooks,
474
+ # emitting the sequence step with the corresponding delay
475
+ for step in range(self._num_inner_steps):
476
+ if step in self._flattened_codebooks:
477
+ # we have codebooks at this virtual step to emit
478
+ step_codebooks = self._flattened_codebooks[step]
479
+ t_for_q = t + step_codebooks.delay
480
+ coords = [LayoutCoord(t, q) for q in step_codebooks.codebooks]
481
+ if t_for_q < max_timesteps and t < max_timesteps:
482
+ indexed_out.append((t_for_q, coords))
483
+ else:
484
+ # there is no codebook in this virtual step so we emit an empty list
485
+ indexed_out.append((t, []))
486
+ out = [coords for _, coords in sorted(indexed_out)]
487
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
488
+
489
+
490
+ class CoarseFirstPattern(CodebooksPatternProvider):
491
+ """First generates all the codebooks #1 (e.g. coarser), then the remaining ones,
492
+ potentially with delays.
493
+
494
+ ..Warning:: You must always generate the full training duration at test time, for instance,
495
+ 30 seconds, as otherwise, the fine codebooks will start being generated in an unexpected
496
+ location. This is due to the non causality of the remaining codebooks with respect to
497
+ the first ones.
498
+
499
+ Args:
500
+ n_q (int): Number of codebooks.
501
+ delays (list of int, optional): Delay for each of the codebooks.
502
+ If delays not defined, each codebook is delayed by 1 compared to the previous one.
503
+ """
504
+ def __init__(self, n_q: int, delays: tp.Optional[tp.List[int]] = None):
505
+ super().__init__(n_q)
506
+ if delays is None:
507
+ delays = [0] * (n_q - 1)
508
+ self.delays = delays
509
+ assert len(self.delays) == self.n_q - 1
510
+ assert sorted(self.delays) == self.delays
511
+
512
+ def get_pattern(self, timesteps: int) -> Pattern:
513
+ out: PatternLayout = [[]]
514
+ for t in range(timesteps):
515
+ out.append([LayoutCoord(t, 0)])
516
+ max_delay = max(self.delays)
517
+ for t in range(timesteps + max_delay):
518
+ v = []
519
+ for q, delay in enumerate(self.delays):
520
+ t_for_q = t - delay
521
+ if t_for_q >= 0:
522
+ v.append(LayoutCoord(t_for_q, q + 1))
523
+ out.append(v)
524
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
525
+
526
+
527
+ class MusicLMPattern(CodebooksPatternProvider):
528
+ """Almost MusicLM style pattern. This is equivalent to full flattening
529
+ but in a different order.
530
+
531
+ Args:
532
+ n_q (int): Number of codebooks.
533
+ group_by (int): Number of codebooks to group together.
534
+ """
535
+ def __init__(self, n_q: int, group_by: int = 2):
536
+ super().__init__(n_q)
537
+ self.group_by = group_by
538
+
539
+ def get_pattern(self, timesteps: int) -> Pattern:
540
+ out: PatternLayout = [[]]
541
+ for offset in range(0, self.n_q, self.group_by):
542
+ for t in range(timesteps):
543
+ for q in range(offset, offset + self.group_by):
544
+ out.append([LayoutCoord(t, q)])
545
+ return Pattern(out, n_q=self.n_q, timesteps=timesteps)
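A minimal usage sketch of the pattern utilities added above (illustrative only, not part of the diff; it assumes the import path matches the file location in this commit):

import torch
from ThinkSound.models.codebook_patterns import DelayedPatternProvider

provider = DelayedPatternProvider(n_q=3)            # delays default to [0, 1, 2]
pattern = provider.get_pattern(timesteps=4)

codes = torch.arange(1, 13).view(1, 3, 4)           # [B=1, K=3, T=4]
seq, _, seq_mask = pattern.build_pattern_sequence(codes, special_token=0)
# seq has shape [1, 3, 7]: one leading special-token step plus T + max_delay steps

recovered, _, rec_mask = pattern.revert_pattern_sequence(seq, special_token=0)
assert torch.equal(recovered, codes)                # every position is recoverable for this pattern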
ThinkSound/models/conditioners.py ADDED
@@ -0,0 +1,1005 @@
1
+ #Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py
2
+
3
+ import torch
4
+ import logging, warnings
5
+ import string
6
+ import typing as tp
7
+ import gc
8
+ from typing import Literal, Optional
9
+ import os
10
+ from ..inference.utils import set_audio_channels
11
+ from .factory import create_pretransform_from_config
12
+ from .pretransforms import Pretransform
13
+ from ..training.utils import copy_state_dict
14
+ from .utils import load_ckpt_state_dict
15
+ import numpy as np
16
+ from einops import rearrange
17
+ from transformers import AutoProcessor, AutoModel
18
+ from torch import nn
19
+
20
+ class Conditioner(nn.Module):
21
+ def __init__(
22
+ self,
23
+ dim: int,
24
+ output_dim: int,
25
+ project_out: bool = False
26
+ ):
27
+
28
+ super().__init__()
29
+
30
+ self.dim = dim
31
+ self.output_dim = output_dim
32
+ self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()
33
+
34
+ def forward(self, x: tp.Any) -> tp.Any:
35
+ raise NotImplementedError()
36
+
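For orientation, a minimal sketch (an assumption, not part of the diff) of how this base class is meant to be subclassed: proj_out is a Linear(dim, output_dim) whenever the dimensions differ or project_out=True, otherwise an Identity, and each conditioner returns an [embeddings, mask] pair.

import typing as tp
import torch
from torch import nn
from ThinkSound.models.conditioners import Conditioner

class OneHotConditioner(Conditioner):
    """Hypothetical conditioner: embeds integer labels as one-hot vectors, then projects."""
    def __init__(self, num_classes: int, output_dim: int):
        super().__init__(num_classes, output_dim)

    def forward(self, labels: tp.List[int], device: tp.Any = "cpu"):
        self.proj_out.to(device)
        one_hot = nn.functional.one_hot(
            torch.tensor(labels), num_classes=self.dim
        ).float().unsqueeze(1).to(device)            # [B, 1, num_classes]
        embeds = self.proj_out(one_hot)              # [B, 1, output_dim]
        return [embeds, torch.ones(embeds.shape[0], 1).to(device)]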
37
+ class VideoHieraConditioner(Conditioner):
38
+ def __init__(self,
39
+ output_dim: int,
40
+ hiera_ckpt_path,
41
+ project_out: bool = False,
42
+ finetune: bool = False):
43
+ super().__init__(768, output_dim, project_out=project_out)
44
+
45
+ self.finetune = finetune
46
+
47
+ # Suppress logging from transformers
48
+ previous_level = logging.root.manager.disable
49
+ logging.disable(logging.ERROR)
50
+ with warnings.catch_warnings():
51
+ warnings.simplefilter("ignore")
52
+ try:
53
+ from hiera import Hiera
54
+ import hiera
55
+ # model = hiera.hiera_base_16x224(pretrained=True, checkpoint="useful_ckpts/hiera_base_224.mae_in1k_ft_in1k")
56
+ model = Hiera(
57
+ num_classes=400, # K400 has 400 classes
58
+ input_size=(64, 224, 224),
59
+ q_stride=[(1, 4, 4),(1,7,7),(1,2,2)],
60
+ mask_unit_size=(1, 8, 8),
61
+ patch_kernel=(3, 7, 7),
62
+ patch_stride=(2, 4, 4),
63
+ patch_padding=(1, 3, 3),
64
+ sep_pos_embed=True,
65
+ )
66
+ state_dict = torch.load(hiera_ckpt_path)['model_state']
67
+ state_dict.pop('pos_embed_temporal', None) # drop the temporal positional embedding if it is not needed
68
+ model.load_state_dict(state_dict,strict=False)
69
+ if self.finetune:
70
+ self.model = model
71
+ else:
72
+ self.__dict__["model"] = model
73
+
74
+ state_dict = model.state_dict()
75
+ self.model.load_state_dict(state_dict, strict=False)
76
+
77
+ if self.finetune:
78
+ self.model.requires_grad_(True)
79
+ self.model.train()
80
+ else:
81
+ self.model.requires_grad_(False)
82
+ self.model.eval()
83
+
84
+ finally:
85
+ logging.disable(previous_level)
86
+
87
+
88
+ gc.collect()
89
+ torch.cuda.empty_cache()
90
+
91
+ def forward(self, x: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
92
+ self.model.to(device)
93
+ # import ipdb
94
+ # ipdb.set_trace()
95
+ output, interm = self.model(x, return_intermediates=True)
96
+
97
+ video_features = interm[-1]
98
+ return [self.proj_out(video_features), torch.ones(video_features.shape[0], 1).to(device)]
99
+
100
+ class Video_Linear(Conditioner):
101
+ """ Transform the video feat encoder"""
102
+
103
+ def __init__(self, dim, output_dim):
104
+ super().__init__(dim, output_dim)
105
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
106
+
107
+ def forward(self, x, device: tp.Any = "cuda"):
108
+ # import ipdb
109
+ # ipdb.set_trace()
110
+ if not isinstance(x[0], torch.Tensor):
111
+ video_feats = []
112
+ for path in x:
113
+ if '.npy' in path:
114
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
115
+ elif '.pth' in path:
116
+ video_feats.append(torch.load(path)['metaclip_features'].to(device))
117
+ else:
118
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
119
+ x = torch.stack(video_feats, dim=0).to(device)
120
+ else:
121
+ # Revise the shape here:
122
+ x = torch.stack(x, dim=0).to(device)
123
+
124
+ x = self.embedder(x) # B x 117 x C
125
+ return [x, torch.ones(x.shape[0], 1).to(device)]
126
+
127
+ class Video_Global(Conditioner):
128
+ """ Transform the video feat encoder"""
129
+
130
+ def __init__(self, dim, output_dim, global_dim=1536):
131
+ super().__init__(dim, output_dim)
132
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
133
+ self.global_proj = nn.Sequential(nn.Linear(output_dim, global_dim))
134
+
135
+ def forward(self, x, device: tp.Any = "cuda"):
136
+ # import ipdb
137
+ # ipdb.set_trace()
138
+ if not isinstance(x[0], torch.Tensor):
139
+ video_feats = []
140
+ for path in x:
141
+ if '.npy' in path:
142
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
143
+ elif '.pth' in path:
144
+ data = torch.load(path)
145
+ video_feats.append(data['metaclip_features'].to(device))
146
+ else:
147
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
148
+ x = torch.stack(video_feats, dim=0).to(device)
149
+ else:
150
+ # Revise the shape here:
151
+ x = torch.stack(x, dim=0).to(device)
152
+
153
+ x = self.embedder(x) # B x 117 x C
154
+ global_x = self.global_proj(x.mean(dim=1))
155
+ return [x, torch.ones(x.shape[0], 1).to(device), global_x, torch.ones(global_x.shape[0], 1).to(device)]
156
+
157
+ class Video_Sync(Conditioner):
158
+ """ Transform the video feat encoder"""
159
+
160
+ def __init__(self, dim, output_dim):
161
+ super().__init__(dim, output_dim)
162
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
163
+
164
+ def forward(self, x, device: tp.Any = "cuda"):
165
+ # import ipdb
166
+ # ipdb.set_trace()
167
+ if not isinstance(x[0], torch.Tensor):
168
+ video_feats = []
169
+ for path in x:
170
+ if '.npy' in path:
171
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
172
+ elif '.pth' in path:
173
+ video_feats.append(torch.load(path)['sync_features'].to(device))
174
+ else:
175
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
176
+ x = torch.stack(video_feats, dim=0).to(device)
177
+ else:
178
+ # Revise the shape here:
179
+ x = torch.stack(x, dim=0).to(device)
180
+
181
+ x = self.embedder(x) # B x 117 x C
182
+ return [x, torch.ones(x.shape[0], 1).to(device)]
183
+
184
+ class Text_Linear(Conditioner):
185
+ """ Transform the video feat encoder"""
186
+
187
+ def __init__(self, dim, output_dim):
188
+ super().__init__(dim, output_dim)
189
+ self.embedder = nn.Sequential(nn.Linear(dim, output_dim))
190
+
191
+ def forward(self, x, device: tp.Any = "cuda"):
192
+ # import ipdb
193
+ # ipdb.set_trace()
194
+ if not isinstance(x[0], torch.Tensor):
195
+ video_feats = []
196
+ for path in x:
197
+ if '.npy' in path:
198
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
199
+ elif '.pth' in path:
200
+ video_feats.append(torch.load(path)['metaclip_text_features'].to(device))
201
+ else:
202
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
203
+ x = torch.stack(video_feats, dim=0).to(device)
204
+ else:
205
+ # Revise the shape here:
206
+ x = torch.stack(x, dim=0).to(device)
207
+
208
+ x = self.embedder(x) # B x 117 x C
209
+ return [x, torch.ones(x.shape[0], 1).to(device)]
210
+
211
+
212
+ class mm_unchang(Conditioner):
213
+ """ Transform the video feat encoder"""
214
+
215
+ def __init__(self, dim, output_dim):
216
+ super().__init__(dim, output_dim)
217
+
218
+ def forward(self, x, device: tp.Any = "cuda"):
219
+ # import ipdb
220
+ # ipdb.set_trace()
221
+ if not isinstance(x[0], torch.Tensor):
222
+ video_feats = []
223
+ for path in x:
224
+ if '.npy' in path:
225
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
226
+ elif '.pth' in path:
227
+ video_feats.append(torch.load(path)['metaclip_features'].to(device))
228
+ else:
229
+ video_feats.append(torch.from_numpy(np.load(path)['feat']).to(device))
230
+ x = torch.stack(video_feats, dim=0).to(device)
231
+ else:
232
+ # Revise the shape here:
233
+ x = torch.stack(x, dim=0).to(device)
234
+ return [x]
235
+
236
+ class CLIPConditioner(Conditioner):
237
+
238
+ CLIP_MODELS = ["metaclip-base", "metaclip-b16", "metaclip-large", "metaclip-huge"]
239
+
240
+ CLIP_MODEL_DIMS = {
241
+ "metaclip-base": 512,
242
+ "metaclip-b16": 512,
243
+ "metaclip-large": 768,
244
+ "metaclip-huge": 1024,
245
+ }
246
+
247
+ def __init__(
248
+ self,
249
+ dim: int,
250
+ output_dim: int,
251
+ clip_model_name: str = "metaclip-huge",
252
+ enable_grad: bool = False,
253
+ project_out: bool = False
254
+ ):
255
+ assert clip_model_name in self.CLIP_MODELS, f"Unknown CLIP model name: {clip_model_name}"
256
+ super().__init__(self.CLIP_MODEL_DIMS[clip_model_name], output_dim, project_out=project_out)
257
+
258
+ self.enable_grad = enable_grad
259
+ model = AutoModel.from_pretrained(f"useful_ckpts/{clip_model_name}").train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
260
+
261
+
262
+
263
+ if self.enable_grad:
264
+ self.model = model
265
+ else:
266
+ self.__dict__["model"] = model
267
+
268
+
269
+ def forward(self, images: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
270
+
271
+ self.model.to(device)
272
+ self.proj_out.to(device)
273
+ # import ipdb
274
+ # ipdb.set_trace()
275
+
276
+ self.model.eval()
277
+ if not isinstance(images[0], torch.Tensor):
278
+ video_feats = []
279
+ for path in images:
280
+ if '.npy' in path:
281
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
282
+ else:
283
+ video_feats.append(torch.from_numpy(np.load(path)).to(device))
284
+ images = torch.stack(video_feats, dim=0).to(device)
285
+ else:
286
+ images = torch.stack(images, dim=0).to(device)
287
+ bsz, t, c, h, w = images.shape
288
+ # merge the batch and time dimensions with rearrange
289
+ images = rearrange(images, 'b t c h w -> (b t) c h w')
290
+ with torch.set_grad_enabled(self.enable_grad):
291
+ image_features = self.model.get_image_features(images)
292
+ image_features = rearrange(image_features, '(b t) d -> b t d', b=bsz, t=t)
293
+ image_features = self.proj_out(image_features)
294
+
295
+
296
+ return [image_features, torch.ones(image_features.shape[0], 1).to(device)]
297
+
298
+ class IntConditioner(Conditioner):
299
+ def __init__(self,
300
+ output_dim: int,
301
+ min_val: int=0,
302
+ max_val: int=512
303
+ ):
304
+ super().__init__(output_dim, output_dim)
305
+
306
+ self.min_val = min_val
307
+ self.max_val = max_val
308
+ self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)
309
+
310
+ def forward(self, ints: tp.List[int], device=None) -> tp.Any:
311
+
312
+ #self.int_embedder.to(device)
313
+
314
+ ints = torch.tensor(ints).to(device)
315
+ ints = ints.clamp(self.min_val, self.max_val)
316
+
317
+ int_embeds = self.int_embedder(ints).unsqueeze(1)
318
+
319
+ return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]
320
+
321
+ class NumberConditioner(Conditioner):
322
+ '''
323
+ Conditioner that takes a list of floats, normalizes them for a given range, and returns a list of embeddings
324
+ '''
325
+ def __init__(self,
326
+ output_dim: int,
327
+ min_val: float=0,
328
+ max_val: float=1
329
+ ):
330
+ super().__init__(output_dim, output_dim)
331
+
332
+ self.min_val = min_val
333
+ self.max_val = max_val
334
+
335
+ self.embedder = NumberEmbedder(features=output_dim)
336
+
337
+ def forward(self, floats: tp.List[float], device=None) -> tp.Any:
338
+
339
+ # Cast the inputs to floats
340
+ floats = [float(x) for x in floats]
341
+
342
+ floats = torch.tensor(floats).to(device)
343
+
344
+ floats = floats.clamp(self.min_val, self.max_val)
345
+
346
+ normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)
347
+
348
+ # Cast floats to same type as embedder
349
+ embedder_dtype = next(self.embedder.parameters()).dtype
350
+ normalized_floats = normalized_floats.to(embedder_dtype)
351
+
352
+ float_embeds = self.embedder(normalized_floats).unsqueeze(1)
353
+
354
+ return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]
355
+
356
+ class CLAPTextConditioner(Conditioner):
357
+ def __init__(self,
358
+ output_dim: int,
359
+ clap_ckpt_path,
360
+ use_text_features = False,
361
+ feature_layer_ix: int = -1,
362
+ audio_model_type="HTSAT-base",
363
+ enable_fusion=True,
364
+ project_out: bool = False,
365
+ finetune: bool = False):
366
+ super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)
367
+
368
+ self.use_text_features = use_text_features
369
+ self.feature_layer_ix = feature_layer_ix
370
+ self.finetune = finetune
371
+
372
+ # Suppress logging from transformers
373
+ previous_level = logging.root.manager.disable
374
+ logging.disable(logging.ERROR)
375
+ with warnings.catch_warnings():
376
+ warnings.simplefilter("ignore")
377
+ try:
378
+ import laion_clap
379
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
380
+
381
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
382
+
383
+ if self.finetune:
384
+ self.model = model
385
+ else:
386
+ self.__dict__["model"] = model
387
+
388
+ state_dict = clap_load_state_dict(clap_ckpt_path)
389
+ self.model.model.load_state_dict(state_dict, strict=False)
390
+
391
+ if self.finetune:
392
+ self.model.model.text_branch.requires_grad_(True)
393
+ self.model.model.text_branch.train()
394
+ else:
395
+ self.model.model.text_branch.requires_grad_(False)
396
+ self.model.model.text_branch.eval()
397
+
398
+ finally:
399
+ logging.disable(previous_level)
400
+
401
+ del self.model.model.audio_branch
402
+
403
+ gc.collect()
404
+ torch.cuda.empty_cache()
405
+
406
+ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
407
+ prompt_tokens = self.model.tokenizer(prompts)
408
+ attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
409
+ prompt_features = self.model.model.text_branch(
410
+ input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
411
+ attention_mask=attention_mask,
412
+ output_hidden_states=True
413
+ )["hidden_states"][layer_ix]
414
+
415
+ return prompt_features, attention_mask
416
+
417
+ def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
418
+ self.model.to(device)
419
+
420
+ if self.use_text_features:
421
+ if len(texts) == 1:
422
+ text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
423
+ text_features = text_features[:1, ...]
424
+ text_attention_mask = text_attention_mask[:1, ...]
425
+ else:
426
+ text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
427
+ return [self.proj_out(text_features), text_attention_mask]
428
+
429
+ # Fix for CLAP bug when only one text is passed
430
+ if len(texts) == 1:
431
+ text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
432
+ else:
433
+ text_embedding = self.model.get_text_embedding(texts, use_tensor=True)
434
+
435
+ text_embedding = text_embedding.unsqueeze(1).to(device)
436
+
437
+ return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]
438
+
439
+ class CLAPAudioConditioner(Conditioner):
440
+ def __init__(self,
441
+ output_dim: int,
442
+ clap_ckpt_path,
443
+ audio_model_type="HTSAT-base",
444
+ enable_fusion=True,
445
+ project_out: bool = False, finetune: bool = False):
446
+ super().__init__(512, output_dim, project_out=project_out)
447
+
448
+ self.finetune = finetune
449
+
450
+ # Suppress logging from transformers
451
+ previous_level = logging.root.manager.disable
452
+ logging.disable(logging.ERROR)
453
+ with warnings.catch_warnings():
454
+ warnings.simplefilter("ignore")
455
+ try:
456
+ import laion_clap
457
+ from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
458
+
459
+ model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
460
+
461
+ if self.finetune:
462
+ self.model = model
463
+ else:
464
+ self.__dict__["model"] = model
465
+
466
+ state_dict = clap_load_state_dict(clap_ckpt_path)
467
+ self.model.model.load_state_dict(state_dict, strict=False)
468
+
469
+ if self.finetune:
470
+ self.model.model.audio_branch.requires_grad_(True)
471
+ self.model.model.audio_branch.train()
472
+ else:
473
+ self.model.model.audio_branch.requires_grad_(False)
474
+ self.model.model.audio_branch.eval()
475
+
476
+ finally:
477
+ logging.disable(previous_level)
478
+
479
+ del self.model.model.text_branch
480
+
481
+ gc.collect()
482
+ torch.cuda.empty_cache()
483
+
484
+ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
485
+
486
+ self.model.to(device)
487
+
488
+ if isinstance(audios, list) or isinstance(audios, tuple):
489
+ audios = torch.cat(audios, dim=0)
490
+
491
+ # Convert to mono
492
+ mono_audios = audios.mean(dim=1)
493
+
494
+ with torch.cuda.amp.autocast(enabled=False):
495
+ audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
496
+
497
+ audio_embedding = audio_embedding.unsqueeze(1).to(device)
498
+
499
+ return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]
500
+
501
+ class T5Conditioner(Conditioner):
502
+
503
+ T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
504
+ "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
505
+ "google/flan-t5-xl", "google/flan-t5-xxl", "t5-v1_1-xl", "google/t5-v1_1-xxl"]
506
+
507
+ T5_MODEL_DIMS = {
508
+ "t5-small": 512,
509
+ "t5-base": 768,
510
+ "t5-large": 1024,
511
+ "t5-3b": 1024,
512
+ "t5-11b": 1024,
513
+ "t5-v1_1-xl": 2048,
514
+ "google/t5-v1_1-xxl": 4096,
515
+ "google/flan-t5-small": 512,
516
+ "google/flan-t5-base": 768,
517
+ "google/flan-t5-large": 1024,
518
+ "google/flan-t5-3b": 1024,
519
+ "google/flan-t5-11b": 1024,
520
+ "google/flan-t5-xl": 2048,
521
+ "google/flan-t5-xxl": 4096,
522
+ }
523
+
524
+ def __init__(
525
+ self,
526
+ output_dim: int,
527
+ t5_model_name: str = "t5-base",
528
+ max_length: int = 77,
529
+ enable_grad: bool = False,
530
+ project_out: bool = False
531
+ ):
532
+ assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
533
+ super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)
534
+
535
+ from transformers import T5EncoderModel, AutoTokenizer
536
+
537
+ self.max_length = max_length
538
+ self.enable_grad = enable_grad
539
+
540
+ # Suppress logging from transformers
541
+ previous_level = logging.root.manager.disable
542
+ logging.disable(logging.ERROR)
543
+ with warnings.catch_warnings():
544
+ warnings.simplefilter("ignore")
545
+ try:
546
+ # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length = max_length)
547
+ # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
548
+ self.tokenizer = AutoTokenizer.from_pretrained(os.path.join('useful_ckpts', t5_model_name))
549
+ model = T5EncoderModel.from_pretrained(os.path.join('useful_ckpts', t5_model_name)).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
550
+ finally:
551
+ logging.disable(previous_level)
552
+
553
+ if self.enable_grad:
554
+ self.model = model
555
+ else:
556
+ self.__dict__["model"] = model
557
+
558
+
559
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
560
+
561
+ self.model.to(device)
562
+ self.proj_out.to(device)
563
+ encoded = self.tokenizer(
564
+ texts,
565
+ truncation=True,
566
+ max_length=self.max_length,
567
+ padding="max_length",
568
+ return_tensors="pt",
569
+ )
570
+
571
+ input_ids = encoded["input_ids"].to(device)
572
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
573
+
574
+ self.model.eval()
575
+
576
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
577
+ embeddings = self.model(
578
+ input_ids=input_ids, attention_mask=attention_mask
579
+ )["last_hidden_state"]
580
+
581
+ embeddings = self.proj_out(embeddings.float())
582
+
583
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
584
+
585
+ return embeddings, attention_mask
586
+
587
+ def patch_clip(clip_model):
588
+ # a hack to make it output last hidden states
589
+ # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
590
+ def new_encode_text(self, text, normalize: bool = False):
591
+ cast_dtype = self.transformer.get_cast_dtype()
592
+
593
+ x = self.token_embedding(text).to(cast_dtype) # [batch_size, n_ctx, d_model]
594
+
595
+ x = x + self.positional_embedding.to(cast_dtype)
596
+ x = self.transformer(x, attn_mask=self.attn_mask)
597
+ x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
598
+ return F.normalize(x, dim=-1) if normalize else x
599
+
600
+ clip_model.encode_text = new_encode_text.__get__(clip_model)
601
+ return clip_model
602
+
603
+ class CLIPTextConditioner(Conditioner):
604
+ def __init__(
605
+ self,
606
+ output_dim: int,
607
+ max_length: int = 77,
608
+ enable_grad: bool = False,
609
+ project_out: bool = False
610
+ ):
611
+ super().__init__(1024, output_dim, project_out=project_out)
612
+
613
+ from transformers import T5EncoderModel, AutoTokenizer
614
+ import open_clip
615
+ from open_clip import create_model_from_pretrained
616
+
617
+ self.max_length = max_length
618
+ self.enable_grad = enable_grad
619
+
620
+ # Suppress logging from transformers
621
+ previous_level = logging.root.manager.disable
622
+ logging.disable(logging.ERROR)
623
+ with warnings.catch_warnings():
624
+ warnings.simplefilter("ignore")
625
+ try:
626
+ model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',cache_dir='useful_ckpts/DFN5B-CLIP-ViT-H-14-384',
627
+ return_transform=False).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
628
+ model = patch_clip(model)
629
+ self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu') # same as 'ViT-H-14'
630
+ finally:
631
+ logging.disable(previous_level)
632
+
633
+ if self.enable_grad:
634
+ self.model = model
635
+ else:
636
+ self.__dict__["model"] = model
637
+
638
+
639
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
640
+
641
+ self.model.to(device)
642
+ self.proj_out.to(device)
643
+
644
+ encoded = self.tokenizer(
645
+ texts
646
+ ).to(device)
647
+
648
+ # input_ids = encoded["input_ids"].to(device)
649
+ # attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
650
+
651
+ self.model.eval()
652
+
653
+ with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
654
+ embeddings = self.model.encode_text(
655
+ encoded
656
+ )
657
+
658
+ embeddings = self.proj_out(embeddings.float())
659
+
660
+ # embeddings = embeddings * attention_mask.unsqueeze(-1).float()
661
+
662
+ return embeddings, torch.ones(embeddings.shape[0], 1).to(device)
663
+
664
+ def patch_hf_clip(clip_model):
665
+ # a hack to make it output last hidden states
666
+ # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
667
+ def new_get_text_features(self, input_ids=None, attention_mask=None, position_ids=None,
668
+ output_attentions: Optional[bool] = None,
669
+ output_hidden_states: Optional[bool] = None,
670
+ return_dict: Optional[bool] = None):
671
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
672
+ output_hidden_states = (
673
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
674
+ )
675
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
676
+
677
+ text_outputs = self.text_model(
678
+ input_ids=input_ids,
679
+ attention_mask=attention_mask,
680
+ position_ids=position_ids,
681
+ output_attentions=output_attentions,
682
+ output_hidden_states=output_hidden_states,
683
+ return_dict=return_dict,
684
+ )
685
+ last_hidden_state = text_outputs[0]
686
+ # pooled_output = text_outputs[1]
687
+ # text_features = self.text_projection(pooled_output)
688
+
689
+ return last_hidden_state
690
+
691
+ clip_model.get_text_features = new_get_text_features.__get__(clip_model)
692
+ return clip_model
693
+
694
+ class MetaCLIPTextConditioner(Conditioner):
695
+ def __init__(
696
+ self,
697
+ output_dim: int,
698
+ max_length: int = 77,
699
+ enable_grad: bool = False,
700
+ project_out: bool = False
701
+ ):
702
+ super().__init__(1024, output_dim, project_out=project_out)
703
+
704
+ from transformers import AutoModel
705
+ from transformers import AutoProcessor
706
+
707
+ self.max_length = max_length
708
+ self.enable_grad = enable_grad
709
+
710
+ # Suppress logging from transformers
711
+ previous_level = logging.root.manager.disable
712
+ logging.disable(logging.ERROR)
713
+ with warnings.catch_warnings():
714
+ warnings.simplefilter("ignore")
715
+ try:
716
+ self.model = AutoModel.from_pretrained("useful_ckpts/metaclip-huge")
717
+ self.model = patch_hf_clip(self.model)
718
+ self.clip_processor = AutoProcessor.from_pretrained("useful_ckpts/metaclip-huge")
719
+ finally:
720
+ logging.disable(previous_level)
721
+
722
+
723
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
724
+
725
+ self.model.to(device)
726
+ self.proj_out.to(device)
727
+ encoded = self.clip_processor(text=texts, return_tensors="pt", padding=True).to(device)
728
+
729
+ # input_ids = encoded["input_ids"].to(device)
730
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
731
+
732
+ self.model.eval()
733
+
734
+ with torch.set_grad_enabled(self.enable_grad):
735
+ embeddings = self.model.get_text_features(
736
+ **encoded
737
+ )
738
+
739
+ embeddings = self.proj_out(embeddings.float())
740
+
741
+ # embeddings = embeddings * attention_mask.unsqueeze(-1).float()
742
+
743
+ return embeddings, torch.ones(embeddings.shape[0],1).to(device)
744
+
745
+ class PhonemeConditioner(Conditioner):
746
+ """
747
+ A conditioner that turns text into phonemes and embeds them using a lookup table
748
+ Only works for English text
749
+
750
+ Args:
751
+ output_dim: the dimension of the output embeddings
752
+ max_length: the maximum number of phonemes to embed
753
+ project_out: whether to add another linear projection to the output embeddings
754
+ """
755
+
756
+ def __init__(
757
+ self,
758
+ output_dim: int,
759
+ max_length: int = 1024,
760
+ project_out: bool = False,
761
+ ):
762
+ super().__init__(output_dim, output_dim, project_out=project_out)
763
+
764
+ from g2p_en import G2p
765
+
766
+ self.max_length = max_length
767
+
768
+ self.g2p = G2p()
769
+
770
+ # Reserving 0 for padding, 1 for ignored
771
+ self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)
772
+
773
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
774
+
775
+ self.phoneme_embedder.to(device)
776
+ self.proj_out.to(device)
777
+
778
+ batch_phonemes = [self.g2p(text) for text in texts] # shape [batch_size, length]
779
+
780
+ phoneme_ignore = [" ", *string.punctuation]
781
+
782
+ # Remove ignored phonemes and cut to max length
783
+ batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes] for phonemes in batch_phonemes]
784
+
785
+ # Convert to ids
786
+ phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]
787
+
788
+ #Pad to match longest and make a mask tensor for the padding
789
+ longest = max([len(ids) for ids in phoneme_ids])
790
+ phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]
791
+
792
+ phoneme_ids = torch.tensor(phoneme_ids).to(device)
793
+
794
+ # Convert to embeddings
795
+ phoneme_embeds = self.phoneme_embedder(phoneme_ids)
796
+
797
+ phoneme_embeds = self.proj_out(phoneme_embeds)
798
+
799
+ return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)
800
+
801
+ class TokenizerLUTConditioner(Conditioner):
802
+ """
803
+ A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary
804
+
805
+ Args:
806
+ tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
807
+ output_dim: the dimension of the output embeddings
808
+ max_length: the maximum length of the text to embed
809
+ project_out: whether to add another linear projection to the output embeddings
810
+ """
811
+
812
+ def __init__(
813
+ self,
814
+ tokenizer_name: str, # Name of a tokenizer from the Hugging Face transformers library
815
+ output_dim: int,
816
+ max_length: int = 1024,
817
+ project_out: bool = False,
818
+ ):
819
+ super().__init__(output_dim, output_dim, project_out=project_out)
820
+
821
+ from transformers import AutoTokenizer
822
+
823
+ # Suppress logging from transformers
824
+ previous_level = logging.root.manager.disable
825
+ logging.disable(logging.ERROR)
826
+ with warnings.catch_warnings():
827
+ warnings.simplefilter("ignore")
828
+ try:
829
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
830
+ finally:
831
+ logging.disable(previous_level)
832
+
833
+ self.max_length = max_length
834
+
835
+ self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)
836
+
837
+ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
838
+ self.proj_out.to(device)
839
+
840
+ encoded = self.tokenizer(
841
+ texts,
842
+ truncation=True,
843
+ max_length=self.max_length,
844
+ padding="max_length",
845
+ return_tensors="pt",
846
+ )
847
+
848
+ input_ids = encoded["input_ids"].to(device)
849
+ attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
850
+
851
+ embeddings = self.token_embedder(input_ids)
852
+
853
+ embeddings = self.proj_out(embeddings)
854
+
855
+ embeddings = embeddings * attention_mask.unsqueeze(-1).float()
856
+
857
+ return embeddings, attention_mask
858
+
859
+ class PretransformConditioner(Conditioner):
860
+ """
861
+ A conditioner that uses a pretransform's encoder for conditioning
862
+
863
+ Args:
864
+ pretransform: an instantiated pretransform to use for conditioning
865
+ output_dim: the dimension of the output embeddings
866
+ """
867
+ def __init__(self, pretransform: Pretransform, output_dim: int):
868
+ super().__init__(pretransform.encoded_channels, output_dim)
869
+
870
+ self.pretransform = pretransform
871
+
872
+ def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
873
+
874
+ self.pretransform.to(device)
875
+ self.proj_out.to(device)
876
+
877
+ if isinstance(audio, list) or isinstance(audio, tuple):
878
+ audio = torch.cat(audio, dim=0)
879
+
880
+ # Convert audio to pretransform input channels
881
+ audio = set_audio_channels(audio, self.pretransform.io_channels)
882
+
883
+ latents = self.pretransform.encode(audio)
884
+
885
+ latents = self.proj_out(latents)
886
+
887
+ return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]
888
+
889
+ class MultiConditioner(nn.Module):
890
+ """
891
+ A module that applies multiple conditioners to an input dictionary based on the keys
892
+
893
+ Args:
894
+ conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
895
+ default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
896
+ """
897
+ def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
898
+ super().__init__()
899
+
900
+ self.conditioners = nn.ModuleDict(conditioners)
901
+ self.default_keys = default_keys
902
+
903
+ def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
904
+ output = {}
905
+
906
+ for key, conditioner in self.conditioners.items():
907
+ condition_key = key
908
+
909
+ conditioner_inputs = []
910
+
911
+ for x in batch_metadata:
912
+
913
+ if condition_key not in x:
914
+ if condition_key in self.default_keys:
915
+ condition_key = self.default_keys[condition_key]
916
+ else:
917
+ raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")
918
+
919
+ # Unwrap the condition info if it's a single-element list or tuple; this supports collation functions that wrap everything in a list
920
+ if (isinstance(x[condition_key], list) or isinstance(x[condition_key], tuple)) and len(x[condition_key]) == 1:
921
+ conditioner_input = x[condition_key][0]
922
+
923
+ else:
924
+ conditioner_input = x[condition_key]
925
+
926
+ conditioner_inputs.append(conditioner_input)
927
+
928
+ cond_output = conditioner(conditioner_inputs, device)
929
+ if len(cond_output) == 1:
930
+ output[key] = cond_output[0]
931
+ elif len(cond_output) == 2:
932
+ output[key] = cond_output
933
+ elif len(cond_output) == 4:
934
+ output[key] = cond_output[:2]
935
+ output[f'{key}_g'] = cond_output[2:]
936
+
937
+ return output
938
+
939
+ def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
940
+ """
941
+ Create a MultiConditioner from a conditioning config dictionary
942
+
943
+ Args:
944
+ config: the conditioning config dictionary
945
+ device: the device to put the conditioners on
946
+ """
947
+ conditioners = {}
948
+ cond_dim = config["cond_dim"]
949
+
950
+ default_keys = config.get("default_keys", {})
951
+
952
+ for conditioner_info in config["configs"]:
953
+ id = conditioner_info["id"]
954
+
955
+ conditioner_type = conditioner_info["type"]
956
+
957
+ conditioner_config = {"output_dim": cond_dim}
958
+
959
+ conditioner_config.update(conditioner_info["config"])
960
+ if conditioner_type == "t5":
961
+ conditioners[id] = T5Conditioner(**conditioner_config)
962
+ elif conditioner_type == "clap_text":
963
+ conditioners[id] = CLAPTextConditioner(**conditioner_config)
964
+ elif conditioner_type == "clip_text":
965
+ conditioners[id] = CLIPTextConditioner(**conditioner_config)
966
+ elif conditioner_type == "metaclip_text":
967
+ conditioners[id] = MetaCLIPTextConditioner(**conditioner_config)
968
+ elif conditioner_type == "clap_audio":
969
+ conditioners[id] = CLAPAudioConditioner(**conditioner_config)
970
+ elif conditioner_type == "video_linear":
971
+ conditioners[id] = Video_Linear(**conditioner_config)
972
+ elif conditioner_type == "video_global":
973
+ conditioners[id] = Video_Global(**conditioner_config)
974
+ elif conditioner_type == "video_sync":
975
+ conditioners[id] = Video_Sync(**conditioner_config)
976
+ elif conditioner_type == "text_linear":
977
+ conditioners[id] = Text_Linear(**conditioner_config)
978
+ elif conditioner_type == "video_clip":
979
+ conditioners[id] = CLIPConditioner(**conditioner_config)
980
+ elif conditioner_type == "video_hiera":
981
+ conditioners[id] = VideoHieraConditioner(**conditioner_config)
982
+ elif conditioner_type == "int":
983
+ conditioners[id] = IntConditioner(**conditioner_config)
984
+ elif conditioner_type == "number":
985
+ conditioners[id] = NumberConditioner(**conditioner_config)
986
+ elif conditioner_type == "phoneme":
987
+ conditioners[id] = PhonemeConditioner(**conditioner_config)
988
+ elif conditioner_type == "lut":
989
+ conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
990
+ elif conditioner_type == "pretransform":
991
+ sample_rate = conditioner_config.pop("sample_rate", None)
992
+ assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"
993
+
994
+ pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)
995
+
996
+ if conditioner_config.get("pretransform_ckpt_path", None) is not None:
997
+ pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))
998
+
999
+ conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
1000
+ elif conditioner_type == "mm_unchang":
1001
+ conditioners[id] = mm_unchang(**conditioner_config)
1002
+ else:
1003
+ raise ValueError(f"Unknown conditioner type: {conditioner_type}")
1004
+
1005
+ return MultiConditioner(conditioners, default_keys=default_keys)
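
For reference, a minimal sketch of the conditioning config this factory consumes is given below. The structure (cond_dim, default_keys, and a list of {id, type, config} entries) follows the code above, but the id names, dimensions and tokenizer choice are illustrative placeholders rather than values taken from the shipped ThinkSound configs.

from ThinkSound.models.conditioners import create_multi_conditioner_from_conditioning_config

# Illustrative sketch: ids, dims and the tokenizer name are hypothetical placeholders.
conditioning_config = {
    "cond_dim": 768,
    "default_keys": {"prompt_tokens": "prompt"},  # fall back to the "prompt" metadata key
    "configs": [
        {"id": "prompt_tokens", "type": "lut",
         "config": {"tokenizer_name": "t5-base", "max_length": 128}},
    ],
}

conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)

# Each batch item is a metadata dict; keys are routed to conditioners by id (or via default_keys).
cond = conditioner([{"prompt": "rain on a tin roof"}], device="cpu")
emb, mask = cond["prompt_tokens"]  # (1, 128, 768) embeddings and a boolean attention mask
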
ThinkSound/models/diffusion.py ADDED
@@ -0,0 +1,920 @@
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from functools import partial
5
+ import numpy as np
6
+ import typing as tp
7
+
8
+ from .blocks import ResConvBlock, FourierFeatures, Upsample1d, Upsample1d_2, Downsample1d, Downsample1d_2, SelfAttention1d, SkipBlock, expand_to_planes
9
+ from .conditioners import MultiConditioner, create_multi_conditioner_from_conditioning_config
10
+ # from .dit import DiffusionTransformer
11
+ from .mmdit import MMAudio
12
+ from .factory import create_pretransform_from_config
13
+ from .pretransforms import Pretransform
14
+ from ..inference.generation import generate_diffusion_cond
15
+
16
+ from time import time
17
+
18
+ class Profiler:
19
+
20
+ def __init__(self):
21
+ self.ticks = [[time(), None]]
22
+
23
+ def tick(self, msg):
24
+ self.ticks.append([time(), msg])
25
+
26
+ def __repr__(self):
27
+ rep = 80 * "=" + "\n"
28
+ for i in range(1, len(self.ticks)):
29
+ msg = self.ticks[i][1]
30
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
31
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
32
+ rep += 80 * "=" + "\n\n\n"
33
+ return rep
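
A quick usage sketch for this helper, assuming Profiler is imported from this module; the sleeps stand in for real work:

import time

p = Profiler()
time.sleep(0.01)  # stand-in for real work, e.g. encoding
p.tick("encode")
time.sleep(0.02)  # stand-in for real work, e.g. sampling
p.tick("sample")
print(p)          # prints the elapsed time of each step in milliseconds
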
34
+
35
+ class DiffusionModel(nn.Module):
36
+ def __init__(self, *args, **kwargs):
37
+ super().__init__(*args, **kwargs)
38
+
39
+ def forward(self, x, t, **kwargs):
40
+ raise NotImplementedError()
41
+
42
+ class DiffusionModelWrapper(nn.Module):
43
+ def __init__(
44
+ self,
45
+ model: DiffusionModel,
46
+ io_channels,
47
+ sample_size,
48
+ sample_rate,
49
+ min_input_length,
50
+ pretransform: tp.Optional[Pretransform] = None,
51
+ ):
52
+ super().__init__()
53
+ self.io_channels = io_channels
54
+ self.sample_size = sample_size
55
+ self.sample_rate = sample_rate
56
+ self.min_input_length = min_input_length
57
+
58
+ self.model = model
59
+
60
+ if pretransform is not None:
61
+ self.pretransform = pretransform
62
+ else:
63
+ self.pretransform = None
64
+
65
+ def forward(self, x, t, **kwargs):
66
+ return self.model(x, t, **kwargs)
67
+
68
+ class ConditionedDiffusionModel(nn.Module):
69
+ def __init__(self,
70
+ *args,
71
+ supports_cross_attention: bool = False,
72
+ supports_input_concat: bool = False,
73
+ supports_global_cond: bool = False,
74
+ supports_prepend_cond: bool = False,
75
+ **kwargs):
76
+ super().__init__(*args, **kwargs)
77
+ self.supports_cross_attention = supports_cross_attention
78
+ self.supports_input_concat = supports_input_concat
79
+ self.supports_global_cond = supports_global_cond
80
+ self.supports_prepend_cond = supports_prepend_cond
81
+
82
+ def forward(self,
83
+ x: torch.Tensor,
84
+ t: torch.Tensor,
85
+ cross_attn_cond: torch.Tensor = None,
86
+ cross_attn_mask: torch.Tensor = None,
87
+ input_concat_cond: torch.Tensor = None,
88
+ global_embed: torch.Tensor = None,
89
+ prepend_cond: torch.Tensor = None,
90
+ prepend_cond_mask: torch.Tensor = None,
91
+ cfg_scale: float = 1.0,
92
+ cfg_dropout_prob: float = 0.0,
93
+ batch_cfg: bool = False,
94
+ rescale_cfg: bool = False,
95
+ **kwargs):
96
+ raise NotImplementedError()
97
+
98
+ class ConditionedDiffusionModelWrapper(nn.Module):
99
+ """
100
+ A diffusion model that takes in conditioning
101
+ """
102
+ def __init__(
103
+ self,
104
+ model: ConditionedDiffusionModel,
105
+ conditioner: MultiConditioner,
106
+ io_channels,
107
+ sample_rate,
108
+ min_input_length: int,
109
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
110
+ pretransform: tp.Optional[Pretransform] = None,
111
+ cross_attn_cond_ids: tp.List[str] = [],
112
+ global_cond_ids: tp.List[str] = [],
113
+ input_concat_ids: tp.List[str] = [],
114
+ prepend_cond_ids: tp.List[str] = [],
115
+ add_cond_ids: tp.List[str] = [],
116
+ ):
117
+ super().__init__()
118
+
119
+ self.model = model
120
+ self.conditioner = conditioner
121
+ self.io_channels = io_channels
122
+ self.sample_rate = sample_rate
123
+ self.diffusion_objective = diffusion_objective
124
+ self.pretransform = pretransform
125
+ self.cross_attn_cond_ids = cross_attn_cond_ids
126
+ self.global_cond_ids = global_cond_ids
127
+ self.input_concat_ids = input_concat_ids
128
+ self.prepend_cond_ids = prepend_cond_ids
129
+ self.add_cond_ids = add_cond_ids
130
+ self.min_input_length = min_input_length
131
+
132
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
133
+ cross_attention_input = None
134
+ cross_attention_masks = None
135
+ global_cond = None
136
+ input_concat_cond = None
137
+ prepend_cond = None
138
+ prepend_cond_mask = None
139
+ add_input = None
140
+
141
+ if len(self.cross_attn_cond_ids) > 0:
142
+ # Concatenate all cross-attention inputs over the sequence dimension
143
+ # Assumes that the cross-attention inputs are of shape (batch, seq, channels)
144
+ cross_attention_input = []
145
+ cross_attention_masks = []
146
+
147
+ for key in self.cross_attn_cond_ids:
148
+ cross_attn_in, cross_attn_mask = conditioning_tensors[key]
149
+
150
+ # Add sequence dimension if it's not there
151
+ if len(cross_attn_in.shape) == 2:
152
+ cross_attn_in = cross_attn_in.unsqueeze(1)
153
+ # cross_attn_mask = cross_attn_mask.unsqueeze(1)
154
+
155
+ cross_attention_input.append(cross_attn_in)
156
+ cross_attention_masks.append(cross_attn_mask)
157
+ # import ipdb
158
+ # ipdb.set_trace()
159
+ cross_attention_input = torch.cat(cross_attention_input, dim=1)
160
+ cross_attention_masks = torch.cat(cross_attention_masks, dim=1)
161
+
162
+ if len(self.add_cond_ids) > 0:
163
+ # Concatenate all additional ("add") conditioning inputs over the sequence dimension
164
+ # Assumes that the additional conditioning inputs are of shape (batch, seq, channels)
165
+ add_input = []
166
+
167
+ for key in self.add_cond_ids:
168
+ add_in, _ = conditioning_tensors[key]
169
+
170
+ # Add sequence dimension if it's not there
171
+ if len(add_in.shape) == 2:
172
+ add_in = add_in.unsqueeze(1)
173
+
174
+ add_input.append(add_in)
175
+
176
+ add_input = torch.cat(add_input, dim=1)
177
+
178
+ if len(self.global_cond_ids) > 0:
179
+ # Concatenate all global conditioning inputs over the channel dimension
180
+ # Assumes that the global conditioning inputs are of shape (batch, channels)
181
+ global_conds = []
182
+ # import ipdb
183
+ # ipdb.set_trace()
184
+ for key in self.global_cond_ids:
185
+ global_cond_input = conditioning_tensors[key][0]
186
+
187
+ global_conds.append(global_cond_input)
188
+
189
+ # Concatenate over the channel dimension
190
+ if global_conds[0].shape[-1] == 768:
191
+ global_cond = torch.cat(global_conds, dim=-1)
192
+ else:
193
+ global_cond = sum(global_conds)
194
+
195
+ # global_cond = torch.cat(global_conds, dim=-1)
196
+
197
+ if len(global_cond.shape) == 3:
198
+ global_cond = global_cond.squeeze(1)
199
+
200
+ if len(self.input_concat_ids) > 0:
201
+ # Concatenate all input concat conditioning inputs over the channel dimension
202
+ # Assumes that the input concat conditioning inputs are of shape (batch, channels, seq)
203
+ input_concat_cond = torch.cat([conditioning_tensors[key][0] for key in self.input_concat_ids], dim=1)
204
+
205
+ if len(self.prepend_cond_ids) > 0:
206
+ # Concatenate all prepend conditioning inputs over the sequence dimension
207
+ # Assumes that the prepend conditioning inputs are of shape (batch, seq, channels)
208
+ prepend_conds = []
209
+ prepend_cond_masks = []
210
+
211
+ for key in self.prepend_cond_ids:
212
+ prepend_cond_input, prepend_cond_mask = conditioning_tensors[key]
213
+ prepend_conds.append(prepend_cond_input)
214
+ prepend_cond_masks.append(prepend_cond_mask)
215
+
216
+ prepend_cond = torch.cat(prepend_conds, dim=1)
217
+ prepend_cond_mask = torch.cat(prepend_cond_masks, dim=1)
218
+
219
+ if negative:
220
+ return {
221
+ "negative_cross_attn_cond": cross_attention_input,
222
+ "negative_cross_attn_mask": cross_attention_masks,
223
+ "negative_global_cond": global_cond,
224
+ "negative_input_concat_cond": input_concat_cond
225
+ }
226
+ else:
227
+ return {
228
+ "cross_attn_cond": cross_attention_input,
229
+ "cross_attn_mask": cross_attention_masks,
230
+ "global_cond": global_cond,
231
+ "input_concat_cond": input_concat_cond,
232
+ "prepend_cond": prepend_cond,
233
+ "prepend_cond_mask": prepend_cond_mask,
234
+ "add_cond": add_input
235
+ }
236
+
237
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
238
+ return self.model(x, t, **self.get_conditioning_inputs(cond), **kwargs)
239
+
240
+ def generate(self, *args, **kwargs):
241
+ return generate_diffusion_cond(self, *args, **kwargs)
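
The central convention in get_conditioning_inputs above is that every conditioner output is an (embedding, mask) pair, and multiple cross-attention conditions are merged by concatenating along the sequence axis. A self-contained sketch of that merge step, with illustrative names and shapes:

import torch

# Two hypothetical cross-attention conditions for a batch of 2: a 77-token text
# embedding and a single pooled audio token, both already projected to 768 channels.
conditioning_tensors = {
    "prompt": (torch.randn(2, 77, 768), torch.ones(2, 77)),
    "clap":   (torch.randn(2, 1, 768),  torch.ones(2, 1)),
}

embs, masks = zip(*(conditioning_tensors[k] for k in ["prompt", "clap"]))
cross_attn_cond = torch.cat(embs, dim=1)   # (2, 78, 768): one joint context sequence
cross_attn_mask = torch.cat(masks, dim=1)  # (2, 78)
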
242
+
243
+ class UNetCFG1DWrapper(ConditionedDiffusionModel):
244
+ def __init__(
245
+ self,
246
+ *args,
247
+ **kwargs
248
+ ):
249
+ super().__init__(supports_cross_attention=True, supports_global_cond=True, supports_input_concat=True)
250
+
251
+ self.model = UNetCFG1d(*args, **kwargs)
252
+
253
+ with torch.no_grad():
254
+ for param in self.model.parameters():
255
+ param *= 0.5
256
+
257
+ def forward(self,
258
+ x,
259
+ t,
260
+ cross_attn_cond=None,
261
+ cross_attn_mask=None,
262
+ input_concat_cond=None,
263
+ global_cond=None,
264
+ cfg_scale=1.0,
265
+ cfg_dropout_prob: float = 0.0,
266
+ batch_cfg: bool = False,
267
+ rescale_cfg: bool = False,
268
+ negative_cross_attn_cond=None,
269
+ negative_cross_attn_mask=None,
270
+ negative_global_cond=None,
271
+ negative_input_concat_cond=None,
272
+ prepend_cond=None,
273
+ prepend_cond_mask=None,
274
+ **kwargs):
275
+ p = Profiler()
276
+
277
+ p.tick("start")
278
+
279
+ channels_list = None
280
+ if input_concat_cond is not None:
281
+ channels_list = [input_concat_cond]
282
+
283
+ outputs = self.model(
284
+ x,
285
+ t,
286
+ embedding=cross_attn_cond,
287
+ embedding_mask=cross_attn_mask,
288
+ features=global_cond,
289
+ channels_list=channels_list,
290
+ embedding_scale=cfg_scale,
291
+ embedding_mask_proba=cfg_dropout_prob,
292
+ batch_cfg=batch_cfg,
293
+ rescale_cfg=rescale_cfg,
294
+ negative_embedding=negative_cross_attn_cond,
295
+ negative_embedding_mask=negative_cross_attn_mask,
296
+ **kwargs)
297
+
298
+ p.tick("UNetCFG1D forward")
299
+
300
+ #print(f"Profiler: {p}")
301
+ return outputs
302
+
303
+ class UNet1DCondWrapper(ConditionedDiffusionModel):
304
+ def __init__(
305
+ self,
306
+ *args,
307
+ **kwargs
308
+ ):
309
+ super().__init__(supports_cross_attention=False, supports_global_cond=True, supports_input_concat=True)
310
+
311
+ self.model = UNet1d(*args, **kwargs)
312
+
313
+ with torch.no_grad():
314
+ for param in self.model.parameters():
315
+ param *= 0.5
316
+
317
+ def forward(self,
318
+ x,
319
+ t,
320
+ input_concat_cond=None,
321
+ global_cond=None,
322
+ cross_attn_cond=None,
323
+ cross_attn_mask=None,
324
+ prepend_cond=None,
325
+ prepend_cond_mask=None,
326
+ cfg_scale=1.0,
327
+ cfg_dropout_prob: float = 0.0,
328
+ batch_cfg: bool = False,
329
+ rescale_cfg: bool = False,
330
+ negative_cross_attn_cond=None,
331
+ negative_cross_attn_mask=None,
332
+ negative_global_cond=None,
333
+ negative_input_concat_cond=None,
334
+ **kwargs):
335
+
336
+ channels_list = None
337
+ if input_concat_cond is not None:
338
+
339
+ # Interpolate input_concat_cond to the same length as x
340
+ if input_concat_cond.shape[2] != x.shape[2]:
341
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest')
342
+
343
+ channels_list = [input_concat_cond]
344
+
345
+ outputs = self.model(
346
+ x,
347
+ t,
348
+ features=global_cond,
349
+ channels_list=channels_list,
350
+ **kwargs)
351
+
352
+ return outputs
353
+
354
+ class UNet1DUncondWrapper(DiffusionModel):
355
+ def __init__(
356
+ self,
357
+ in_channels,
358
+ *args,
359
+ **kwargs
360
+ ):
361
+ super().__init__()
362
+
363
+ self.model = UNet1d(in_channels=in_channels, *args, **kwargs)
364
+
365
+ self.io_channels = in_channels
366
+
367
+ with torch.no_grad():
368
+ for param in self.model.parameters():
369
+ param *= 0.5
370
+
371
+ def forward(self, x, t, **kwargs):
372
+ return self.model(x, t, **kwargs)
373
+
374
+ class DAU1DCondWrapper(ConditionedDiffusionModel):
375
+ def __init__(
376
+ self,
377
+ *args,
378
+ **kwargs
379
+ ):
380
+ super().__init__(supports_cross_attention=False, supports_global_cond=False, supports_input_concat=True)
381
+
382
+ self.model = DiffusionAttnUnet1D(*args, **kwargs)
383
+
384
+ with torch.no_grad():
385
+ for param in self.model.parameters():
386
+ param *= 0.5
387
+
388
+ def forward(self,
389
+ x,
390
+ t,
391
+ input_concat_cond=None,
392
+ cross_attn_cond=None,
393
+ cross_attn_mask=None,
394
+ global_cond=None,
395
+ cfg_scale=1.0,
396
+ cfg_dropout_prob: float = 0.0,
397
+ batch_cfg: bool = False,
398
+ rescale_cfg: bool = False,
399
+ negative_cross_attn_cond=None,
400
+ negative_cross_attn_mask=None,
401
+ negative_global_cond=None,
402
+ negative_input_concat_cond=None,
403
+ prepend_cond=None,
404
+ **kwargs):
405
+
406
+ return self.model(x, t, cond = input_concat_cond)
407
+
408
+ class DiffusionAttnUnet1D(nn.Module):
409
+ def __init__(
410
+ self,
411
+ io_channels = 2,
412
+ depth=14,
413
+ n_attn_layers = 6,
414
+ channels = [128, 128, 256, 256] + [512] * 10,
415
+ cond_dim = 0,
416
+ cond_noise_aug = False,
417
+ kernel_size = 5,
418
+ learned_resample = False,
419
+ strides = [2] * 13,
420
+ conv_bias = True,
421
+ use_snake = False
422
+ ):
423
+ super().__init__()
424
+
425
+ self.cond_noise_aug = cond_noise_aug
426
+
427
+ self.io_channels = io_channels
428
+
429
+ if self.cond_noise_aug:
430
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
431
+
432
+ self.timestep_embed = FourierFeatures(1, 16)
433
+
434
+ attn_layer = depth - n_attn_layers
435
+
436
+ strides = [1] + strides
437
+
438
+ block = nn.Identity()
439
+
440
+ conv_block = partial(ResConvBlock, kernel_size=kernel_size, conv_bias = conv_bias, use_snake=use_snake)
441
+
442
+ for i in range(depth, 0, -1):
443
+ c = channels[i - 1]
444
+ stride = strides[i-1]
445
+ if stride > 2 and not learned_resample:
446
+ raise ValueError("Must have stride 2 without learned resampling")
447
+
448
+ if i > 1:
449
+ c_prev = channels[i - 2]
450
+ add_attn = i >= attn_layer and n_attn_layers > 0
451
+ block = SkipBlock(
452
+ Downsample1d_2(c_prev, c_prev, stride) if (learned_resample or stride == 1) else Downsample1d("cubic"),
453
+ conv_block(c_prev, c, c),
454
+ SelfAttention1d(
455
+ c, c // 32) if add_attn else nn.Identity(),
456
+ conv_block(c, c, c),
457
+ SelfAttention1d(
458
+ c, c // 32) if add_attn else nn.Identity(),
459
+ conv_block(c, c, c),
460
+ SelfAttention1d(
461
+ c, c // 32) if add_attn else nn.Identity(),
462
+ block,
463
+ conv_block(c * 2 if i != depth else c, c, c),
464
+ SelfAttention1d(
465
+ c, c // 32) if add_attn else nn.Identity(),
466
+ conv_block(c, c, c),
467
+ SelfAttention1d(
468
+ c, c // 32) if add_attn else nn.Identity(),
469
+ conv_block(c, c, c_prev),
470
+ SelfAttention1d(c_prev, c_prev //
471
+ 32) if add_attn else nn.Identity(),
472
+ Upsample1d_2(c_prev, c_prev, stride) if learned_resample else Upsample1d(kernel="cubic")
473
+ )
474
+ else:
475
+ cond_embed_dim = 16 if not self.cond_noise_aug else 32
476
+ block = nn.Sequential(
477
+ conv_block((io_channels + cond_dim) + cond_embed_dim, c, c),
478
+ conv_block(c, c, c),
479
+ conv_block(c, c, c),
480
+ block,
481
+ conv_block(c * 2, c, c),
482
+ conv_block(c, c, c),
483
+ conv_block(c, c, io_channels, is_last=True),
484
+ )
485
+ self.net = block
486
+
487
+ with torch.no_grad():
488
+ for param in self.net.parameters():
489
+ param *= 0.5
490
+
491
+ def forward(self, x, t, cond=None, cond_aug_scale=None):
492
+
493
+ timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), x.shape)
494
+
495
+ inputs = [x, timestep_embed]
496
+
497
+ if cond is not None:
498
+ if cond.shape[2] != x.shape[2]:
499
+ cond = F.interpolate(cond, (x.shape[2], ), mode='linear', align_corners=False)
500
+
501
+ if self.cond_noise_aug:
502
+ # Get a random number between 0 and 1, uniformly sampled
503
+ if cond_aug_scale is None:
504
+ aug_level = self.rng.draw(cond.shape[0])[:, 0].to(cond)
505
+ else:
506
+ aug_level = torch.tensor([cond_aug_scale]).repeat([cond.shape[0]]).to(cond)
507
+
508
+ # Add noise to the conditioning signal
509
+ cond = cond + torch.randn_like(cond) * aug_level[:, None, None]
510
+
511
+ # Get embedding for noise cond level, reusing timestamp_embed
512
+ aug_level_embed = expand_to_planes(self.timestep_embed(aug_level[:, None]), x.shape)
513
+
514
+ inputs.append(aug_level_embed)
515
+
516
+ inputs.append(cond)
517
+
518
+ outputs = self.net(torch.cat(inputs, dim=1))
519
+
520
+ return outputs
521
+
522
+ class DiTWrapper(ConditionedDiffusionModel):
523
+ def __init__(
524
+ self,
525
+ *args,
526
+ **kwargs
527
+ ):
528
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
529
+
530
+ self.model = DiffusionTransformer(*args, **kwargs)
531
+
532
+ with torch.no_grad():
533
+ for param in self.model.parameters():
534
+ param *= 0.5
535
+
536
+ def forward(self,
537
+ x,
538
+ t,
539
+ cross_attn_cond=None,
540
+ cross_attn_mask=None,
541
+ negative_cross_attn_cond=None,
542
+ negative_cross_attn_mask=None,
543
+ input_concat_cond=None,
544
+ negative_input_concat_cond=None,
545
+ global_cond=None,
546
+ negative_global_cond=None,
547
+ prepend_cond=None,
548
+ prepend_cond_mask=None,
549
+ cfg_scale=1.0,
550
+ cfg_dropout_prob: float = 0.0,
551
+ batch_cfg: bool = True,
552
+ rescale_cfg: bool = False,
553
+ scale_phi: float = 0.0,
554
+ **kwargs):
555
+
556
+ assert batch_cfg, "batch_cfg must be True for DiTWrapper"
557
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
558
+
559
+ return self.model(
560
+ x,
561
+ t,
562
+ cross_attn_cond=cross_attn_cond,
563
+ cross_attn_cond_mask=cross_attn_mask,
564
+ negative_cross_attn_cond=negative_cross_attn_cond,
565
+ negative_cross_attn_mask=negative_cross_attn_mask,
566
+ input_concat_cond=input_concat_cond,
567
+ prepend_cond=prepend_cond,
568
+ prepend_cond_mask=prepend_cond_mask,
569
+ cfg_scale=cfg_scale,
570
+ cfg_dropout_prob=cfg_dropout_prob,
571
+ scale_phi=scale_phi,
572
+ global_embed=global_cond,
573
+ **kwargs)
574
+
575
+ class MMDiTWrapper(ConditionedDiffusionModel):
576
+ def __init__(
577
+ self,
578
+ *args,
579
+ **kwargs
580
+ ):
581
+ super().__init__(supports_cross_attention=True, supports_global_cond=False, supports_input_concat=False)
582
+
583
+ self.model = MMAudio(*args, **kwargs)
584
+
585
+ # with torch.no_grad():
586
+ # for param in self.model.parameters():
587
+ # param *= 0.5
588
+
589
+ def forward(self,
590
+ x,
591
+ t,
592
+ clip_f,
593
+ sync_f,
594
+ text_f,
595
+ inpaint_masked_input=None,
596
+ t5_features=None,
597
+ metaclip_global_text_features=None,
598
+ cfg_scale=1.0,
599
+ cfg_dropout_prob: float = 0.0,
600
+ batch_cfg: bool = True,
601
+ rescale_cfg: bool = False,
602
+ scale_phi: float = 0.0,
603
+ **kwargs):
604
+
605
+ # breakpoint()
606
+ assert batch_cfg, "batch_cfg must be True for MMDiTWrapper"
607
+ #assert negative_input_concat_cond is None, "negative_input_concat_cond is not supported for DiTWrapper"
608
+
609
+ return self.model(
610
+ latent=x,
611
+ t=t,
612
+ clip_f=clip_f,
613
+ sync_f=sync_f,
614
+ text_f=text_f,
615
+ inpaint_masked_input=inpaint_masked_input,
616
+ t5_features=t5_features,
617
+ metaclip_global_text_features=metaclip_global_text_features,
618
+ cfg_scale=cfg_scale,
619
+ cfg_dropout_prob=cfg_dropout_prob,
620
+ scale_phi=scale_phi,
621
+ **kwargs)
622
+
623
+ class MMConditionedDiffusionModelWrapper(ConditionedDiffusionModel):
624
+ """
625
+ A diffusion model that takes in conditioning
626
+ """
627
+ def __init__(
628
+ self,
629
+ model: MMAudio,
630
+ conditioner: MultiConditioner,
631
+ io_channels,
632
+ sample_rate,
633
+ min_input_length: int,
634
+ diffusion_objective: tp.Literal["v", "rectified_flow"] = "v",
635
+ pretransform: tp.Optional[Pretransform] = None,
636
+ cross_attn_cond_ids: tp.List[str] = [],
637
+ global_cond_ids: tp.List[str] = [],
638
+ input_concat_ids: tp.List[str] = [],
639
+ prepend_cond_ids: tp.List[str] = [],
640
+ add_cond_ids: tp.List[str] = [],
641
+ mm_cond_ids: tp.List[str] = [],
642
+ ):
643
+ super().__init__()
644
+
645
+ self.model = model
646
+ self.conditioner = conditioner
647
+ self.io_channels = io_channels
648
+ self.sample_rate = sample_rate
649
+ self.diffusion_objective = diffusion_objective
650
+ self.pretransform = pretransform
651
+ self.cross_attn_cond_ids = cross_attn_cond_ids
652
+ self.global_cond_ids = global_cond_ids
653
+ self.input_concat_ids = input_concat_ids
654
+ self.prepend_cond_ids = prepend_cond_ids
655
+ self.add_cond_ids = add_cond_ids
656
+ self.min_input_length = min_input_length
657
+ self.mm_cond_ids = mm_cond_ids
658
+
659
+ assert len(self.cross_attn_cond_ids) == 0, "cross_attn_cond_ids is not supported for MMDiTWrapper"
660
+ assert len(self.global_cond_ids) == 0, "global_cond_ids is not supported for MMDiTWrapper"
661
+ assert len(self.input_concat_ids) == 0, "input_concat_ids is not supported for MMDiTWrapper"
662
+ assert len(self.prepend_cond_ids) == 0, "prepend_cond_ids is not supported for MMDiTWrapper"
663
+ assert len(self.add_cond_ids) == 0, "add_cond_ids is not supported for MMDiTWrapper"
664
+ assert len(self.mm_cond_ids) > 0, "mm_cond_ids must be specified for MMDiTWrapper"
665
+ assert "metaclip_features" in self.mm_cond_ids, "clip_f must be specified in mm_cond_ids for MMDiTWrapper"
666
+ assert "sync_features" in self.mm_cond_ids, "sync_features must be specified in mm_cond_ids for MMDiTWrapper"
667
+ assert "metaclip_text_features" in self.mm_cond_ids, "metaclip_text_features must be specified in mm_cond_ids for MMDiTWrapper"
668
+ # assert len(self.mm_cond_ids) == 3, "mm_cond_ids must be clip_f sync_f text_f for MMDiTWrapper"
669
+
670
+ def get_conditioning_inputs(self, conditioning_tensors: tp.Dict[str, tp.Any], negative=False):
671
+ assert negative == False, "negative conditioning is not supported for MMDiTWrapper"
672
+ cross_attention_input = None
673
+ cross_attention_masks = None
674
+ global_cond = None
675
+ input_concat_cond = None
676
+ prepend_cond = None
677
+ prepend_cond_mask = None
678
+ add_input = None
679
+ inpaint_masked_input = None
680
+ t5_features = None
681
+ metaclip_global_text_features = None
682
+ clip_f = conditioning_tensors["metaclip_features"]
683
+ sync_f = conditioning_tensors["sync_features"]
684
+ text_f = conditioning_tensors["metaclip_text_features"]
685
+ if 'inpaint_masked_input' in conditioning_tensors.keys():
686
+ inpaint_masked_input = conditioning_tensors["inpaint_masked_input"]
687
+ if 't5_features' in conditioning_tensors.keys():
688
+ t5_features = conditioning_tensors["t5_features"]
689
+ if 'metaclip_global_text_features' in conditioning_tensors.keys():
690
+ metaclip_global_text_features = conditioning_tensors["metaclip_global_text_features"]
691
+ return {
692
+ "clip_f": clip_f,
693
+ "sync_f": sync_f,
694
+ "text_f": text_f,
695
+ "inpaint_masked_input": inpaint_masked_input,
696
+ "t5_features": t5_features,
697
+ "metaclip_global_text_features": metaclip_global_text_features
698
+ }
699
+
700
+ def forward(self, x: torch.Tensor, t: torch.Tensor, cond: tp.Dict[str, tp.Any], **kwargs):
701
+ # breakpoint()
702
+ # print(kwargs)
703
+ return self.model(x=x, t=t, **self.get_conditioning_inputs(cond), **kwargs)
704
+
705
+ def generate(self, *args, **kwargs):
706
+ return generate_diffusion_cond(self, *args, **kwargs)
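
Unlike the wrapper above, the multimodal path does no merging: it only requires the conditioner to have produced the three named feature streams (plus optional extras) and hands them straight to MMAudio. Only the key names below come from the code; the tensors are illustrative placeholders for whatever the corresponding conditioners emit:

import torch

conditioning_tensors = {
    "metaclip_features":      torch.randn(1, 64, 1024),  # clip_f: per-frame video features
    "sync_features":          torch.randn(1, 192, 768),  # sync_f: fine-grained temporal features
    "metaclip_text_features": torch.randn(1, 77, 1024),  # text_f: caption / CoT text features
}
# Optional extras: "inpaint_masked_input", "t5_features", "metaclip_global_text_features"
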
707
+
708
+ class DiTUncondWrapper(DiffusionModel):
709
+ def __init__(
710
+ self,
711
+ io_channels,
712
+ *args,
713
+ **kwargs
714
+ ):
715
+ super().__init__()
716
+
717
+ self.model = DiffusionTransformer(io_channels=io_channels, *args, **kwargs)
718
+
719
+ self.io_channels = io_channels
720
+
721
+ with torch.no_grad():
722
+ for param in self.model.parameters():
723
+ param *= 0.5
724
+
725
+ def forward(self, x, t, **kwargs):
726
+ return self.model(x, t, **kwargs)
727
+
728
+ def create_diffusion_uncond_from_config(config: tp.Dict[str, tp.Any]):
729
+ diffusion_uncond_config = config["model"]
730
+
731
+ model_type = diffusion_uncond_config.get('type', None)
732
+
733
+ diffusion_config = diffusion_uncond_config.get('config', {})
734
+
735
+ assert model_type is not None, "Must specify model type in config"
736
+
737
+ pretransform = diffusion_uncond_config.get("pretransform", None)
738
+
739
+ sample_size = config.get("sample_size", None)
740
+ assert sample_size is not None, "Must specify sample size in config"
741
+
742
+ sample_rate = config.get("sample_rate", None)
743
+ assert sample_rate is not None, "Must specify sample rate in config"
744
+
745
+ if pretransform is not None:
746
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
747
+ min_input_length = pretransform.downsampling_ratio
748
+ else:
749
+ min_input_length = 1
750
+
751
+ if model_type == 'DAU1d':
752
+
753
+ model = DiffusionAttnUnet1D(
754
+ **diffusion_config
755
+ )
756
+
757
+ elif model_type == "adp_uncond_1d":
758
+
759
+ model = UNet1DUncondWrapper(
760
+ **diffusion_config
761
+ )
762
+
763
+ elif model_type == "dit":
764
+ model = DiTUncondWrapper(
765
+ **diffusion_config
766
+ )
767
+
768
+ else:
769
+ raise NotImplementedError(f'Unknown model type: {model_type}')
770
+
771
+ return DiffusionModelWrapper(model,
772
+ io_channels=model.io_channels,
773
+ sample_size=sample_size,
774
+ sample_rate=sample_rate,
775
+ pretransform=pretransform,
776
+ min_input_length=min_input_length)
777
+
778
+ def create_diffusion_infill_from_config(config: tp.Dict[str, tp.Any]):
779
+ diffusion_uncond_config = config["model"]
780
+
781
+
782
+ diffusion_config = diffusion_uncond_config.get('diffusion', {})
783
+ model_type = diffusion_config.get('type', None)
784
+ model_config = diffusion_config.get("config",{})
785
+ assert model_type is not None, "Must specify model type in config"
786
+
787
+ pretransform = diffusion_uncond_config.get("pretransform", None)
788
+
789
+ sample_size = config.get("sample_size", None)
790
+ assert sample_size is not None, "Must specify sample size in config"
791
+
792
+ sample_rate = config.get("sample_rate", None)
793
+ assert sample_rate is not None, "Must specify sample rate in config"
794
+
795
+ if pretransform is not None:
796
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
797
+ min_input_length = pretransform.downsampling_ratio
798
+ else:
799
+ min_input_length = 1
800
+
801
+ if model_type == 'DAU1d':
802
+
803
+ model = DiffusionAttnUnet1D(
804
+ **model_config
805
+ )
806
+
807
+ elif model_type == "adp_uncond_1d":
808
+
809
+ model = UNet1DUncondWrapper(
810
+ io_channels = io_channels,
811
+ **model_config
812
+ )
813
+ elif model_type == "dit":
814
+ model = DiTUncondWrapper(
815
+ **model_config
816
+ )
817
+
818
+ else:
819
+ raise NotImplementedError(f'Unknown model type: {model_type}')
820
+
821
+ return DiffusionModelWrapper(model,
822
+ io_channels=model.io_channels,
823
+ sample_size=sample_size,
824
+ sample_rate=sample_rate,
825
+ pretransform=pretransform,
826
+ min_input_length=min_input_length)
827
+
828
+ def create_diffusion_cond_from_config(config: tp.Dict[str, tp.Any]):
829
+
830
+ model_config = config["model"]
831
+
832
+ model_type = config["model_type"]
833
+
834
+ diffusion_config = model_config.get('diffusion', None)
835
+ assert diffusion_config is not None, "Must specify diffusion config"
836
+
837
+ diffusion_model_type = diffusion_config.get('type', None)
838
+ assert diffusion_model_type is not None, "Must specify diffusion model type"
839
+
840
+ diffusion_model_config = diffusion_config.get('config', None)
841
+ assert diffusion_model_config is not None, "Must specify diffusion model config"
842
+
843
+ if diffusion_model_type == 'adp_cfg_1d':
844
+ diffusion_model = UNetCFG1DWrapper(**diffusion_model_config)
845
+ elif diffusion_model_type == 'adp_1d':
846
+ diffusion_model = UNet1DCondWrapper(**diffusion_model_config)
847
+ elif diffusion_model_type == 'dit':
848
+ diffusion_model = DiTWrapper(**diffusion_model_config)
849
+ elif diffusion_model_type == 'mmdit':
850
+ diffusion_model = MMDiTWrapper(**diffusion_model_config)
+ else:
+ raise NotImplementedError(f'Unknown diffusion model type: {diffusion_model_type}')
851
+
852
+ io_channels = model_config.get('io_channels', None)
853
+ assert io_channels is not None, "Must specify io_channels in model config"
854
+
855
+ sample_rate = config.get('sample_rate', None)
856
+ assert sample_rate is not None, "Must specify sample_rate in config"
857
+
858
+ diffusion_objective = diffusion_config.get('diffusion_objective', 'v')
859
+
860
+ conditioning_config = model_config.get('conditioning', None)
861
+
862
+ conditioner = None
863
+ if conditioning_config is not None:
864
+ conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)
865
+
866
+ cross_attention_ids = diffusion_config.get('cross_attention_cond_ids', [])
867
+ add_cond_ids = diffusion_config.get('add_cond_ids', [])
868
+ global_cond_ids = diffusion_config.get('global_cond_ids', [])
869
+ input_concat_ids = diffusion_config.get('input_concat_ids', [])
870
+ prepend_cond_ids = diffusion_config.get('prepend_cond_ids', [])
871
+ mm_cond_ids = diffusion_config.get('mm_cond_ids', [])
872
+
873
+ pretransform = model_config.get("pretransform", None)
874
+
875
+ if pretransform is not None:
876
+ pretransform = create_pretransform_from_config(pretransform, sample_rate)
877
+ min_input_length = pretransform.downsampling_ratio
878
+ else:
879
+ min_input_length = 1
880
+
881
+ if diffusion_model_type == "adp_cfg_1d" or diffusion_model_type == "adp_1d":
882
+ min_input_length *= np.prod(diffusion_model_config["factors"])
883
+ elif diffusion_model_type == "dit":
884
+ min_input_length *= diffusion_model.model.patch_size
885
+
886
+ # Get the proper wrapper class
887
+
888
+ extra_kwargs = {}
889
+
890
+ if model_type == "mm_diffusion_cond":
891
+ wrapper_fn = MMConditionedDiffusionModelWrapper
892
+ extra_kwargs["diffusion_objective"] = diffusion_objective
893
+ extra_kwargs["mm_cond_ids"] = mm_cond_ids
894
+
895
+ if model_type == "diffusion_cond" or model_type == "diffusion_cond_inpaint" or model_type == 'diffusion_infill':
896
+ wrapper_fn = ConditionedDiffusionModelWrapper
897
+ extra_kwargs["diffusion_objective"] = diffusion_objective
898
+
899
+ elif model_type == "diffusion_prior":
900
+ prior_type = model_config.get("prior_type", None)
901
+ assert prior_type is not None, "Must specify prior_type in diffusion prior model config"
902
+
903
+ if prior_type == "mono_stereo":
904
+ from .diffusion_prior import MonoToStereoDiffusionPrior
905
+ wrapper_fn = MonoToStereoDiffusionPrior
906
+
907
+ return wrapper_fn(
908
+ diffusion_model,
909
+ conditioner,
910
+ min_input_length=min_input_length,
911
+ sample_rate=sample_rate,
912
+ cross_attn_cond_ids=cross_attention_ids,
913
+ global_cond_ids=global_cond_ids,
914
+ input_concat_ids=input_concat_ids,
915
+ prepend_cond_ids=prepend_cond_ids,
916
+ add_cond_ids=add_cond_ids,
917
+ pretransform=pretransform,
918
+ io_channels=io_channels,
919
+ **extra_kwargs
920
+ )
ThinkSound/models/dit.py ADDED
@@ -0,0 +1,439 @@
1
+ import typing as tp
2
+
3
+ import torch
4
+ # from beartype.typing import Tuple
5
+ from einops import rearrange
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from x_transformers import ContinuousTransformerWrapper, Encoder
9
+ from .mmmodules.model.low_level import MLP, ChannelLastConv1d, ConvMLP
10
+ from .blocks import FourierFeatures
11
+ from .transformer import ContinuousTransformer
12
+ from .utils import mask_from_frac_lengths, resample
13
+ class DiffusionTransformer(nn.Module):
14
+ def __init__(self,
15
+ io_channels=32,
16
+ patch_size=1,
17
+ embed_dim=768,
18
+ cond_token_dim=0,
19
+ project_cond_tokens=True,
20
+ global_cond_dim=0,
21
+ project_global_cond=True,
22
+ input_concat_dim=0,
23
+ prepend_cond_dim=0,
24
+ cond_ctx_dim=0,
25
+ depth=12,
26
+ num_heads=8,
27
+ transformer_type: tp.Literal["x-transformers", "continuous_transformer","mm_transformer"] = "x-transformers",
28
+ global_cond_type: tp.Literal["prepend", "adaLN"] = "prepend",
29
+ frac_lengths_mask = (0.7, 1.),
30
+ ctx_drop: float = 0.1,
31
+ add_token_dim=0,
32
+ use_mlp=False,
33
+ **kwargs):
34
+
35
+ super().__init__()
36
+
37
+ self.cond_token_dim = cond_token_dim
38
+
39
+ # Timestep embeddings
40
+ timestep_features_dim = 256
41
+
42
+ self.timestep_features = FourierFeatures(1, timestep_features_dim)
43
+
44
+ self.to_timestep_embed = nn.Sequential(
45
+ nn.Linear(timestep_features_dim, embed_dim, bias=True),
46
+ nn.SiLU(),
47
+ nn.Linear(embed_dim, embed_dim, bias=True),
48
+ )
49
+ self.use_mlp = use_mlp
50
+ if cond_token_dim > 0:
51
+ # Conditioning tokens
52
+ cond_embed_dim = cond_token_dim if not project_cond_tokens else embed_dim
53
+ self.to_cond_embed = nn.Sequential(
54
+ nn.Linear(cond_token_dim, cond_embed_dim, bias=False),
55
+ nn.SiLU(),
56
+ nn.Linear(cond_embed_dim, cond_embed_dim, bias=False)
57
+ )
58
+ else:
59
+ cond_embed_dim = 0
60
+
61
+ if global_cond_dim > 0:
62
+ # Global conditioning
63
+ global_embed_dim = global_cond_dim if not project_global_cond else embed_dim
64
+ self.to_global_embed = nn.Sequential(
65
+ nn.Linear(global_cond_dim, global_embed_dim, bias=False),
66
+ nn.SiLU(),
67
+ nn.Linear(global_embed_dim, global_embed_dim, bias=False)
68
+ )
69
+
70
+ if add_token_dim > 0:
71
+ # Conditioning tokens
72
+
73
+ add_embed_dim = add_token_dim if not project_cond_tokens else embed_dim
74
+ self.to_add_embed = nn.Sequential(
75
+ nn.SiLU(),
76
+ ConvMLP(add_embed_dim, add_embed_dim * 4, kernel_size=3, padding=1),
77
+ )
78
+ else:
79
+ add_embed_dim = 0
80
+
81
+ if cond_ctx_dim > 0:
82
+ self.ctx_linear = nn.Linear(cond_ctx_dim*2, cond_ctx_dim, bias=True)
83
+ self.frac_lengths_mask = frac_lengths_mask
84
+ self.ctx_drop = ctx_drop
85
+
86
+ if prepend_cond_dim > 0:
87
+ # Prepend conditioning
88
+ self.to_prepend_embed = nn.Sequential(
89
+ nn.Linear(prepend_cond_dim, embed_dim, bias=False),
90
+ nn.SiLU(),
91
+ nn.Linear(embed_dim, embed_dim, bias=False)
92
+ )
93
+
94
+ self.input_concat_dim = input_concat_dim
95
+
96
+ dim_in = io_channels + self.input_concat_dim
97
+
98
+ self.patch_size = patch_size
99
+
100
+ # Transformer
101
+
102
+ self.transformer_type = transformer_type
103
+
104
+ self.global_cond_type = global_cond_type
105
+ print("######################")
106
+ print(f'global type: {global_cond_type}')
107
+ print("######################")
108
+ if self.transformer_type == "x-transformers":
109
+ self.transformer = ContinuousTransformerWrapper(
110
+ dim_in=dim_in * patch_size,
111
+ dim_out=io_channels * patch_size,
112
+ max_seq_len=0, #Not relevant without absolute positional embeds
113
+ attn_layers = Encoder(
114
+ dim=embed_dim,
115
+ depth=depth,
116
+ heads=num_heads,
117
+ attn_flash = True,
118
+ cross_attend = cond_token_dim > 0,
119
+ dim_context=None if cond_embed_dim == 0 else cond_embed_dim,
120
+ zero_init_branch_output=True,
121
+ use_abs_pos_emb = False,
122
+ rotary_pos_emb=True,
123
+ ff_swish = True,
124
+ ff_glu = True,
125
+ **kwargs
126
+ )
127
+ )
128
+
129
+ elif self.transformer_type == "continuous_transformer":
130
+
131
+ global_dim = None
132
+
133
+ if self.global_cond_type == "adaLN":
134
+ # The global conditioning is projected to the embed_dim already at this point
135
+ global_dim = embed_dim
136
+
137
+ self.transformer = ContinuousTransformer(
138
+ dim=embed_dim,
139
+ depth=depth,
140
+ dim_heads=embed_dim // num_heads,
141
+ dim_in=dim_in * patch_size,
142
+ dim_out=io_channels * patch_size,
143
+ cross_attend = cond_token_dim > 0,
144
+ cond_token_dim = cond_embed_dim,
145
+ global_cond_dim=global_dim,
146
+ **kwargs
147
+ )
148
+
149
+ else:
150
+ raise ValueError(f"Unknown transformer type: {self.transformer_type}")
151
+
152
+ self.preprocess_conv = nn.Conv1d(dim_in, dim_in, 1, bias=False)
153
+ nn.init.zeros_(self.preprocess_conv.weight)
154
+ self.postprocess_conv = nn.Conv1d(io_channels, io_channels, 1, bias=False)
155
+ nn.init.zeros_(self.postprocess_conv.weight)
156
+
157
+ def _forward(
158
+ self,
159
+ x,
160
+ t,
161
+ mask=None,
162
+ cross_attn_cond=None,
163
+ cross_attn_cond_mask=None,
164
+ input_concat_cond=None,
165
+ global_embed=None,
166
+ prepend_cond=None,
167
+ prepend_cond_mask=None,
168
+ add_cond=None,
169
+ add_masks=None,
170
+ # x_ctx=None,
171
+ return_info=False,
172
+ **kwargs):
173
+
174
+ if cross_attn_cond is not None:
175
+ cross_attn_cond = self.to_cond_embed(cross_attn_cond)
176
+ if global_embed is not None:
177
+ # Project the global conditioning to the embedding dimension
178
+ global_embed = self.to_global_embed(global_embed)
179
+ if len(global_embed.shape) == 3:
180
+ global_embed = torch.max(global_embed, dim=1).values
181
+
182
+ prepend_inputs = None
183
+ prepend_mask = None
184
+ prepend_length = 0
185
+ if prepend_cond is not None:
186
+ # Project the prepend conditioning to the embedding dimension
187
+ prepend_cond = self.to_prepend_embed(prepend_cond)
188
+
189
+ prepend_inputs = prepend_cond
190
+ if prepend_cond_mask is not None:
191
+ prepend_mask = prepend_cond_mask
192
+
193
+ if input_concat_cond is not None:
194
+
195
+ # Interpolate input_concat_cond to the same length as x
196
+ if input_concat_cond.shape[2] != x.shape[2]:
197
+ input_concat_cond = F.interpolate(input_concat_cond, (x.shape[2], ), mode='nearest-exact')
198
+
199
+ x = torch.cat([x, input_concat_cond], dim=1)
200
+
201
+ if add_cond is not None:
202
+ # Project add_cond and resample it to the same length as x if needed
203
+
204
+ if self.use_mlp:
205
+ add_cond = self.to_add_embed(add_cond)
206
+ if add_cond.shape[1] != x.shape[2]:
207
+ # add_cond = add_cond.transpose(1,2)
208
+ # add_cond = F.interpolate(add_cond, (x.shape[2], ), mode='nearest-exact')
209
+ # add_cond = add_cond.transpose(1,2)
210
+ add_cond = resample(add_cond, x)
211
+
212
+ # Get the batch of timestep embeddings
213
+ timestep_embed = self.to_timestep_embed(self.timestep_features(t[:, None])) # (b, embed_dim)
214
+ # import ipdb
215
+ # ipdb.set_trace()
216
+ # Timestep embedding is considered a global embedding. Add to the global conditioning if it exists
217
+ if global_embed is not None:
218
+ global_embed = global_embed + timestep_embed
219
+ else:
220
+ global_embed = timestep_embed
221
+
222
+ # Add the global_embed to the prepend inputs if there is no global conditioning support in the transformer
223
+ if self.global_cond_type == "prepend":
224
+ if prepend_inputs is None:
225
+ # Prepend inputs are just the global embed, and the mask is all ones
226
+ prepend_inputs = global_embed.unsqueeze(1)
227
+ prepend_mask = torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)
228
+ else:
229
+ # Prepend inputs are the prepend conditioning + the global embed
230
+ prepend_inputs = torch.cat([prepend_inputs, global_embed.unsqueeze(1)], dim=1)
231
+ prepend_mask = torch.cat([prepend_mask, torch.ones((x.shape[0], 1), device=x.device, dtype=torch.bool)], dim=1)
232
+
233
+ prepend_length = prepend_inputs.shape[1]
234
+
235
+ x = self.preprocess_conv(x) + x
236
+ x = rearrange(x, "b c t -> b t c")
237
+
238
+
239
+ extra_args = {}
240
+
241
+ if self.global_cond_type == "adaLN":
242
+ extra_args["global_cond"] = global_embed
243
+
244
+ if self.patch_size > 1:
245
+ x = rearrange(x, "b (t p) c -> b t (c p)", p=self.patch_size)
246
+
247
+ if self.transformer_type == "x-transformers":
248
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, context_mask=cross_attn_cond_mask, add_cond=add_cond, mask=mask, prepend_mask=prepend_mask, **extra_args, **kwargs)
249
+ elif self.transformer_type == "continuous_transformer":
250
+ output = self.transformer(x, prepend_embeds=prepend_inputs, context=cross_attn_cond, add_cond=add_cond, context_mask=cross_attn_cond_mask, mask=mask, prepend_mask=prepend_mask, return_info=return_info, **extra_args, **kwargs)
251
+
252
+ if return_info:
253
+ output, info = output
254
+ elif self.transformer_type == "mm_transformer":
255
+ output = self.transformer(x, context=cross_attn_cond, mask=mask, context_mask=cross_attn_cond_mask, **extra_args, **kwargs)
256
+
257
+ output = rearrange(output, "b t c -> b c t")[:,:,prepend_length:]
258
+
259
+ if self.patch_size > 1:
260
+ output = rearrange(output, "b (c p) t -> b c (t p)", p=self.patch_size)
261
+
262
+ output = self.postprocess_conv(output) + output
263
+
264
+ if return_info:
265
+ return output, info
266
+
267
+ return output
268
+
269
+ def forward(
270
+ self,
271
+ x,
272
+ t,
273
+ cross_attn_cond=None,
274
+ cross_attn_cond_mask=None,
275
+ negative_cross_attn_cond=None,
276
+ negative_cross_attn_mask=None,
277
+ input_concat_cond=None,
278
+ global_embed=None,
279
+ negative_global_embed=None,
280
+ prepend_cond=None,
281
+ prepend_cond_mask=None,
282
+ add_cond=None,
283
+ cfg_scale=1.0,
284
+ cfg_dropout_prob=0.0,
285
+ causal=False,
286
+ scale_phi=0.0,
287
+ mask=None,
288
+ x_ctx=None,
289
+ ctx_mask=None,
290
+ return_info=False,
291
+ **kwargs):
292
+
293
+ assert causal == False, "Causal mode is not supported for DiffusionTransformer"
294
+ bsz, a, b = x.shape
295
+
296
+ if cross_attn_cond_mask is not None:
297
+ cross_attn_cond_mask = cross_attn_cond_mask.bool()
298
+
299
+ cross_attn_cond_mask = None # Temporarily disabling conditioning masks due to kernel issue for flash attention
300
+
301
+ if prepend_cond_mask is not None:
302
+ prepend_cond_mask = prepend_cond_mask.bool()
303
+
304
+ # CFG dropout
305
+ if cfg_dropout_prob > 0.0:
306
+ if cross_attn_cond is not None:
307
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
308
+ dropout_mask = torch.bernoulli(torch.full((cross_attn_cond.shape[0], 1, 1), cfg_dropout_prob, device=cross_attn_cond.device)).to(torch.bool)
309
+ cross_attn_cond = torch.where(dropout_mask, null_embed, cross_attn_cond)
310
+
311
+ if prepend_cond is not None:
312
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
313
+ dropout_mask = torch.bernoulli(torch.full((prepend_cond.shape[0], 1, 1), cfg_dropout_prob, device=prepend_cond.device)).to(torch.bool)
314
+ prepend_cond = torch.where(dropout_mask, null_embed, prepend_cond)
315
+
316
+ if add_cond is not None:
317
+ null_embed = torch.zeros_like(add_cond, device=add_cond.device)
318
+ dropout_mask = torch.bernoulli(torch.full((add_cond.shape[0], 1, 1), cfg_dropout_prob, device=add_cond.device)).to(torch.bool)
319
+ add_cond = torch.where(dropout_mask, null_embed, add_cond)
320
+
321
+ if cfg_scale != 1.0 and (cross_attn_cond is not None or prepend_cond is not None or add_cond is not None):
322
+ # Classifier-free guidance
323
+ # Concatenate conditioned and unconditioned inputs on the batch dimension
324
+ batch_inputs = torch.cat([x, x], dim=0)
325
+ batch_timestep = torch.cat([t, t], dim=0)
326
+
327
+ if global_embed is not None:
328
+ batch_global_cond = torch.cat([global_embed, global_embed], dim=0)
329
+ else:
330
+ batch_global_cond = None
331
+
332
+ if input_concat_cond is not None:
333
+ batch_input_concat_cond = torch.cat([input_concat_cond, input_concat_cond], dim=0)
334
+ else:
335
+ batch_input_concat_cond = None
336
+
337
+ batch_cond = None
338
+ batch_cond_masks = None
339
+
340
+ # Handle CFG for cross-attention conditioning
341
+ if cross_attn_cond is not None:
342
+
343
+ null_embed = torch.zeros_like(cross_attn_cond, device=cross_attn_cond.device)
344
+
345
+ # For negative cross-attention conditioning, replace the null embed with the negative cross-attention conditioning
346
+ if negative_cross_attn_cond is not None:
347
+
348
+ # If there's a negative cross-attention mask, set the masked tokens to the null embed
349
+ if negative_cross_attn_mask is not None:
350
+ negative_cross_attn_mask = negative_cross_attn_mask.to(torch.bool).unsqueeze(2)
351
+
352
+ negative_cross_attn_cond = torch.where(negative_cross_attn_mask, negative_cross_attn_cond, null_embed)
353
+
354
+ batch_cond = torch.cat([cross_attn_cond, negative_cross_attn_cond], dim=0)
355
+
356
+ else:
357
+ batch_cond = torch.cat([cross_attn_cond, null_embed], dim=0)
358
+
359
+ if cross_attn_cond_mask is not None:
360
+ batch_cond_masks = torch.cat([cross_attn_cond_mask, cross_attn_cond_mask], dim=0)
361
+
362
+ batch_prepend_cond = None
363
+ batch_prepend_cond_mask = None
364
+
365
+ if prepend_cond is not None:
366
+
367
+ null_embed = torch.zeros_like(prepend_cond, device=prepend_cond.device)
368
+
369
+ batch_prepend_cond = torch.cat([prepend_cond, null_embed], dim=0)
370
+
371
+ if prepend_cond_mask is not None:
372
+ batch_prepend_cond_mask = torch.cat([prepend_cond_mask, prepend_cond_mask], dim=0)
373
+
374
+ batch_add_cond = None
375
+
376
+ # Handle CFG for cross-attention conditioning
377
+ if add_cond is not None:
378
+
379
+ null_embed = torch.zeros_like(add_cond, device=add_cond.device)
380
+
381
+
382
+ batch_add_cond = torch.cat([add_cond, null_embed], dim=0)
383
+
384
+
385
+ if mask is not None:
386
+ batch_masks = torch.cat([mask, mask], dim=0)
387
+ else:
388
+ batch_masks = None
389
+
390
+ batch_output = self._forward(
391
+ batch_inputs,
392
+ batch_timestep,
393
+ cross_attn_cond=batch_cond,
394
+ cross_attn_cond_mask=batch_cond_masks,
395
+ mask = batch_masks,
396
+ # x_ctx=x_ctx,
397
+ input_concat_cond=batch_input_concat_cond,
398
+ global_embed = batch_global_cond,
399
+ prepend_cond = batch_prepend_cond,
400
+ prepend_cond_mask = batch_prepend_cond_mask,
401
+ add_cond = batch_add_cond,
402
+ return_info = return_info,
403
+ **kwargs)
404
+
405
+ if return_info:
406
+ batch_output, info = batch_output
407
+
408
+ cond_output, uncond_output = torch.chunk(batch_output, 2, dim=0)
409
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
410
+
411
+ # CFG Rescale
412
+ if scale_phi != 0.0:
413
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
414
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
415
+ output = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
416
+ else:
417
+ output = cfg_output
418
+
419
+ if return_info:
420
+ return output, info
421
+
422
+ return output
423
+
424
+ else:
425
+ return self._forward(
426
+ x,
427
+ t,
428
+ cross_attn_cond=cross_attn_cond,
429
+ cross_attn_cond_mask=cross_attn_cond_mask,
430
+ input_concat_cond=input_concat_cond,
431
+ global_embed=global_embed,
432
+ prepend_cond=prepend_cond,
433
+ prepend_cond_mask=prepend_cond_mask,
434
+ add_cond=add_cond,
435
+ # x_ctx=x_ctx,
436
+ mask=mask,
437
+ return_info=return_info,
438
+ **kwargs
439
+ )
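
The classifier-free guidance arithmetic at the end of forward is easy to isolate. A self-contained sketch of the combine-and-rescale step used above (guidance scale, scale_phi and tensor shapes are illustrative):

import torch

def cfg_combine(cond_output, uncond_output, cfg_scale=6.0, scale_phi=0.75):
    # Standard CFG: push the prediction away from the unconditional branch.
    cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
    if scale_phi == 0.0:
        return cfg_output
    # Rescale so the guided output's std over the channel dimension matches the
    # conditioned output, then blend with the unrescaled result.
    cond_std = cond_output.std(dim=1, keepdim=True)
    cfg_std = cfg_output.std(dim=1, keepdim=True)
    return scale_phi * (cfg_output * (cond_std / cfg_std)) + (1 - scale_phi) * cfg_output

cond, uncond = torch.randn(2, 64, 215), torch.randn(2, 64, 215)
guided = cfg_combine(cond, uncond)
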
ThinkSound/models/embeddings.py ADDED
@@ -0,0 +1,85 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ # https://github.com/facebookresearch/DiT
5
+
6
+ from typing import Union
7
+
8
+ import torch
9
+ from einops import rearrange
10
+ from torch import Tensor
11
+
12
+ # Ref: https://github.com/black-forest-labs/flux/blob/main/src/flux/math.py
13
+ # Ref: https://github.com/lucidrains/rotary-embedding-torch
14
+
15
+
16
+ def compute_rope_rotations(length: int,
17
+ dim: int,
18
+ theta: int,
19
+ *,
20
+ freq_scaling: float = 1.0,
21
+ device: Union[torch.device, str] = 'cpu') -> Tensor:
22
+ assert dim % 2 == 0
23
+
24
+ with torch.amp.autocast(device_type='cuda', enabled=False):
25
+ pos = torch.arange(length, dtype=torch.float32, device=device)
26
+ freqs = 1.0 / (theta**(torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
27
+ freqs *= freq_scaling
28
+
29
+ rot = torch.einsum('..., f -> ... f', pos, freqs)
30
+ rot = torch.stack([torch.cos(rot), -torch.sin(rot), torch.sin(rot), torch.cos(rot)], dim=-1)
31
+ rot = rearrange(rot, 'n d (i j) -> 1 n d i j', i=2, j=2)
32
+ return rot
33
+
34
+
35
+ def apply_rope(x: Tensor, rot: Tensor) -> Tensor:
36
+ with torch.amp.autocast(device_type='cuda', enabled=False):
37
+ _x = x.float()
38
+ _x = _x.view(*_x.shape[:-1], -1, 1, 2)
39
+ x_out = rot[..., 0] * _x[..., 0] + rot[..., 1] * _x[..., 1]
40
+ return x_out.reshape(*x.shape).to(dtype=x.dtype)
41
+
42
+
43
+ class TimestepEmbedder(nn.Module):
44
+ """
45
+ Embeds scalar timesteps into vector representations.
46
+ """
47
+
48
+ def __init__(self, dim, frequency_embedding_size, max_period):
49
+ super().__init__()
50
+ self.mlp = nn.Sequential(
51
+ nn.Linear(frequency_embedding_size, dim),
52
+ nn.SiLU(),
53
+ nn.Linear(dim, dim),
54
+ )
55
+ self.dim = dim
56
+ self.max_period = max_period
57
+ assert dim % 2 == 0, 'dim must be even.'
58
+
59
+ with torch.autocast('cuda', enabled=False):
60
+ self.register_buffer("freqs",
61
+ 1.0 / (10000**(torch.arange(0, frequency_embedding_size, 2, dtype=torch.float32) /
62
+ frequency_embedding_size)),
63
+ persistent=False)
64
+ freq_scale = 10000 / max_period
65
+ self.freqs = freq_scale * self.freqs
66
+
67
+ def timestep_embedding(self, t):
68
+ """
69
+ Create sinusoidal timestep embeddings.
70
+ :param t: a 1-D Tensor of N indices, one per batch element.
71
+ These may be fractional.
72
+ :param dim: the dimension of the output.
73
+ :param max_period: controls the minimum frequency of the embeddings.
74
+ :return: an (N, D) Tensor of positional embeddings.
75
+ """
76
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
77
+
78
+ args = t[:, None].float() * self.freqs[None]
79
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
80
+ return embedding
81
+
82
+ def forward(self, t):
83
+ t_freq = self.timestep_embedding(t).to(t.dtype)
84
+ t_emb = self.mlp(t_freq)
85
+ return t_emb
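
A minimal usage sketch for the helpers above (illustrative only, not part of the committed file; the batch size, sequence length, and head dimension are arbitrary assumptions, and the snippet should also run on CPU since the helpers only open disabled CUDA autocast contexts):

import torch
from ThinkSound.models.embeddings import TimestepEmbedder, compute_rope_rotations, apply_rope

embedder = TimestepEmbedder(dim=1024, frequency_embedding_size=256, max_period=10000)
t_emb = embedder(torch.rand(4))                                # (4, 1024), one embedding per timestep

rot = compute_rope_rotations(length=128, dim=64, theta=10000)  # (1, 128, 32, 2, 2)
q = torch.randn(4, 8, 128, 64)                                 # (batch, heads, seq, head_dim)
q_rot = apply_rope(q, rot)                                     # rotations broadcast over the head axis
print(t_emb.shape, q_rot.shape)
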
ThinkSound/models/factory.py ADDED
@@ -0,0 +1,156 @@
1
+ import json
2
+
3
+ def create_model_from_config(model_config):
4
+ model_type = model_config.get('model_type', None)
5
+
6
+ assert model_type is not None, 'model_type must be specified in model config'
7
+
8
+ if model_type == 'autoencoder':
9
+ from .autoencoders import create_autoencoder_from_config
10
+ return create_autoencoder_from_config(model_config)
11
+ elif model_type == 'diffusion_uncond':
12
+ from .diffusion import create_diffusion_uncond_from_config
13
+ return create_diffusion_uncond_from_config(model_config)
14
+ # elif model_type == 'diffusion_infill':
15
+ # from .diffusion import create_diffusion_infill_from_config
16
+ # return create_diffusion_infill_from_config(model_config)
17
+ elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior" or model_type == "diffusion_infill" or model_type == "mm_diffusion_cond":
18
+ from .diffusion import create_diffusion_cond_from_config
19
+ return create_diffusion_cond_from_config(model_config)
20
+ elif model_type == 'diffusion_autoencoder':
21
+ from .autoencoders import create_diffAE_from_config
22
+ return create_diffAE_from_config(model_config)
23
+ elif model_type == 'lm':
24
+ from .lm import create_audio_lm_from_config
25
+ return create_audio_lm_from_config(model_config)
26
+ else:
27
+ raise NotImplementedError(f'Unknown model type: {model_type}')
28
+
29
+ def create_model_from_config_path(model_config_path):
30
+ with open(model_config_path) as f:
31
+ model_config = json.load(f)
32
+
33
+ return create_model_from_config(model_config)
34
+
35
+ def create_pretransform_from_config(pretransform_config, sample_rate):
36
+ pretransform_type = pretransform_config.get('type', None)
37
+
38
+ assert pretransform_type is not None, 'type must be specified in pretransform config'
39
+
40
+ if pretransform_type == 'autoencoder':
41
+ from .autoencoders import create_autoencoder_from_config
42
+ from .pretransforms import AutoencoderPretransform
43
+
44
+ # Create fake top-level config to pass sample rate to autoencoder constructor
45
+ # This is a bit of a hack but it keeps us from re-defining the sample rate in the config
46
+ autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]}
47
+ autoencoder = create_autoencoder_from_config(autoencoder_config)
48
+
49
+ scale = pretransform_config.get("scale", 1.0)
50
+ model_half = pretransform_config.get("model_half", False)
51
+ iterate_batch = pretransform_config.get("iterate_batch", False)
52
+ chunked = pretransform_config.get("chunked", False)
53
+
54
+ pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked)
55
+ elif pretransform_type == 'wavelet':
56
+ from .pretransforms import WaveletPretransform
57
+
58
+ wavelet_config = pretransform_config["config"]
59
+ channels = wavelet_config["channels"]
60
+ levels = wavelet_config["levels"]
61
+ wavelet = wavelet_config["wavelet"]
62
+
63
+ pretransform = WaveletPretransform(channels, levels, wavelet)
64
+ elif pretransform_type == 'pqmf':
65
+ from .pretransforms import PQMFPretransform
66
+ pqmf_config = pretransform_config["config"]
67
+ pretransform = PQMFPretransform(**pqmf_config)
68
+ elif pretransform_type == 'dac_pretrained':
69
+ from .pretransforms import PretrainedDACPretransform
70
+ pretrained_dac_config = pretransform_config["config"]
71
+ pretransform = PretrainedDACPretransform(**pretrained_dac_config)
72
+ elif pretransform_type == "audiocraft_pretrained":
73
+ from .pretransforms import AudiocraftCompressionPretransform
74
+
75
+ audiocraft_config = pretransform_config["config"]
76
+ pretransform = AudiocraftCompressionPretransform(**audiocraft_config)
77
+ else:
78
+ raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}')
79
+
80
+ enable_grad = pretransform_config.get('enable_grad', False)
81
+ pretransform.enable_grad = enable_grad
82
+
83
+ pretransform.eval().requires_grad_(pretransform.enable_grad)
84
+
85
+ return pretransform
86
+
87
+ def create_bottleneck_from_config(bottleneck_config):
88
+ bottleneck_type = bottleneck_config.get('type', None)
89
+
90
+ assert bottleneck_type is not None, 'type must be specified in bottleneck config'
91
+
92
+ if bottleneck_type == 'tanh':
93
+ from .bottleneck import TanhBottleneck
94
+ bottleneck = TanhBottleneck()
95
+ elif bottleneck_type == 'vae':
96
+ from .bottleneck import VAEBottleneck
97
+ bottleneck = VAEBottleneck()
98
+ elif bottleneck_type == 'rvq':
99
+ from .bottleneck import RVQBottleneck
100
+
101
+ quantizer_params = {
102
+ "dim": 128,
103
+ "codebook_size": 1024,
104
+ "num_quantizers": 8,
105
+ "decay": 0.99,
106
+ "kmeans_init": True,
107
+ "kmeans_iters": 50,
108
+ "threshold_ema_dead_code": 2,
109
+ }
110
+
111
+ quantizer_params.update(bottleneck_config["config"])
112
+
113
+ bottleneck = RVQBottleneck(**quantizer_params)
114
+ elif bottleneck_type == "dac_rvq":
115
+ from .bottleneck import DACRVQBottleneck
116
+
117
+ bottleneck = DACRVQBottleneck(**bottleneck_config["config"])
118
+
119
+ elif bottleneck_type == 'rvq_vae':
120
+ from .bottleneck import RVQVAEBottleneck
121
+
122
+ quantizer_params = {
123
+ "dim": 128,
124
+ "codebook_size": 1024,
125
+ "num_quantizers": 8,
126
+ "decay": 0.99,
127
+ "kmeans_init": True,
128
+ "kmeans_iters": 50,
129
+ "threshold_ema_dead_code": 2,
130
+ }
131
+
132
+ quantizer_params.update(bottleneck_config["config"])
133
+
134
+ bottleneck = RVQVAEBottleneck(**quantizer_params)
135
+
136
+ elif bottleneck_type == 'dac_rvq_vae':
137
+ from .bottleneck import DACRVQVAEBottleneck
138
+ bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"])
139
+ elif bottleneck_type == 'l2_norm':
140
+ from .bottleneck import L2Bottleneck
141
+ bottleneck = L2Bottleneck()
142
+ elif bottleneck_type == "wasserstein":
143
+ from .bottleneck import WassersteinBottleneck
144
+ bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {}))
145
+ elif bottleneck_type == "fsq":
146
+ from .bottleneck import FSQBottleneck
147
+ bottleneck = FSQBottleneck(**bottleneck_config["config"])
148
+ else:
149
+ raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}')
150
+
151
+ requires_grad = bottleneck_config.get('requires_grad', True)
152
+ if not requires_grad:
153
+ for param in bottleneck.parameters():
154
+ param.requires_grad = False
155
+
156
+ return bottleneck
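
A hedged sketch of how these factory helpers are typically driven (the config path below is a placeholder for any model config JSON with a top-level "model_type" field):

import json
from ThinkSound.models.factory import create_model_from_config, create_model_from_config_path

# Placeholder path; any config with a top-level "model_type" key works here.
model = create_model_from_config_path("path/to/model_config.json")

# Equivalent explicit form:
with open("path/to/model_config.json") as f:
    model_config = json.load(f)
model = create_model_from_config(model_config)  # dispatches on model_config["model_type"]
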
ThinkSound/models/local_attention.py ADDED
@@ -0,0 +1,278 @@
1
+ import torch
2
+
3
+ from einops import rearrange
4
+ from torch import nn
5
+
6
+ from .blocks import AdaRMSNorm
7
+ from .transformer import Attention, FeedForward, RotaryEmbedding, LayerNorm
8
+
9
+ def checkpoint(function, *args, **kwargs):
10
+ kwargs.setdefault("use_reentrant", False)
11
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
12
+
13
+ # Adapted from https://github.com/lucidrains/local-attention/blob/master/local_attention/transformer.py
14
+ class ContinuousLocalTransformer(nn.Module):
15
+ def __init__(
16
+ self,
17
+ *,
18
+ dim,
19
+ depth,
20
+ dim_in = None,
21
+ dim_out = None,
22
+ causal = False,
23
+ local_attn_window_size = 64,
24
+ heads = 8,
25
+ ff_mult = 2,
26
+ cond_dim = 0,
27
+ cross_attn_cond_dim = 0,
28
+ **kwargs
29
+ ):
30
+ super().__init__()
31
+
32
+ dim_head = dim//heads
33
+
34
+ self.layers = nn.ModuleList([])
35
+
36
+ self.project_in = nn.Linear(dim_in, dim) if dim_in is not None else nn.Identity()
37
+
38
+ self.project_out = nn.Linear(dim, dim_out) if dim_out is not None else nn.Identity()
39
+
40
+ self.local_attn_window_size = local_attn_window_size
41
+
42
+ self.cond_dim = cond_dim
43
+
44
+ self.cross_attn_cond_dim = cross_attn_cond_dim
45
+
46
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_head // 2, 32))
47
+
48
+ for _ in range(depth):
49
+
50
+ self.layers.append(nn.ModuleList([
51
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
52
+ Attention(
53
+ dim=dim,
54
+ dim_heads=dim_head,
55
+ causal=causal,
56
+ zero_init_output=True,
57
+ natten_kernel_size=local_attn_window_size,
58
+ ),
59
+ Attention(
60
+ dim=dim,
61
+ dim_heads=dim_head,
62
+ dim_context = cross_attn_cond_dim,
63
+ zero_init_output=True
64
+ ) if self.cross_attn_cond_dim > 0 else nn.Identity(),
65
+ AdaRMSNorm(dim, cond_dim, eps=1e-8) if cond_dim > 0 else LayerNorm(dim),
66
+ FeedForward(dim = dim, mult = ff_mult, no_bias=True)
67
+ ]))
68
+
69
+ def forward(self, x, mask = None, cond = None, cross_attn_cond = None, cross_attn_cond_mask = None, prepend_cond = None):
70
+
71
+ x = checkpoint(self.project_in, x)
72
+
73
+ if prepend_cond is not None:
74
+ x = torch.cat([prepend_cond, x], dim=1)
75
+
76
+ pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
77
+
78
+ for attn_norm, attn, xattn, ff_norm, ff in self.layers:
79
+
80
+ residual = x
81
+ if cond is not None:
82
+ x = checkpoint(attn_norm, x, cond)
83
+ else:
84
+ x = checkpoint(attn_norm, x)
85
+
86
+ x = checkpoint(attn, x, mask = mask, rotary_pos_emb=pos_emb) + residual
87
+
88
+ if cross_attn_cond is not None:
89
+ x = checkpoint(xattn, x, context=cross_attn_cond, context_mask=cross_attn_cond_mask) + x
90
+
91
+ residual = x
92
+
93
+ if cond is not None:
94
+ x = checkpoint(ff_norm, x, cond)
95
+ else:
96
+ x = checkpoint(ff_norm, x)
97
+
98
+ x = checkpoint(ff, x) + residual
99
+
100
+ return checkpoint(self.project_out, x)
101
+
102
+ class TransformerDownsampleBlock1D(nn.Module):
103
+ def __init__(
104
+ self,
105
+ in_channels,
106
+ embed_dim = 768,
107
+ depth = 3,
108
+ heads = 12,
109
+ downsample_ratio = 2,
110
+ local_attn_window_size = 64,
111
+ **kwargs
112
+ ):
113
+ super().__init__()
114
+
115
+ self.downsample_ratio = downsample_ratio
116
+
117
+ self.transformer = ContinuousLocalTransformer(
118
+ dim=embed_dim,
119
+ depth=depth,
120
+ heads=heads,
121
+ local_attn_window_size=local_attn_window_size,
122
+ **kwargs
123
+ )
124
+
125
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
126
+
127
+ self.project_down = nn.Linear(embed_dim * self.downsample_ratio, embed_dim, bias=False)
128
+
129
+
130
+ def forward(self, x):
131
+
132
+ x = checkpoint(self.project_in, x)
133
+
134
+ # Compute
135
+ x = self.transformer(x)
136
+
137
+ # Trade sequence length for channels
138
+ x = rearrange(x, "b (n r) c -> b n (c r)", r=self.downsample_ratio)
139
+
140
+ # Project back to embed dim
141
+ x = checkpoint(self.project_down, x)
142
+
143
+ return x
144
+
145
+ class TransformerUpsampleBlock1D(nn.Module):
146
+ def __init__(
147
+ self,
148
+ in_channels,
149
+ embed_dim,
150
+ depth = 3,
151
+ heads = 12,
152
+ upsample_ratio = 2,
153
+ local_attn_window_size = 64,
154
+ **kwargs
155
+ ):
156
+ super().__init__()
157
+
158
+ self.upsample_ratio = upsample_ratio
159
+
160
+ self.transformer = ContinuousLocalTransformer(
161
+ dim=embed_dim,
162
+ depth=depth,
163
+ heads=heads,
164
+ local_attn_window_size = local_attn_window_size,
165
+ **kwargs
166
+ )
167
+
168
+ self.project_in = nn.Linear(in_channels, embed_dim, bias=False) if in_channels != embed_dim else nn.Identity()
169
+
170
+ self.project_up = nn.Linear(embed_dim, embed_dim * self.upsample_ratio, bias=False)
171
+
172
+ def forward(self, x):
173
+
174
+ # Project to embed dim
175
+ x = checkpoint(self.project_in, x)
176
+
177
+ # Project to increase channel dim
178
+ x = checkpoint(self.project_up, x)
179
+
180
+ # Trade channels for sequence length
181
+ x = rearrange(x, "b n (c r) -> b (n r) c", r=self.upsample_ratio)
182
+
183
+ # Compute
184
+ x = self.transformer(x)
185
+
186
+ return x
187
+
188
+
189
+ class TransformerEncoder1D(nn.Module):
190
+ def __init__(
191
+ self,
192
+ in_channels,
193
+ out_channels,
194
+ embed_dims = [96, 192, 384, 768],
195
+ heads = [12, 12, 12, 12],
196
+ depths = [3, 3, 3, 3],
197
+ ratios = [2, 2, 2, 2],
198
+ local_attn_window_size = 64,
199
+ **kwargs
200
+ ):
201
+ super().__init__()
202
+
203
+ layers = []
204
+
205
+ for layer in range(len(depths)):
206
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
207
+
208
+ layers.append(
209
+ TransformerDownsampleBlock1D(
210
+ in_channels = prev_dim,
211
+ embed_dim = embed_dims[layer],
212
+ heads = heads[layer],
213
+ depth = depths[layer],
214
+ downsample_ratio = ratios[layer],
215
+ local_attn_window_size = local_attn_window_size,
216
+ **kwargs
217
+ )
218
+ )
219
+
220
+ self.layers = nn.Sequential(*layers)
221
+
222
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
223
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
224
+
225
+ def forward(self, x):
226
+ x = rearrange(x, "b c n -> b n c")
227
+ x = checkpoint(self.project_in, x)
228
+ x = self.layers(x)
229
+ x = checkpoint(self.project_out, x)
230
+ x = rearrange(x, "b n c -> b c n")
231
+
232
+ return x
233
+
234
+
235
+ class TransformerDecoder1D(nn.Module):
236
+ def __init__(
237
+ self,
238
+ in_channels,
239
+ out_channels,
240
+ embed_dims = [768, 384, 192, 96],
241
+ heads = [12, 12, 12, 12],
242
+ depths = [3, 3, 3, 3],
243
+ ratios = [2, 2, 2, 2],
244
+ local_attn_window_size = 64,
245
+ **kwargs
246
+ ):
247
+
248
+ super().__init__()
249
+
250
+ layers = []
251
+
252
+ for layer in range(len(depths)):
253
+ prev_dim = embed_dims[layer - 1] if layer > 0 else embed_dims[0]
254
+
255
+ layers.append(
256
+ TransformerUpsampleBlock1D(
257
+ in_channels = prev_dim,
258
+ embed_dim = embed_dims[layer],
259
+ heads = heads[layer],
260
+ depth = depths[layer],
261
+ upsample_ratio = ratios[layer],
262
+ local_attn_window_size = local_attn_window_size,
263
+ **kwargs
264
+ )
265
+ )
266
+
267
+ self.layers = nn.Sequential(*layers)
268
+
269
+ self.project_in = nn.Linear(in_channels, embed_dims[0], bias=False)
270
+ self.project_out = nn.Linear(embed_dims[-1], out_channels, bias=False)
271
+
272
+ def forward(self, x):
273
+ x = rearrange(x, "b c n -> b n c")
274
+ x = checkpoint(self.project_in, x)
275
+ x = self.layers(x)
276
+ x = checkpoint(self.project_out, x)
277
+ x = rearrange(x, "b n c -> b c n")
278
+ return x
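
A shape-level sketch of how the encoder/decoder pair above composes (illustrative only, not part of the committed file; channel counts and sequence length are arbitrary assumptions, and the windowed attention inside ContinuousLocalTransformer assumes the optional natten dependency is installed):

import torch
from ThinkSound.models.local_attention import TransformerEncoder1D, TransformerDecoder1D

enc = TransformerEncoder1D(in_channels=2, out_channels=32)  # four stages, each halving the sequence
dec = TransformerDecoder1D(in_channels=32, out_channels=2)  # mirror image, each stage doubling it

x = torch.randn(1, 2, 1024)   # (batch, channels, time)
z = enc(x)                    # (1, 32, 64) with the default 2*2*2*2 downsampling
y = dec(z)                    # (1, 2, 1024)
print(z.shape, y.shape)
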
ThinkSound/models/mmdit.py ADDED
@@ -0,0 +1,578 @@
1
+ import logging
2
+ from dataclasses import dataclass
3
+ from typing import Optional
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import sys
9
+ from .embeddings import compute_rope_rotations
10
+ from .embeddings import TimestepEmbedder
11
+ from .blocks import MLP, ChannelLastConv1d, ConvMLP
12
+ from .transformer_layers import (FinalBlock, JointBlock, MMDitSingleBlock)
13
+ from .utils import resample
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
+ @dataclass
19
+ class PreprocessedConditions:
20
+ clip_f: torch.Tensor
21
+ sync_f: torch.Tensor
22
+ text_f: torch.Tensor
23
+ clip_f_c: torch.Tensor
24
+ text_f_c: torch.Tensor
25
+
26
+
27
+ class MMAudio(nn.Module):
28
+
29
+ def __init__(self,
30
+ *,
31
+ latent_dim: int,
32
+ clip_dim: int,
33
+ sync_dim: int,
34
+ text_dim: int,
35
+ hidden_dim: int,
36
+ depth: int,
37
+ fused_depth: int,
38
+ num_heads: int,
39
+ mlp_ratio: float = 4.0,
40
+ latent_seq_len: int,
41
+ clip_seq_len: int,
42
+ sync_seq_len: int,
43
+ text_seq_len: int = 77,
44
+ latent_mean: Optional[torch.Tensor] = None,
45
+ latent_std: Optional[torch.Tensor] = None,
46
+ empty_string_feat: Optional[torch.Tensor] = None,
47
+ v2: bool = False,
48
+ kernel_size: int = 7,
49
+ sync_kernel: int = 7,
50
+ use_inpaint: bool = False,
51
+ use_mlp: bool = False,
52
+ cross_attend: bool = False,
53
+ add_video: bool = False,
54
+ triple_fusion: bool = False,
55
+ gated_video: bool = False) -> None:
56
+ super().__init__()
57
+
58
+ self.v2 = v2
59
+ self.latent_dim = latent_dim
60
+ self._latent_seq_len = latent_seq_len
61
+ self._clip_seq_len = clip_seq_len
62
+ self._sync_seq_len = sync_seq_len
63
+ self._text_seq_len = text_seq_len
64
+ self.hidden_dim = hidden_dim
65
+ self.num_heads = num_heads
66
+ self.cross_attend = cross_attend
67
+ self.add_video = add_video
68
+ self.gated_video = gated_video
69
+ self.triple_fusion = triple_fusion
70
+ self.use_inpaint = use_inpaint
71
+ if self.gated_video:
72
+ self.gated_mlp = nn.Sequential(
73
+ nn.LayerNorm(hidden_dim * 2),
74
+ nn.Linear(hidden_dim*2, hidden_dim * 4, bias=False),
75
+ nn.SiLU(),
76
+ nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
77
+ nn.Sigmoid()
78
+ )
79
+ # Initialize the final layer's weights to zero to encourage even fusion at initialization
80
+ nn.init.zeros_(self.gated_mlp[3].weight)
81
+ if self.triple_fusion:
82
+ self.gated_mlp_v = nn.Sequential(
83
+ nn.LayerNorm(hidden_dim * 3),
84
+ nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False),
85
+ nn.SiLU(),
86
+ nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
87
+ nn.Sigmoid()
88
+ )
89
+ self.gated_mlp_t = nn.Sequential(
90
+ nn.LayerNorm(hidden_dim * 3),
91
+ nn.Linear(hidden_dim*3, hidden_dim * 4, bias=False),
92
+ nn.SiLU(),
93
+ nn.Linear(hidden_dim * 4, hidden_dim, bias=False),
94
+ nn.Sigmoid()
95
+ )
96
+ nn.init.zeros_(self.gated_mlp_v[3].weight)
97
+ nn.init.zeros_(self.gated_mlp_t[3].weight)
98
+ if v2:
99
+ padding_size = (kernel_size - 1) // 2
100
+ if use_inpaint:
101
+ self.audio_input_proj = nn.Sequential(
102
+ ChannelLastConv1d(latent_dim*2, hidden_dim, kernel_size=kernel_size, padding=padding_size),
103
+ nn.SiLU(),
104
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size),
105
+ )
106
+ else:
107
+ self.audio_input_proj = nn.Sequential(
108
+ ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=kernel_size, padding=padding_size),
109
+ nn.SiLU(),
110
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=kernel_size, padding=padding_size),
111
+ )
112
+
113
+ self.clip_input_proj = nn.Sequential(
114
+ nn.Linear(clip_dim, hidden_dim),
115
+ nn.SiLU(),
116
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
117
+ )
118
+ sync_pad = (sync_kernel - 1) // 2
119
+ self.sync_input_proj = nn.Sequential(
120
+ ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=sync_kernel, padding=sync_pad),
121
+ nn.SiLU(),
122
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
123
+ )
124
+
125
+ self.text_input_proj = nn.Sequential(
126
+ nn.Linear(text_dim, hidden_dim),
127
+ nn.SiLU(),
128
+ MLP(hidden_dim, hidden_dim * 4),
129
+ )
130
+ else:
131
+ self.audio_input_proj = nn.Sequential(
132
+ ChannelLastConv1d(latent_dim, hidden_dim, kernel_size=7, padding=3),
133
+ nn.SELU(),
134
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=7, padding=3),
135
+ )
136
+
137
+ self.clip_input_proj = nn.Sequential(
138
+ nn.Linear(clip_dim, hidden_dim),
139
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
140
+ )
141
+
142
+ self.sync_input_proj = nn.Sequential(
143
+ ChannelLastConv1d(sync_dim, hidden_dim, kernel_size=7, padding=3),
144
+ nn.SELU(),
145
+ ConvMLP(hidden_dim, hidden_dim * 4, kernel_size=3, padding=1),
146
+ )
147
+
148
+ self.text_input_proj = nn.Sequential(
149
+ nn.Linear(text_dim, hidden_dim),
150
+ MLP(hidden_dim, hidden_dim * 4),
151
+ )
152
+
153
+ self.clip_cond_proj = nn.Linear(hidden_dim, hidden_dim)
154
+ if use_mlp:
155
+ self.text_cond_proj = nn.Sequential(
156
+ nn.Linear(1024, hidden_dim),
157
+ MLP(hidden_dim, hidden_dim * 4),
158
+ )
159
+ else:
160
+ self.text_cond_proj = nn.Linear(1024, hidden_dim)
161
+ self.global_cond_mlp = MLP(hidden_dim, hidden_dim * 4)
162
+ # each synchformer output segment has 8 feature frames
163
+ self.sync_pos_emb = nn.Parameter(torch.zeros((1, 1, 8, sync_dim)))
164
+
165
+ self.final_layer = FinalBlock(hidden_dim, latent_dim)
166
+
167
+ if v2:
168
+ self.t_embed = TimestepEmbedder(hidden_dim,
169
+ frequency_embedding_size=hidden_dim,
170
+ max_period=1)
171
+ else:
172
+ self.t_embed = TimestepEmbedder(hidden_dim,
173
+ frequency_embedding_size=256,
174
+ max_period=10000)
175
+ self.joint_blocks = nn.ModuleList([
176
+ JointBlock(hidden_dim,
177
+ num_heads,
178
+ mlp_ratio=mlp_ratio,
179
+ pre_only=(i == depth - fused_depth - 1)) for i in range(depth - fused_depth)
180
+ ])
181
+
182
+ self.fused_blocks = nn.ModuleList([
183
+ MMDitSingleBlock(hidden_dim, num_heads, mlp_ratio=mlp_ratio, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, cross_attend=cross_attend)
184
+ for i in range(fused_depth)
185
+ ])
186
+
187
+ if empty_string_feat is None:
188
+ empty_string_feat = torch.zeros((77, 1024))
189
+
190
+ empty_t5_feat = torch.zeros((77, 2048))
191
+
192
+ self.empty_string_feat = nn.Parameter(empty_string_feat, requires_grad=False)
193
+ self.empty_t5_feat = nn.Parameter(empty_t5_feat, requires_grad=False)
194
+ self.empty_clip_feat = nn.Parameter(torch.zeros(1, clip_dim), requires_grad=True)
195
+ self.empty_sync_feat = nn.Parameter(torch.zeros(1, sync_dim), requires_grad=True)
196
+
197
+ self.initialize_weights()
198
+ self.initialize_rotations()
199
+
200
+ def initialize_rotations(self):
201
+ base_freq = 1.0
202
+ latent_rot = compute_rope_rotations(self._latent_seq_len,
203
+ self.hidden_dim // self.num_heads,
204
+ 10000,
205
+ freq_scaling=base_freq,
206
+ device=self.device)
207
+ clip_rot = compute_rope_rotations(self._clip_seq_len,
208
+ self.hidden_dim // self.num_heads,
209
+ 10000,
210
+ freq_scaling=base_freq * self._latent_seq_len /
211
+ self._clip_seq_len,
212
+ device=self.device)
213
+
214
+ self.register_buffer("latent_rot", latent_rot, persistent=False)
215
+ self.register_buffer("clip_rot", clip_rot, persistent=False)
216
+
217
+ def update_seq_lengths(self, latent_seq_len: int, clip_seq_len: int, sync_seq_len: int) -> None:
218
+ self._latent_seq_len = latent_seq_len
219
+ self._clip_seq_len = clip_seq_len
220
+ self._sync_seq_len = sync_seq_len
221
+ self.initialize_rotations()
222
+
223
+ def initialize_weights(self):
224
+
225
+ def _basic_init(module):
226
+ if isinstance(module, nn.Linear):
227
+ torch.nn.init.xavier_uniform_(module.weight)
228
+ if module.bias is not None:
229
+ nn.init.constant_(module.bias, 0)
230
+
231
+ self.apply(_basic_init)
232
+
233
+ # Initialize timestep embedding MLP:
234
+ nn.init.normal_(self.t_embed.mlp[0].weight, std=0.02)
235
+ nn.init.normal_(self.t_embed.mlp[2].weight, std=0.02)
236
+
237
+ # Zero-out adaLN modulation layers in DiT blocks:
238
+ for block in self.joint_blocks:
239
+ nn.init.constant_(block.latent_block.adaLN_modulation[-1].weight, 0)
240
+ nn.init.constant_(block.latent_block.adaLN_modulation[-1].bias, 0)
241
+ nn.init.constant_(block.clip_block.adaLN_modulation[-1].weight, 0)
242
+ nn.init.constant_(block.clip_block.adaLN_modulation[-1].bias, 0)
243
+ nn.init.constant_(block.text_block.adaLN_modulation[-1].weight, 0)
244
+ nn.init.constant_(block.text_block.adaLN_modulation[-1].bias, 0)
245
+ for block in self.fused_blocks:
246
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
247
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
248
+
249
+ # Zero-out output layers:
250
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
251
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
252
+ nn.init.constant_(self.final_layer.conv.weight, 0)
253
+ nn.init.constant_(self.final_layer.conv.bias, 0)
254
+
255
+ # empty string feat shall be initialized by a CLIP encoder
256
+ nn.init.constant_(self.sync_pos_emb, 0)
257
+ nn.init.constant_(self.empty_clip_feat, 0)
258
+ nn.init.constant_(self.empty_sync_feat, 0)
259
+
260
+ def preprocess_conditions(self, clip_f: torch.Tensor, sync_f: torch.Tensor,
261
+ text_f: torch.Tensor, t5_features: torch.Tensor, metaclip_global_text_features: torch.Tensor) -> PreprocessedConditions:
262
+ """
263
+ cache computations that do not depend on the latent/time step
264
+ i.e., the features are reused over steps during inference
265
+ """
266
+ # breakpoint()
267
+ assert clip_f.shape[1] == self._clip_seq_len, f'{clip_f.shape=} {self._clip_seq_len=}'
268
+ assert sync_f.shape[1] == self._sync_seq_len, f'{sync_f.shape=} {self._sync_seq_len=}'
269
+ assert text_f.shape[1] == self._text_seq_len, f'{text_f.shape=} {self._text_seq_len=}'
270
+
271
+ bs = clip_f.shape[0]
272
+
273
+ # B * num_segments (24) * 8 * 768
274
+ num_sync_segments = self._sync_seq_len // 8
275
+ sync_f = sync_f.view(bs, num_sync_segments, 8, -1) + self.sync_pos_emb
276
+ sync_f = sync_f.flatten(1, 2) # (B, VN, D)
277
+
278
+ # extend vf to match x
279
+ clip_f = self.clip_input_proj(clip_f) # (B, VN, D)
280
+ sync_f = self.sync_input_proj(sync_f) # (B, VN, D)
281
+
282
+ if t5_features is not None:
283
+
284
+ if metaclip_global_text_features is not None:
285
+ text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D)
286
+ else:
287
+ text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D)
288
+ # Compute the padding length
290
+ padding_size = t5_features.size(2) - text_f.size(2) # amount of padding needed
291
+ # Only pad when the padding length is actually positive
292
+ if padding_size > 0:
293
+ # Pad text_f along its feature dimension
294
+ text_f = F.pad(text_f, pad=(0, padding_size), mode='constant', value=0) # pad on the last dimension
295
+ else:
296
+ text_f = text_f # no padding needed when the padding length is not positive
296
+ text_concat = torch.cat((text_f, t5_features), dim=1)
297
+ text_f = self.text_input_proj(text_concat) # (B, VN, D)
298
+ else:
299
+ text_f = self.text_input_proj(text_f) # (B, VN, D)
300
+ if metaclip_global_text_features is not None:
301
+ text_f_c = self.text_cond_proj(metaclip_global_text_features) # (B, D)
302
+ else:
303
+ text_f_c = self.text_cond_proj(text_f.mean(dim=1)) # (B, D)
304
+
305
+ # upsample the sync features to match the audio
306
+ sync_f = sync_f.transpose(1, 2) # (B, D, VN)
307
+ # sync_f = resample(sync_f, self._latent_seq_len)
308
+ sync_f = F.interpolate(sync_f, size=self._latent_seq_len, mode='nearest-exact')
309
+ sync_f = sync_f.transpose(1, 2) # (B, N, D)
310
+
311
+ # get conditional features from the clip side
312
+ clip_f_c = self.clip_cond_proj(clip_f.mean(dim=1)) # (B, D)
313
+
314
+ return PreprocessedConditions(clip_f=clip_f,
315
+ sync_f=sync_f,
316
+ text_f=text_f,
317
+ clip_f_c=clip_f_c,
318
+ text_f_c=text_f_c)
319
+
320
+ def predict_flow(self, latent: torch.Tensor, t: torch.Tensor,
321
+ conditions: PreprocessedConditions, inpaint_masked_input=None, cfg_scale:float=1.0,cfg_dropout_prob:float=0.0,scale_phi:float=0.0
322
+ ) -> torch.Tensor:
323
+ """
324
+ for non-cacheable computations
325
+ """
326
+ # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}')
327
+ assert latent.shape[1] == self._latent_seq_len, f'{latent.shape=} {self._latent_seq_len=}'
328
+ empty_conditions = None
329
+ if inpaint_masked_input is not None:
330
+ inpaint_masked_input = inpaint_masked_input.transpose(1,2)
331
+ clip_f = conditions.clip_f
332
+ sync_f = conditions.sync_f
333
+ text_f = conditions.text_f
334
+ clip_f_c = conditions.clip_f_c
335
+ text_f_c = conditions.text_f_c
336
+
337
+ # breakpoint()
338
+ if inpaint_masked_input is not None:
339
+ latent = torch.cat([latent,inpaint_masked_input],dim=2)
340
+ latent = self.audio_input_proj(latent) # (B, N, D)
341
+ global_c = self.global_cond_mlp(clip_f_c + text_f_c) # (B, D)
342
+ # global_c = text_f_c
343
+ global_c = self.t_embed(t).unsqueeze(1) + global_c.unsqueeze(1) # (B, D)
344
+ extended_c = global_c + sync_f
345
+
346
+ for block in self.joint_blocks:
347
+ latent, clip_f, text_f = block(latent, clip_f, text_f, global_c, extended_c,
348
+ self.latent_rot, self.clip_rot) # (B, N, D)
349
+ if self.add_video:
350
+ if clip_f.shape[1] != latent.shape[1]:
351
+ clip_f = resample(clip_f, latent)
352
+
353
+ if self.triple_fusion:
354
+ text_f = torch.mean(text_f, dim=1, keepdim=True) # (bsz, 1, D)
355
+ text_f = text_f.expand(-1,latent.shape[1], -1) # (T_audio, D)
356
+ fusion = torch.concat((latent, clip_f, text_f),dim=-1)
357
+ gate_v = self.gated_mlp_v(fusion)
358
+ gate_t = self.gated_mlp_t(fusion)
359
+ # modulated_latent = gate * latent # asymmetric design
360
+ latent = latent + gate_v * clip_f + gate_t * text_f
361
+ elif self.gated_video:
362
+ fusion = torch.concat((latent, clip_f),dim=-1)
363
+ gate = self.gated_mlp(fusion)
364
+ modulated_latent = gate * latent # asymmetric design
365
+ latent = latent + modulated_latent
366
+ else:
367
+ latent = latent + clip_f
368
+
369
+ for block in self.fused_blocks:
370
+ if self.cross_attend:
371
+ latent = block(latent, extended_c, self.latent_rot, context=text_f)
372
+ else:
373
+ latent = block(latent, extended_c, self.latent_rot)
374
+
375
+ # should be extended_c; this is a minor implementation error #55
376
+ flow = self.final_layer(latent, extended_c) # (B, N, out_dim), remove t
377
+ return flow
378
+
379
+ def forward(self, latent: torch.Tensor, t: torch.Tensor, clip_f: torch.Tensor, sync_f: torch.Tensor,
380
+ text_f: torch.Tensor, inpaint_masked_input, t5_features, metaclip_global_text_features, cfg_scale:float,cfg_dropout_prob:float,scale_phi:float) -> torch.Tensor:
381
+ """
382
+ latent: (B, N, C)
383
+ clip_f: (B, T_clip, C_clip); sync_f: (B, T_sync, C_sync); text_f: (B, T_text, C_text)
384
+ t: (B,)
385
+ """
386
+ # breakpoint()
387
+ # print(f'cfg_scale: {cfg_scale}, cfg_dropout_prob: {cfg_dropout_prob}, scale_phi: {scale_phi}')
388
+ if self.use_inpaint and inpaint_masked_input is None:
389
+ inpaint_masked_input = torch.zeros_like(latent, device=latent.device)
390
+ latent = latent.permute(0, 2, 1)
391
+
392
+ if cfg_dropout_prob > 0.0:
393
+ if inpaint_masked_input is not None:
394
+ null_embed = torch.zeros_like(inpaint_masked_input,device=latent.device)
395
+ dropout_mask = torch.bernoulli(torch.full((inpaint_masked_input.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
396
+ inpaint_masked_input = torch.where(dropout_mask, null_embed, inpaint_masked_input)
397
+
398
+ null_embed = torch.zeros_like(clip_f,device=latent.device)
399
+ dropout_mask = torch.bernoulli(torch.full((clip_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
400
+ # clip_f = torch.where(dropout_mask, null_embed, clip_f)
401
+ clip_f = torch.where(dropout_mask, self.empty_clip_feat, clip_f)
402
+ null_embed = torch.zeros_like(sync_f,device=latent.device)
403
+ dropout_mask = torch.bernoulli(torch.full((sync_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
404
+ # sync_f = torch.where(dropout_mask, null_embed, sync_f)
405
+ sync_f = torch.where(dropout_mask, self.empty_sync_feat, sync_f)
406
+ null_embed = torch.zeros_like(text_f,device=latent.device)
407
+ dropout_mask = torch.bernoulli(torch.full((text_f.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
408
+ # text_f = torch.where(dropout_mask, null_embed, text_f)
409
+ text_f = torch.where(dropout_mask, self.empty_string_feat, text_f)
410
+ if t5_features is not None:
411
+ null_embed = torch.zeros_like(t5_features,device=latent.device)
412
+ dropout_mask = torch.bernoulli(torch.full((t5_features.shape[0], 1, 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
413
+ # t5_features = torch.where(dropout_mask, null_embed, t5_features)
414
+ t5_features = torch.where(dropout_mask, self.empty_t5_feat, t5_features)
415
+ if metaclip_global_text_features is not None:
416
+ null_embed = torch.zeros_like(metaclip_global_text_features,device=latent.device)
417
+ dropout_mask = torch.bernoulli(torch.full((metaclip_global_text_features.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
418
+ metaclip_global_text_features = torch.where(dropout_mask, null_embed, metaclip_global_text_features)
419
+ # null_embed = torch.zeros_like(clip_f_c,device=latent.device)
420
+ # dropout_mask = torch.bernoulli(torch.full((clip_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
421
+ # clip_f_c = torch.where(dropout_mask, null_embed, clip_f_c)
422
+ # null_embed = torch.zeros_like(text_f_c,device=latent.device)
423
+ # dropout_mask = torch.bernoulli(torch.full((text_f_c.shape[0], 1), cfg_dropout_prob, device=latent.device)).to(torch.bool)
424
+ # text_f_c = torch.where(dropout_mask, null_embed, text_f_c)
425
+
426
+ if cfg_scale != 1.0:
427
+ # empty_conditions = self.get_empty_conditions(latent.shape[0])
428
+ # breakpoint()
429
+ bsz = latent.shape[0]
430
+ latent = torch.cat([latent,latent], dim=0)
431
+ if inpaint_masked_input is not None:
432
+ empty_inpaint_masked_input = torch.zeros_like(inpaint_masked_input, device=latent.device)
433
+ inpaint_masked_input = torch.cat([inpaint_masked_input,empty_inpaint_masked_input], dim=0)
434
+ t = torch.cat([t, t], dim=0)
435
+ empty_clip_f = torch.zeros_like(clip_f, device=latent.device)
436
+ empty_sync_f = torch.zeros_like(sync_f, device=latent.device)
437
+ empty_text_f = torch.zeros_like(text_f, device=latent.device)
438
+
439
+ # clip_f = torch.cat([clip_f,empty_clip_f], dim=0)
440
+ # sync_f = torch.cat([sync_f,empty_sync_f], dim=0)
441
+ # text_f = torch.cat([text_f,empty_text_f], dim=0)
442
+ clip_f = safe_cat(clip_f,self.get_empty_clip_sequence(bsz), dim=0, match_dim=1)
443
+ sync_f = safe_cat(sync_f,self.get_empty_sync_sequence(bsz), dim=0, match_dim=1)
444
+ text_f = safe_cat(text_f,self.get_empty_string_sequence(bsz), dim=0, match_dim=1)
445
+ if t5_features is not None:
446
+ empty_t5_features = torch.zeros_like(t5_features, device=latent.device)
447
+ # t5_features = torch.cat([t5_features,empty_t5_features], dim=0)
448
+ t5_features = torch.cat([t5_features,self.get_empty_t5_sequence(bsz)], dim=0)
449
+ if metaclip_global_text_features is not None:
450
+ empty_metaclip_global_text_features = torch.zeros_like(metaclip_global_text_features, device=latent.device)
451
+ metaclip_global_text_features = torch.cat([metaclip_global_text_features,empty_metaclip_global_text_features], dim=0)
452
+ # metaclip_global_text_features = torch.cat([metaclip_global_text_features,metaclip_global_text_features], dim=0)
453
+ # clip_f_c = torch.cat([clip_f_c,empty_clip_f_c], dim=0)
454
+ # text_f_c = torch.cat([text_f_c,empty_text_f_c], dim=0)
455
+
456
+ conditions = self.preprocess_conditions(clip_f, sync_f, text_f, t5_features, metaclip_global_text_features)
457
+ flow = self.predict_flow(latent, t, conditions, inpaint_masked_input, cfg_scale,cfg_dropout_prob,scale_phi)
458
+ if cfg_scale != 1.0:
459
+ cond_output, uncond_output = torch.chunk(flow, 2, dim=0)
460
+ cfg_output = uncond_output + (cond_output - uncond_output) * cfg_scale
461
+ if scale_phi != 0.0:
462
+ cond_out_std = cond_output.std(dim=1, keepdim=True)
463
+ out_cfg_std = cfg_output.std(dim=1, keepdim=True)
464
+ flow = scale_phi * (cfg_output * (cond_out_std/out_cfg_std)) + (1-scale_phi) * cfg_output
465
+ else:
466
+ flow = cfg_output
467
+ flow = flow.permute(0, 2, 1)
468
+ return flow
469
+
470
+ def get_empty_string_sequence(self, bs: int) -> torch.Tensor:
471
+ return self.empty_string_feat.unsqueeze(0).expand(bs, -1, -1)
472
+
473
+ def get_empty_t5_sequence(self, bs: int) -> torch.Tensor:
474
+ return self.empty_t5_feat.unsqueeze(0).expand(bs, -1, -1)
475
+
476
+ def get_empty_clip_sequence(self, bs: int) -> torch.Tensor:
477
+ return self.empty_clip_feat.unsqueeze(0).expand(bs, self._clip_seq_len, -1)
478
+
479
+ def get_empty_sync_sequence(self, bs: int) -> torch.Tensor:
480
+ return self.empty_sync_feat.unsqueeze(0).expand(bs, self._sync_seq_len, -1)
481
+
482
+ def get_empty_conditions(
483
+ self,
484
+ bs: int,
485
+ *,
486
+ negative_text_features: Optional[torch.Tensor] = None) -> PreprocessedConditions:
487
+ if negative_text_features is not None:
488
+ empty_text = negative_text_features
489
+ else:
490
+ empty_text = self.get_empty_string_sequence(1)
491
+
492
+ empty_clip = self.get_empty_clip_sequence(1)
493
+ empty_sync = self.get_empty_sync_sequence(1)
494
+ conditions = self.preprocess_conditions(empty_clip, empty_sync, empty_text, None, None)
495
+ conditions.clip_f = conditions.clip_f.expand(bs, -1, -1)
496
+ conditions.sync_f = conditions.sync_f.expand(bs, -1, -1)
497
+ conditions.clip_f_c = conditions.clip_f_c.expand(bs, -1)
498
+ if negative_text_features is None:
499
+ conditions.text_f = conditions.text_f.expand(bs, -1, -1)
500
+ conditions.text_f_c = conditions.text_f_c.expand(bs, -1)
501
+
502
+ return conditions
503
+
504
+ def load_weights(self, src_dict) -> None:
505
+ if 't_embed.freqs' in src_dict:
506
+ del src_dict['t_embed.freqs']
507
+ if 'latent_rot' in src_dict:
508
+ del src_dict['latent_rot']
509
+ if 'clip_rot' in src_dict:
510
+ del src_dict['clip_rot']
511
+
512
+ self.load_state_dict(src_dict, strict=True)
513
+
514
+ @property
515
+ def device(self) -> torch.device:
516
+ return self.empty_clip_feat.device
517
+
518
+ @property
519
+ def latent_seq_len(self) -> int:
520
+ return self._latent_seq_len
521
+
522
+ @property
523
+ def clip_seq_len(self) -> int:
524
+ return self._clip_seq_len
525
+
526
+ @property
527
+ def sync_seq_len(self) -> int:
528
+ return self._sync_seq_len
529
+
530
+
531
+
532
+
533
+
534
+
535
+
536
+
537
+
538
+
539
+
540
+
541
+
542
+
543
+
544
+
545
+
546
+ def truncate_to_target(tensor, target_size, dim=1):
547
+ current_size = tensor.size(dim)
548
+ if current_size > target_size:
549
+ slices = [slice(None)] * tensor.dim()
550
+ slices[dim] = slice(0, target_size)
551
+ return tensor[slices]
552
+ return tensor
553
+
554
+ def pad_to_target(tensor, target_size, dim=1, pad_value=0):
555
+ current_size = tensor.size(dim)
556
+ if current_size < target_size:
557
+ pad_size = target_size - current_size
558
+
559
+ pad_config = [0, 0] * tensor.dim()
560
+ pad_index = 2 * (tensor.dim() - dim - 1) + 1
561
+ pad_config[pad_index] = pad_size
562
+
563
+ return torch.nn.functional.pad(tensor, pad_config, value=pad_value)
564
+ return tensor
565
+
566
+
567
+ def safe_cat(tensor1, tensor2, dim=0, match_dim=1):
568
+
569
+ target_size = tensor2.size(match_dim)
570
+
571
+ if tensor1.size(match_dim) > target_size:
572
+ tensor1 = truncate_to_target(tensor1, target_size, match_dim)
573
+
574
+ else:
575
+ tensor1 = pad_to_target(tensor1, target_size, match_dim)
576
+
577
+ return torch.cat([tensor1, tensor2], dim=dim)
578
+
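
The padding/truncation helpers at the bottom of this file are easiest to see on toy tensors; a quick illustrative check (not part of the committed file; shapes are arbitrary, and importing the module assumes its own attention/backend dependencies are installed):

import torch
from ThinkSound.models.mmdit import safe_cat, pad_to_target, truncate_to_target

a = torch.randn(2, 60, 8)   # shorter than b along dim=1
b = torch.randn(2, 77, 8)

print(pad_to_target(a, 77, dim=1).shape)         # torch.Size([2, 77, 8])
print(truncate_to_target(b, 60, dim=1).shape)    # torch.Size([2, 60, 8])

# safe_cat first pads or truncates the first tensor to match the second along match_dim,
# then concatenates along dim (here the batch axis, as in the CFG batching above).
print(safe_cat(a, b, dim=0, match_dim=1).shape)  # torch.Size([4, 77, 8])
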
ThinkSound/models/pretrained.py ADDED
@@ -0,0 +1,25 @@
1
+ import json
2
+
3
+ from .factory import create_model_from_config
4
+ from .utils import load_ckpt_state_dict
5
+
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ def get_pretrained_model(name: str):
9
+
10
+ model_config_path = hf_hub_download(name, filename="model_config.json", repo_type='model')
11
+
12
+ with open(model_config_path) as f:
13
+ model_config = json.load(f)
14
+
15
+ model = create_model_from_config(model_config)
16
+
17
+ # Try to download the model.safetensors file first, if it doesn't exist, download the model.ckpt file
18
+ try:
19
+ model_ckpt_path = hf_hub_download(name, filename="model.safetensors", repo_type='model')
20
+ except Exception as e:
21
+ model_ckpt_path = hf_hub_download(name, filename="model.ckpt", repo_type='model')
22
+
23
+ model.load_state_dict(load_ckpt_state_dict(model_ckpt_path))
24
+
25
+ return model, model_config
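
A hedged usage sketch; the repository id below is a placeholder, not a published checkpoint:

from ThinkSound.models.pretrained import get_pretrained_model

# Downloads model_config.json plus model.safetensors (falling back to model.ckpt) from the
# Hugging Face Hub, builds the model from the config, and loads the checkpoint weights.
model, model_config = get_pretrained_model("your-org/your-thinksound-checkpoint")  # placeholder repo id
model = model.eval().requires_grad_(False)
print(model_config["model_type"])
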
ThinkSound/models/pretransforms.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ from einops import rearrange
3
+ from torch import nn
4
+
5
+ class Pretransform(nn.Module):
6
+ def __init__(self, enable_grad, io_channels, is_discrete):
7
+ super().__init__()
8
+
9
+ self.is_discrete = is_discrete
10
+ self.io_channels = io_channels
11
+ self.encoded_channels = None
12
+ self.downsampling_ratio = None
13
+
14
+ self.enable_grad = enable_grad
15
+
16
+ def encode(self, x):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, z):
20
+ raise NotImplementedError
21
+
22
+ def tokenize(self, x):
23
+ raise NotImplementedError
24
+
25
+ def decode_tokens(self, tokens):
26
+ raise NotImplementedError
27
+
28
+ class AutoencoderPretransform(Pretransform):
29
+ def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False):
30
+ super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete)
31
+ self.model = model
32
+ self.model.requires_grad_(False).eval()
33
+ self.scale=scale
34
+ self.downsampling_ratio = model.downsampling_ratio
35
+ self.io_channels = model.io_channels
36
+ self.sample_rate = model.sample_rate
37
+
38
+ self.model_half = model_half
39
+ self.iterate_batch = iterate_batch
40
+
41
+ self.encoded_channels = model.latent_dim
42
+
43
+ self.chunked = chunked
44
+ self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None
45
+ self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None
46
+
47
+ if self.model_half:
48
+ self.model.half()
49
+
50
+ def encode(self, x, **kwargs):
51
+
52
+ if self.model_half:
53
+ x = x.half()
54
+ self.model.to(torch.float16)
55
+
56
+ encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
57
+
58
+ if self.model_half:
59
+ encoded = encoded.float()
60
+
61
+ return encoded / self.scale
62
+
63
+ def decode(self, z, **kwargs):
64
+ z = z * self.scale
65
+
66
+ if self.model_half:
67
+ z = z.half()
68
+ self.model.to(torch.float16)
69
+
70
+ decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs)
71
+
72
+ if self.model_half:
73
+ decoded = decoded.float()
74
+
75
+ return decoded
76
+
77
+ def tokenize(self, x, **kwargs):
78
+ assert self.model.is_discrete, "Cannot tokenize with a continuous model"
79
+
80
+ _, info = self.model.encode(x, return_info = True, **kwargs)
81
+
82
+ return info[self.model.bottleneck.tokens_id]
83
+
84
+ def decode_tokens(self, tokens, **kwargs):
85
+ assert self.model.is_discrete, "Cannot decode tokens with a continuous model"
86
+
87
+ return self.model.decode_tokens(tokens, **kwargs)
88
+
89
+ def load_state_dict(self, state_dict, strict=True):
90
+ self.model.load_state_dict(state_dict, strict=strict)
91
+
92
+ class WaveletPretransform(Pretransform):
93
+ def __init__(self, channels, levels, wavelet):
94
+ super().__init__(enable_grad=False, io_channels=channels, is_discrete=False)
95
+
96
+ from .wavelets import WaveletEncode1d, WaveletDecode1d
97
+
98
+ self.encoder = WaveletEncode1d(channels, levels, wavelet)
99
+ self.decoder = WaveletDecode1d(channels, levels, wavelet)
100
+
101
+ self.downsampling_ratio = 2 ** levels
102
+ self.io_channels = channels
103
+ self.encoded_channels = channels * self.downsampling_ratio
104
+
105
+ def encode(self, x):
106
+ return self.encoder(x)
107
+
108
+ def decode(self, z):
109
+ return self.decoder(z)
110
+
111
+ class PQMFPretransform(Pretransform):
112
+ def __init__(self, attenuation=100, num_bands=16):
113
+ # TODO: Fix PQMF to take in in-channels
114
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=False)
115
+ from .pqmf import PQMF
116
+ self.pqmf = PQMF(attenuation, num_bands)
117
+
118
+
119
+ def encode(self, x):
120
+ # x is (Batch x Channels x Time)
121
+ x = self.pqmf.forward(x)
122
+ # pqmf.forward returns (Batch x Channels x Bands x Time)
123
+ # but Pretransform needs Batch x Channels x Time
124
+ # so concatenate channels and bands into one axis
125
+ return rearrange(x, "b c n t -> b (c n) t")
126
+
127
+ def decode(self, x):
128
+ # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time)
129
+ x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands)
130
+ # returns (Batch x Channels x Time)
131
+ return self.pqmf.inverse(x)
132
+
133
+ class PretrainedDACPretransform(Pretransform):
134
+ def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True):
135
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
136
+
137
+ import dac
138
+
139
+ model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate)
140
+
141
+ self.model = dac.DAC.load(model_path)
142
+
143
+ self.quantize_on_decode = quantize_on_decode
144
+
145
+ if model_type == "44khz":
146
+ self.downsampling_ratio = 512
147
+ else:
148
+ self.downsampling_ratio = 320
149
+
150
+ self.io_channels = 1
151
+
152
+ self.scale = scale
153
+
154
+ self.chunked = chunked
155
+
156
+ self.encoded_channels = self.model.latent_dim
157
+
158
+ self.num_quantizers = self.model.n_codebooks
159
+
160
+ self.codebook_size = self.model.codebook_size
161
+
162
+ def encode(self, x):
163
+
164
+ latents = self.model.encoder(x)
165
+
166
+ if self.quantize_on_decode:
167
+ output = latents
168
+ else:
169
+ z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
170
+ output = z
171
+
172
+ if self.scale != 1.0:
173
+ output = output / self.scale
174
+
175
+ return output
176
+
177
+ def decode(self, z):
178
+
179
+ if self.scale != 1.0:
180
+ z = z * self.scale
181
+
182
+ if self.quantize_on_decode:
183
+ z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
184
+
185
+ return self.model.decode(z)
186
+
187
+ def tokenize(self, x):
188
+ return self.model.encode(x)[1]
189
+
190
+ def decode_tokens(self, tokens):
191
+ latents = self.model.quantizer.from_codes(tokens)
192
+ return self.model.decode(latents)
193
+
194
+ class AudiocraftCompressionPretransform(Pretransform):
195
+ def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True):
196
+ super().__init__(enable_grad=False, io_channels=1, is_discrete=True)
197
+
198
+ try:
199
+ from audiocraft.models import CompressionModel
200
+ except ImportError:
201
+ raise ImportError("Audiocraft is not installed. Please install audiocraft to use Audiocraft models.")
202
+
203
+ self.model = CompressionModel.get_pretrained(model_type)
204
+
205
+ self.quantize_on_decode = quantize_on_decode
206
+
207
+ self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate)
208
+
209
+ self.sample_rate = self.model.sample_rate
210
+
211
+ self.io_channels = self.model.channels
212
+
213
+ self.scale = scale
214
+
215
+ #self.encoded_channels = self.model.latent_dim
216
+
217
+ self.num_quantizers = self.model.num_codebooks
218
+
219
+ self.codebook_size = self.model.cardinality
220
+
221
+ self.model.to(torch.float16).eval().requires_grad_(False)
222
+
223
+ def encode(self, x):
224
+
225
+ assert False, "Audiocraft compression models do not support continuous encoding"
226
+
227
+ # latents = self.model.encoder(x)
228
+
229
+ # if self.quantize_on_decode:
230
+ # output = latents
231
+ # else:
232
+ # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks)
233
+ # output = z
234
+
235
+ # if self.scale != 1.0:
236
+ # output = output / self.scale
237
+
238
+ # return output
239
+
240
+ def decode(self, z):
241
+
242
+ assert False, "Audiocraft compression models do not support continuous decoding"
243
+
244
+ # if self.scale != 1.0:
245
+ # z = z * self.scale
246
+
247
+ # if self.quantize_on_decode:
248
+ # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks)
249
+
250
+ # return self.model.decode(z)
251
+
252
+ def tokenize(self, x):
253
+ with torch.cuda.amp.autocast(enabled=False):
254
+ return self.model.encode(x.to(torch.float16))[0]
255
+
256
+ def decode_tokens(self, tokens):
257
+ with torch.cuda.amp.autocast(enabled=False):
258
+ return self.model.decode(tokens)
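
To make the scaling convention of AutoencoderPretransform concrete, a small sketch (illustrative only, not part of the committed file; the config path is a placeholder for an autoencoder model config with the usual {"sample_rate": ..., "model": ...} layout):

import json
import torch
from ThinkSound.models.autoencoders import create_autoencoder_from_config
from ThinkSound.models.pretransforms import AutoencoderPretransform

with open("path/to/vae_model_config.json") as f:   # placeholder path
    vae_config = json.load(f)

autoencoder = create_autoencoder_from_config(vae_config)
pretransform = AutoencoderPretransform(autoencoder, scale=1.0)

audio = torch.randn(1, 2, 44100 * 2)    # (batch, channels, samples)
latents = pretransform.encode(audio)    # encode_audio output divided by `scale`
recon = pretransform.decode(latents)    # latents multiplied by `scale`, then decoded
print(latents.shape, recon.shape)
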
ThinkSound/models/transformer.py ADDED
@@ -0,0 +1,821 @@
1
+ from functools import reduce, partial
2
+ from packaging import version
3
+
4
+ from einops import rearrange, repeat
5
+ from einops.layers.torch import Rearrange
6
+ import torch
7
+ import torch.nn.functional as F
8
+ from torch import nn, einsum
9
+ from torch.cuda.amp import autocast
10
+ from typing import Callable, Literal
11
+
12
+ try:
13
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
14
+ except ImportError as e:
15
+ print(e)
16
+ print('flash_attn not installed, disabling Flash Attention')
17
+ flash_attn_kvpacked_func = None
18
+ flash_attn_func = None
19
+
20
+ try:
21
+ import natten
22
+ except ImportError:
23
+ natten = None
24
+
25
+ def checkpoint(function, *args, **kwargs):
26
+ kwargs.setdefault("use_reentrant", False)
27
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
28
+
29
+
30
+ # Copied and modified from https://github.com/lucidrains/x-transformers/blob/main/x_transformers/attend.py under MIT License
31
+ # License can be found in LICENSES/LICENSE_XTRANSFORMERS.txt
32
+
33
+ def create_causal_mask(i, j, device):
34
+ return torch.ones((i, j), device = device, dtype = torch.bool).triu(j - i + 1)
35
+
36
+ def or_reduce(masks):
37
+ head, *body = masks
38
+ for rest in body:
39
+ head = head | rest
40
+ return head
41
+
42
+ # positional embeddings
43
+
44
+ class AbsolutePositionalEmbedding(nn.Module):
45
+ def __init__(self, dim, max_seq_len):
46
+ super().__init__()
47
+ self.scale = dim ** -0.5
48
+ self.max_seq_len = max_seq_len
49
+ self.emb = nn.Embedding(max_seq_len, dim)
50
+
51
+ def forward(self, x, pos = None, seq_start_pos = None):
52
+ seq_len, device = x.shape[1], x.device
53
+ assert seq_len <= self.max_seq_len, f'you are passing in a sequence length of {seq_len} but your absolute positional embedding has a max sequence length of {self.max_seq_len}'
54
+
55
+ if pos is None:
56
+ pos = torch.arange(seq_len, device = device)
57
+
58
+ if seq_start_pos is not None:
59
+ pos = (pos - seq_start_pos[..., None]).clamp(min = 0)
60
+
61
+ pos_emb = self.emb(pos)
62
+ pos_emb = pos_emb * self.scale
63
+ return pos_emb
64
+
65
+ class ScaledSinusoidalEmbedding(nn.Module):
66
+ def __init__(self, dim, theta = 10000):
67
+ super().__init__()
68
+ assert (dim % 2) == 0, 'dimension must be divisible by 2'
69
+ self.scale = nn.Parameter(torch.ones(1) * dim ** -0.5)
70
+
71
+ half_dim = dim // 2
72
+ freq_seq = torch.arange(half_dim).float() / half_dim
73
+ inv_freq = theta ** -freq_seq
74
+ self.register_buffer('inv_freq', inv_freq, persistent = False)
75
+
76
+ def forward(self, x, pos = None, seq_start_pos = None):
77
+ seq_len, device = x.shape[1], x.device
78
+
79
+ if pos is None:
80
+ pos = torch.arange(seq_len, device = device)
81
+
82
+ if seq_start_pos is not None:
83
+ pos = pos - seq_start_pos[..., None]
84
+
85
+ emb = einsum('i, j -> i j', pos, self.inv_freq)
86
+ emb = torch.cat((emb.sin(), emb.cos()), dim = -1)
87
+ return emb * self.scale
88
+
89
+ class RotaryEmbedding(nn.Module):
90
+ def __init__(
91
+ self,
92
+ dim,
93
+ use_xpos = False,
94
+ scale_base = 512,
95
+ interpolation_factor = 1.,
96
+ base = 10000,
97
+ base_rescale_factor = 1.
98
+ ):
99
+ super().__init__()
100
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
101
+ # has some connection to NTK literature
102
+ # https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
103
+ base *= base_rescale_factor ** (dim / (dim - 2))
104
+
105
+ inv_freq = 1. / (base ** (torch.arange(0, dim, 2).float() / dim))
106
+ self.register_buffer('inv_freq', inv_freq)
107
+
108
+ assert interpolation_factor >= 1.
109
+ self.interpolation_factor = interpolation_factor
110
+
111
+ if not use_xpos:
112
+ self.register_buffer('scale', None)
113
+ return
114
+
115
+ scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
116
+
117
+ self.scale_base = scale_base
118
+ self.register_buffer('scale', scale)
119
+
120
+ def forward_from_seq_len(self, seq_len):
121
+ device = self.inv_freq.device
122
+
123
+ t = torch.arange(seq_len, device = device)
124
+ return self.forward(t)
125
+
126
+ @autocast(enabled = False)
127
+ def forward(self, t):
128
+ device = self.inv_freq.device
129
+
130
+ t = t.to(torch.float32)
131
+
132
+ t = t / self.interpolation_factor
133
+
134
+ freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
135
+ freqs = torch.cat((freqs, freqs), dim = -1)
136
+
137
+ if self.scale is None:
138
+ return freqs, 1.
139
+
140
+ power = (torch.arange(t.shape[-1], device = device) - (t.shape[-1] // 2)) / self.scale_base
141
+ scale = self.scale ** rearrange(power, 'n -> n 1')
142
+ scale = torch.cat((scale, scale), dim = -1)
143
+
144
+ return freqs, scale
145
+
146
+ def rotate_half(x):
147
+ x = rearrange(x, '... (j d) -> ... j d', j = 2)
148
+ x1, x2 = x.unbind(dim = -2)
149
+ return torch.cat((-x2, x1), dim = -1)
150
+
151
+ @autocast(enabled = False)
152
+ def apply_rotary_pos_emb(t, freqs, scale = 1):
153
+ out_dtype = t.dtype
154
+
155
+ # cast to float32 if necessary for numerical stability
156
+ dtype = reduce(torch.promote_types, (t.dtype, freqs.dtype, torch.float32))
157
+ rot_dim, seq_len = freqs.shape[-1], t.shape[-2]
158
+ freqs, t = freqs.to(dtype), t.to(dtype)
159
+ freqs = freqs[-seq_len:, :]
160
+
161
+ if t.ndim == 4 and freqs.ndim == 3:
162
+ freqs = rearrange(freqs, 'b n d -> b 1 n d')
163
+
164
+ # partial rotary embeddings, Wang et al. GPT-J
165
+ t, t_unrotated = t[..., :rot_dim], t[..., rot_dim:]
166
+ t = (t * freqs.cos() * scale) + (rotate_half(t) * freqs.sin() * scale)
167
+
168
+ t, t_unrotated = t.to(out_dtype), t_unrotated.to(out_dtype)
169
+
170
+ return torch.cat((t, t_unrotated), dim = -1)
171
+
172
+ # norms
173
+ class LayerNorm(nn.Module):
174
+ def __init__(self, dim, bias=False, fix_scale=False):
175
+ """
176
+ bias-less layernorm has been shown to be more stable. most newer models have moved towards rmsnorm, also bias-less
177
+ """
178
+ super().__init__()
179
+
180
+ if fix_scale:
181
+ self.register_buffer("gamma", torch.ones(dim))
182
+ else:
183
+ self.gamma = nn.Parameter(torch.ones(dim))
184
+
185
+ if bias:
186
+ self.beta = nn.Parameter(torch.zeros(dim))
187
+ else:
188
+ self.register_buffer("beta", torch.zeros(dim))
189
+
190
+
191
+ def forward(self, x):
192
+ return F.layer_norm(x, x.shape[-1:], weight=self.gamma, bias=self.beta)
193
+
194
+ # feedforward
195
+
196
+ class GLU(nn.Module):
197
+ def __init__(
198
+ self,
199
+ dim_in,
200
+ dim_out,
201
+ activation: Callable,
202
+ use_conv = False,
203
+ conv_kernel_size = 3,
204
+ ):
205
+ super().__init__()
206
+ self.act = activation
207
+ self.proj = nn.Linear(dim_in, dim_out * 2) if not use_conv else nn.Conv1d(dim_in, dim_out * 2, conv_kernel_size, padding = (conv_kernel_size // 2))
208
+ self.use_conv = use_conv
209
+
210
+ def forward(self, x):
211
+ if self.use_conv:
212
+ x = rearrange(x, 'b n d -> b d n')
213
+ x = self.proj(x)
214
+ x = rearrange(x, 'b d n -> b n d')
215
+ else:
216
+ x = self.proj(x)
217
+
218
+ x, gate = x.chunk(2, dim = -1)
219
+ return x * self.act(gate)
220
+
221
+ class FeedForward(nn.Module):
222
+ def __init__(
223
+ self,
224
+ dim,
225
+ dim_out = None,
226
+ mult = 4,
227
+ no_bias = False,
228
+ glu = True,
229
+ use_conv = False,
230
+ conv_kernel_size = 3,
231
+ zero_init_output = True,
232
+ ):
233
+ super().__init__()
234
+ inner_dim = int(dim * mult)
235
+
236
+ # Default to SwiGLU
237
+
238
+ activation = nn.SiLU()
239
+
240
+ dim_out = dim if dim_out is None else dim_out
241
+
242
+ if glu:
243
+ linear_in = GLU(dim, inner_dim, activation)
244
+ else:
245
+ linear_in = nn.Sequential(
246
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
247
+ nn.Linear(dim, inner_dim, bias = not no_bias) if not use_conv else nn.Conv1d(dim, inner_dim, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias),
248
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
249
+ activation
250
+ )
251
+
252
+ linear_out = nn.Linear(inner_dim, dim_out, bias = not no_bias) if not use_conv else nn.Conv1d(inner_dim, dim_out, conv_kernel_size, padding = (conv_kernel_size // 2), bias = not no_bias)
253
+
254
+ # init last linear layer to 0
255
+ if zero_init_output:
256
+ nn.init.zeros_(linear_out.weight)
257
+ if not no_bias:
258
+ nn.init.zeros_(linear_out.bias)
259
+
260
+
261
+ self.ff = nn.Sequential(
262
+ linear_in,
263
+ Rearrange('b d n -> b n d') if use_conv else nn.Identity(),
264
+ linear_out,
265
+ Rearrange('b n d -> b d n') if use_conv else nn.Identity(),
266
+ )
267
+
268
+ def forward(self, x):
269
+ return self.ff(x)
270
+
271
+ class Attention(nn.Module):
272
+ def __init__(
273
+ self,
274
+ dim,
275
+ dim_heads = 64,
276
+ dim_context = None,
277
+ causal = False,
278
+ zero_init_output=True,
279
+ qk_norm: Literal['l2', 'ln', 'rns', 'none'] = 'none',
280
+ natten_kernel_size = None
281
+ ):
282
+ super().__init__()
283
+ self.dim = dim
284
+ self.dim_heads = dim_heads
285
+ self.causal = causal
286
+
287
+ dim_kv = dim_context if dim_context is not None else dim
288
+
289
+ self.num_heads = dim // dim_heads
290
+ self.kv_heads = dim_kv // dim_heads
291
+
292
+ if dim_context is not None:
293
+ self.to_q = nn.Linear(dim, dim, bias=False)
294
+ self.to_kv = nn.Linear(dim_kv, dim_kv * 2, bias=False)
295
+ else:
296
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
297
+
298
+ self.to_out = nn.Linear(dim, dim, bias=False)
299
+
300
+ if zero_init_output:
301
+ nn.init.zeros_(self.to_out.weight)
302
+
303
+ self.qk_norm = qk_norm
304
+
305
+ if self.qk_norm == "ln":
306
+ self.q_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
307
+ self.k_norm = nn.LayerNorm(dim_heads, elementwise_affine=True, eps=1.0e-6)
308
+ elif self.qk_norm == 'rns':
309
+ self.q_norm = nn.RMSNorm(dim_heads)
310
+ self.k_norm = nn.RMSNorm(dim_heads)
311
+
312
+ # Using 1d neighborhood attention
313
+ self.natten_kernel_size = natten_kernel_size
314
+ if natten_kernel_size is not None:
315
+ return
316
+
317
+ self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
318
+
319
+ self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
320
+
321
+ self.sdp_kwargs = dict(
322
+ enable_flash = True,
323
+ enable_math = True,
324
+ enable_mem_efficient = True
325
+ )
326
+
327
+ def flash_attn(
328
+ self,
329
+ q,
330
+ k,
331
+ v,
332
+ mask = None,
333
+ causal = None
334
+ ):
335
+ batch, heads, q_len, _, k_len, device = *q.shape, k.shape[-2], q.device
336
+ kv_heads = k.shape[1]
337
+ # Recommended for multi-query single-key-value attention by Tri Dao
338
+ # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
339
+
340
+ if heads != kv_heads:
341
+ # Repeat interleave kv_heads to match q_heads
342
+ heads_per_kv_head = heads // kv_heads
343
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
344
+
345
+ if k.ndim == 3:
346
+ k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
347
+
348
+ if v.ndim == 3:
349
+ v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
350
+
351
+ causal = self.causal if causal is None else causal
352
+
353
+ if q_len == 1 and causal:
354
+ causal = False
355
+
356
+ if mask is not None:
357
+ assert mask.ndim == 4
358
+ mask = mask.expand(batch, heads, q_len, k_len)
359
+
360
+ # handle kv cache - this should be bypassable in updated flash attention 2
361
+
362
+ if k_len > q_len and causal:
363
+ causal_mask = create_causal_mask(q_len, k_len, device = device)  # module-level helper; Attention defines no such method
364
+ if mask is None:
365
+ mask = ~causal_mask
366
+ else:
367
+ mask = mask & ~causal_mask
368
+ causal = False
369
+
370
+ # manually handle causal mask, if another mask was given
371
+
372
+ row_is_entirely_masked = None
373
+
374
+ if mask is not None and causal:
375
+ causal_mask = create_causal_mask(q_len, k_len, device = device)
376
+ mask = mask & ~causal_mask
377
+
378
+ # protect against an entire row being masked out
379
+
380
+ row_is_entirely_masked = ~mask.any(dim = -1)
381
+ mask[..., 0] = mask[..., 0] | row_is_entirely_masked
382
+
383
+ causal = False
384
+
385
+ with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
386
+ out = F.scaled_dot_product_attention(
387
+ q, k, v,
388
+ attn_mask = mask,
389
+ is_causal = causal
390
+ )
391
+
392
+ # for a row that is entirely masked out, should zero out the output of that row token
393
+
394
+ if row_is_entirely_masked is not None:
395
+ out = out.masked_fill(row_is_entirely_masked[..., None], 0.)
396
+
397
+ return out
398
+
399
+ def forward(
400
+ self,
401
+ x,
402
+ context = None,
403
+ mask = None,
404
+ context_mask = None,
405
+ rotary_pos_emb = None,
406
+ causal = None
407
+ ):
408
+ h, kv_h, has_context = self.num_heads, self.kv_heads, context is not None
409
+ kv_input = context if has_context else x
410
+
411
+ if hasattr(self, 'to_q'):
412
+ # Use separate linear projections for q and k/v
413
+ q = self.to_q(x)
414
+ q = rearrange(q, 'b n (h d) -> b h n d', h = h)
415
+
416
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
417
+
418
+ k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = kv_h), (k, v))
419
+ else:
420
+ # Use fused linear projection
421
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
422
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
423
+
424
+ # Normalize q and k for cosine sim attention
425
+ if self.qk_norm == "l2":
426
+ q = F.normalize(q, dim=-1)
427
+ k = F.normalize(k, dim=-1)
428
+ elif self.qk_norm == "ln":
429
+ q = self.q_norm(q)
430
+ k = self.k_norm(k)
431
+ elif self.qk_norm == "rns":
432
+ q = self.q_norm(q)
433
+ k = self.k_norm(k)
434
+
435
+ if rotary_pos_emb is not None and not has_context:
436
+ freqs, _ = rotary_pos_emb
437
+
438
+ q_dtype = q.dtype
439
+ k_dtype = k.dtype
440
+
441
+ q = q.to(torch.float32)
442
+ k = k.to(torch.float32)
443
+ freqs = freqs.to(torch.float32)
444
+
445
+ q = apply_rotary_pos_emb(q, freqs)
446
+ k = apply_rotary_pos_emb(k, freqs)
447
+
448
+ q = q.to(q_dtype)
449
+ k = k.to(k_dtype)
450
+
451
+ input_mask = context_mask
452
+
453
+ if input_mask is None and not has_context:
454
+ input_mask = mask
455
+
456
+ # determine masking
457
+ masks = []
458
+ final_attn_mask = None # The mask that will be applied to the attention matrix, taking all masks into account
459
+
460
+ if input_mask is not None:
461
+ input_mask = rearrange(input_mask, 'b j -> b 1 1 j')
462
+ masks.append(~input_mask)
463
+
464
+ # Other masks will be added here later
465
+
466
+ if len(masks) > 0:
467
+ final_attn_mask = ~or_reduce(masks)
468
+
469
+ n, device = q.shape[-2], q.device
470
+
471
+ causal = self.causal if causal is None else causal
472
+
473
+ if n == 1 and causal:
474
+ causal = False
475
+
476
+ if self.natten_kernel_size is not None:
477
+ if natten is None:
478
+ raise ImportError('natten not installed, please install natten to use neighborhood attention')
479
+
480
+ dtype_in = q.dtype
481
+ q, k, v = map(lambda t: t.to(torch.float32), (q, k, v))
482
+
483
+ attn = natten.functional.natten1dqk(q, k, kernel_size = self.natten_kernel_size, dilation=1)
484
+
485
+ if final_attn_mask is not None:
486
+ attn = attn.masked_fill(final_attn_mask, -torch.finfo(attn.dtype).max)
487
+
488
+ attn = F.softmax(attn, dim=-1, dtype=torch.float32)
489
+
490
+ out = natten.functional.natten1dav(attn, v, kernel_size = self.natten_kernel_size, dilation=1).to(dtype_in)
491
+
492
+ # Prioritize Flash Attention 2
493
+ elif self.use_fa_flash:
494
+ assert final_attn_mask is None, 'masking not yet supported for Flash Attention 2'
495
+ # Flash Attention 2 requires FP16 inputs
496
+ fa_dtype_in = q.dtype
497
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.float16), (q, k, v))
498
+
499
+ out = flash_attn_func(q, k, v, causal = causal)
500
+
501
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b h n d')
502
+
503
+ # Fall back to PyTorch implementation
504
+ elif self.use_pt_flash:
505
+ out = self.flash_attn(q, k, v, causal = causal, mask = final_attn_mask)
506
+
507
+ else:
508
+ # Fall back to custom implementation
509
+
510
+ if h != kv_h:
511
+ # Repeat interleave kv_heads to match q_heads
512
+ heads_per_kv_head = h // kv_h
513
+ k, v = map(lambda t: t.repeat_interleave(heads_per_kv_head, dim = 1), (k, v))
514
+
515
+ scale = 1. / (q.shape[-1] ** 0.5)
516
+
517
+ kv_einsum_eq = 'b j d' if k.ndim == 3 else 'b h j d'
518
+
519
+ dots = einsum(f'b h i d, {kv_einsum_eq} -> b h i j', q, k) * scale
520
+
521
+ i, j, dtype = *dots.shape[-2:], dots.dtype
522
+
523
+ mask_value = -torch.finfo(dots.dtype).max
524
+
525
+ if final_attn_mask is not None:
526
+ dots = dots.masked_fill(~final_attn_mask, mask_value)
527
+
528
+ if causal:
529
+ causal_mask = create_causal_mask(i, j, device = device)
530
+ dots = dots.masked_fill(causal_mask, mask_value)
531
+
532
+ attn = F.softmax(dots, dim=-1, dtype=torch.float32)
533
+ attn = attn.type(dtype)
534
+
535
+ out = einsum(f'b h i j, {kv_einsum_eq} -> b h i d', attn, v)
536
+
537
+ # merge heads
538
+ out = rearrange(out, ' b h n d -> b n (h d)')
539
+
540
+ # Communicate between heads
541
+
542
+ # with autocast(enabled = False):
543
+ # out_dtype = out.dtype
544
+ # out = out.to(torch.float32)
545
+ # out = self.to_out(out).to(out_dtype)
546
+ out = self.to_out(out)
547
+
548
+ if mask is not None:
549
+ mask = rearrange(mask, 'b n -> b n 1')
550
+ out = out.masked_fill(~mask, 0.)
551
+
552
+ return out
553
+
554
+ class ConformerModule(nn.Module):
555
+ def __init__(
556
+ self,
557
+ dim,
558
+ norm_kwargs = {},
559
+ ):
560
+
561
+ super().__init__()
562
+
563
+ self.dim = dim
564
+
565
+ self.in_norm = LayerNorm(dim, **norm_kwargs)
566
+ self.pointwise_conv = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
567
+ self.glu = GLU(dim, dim, nn.SiLU())
568
+ self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=17, groups=dim, padding=8, bias=False)
569
+ self.mid_norm = LayerNorm(dim, **norm_kwargs) # This is a batch norm in the original but I don't like batch norm
570
+ self.swish = nn.SiLU()
571
+ self.pointwise_conv_2 = nn.Conv1d(dim, dim, kernel_size=1, bias=False)
572
+
573
+ def forward(self, x):
574
+ x = self.in_norm(x)
575
+ x = rearrange(x, 'b n d -> b d n')
576
+ x = self.pointwise_conv(x)
577
+ x = rearrange(x, 'b d n -> b n d')
578
+ x = self.glu(x)
579
+ x = rearrange(x, 'b n d -> b d n')
580
+ x = self.depthwise_conv(x)
581
+ x = rearrange(x, 'b d n -> b n d')
582
+ x = self.mid_norm(x)
583
+ x = self.swish(x)
584
+ x = rearrange(x, 'b n d -> b d n')
585
+ x = self.pointwise_conv_2(x)
586
+ x = rearrange(x, 'b d n -> b n d')
587
+
588
+ return x
589
+
590
+ class TransformerBlock(nn.Module):
591
+ def __init__(
592
+ self,
593
+ dim,
594
+ dim_heads = 64,
595
+ cross_attend = False,
596
+ dim_context = None,
597
+ global_cond_dim = None,
598
+ causal = False,
599
+ zero_init_branch_outputs = True,
600
+ conformer = False,
601
+ layer_ix = -1,
602
+ remove_norms = False,
603
+ attn_kwargs = {},
604
+ ff_kwargs = {},
605
+ norm_kwargs = {}
606
+ ):
607
+
608
+ super().__init__()
609
+ self.dim = dim
610
+ self.dim_heads = dim_heads
611
+ self.cross_attend = cross_attend
612
+ self.dim_context = dim_context
613
+ self.causal = causal
614
+
615
+ self.pre_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
616
+
617
+ self.self_attn = Attention(
618
+ dim,
619
+ dim_heads = dim_heads,
620
+ causal = causal,
621
+ zero_init_output=zero_init_branch_outputs,
622
+ **attn_kwargs
623
+ )
624
+
625
+ if cross_attend:
626
+ self.cross_attend_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
627
+ self.cross_attn = Attention(
628
+ dim,
629
+ dim_heads = dim_heads,
630
+ dim_context=dim_context,
631
+ causal = causal,
632
+ zero_init_output=zero_init_branch_outputs,
633
+ **attn_kwargs
634
+ )
635
+
636
+ self.ff_norm = LayerNorm(dim, **norm_kwargs) if not remove_norms else nn.Identity()
637
+ self.ff = FeedForward(dim, zero_init_output=zero_init_branch_outputs, **ff_kwargs)
638
+
639
+ self.layer_ix = layer_ix
640
+
641
+ self.conformer = ConformerModule(dim, norm_kwargs=norm_kwargs) if conformer else None
642
+
643
+ self.global_cond_dim = global_cond_dim
644
+
645
+ if global_cond_dim is not None:
646
+ self.to_scale_shift_gate = nn.Sequential(
647
+ nn.SiLU(),
648
+ nn.Linear(global_cond_dim, dim * 6, bias=False)
649
+ )
650
+
651
+ nn.init.zeros_(self.to_scale_shift_gate[1].weight)
652
+ #nn.init.zeros_(self.to_scale_shift_gate_self[1].bias)
653
+
654
+ def forward(
655
+ self,
656
+ x,
657
+ context = None,
658
+ global_cond=None,
659
+ mask = None,
660
+ context_mask = None,
661
+ rotary_pos_emb = None
662
+ ):
663
+ if self.global_cond_dim is not None and self.global_cond_dim > 0 and global_cond is not None:
664
+
665
+ scale_self, shift_self, gate_self, scale_ff, shift_ff, gate_ff = self.to_scale_shift_gate(global_cond).unsqueeze(1).chunk(6, dim = -1)
666
+
667
+ # self-attention with adaLN
668
+ residual = x
669
+ x = self.pre_norm(x)
670
+ x = x * (1 + scale_self) + shift_self
671
+ x = self.self_attn(x, mask = mask, rotary_pos_emb = rotary_pos_emb)
672
+ x = x * torch.sigmoid(1 - gate_self)
673
+ x = x + residual
674
+
675
+ if context is not None:
676
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
677
+
678
+ if self.conformer is not None:
679
+ x = x + self.conformer(x)
680
+
681
+ # feedforward with adaLN
682
+ residual = x
683
+ x = self.ff_norm(x)
684
+ x = x * (1 + scale_ff) + shift_ff
685
+ x = self.ff(x)
686
+ x = x * torch.sigmoid(1 - gate_ff)
687
+ x = x + residual
688
+
689
+ else:
690
+ x = x + self.self_attn(self.pre_norm(x), mask = mask, rotary_pos_emb = rotary_pos_emb)
691
+
692
+ if context is not None:
693
+ x = x + self.cross_attn(self.cross_attend_norm(x), context = context, context_mask = context_mask)
694
+
695
+ if self.conformer is not None:
696
+ x = x + self.conformer(x)
697
+
698
+ x = x + self.ff(self.ff_norm(x))
699
+
700
+ return x
701
+
702
+ class ContinuousTransformer(nn.Module):
703
+ def __init__(
704
+ self,
705
+ dim,
706
+ depth,
707
+ *,
708
+ dim_in = None,
709
+ dim_out = None,
710
+ dim_heads = 64,
711
+ cross_attend=False,
712
+ cond_token_dim=None,
713
+ global_cond_dim=None,
714
+ causal=False,
715
+ rotary_pos_emb=True,
716
+ zero_init_branch_outputs=True,
717
+ conformer=False,
718
+ use_sinusoidal_emb=False,
719
+ use_abs_pos_emb=False,
720
+ abs_pos_emb_max_length=10000,
721
+ **kwargs
722
+ ):
723
+
724
+ super().__init__()
725
+
726
+ self.dim = dim
727
+ self.depth = depth
728
+ self.causal = causal
729
+ self.layers = nn.ModuleList([])
730
+
731
+ self.project_in = nn.Linear(dim_in, dim, bias=False) if dim_in is not None else nn.Identity()
732
+ self.project_out = nn.Linear(dim, dim_out, bias=False) if dim_out is not None else nn.Identity()
733
+
734
+ if rotary_pos_emb:
735
+ self.rotary_pos_emb = RotaryEmbedding(max(dim_heads // 2, 32))
736
+ else:
737
+ self.rotary_pos_emb = None
738
+
739
+ self.use_sinusoidal_emb = use_sinusoidal_emb
740
+ if use_sinusoidal_emb:
741
+ self.pos_emb = ScaledSinusoidalEmbedding(dim)
742
+
743
+ self.use_abs_pos_emb = use_abs_pos_emb
744
+ if use_abs_pos_emb:
745
+ self.pos_emb = AbsolutePositionalEmbedding(dim, abs_pos_emb_max_length)
746
+
747
+ for i in range(depth):
748
+ self.layers.append(
749
+ TransformerBlock(
750
+ dim,
751
+ dim_heads = dim_heads,
752
+ cross_attend = cross_attend,
753
+ dim_context = cond_token_dim,
754
+ global_cond_dim = global_cond_dim,
755
+ causal = causal,
756
+ zero_init_branch_outputs = zero_init_branch_outputs,
757
+ conformer=conformer,
758
+ layer_ix=i,
759
+ **kwargs
760
+ )
761
+ )
762
+
763
+ def forward(
764
+ self,
765
+ x,
766
+ mask = None,
767
+ prepend_embeds = None,
768
+ prepend_mask = None,
769
+ add_cond = None,
770
+ global_cond = None,
771
+ return_info = False,
772
+ **kwargs
773
+ ):
774
+ batch, seq, device = *x.shape[:2], x.device
775
+
776
+ info = {
777
+ "hidden_states": [],
778
+ }
779
+
780
+ x = self.project_in(x)
781
+ if add_cond is not None:
782
+ x = x + add_cond
783
+
784
+ if prepend_embeds is not None:
785
+ prepend_length, prepend_dim = prepend_embeds.shape[1:]
786
+
787
+ assert prepend_dim == x.shape[-1], 'prepend dimension must match sequence dimension'
788
+
789
+ x = torch.cat((prepend_embeds, x), dim = -2)
790
+
791
+ if prepend_mask is not None or mask is not None:
792
+ mask = mask if mask is not None else torch.ones((batch, seq), device = device, dtype = torch.bool)
793
+ prepend_mask = prepend_mask if prepend_mask is not None else torch.ones((batch, prepend_length), device = device, dtype = torch.bool)
794
+
795
+ mask = torch.cat((prepend_mask, mask), dim = -1)
796
+
797
+
798
+ # Attention layers
799
+
800
+ if self.rotary_pos_emb is not None:
801
+ rotary_pos_emb = self.rotary_pos_emb.forward_from_seq_len(x.shape[1])
802
+ else:
803
+ rotary_pos_emb = None
804
+
805
+ if self.use_sinusoidal_emb or self.use_abs_pos_emb:
806
+ x = x + self.pos_emb(x)
807
+
808
+ # Iterate over the transformer layers
809
+ for layer in self.layers:
810
+ #x = layer(x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
811
+ x = checkpoint(layer, x, rotary_pos_emb = rotary_pos_emb, global_cond=global_cond, **kwargs)
812
+
813
+ if return_info:
814
+ info["hidden_states"].append(x)
815
+
816
+ x = self.project_out(x)
817
+
818
+ if return_info:
819
+ return x, info
820
+
821
+ return x
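
Not part of the commit: a minimal usage sketch for the ContinuousTransformer defined above, assuming the module is importable as ThinkSound.models.transformer and using illustrative sizes. It projects a 64-channel latent sequence to the model width, runs the rotary-embedded transformer blocks, and projects back out.

import torch
from ThinkSound.models.transformer import ContinuousTransformer  # assumed import path

model = ContinuousTransformer(
    dim=512, depth=4,        # model width and number of TransformerBlocks
    dim_in=64, dim_out=64,   # project_in / project_out around the stack
    dim_heads=64,            # 512 // 64 = 8 self-attention heads
    rotary_pos_emb=True,
)
x = torch.randn(2, 128, 64)  # (batch, seq_len, dim_in); move model and x to CUDA if flash_attn is installed, since that path expects GPU tensors
out = model(x)               # -> (2, 128, 64)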
ThinkSound/models/transformer_layers.py ADDED
@@ -0,0 +1,271 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from einops import rearrange
7
+ from einops.layers.torch import Rearrange
8
+
9
+ from .embeddings import apply_rope
10
+ from .blocks import MLP, ChannelLastConv1d, ConvMLP
11
+ try:
12
+ from flash_attn import flash_attn_func, flash_attn_kvpacked_func
13
+ print('flash_attn installed, using Flash Attention')
14
+ except ImportError as e:
15
+ print(e)
16
+ print('flash_attn not installed, disabling Flash Attention')
17
+ flash_attn_kvpacked_func = None
18
+ flash_attn_func = None
19
+
20
+ def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor):
21
+ return x * (1 + scale) + shift
22
+
23
+
24
+ def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
25
+ # training will crash without these contiguous calls and the CUDNN limitation
26
+ # I believe this is related to https://github.com/pytorch/pytorch/issues/133974
27
+ # unresolved at the time of writing
28
+ fa_dtype_in = q.dtype
29
+
30
+ q = q.contiguous()
31
+ k = k.contiguous()
32
+ v = v.contiguous()
33
+ out = F.scaled_dot_product_attention(q, k, v)
34
+ out = rearrange(out, 'b h n d -> b n (h d)').contiguous()
35
+ return out  # NOTE: the flash-attn variant below is unreachable because of this early return
36
+ q, k, v = map(lambda t: rearrange(t, 'b h n d -> b n h d').to(torch.bfloat16), (q, k, v))
37
+ # print(f"q dtype: {q.dtype}")
38
+ # print(f"k dtype: {k.dtype}")
39
+ # print(f"v dtype: {v.dtype}")
40
+ # breakpoint()
41
+ out = flash_attn_func(q, k, v)
42
+ out = rearrange(out.to(fa_dtype_in), 'b n h d -> b n (h d)')
43
+ # out = rearrange(out.to(fa_dtype_in), 'b h n d -> b n (h d)').contiguous()
44
+ return out
45
+
46
+
47
+ class SelfAttention(nn.Module):
48
+
49
+ def __init__(self, dim: int, nheads: int):
50
+ super().__init__()
51
+ self.dim = dim
52
+ self.nheads = nheads
53
+
54
+ self.qkv = nn.Linear(dim, dim * 3, bias=True)
55
+ self.q_norm = nn.RMSNorm(dim // nheads)
56
+ self.k_norm = nn.RMSNorm(dim // nheads)
57
+
58
+ self.split_into_heads = Rearrange('b n (h d j) -> b h n d j',
59
+ h=nheads,
60
+ d=dim // nheads,
61
+ j=3)
62
+
63
+ def pre_attention(
64
+ self, x: torch.Tensor,
65
+ rot: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
66
+ # x: batch_size * n_tokens * n_channels
67
+ qkv = self.qkv(x)
68
+ q, k, v = self.split_into_heads(qkv).chunk(3, dim=-1)
69
+ q = q.squeeze(-1)
70
+ k = k.squeeze(-1)
71
+ v = v.squeeze(-1)
72
+ q = self.q_norm(q)
73
+ k = self.k_norm(k)
74
+
75
+ if rot is not None:
76
+ q = apply_rope(q, rot)
77
+ k = apply_rope(k, rot)
78
+
79
+ return q, k, v
80
+
81
+ def forward(
82
+ self,
83
+ x: torch.Tensor, # batch_size * n_tokens * n_channels
84
+ ) -> torch.Tensor:
85
+ q, k, v = self.pre_attention(x, rot=None)  # unpack in the (q, k, v) order returned by pre_attention
86
+ out = attention(q, k, v)
87
+ return out
88
+
89
+ class CrossAttention(nn.Module):
90
+
91
+ def __init__(self, dim: int, nheads: int):
92
+ super().__init__()
93
+ self.dim = dim
94
+ self.nheads = nheads
95
+
96
+ self.to_q = nn.Linear(dim, dim, bias=False)
97
+ self.to_kv = nn.Linear(dim, dim * 2, bias=False)
98
+ self.q_norm = nn.RMSNorm(dim // nheads)
99
+ self.k_norm = nn.RMSNorm(dim // nheads)
100
+
101
+ self.split_q_into_heads = Rearrange('b n (h d) -> b h n d',
102
+ h=nheads,
103
+ d=dim // nheads)
104
+ self.split_kv_into_heads = Rearrange('b n (h d j) -> b h n d j',
105
+ h=nheads,
106
+ d=dim // nheads,
107
+ j=2)
108
+
109
+ def pre_attention(
110
+ self, x: torch.Tensor,
111
+ context: Optional[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
112
+ # x: batch_size * n_tokens * n_channels
113
+ q = self.to_q(x)
114
+ kv = self.to_kv(context)
115
+ q = self.split_q_into_heads(q)
116
+ k, v = self.split_kv_into_heads(kv).chunk(2, dim=-1)
117
+ q = q.squeeze(-1)
118
+ k = k.squeeze(-1)
119
+ v = v.squeeze(-1)
120
+ q = self.q_norm(q)
121
+ k = self.k_norm(k)
122
+
123
+
124
+ return q, k, v
125
+
126
+ def forward(
127
+ self,
128
+ x: torch.Tensor, context=None
129
+ ) -> torch.Tensor:
130
+ q, k, v = self.pre_attention(x, context=context)  # unpack in the (q, k, v) order returned by pre_attention
131
+ out = attention(q, k, v)
132
+ return out
133
+
134
+
135
+ class MMDitSingleBlock(nn.Module):
136
+
137
+ def __init__(self,
138
+ dim: int,
139
+ nhead: int,
140
+ mlp_ratio: float = 4.0,
141
+ pre_only: bool = False,
142
+ kernel_size: int = 7,
143
+ padding: int = 3,
144
+ cross_attend: bool = False):
145
+ super().__init__()
146
+ self.norm1 = nn.LayerNorm(dim, elementwise_affine=False)
147
+ self.attn = SelfAttention(dim, nhead)
148
+ if cross_attend:
149
+ self.cross_attn = CrossAttention(dim, nhead)
150
+ self.pre_only = pre_only
151
+ if pre_only:
152
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
153
+ else:
154
+ if kernel_size == 1:
155
+ self.linear1 = nn.Linear(dim, dim)
156
+ else:
157
+ self.linear1 = ChannelLastConv1d(dim, dim, kernel_size=kernel_size, padding=padding)
158
+ self.norm2 = nn.LayerNorm(dim, elementwise_affine=False)
159
+
160
+ if kernel_size == 1:
161
+ self.ffn = MLP(dim, int(dim * mlp_ratio))
162
+ else:
163
+ self.ffn = ConvMLP(dim,
164
+ int(dim * mlp_ratio),
165
+ kernel_size=kernel_size,
166
+ padding=padding)
167
+
168
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
169
+
170
+ def pre_attention(self, x: torch.Tensor, c: torch.Tensor, rot: Optional[torch.Tensor]):
171
+ # x: BS * N * D
172
+ # cond: BS * D
173
+ modulation = self.adaLN_modulation(c)
174
+ if self.pre_only:
175
+ (shift_msa, scale_msa) = modulation.chunk(2, dim=-1)
176
+ gate_msa = shift_mlp = scale_mlp = gate_mlp = None
177
+ else:
178
+ (shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp,
179
+ gate_mlp) = modulation.chunk(6, dim=-1)
180
+
181
+ x = modulate(self.norm1(x), shift_msa, scale_msa)
182
+ q, k, v = self.attn.pre_attention(x, rot)
183
+ return (q, k, v), (gate_msa, shift_mlp, scale_mlp, gate_mlp)
184
+
185
+ def post_attention(self, x: torch.Tensor, attn_out: torch.Tensor, c: tuple[torch.Tensor], context=None):
186
+ if self.pre_only:
187
+ return x
188
+
189
+ (gate_msa, shift_mlp, scale_mlp, gate_mlp) = c
190
+ x = x + self.linear1(attn_out) * gate_msa
191
+
192
+ if context is not None:
193
+ x = x + self.cross_attn(x, context=context)
194
+
195
+ r = modulate(self.norm2(x), shift_mlp, scale_mlp)
196
+ x = x + self.ffn(r) * gate_mlp
197
+
198
+ return x
199
+
200
+ def forward(self, x: torch.Tensor, cond: torch.Tensor,
201
+ rot: Optional[torch.Tensor], context: torch.Tensor = None) -> torch.Tensor:
202
+ # x: BS * N * D
203
+ # cond: BS * D
204
+ x_qkv, x_conditions = self.pre_attention(x, cond, rot)
205
+ attn_out = attention(*x_qkv)
206
+ x = self.post_attention(x, attn_out, x_conditions, context = context)
207
+
208
+ return x
209
+
210
+
211
+ class JointBlock(nn.Module):
212
+
213
+ def __init__(self, dim: int, nhead: int, mlp_ratio: float = 4.0, pre_only: bool = False):
214
+ super().__init__()
215
+ self.pre_only = pre_only
216
+ self.latent_block = MMDitSingleBlock(dim,
217
+ nhead,
218
+ mlp_ratio,
219
+ pre_only=False,
220
+ kernel_size=3,
221
+ padding=1)
222
+ self.clip_block = MMDitSingleBlock(dim,
223
+ nhead,
224
+ mlp_ratio,
225
+ pre_only=pre_only,
226
+ kernel_size=3,
227
+ padding=1)
228
+ self.text_block = MMDitSingleBlock(dim, nhead, mlp_ratio, pre_only=pre_only, kernel_size=1)
229
+
230
+ def forward(self, latent: torch.Tensor, clip_f: torch.Tensor, text_f: torch.Tensor,
231
+ global_c: torch.Tensor, extended_c: torch.Tensor, latent_rot: torch.Tensor,
232
+ clip_rot: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
233
+ # latent: BS * N1 * D
234
+ # clip_f: BS * N2 * D
235
+ # c: BS * (1/N) * D
236
+ x_qkv, x_mod = self.latent_block.pre_attention(latent, extended_c, latent_rot)
237
+ c_qkv, c_mod = self.clip_block.pre_attention(clip_f, global_c, clip_rot)
238
+ t_qkv, t_mod = self.text_block.pre_attention(text_f, global_c, rot=None)
239
+
240
+ latent_len = latent.shape[1]
241
+ clip_len = clip_f.shape[1]
242
+ text_len = text_f.shape[1]
243
+
244
+ joint_qkv = [torch.cat([x_qkv[i], c_qkv[i], t_qkv[i]], dim=2) for i in range(3)]
245
+
246
+ attn_out = attention(*joint_qkv)
247
+ x_attn_out = attn_out[:, :latent_len]
248
+ c_attn_out = attn_out[:, latent_len:latent_len + clip_len]
249
+ t_attn_out = attn_out[:, latent_len + clip_len:]
250
+
251
+ latent = self.latent_block.post_attention(latent, x_attn_out, x_mod)
252
+ if not self.pre_only:
253
+ clip_f = self.clip_block.post_attention(clip_f, c_attn_out, c_mod)
254
+ text_f = self.text_block.post_attention(text_f, t_attn_out, t_mod)
255
+
256
+ return latent, clip_f, text_f
257
+
258
+
259
+ class FinalBlock(nn.Module):
260
+
261
+ def __init__(self, dim, out_dim):
262
+ super().__init__()
263
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim, 2 * dim, bias=True))
264
+ self.norm = nn.LayerNorm(dim, elementwise_affine=False)
265
+ self.conv = ChannelLastConv1d(dim, out_dim, kernel_size=7, padding=3)
266
+
267
+ def forward(self, latent, c):
268
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
269
+ latent = modulate(self.norm(latent), shift, scale)
270
+ latent = self.conv(latent)
271
+ return latent
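
Not part of the commit: a minimal sketch of one MMDitSingleBlock from the layers above, with assumed sizes and import path. The conditioning tensor follows the BS * (1/N) * D shape noted in JointBlock; kernel_size=1 selects the plain Linear/MLP projections, and rot=None skips the rotary embedding. SelfAttention uses nn.RMSNorm, which requires a recent PyTorch (2.4 or later).

import torch
from ThinkSound.models.transformer_layers import MMDitSingleBlock  # assumed import path

block = MMDitSingleBlock(dim=256, nhead=4, kernel_size=1)
x = torch.randn(2, 100, 256)   # BS * N * D latent tokens
c = torch.randn(2, 1, 256)     # BS * 1 * D conditioning, broadcast over the tokens by adaLN
out = block(x, c, rot=None)    # -> (2, 100, 256)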
ThinkSound/models/utils.py ADDED
@@ -0,0 +1,164 @@
1
+ import torch
2
+ from safetensors.torch import load_file
3
+ from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
4
+ from torchcubicspline import natural_cubic_spline_coeffs, NaturalCubicSpline
5
+ from torch.nn.utils import remove_weight_norm
6
+
7
+ def load_ckpt_state_dict(ckpt_path, prefix=None):
8
+ if ckpt_path.endswith(".safetensors"):
9
+ state_dict = load_file(ckpt_path)
10
+ else:
11
+ state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]
12
+
13
+ # Filter the state_dict down to keys with the given prefix and strip that prefix
14
+ filtered_state_dict = {k.replace(f'{prefix}',''): v for k, v in state_dict.items() if k.startswith(prefix)} if prefix is not None else state_dict
15
+
16
+ return filtered_state_dict
17
+
18
+ def remove_weight_norm_from_model(model):
19
+ for module in model.modules():
20
+ if hasattr(module, "weight"):
21
+ print(f"Removing weight norm from {module}")
22
+ remove_weight_norm(module)
23
+
24
+ return model
25
+
26
+ # Sampling functions copied from https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/utils/utils.py under MIT license
27
+ # License can be found in LICENSES/LICENSE_META.txt
28
+
29
+ def multinomial(input: torch.Tensor, num_samples: int, replacement=False, *, generator=None):
30
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
31
+
32
+ Args:
33
+ input (torch.Tensor): The input tensor containing probabilities.
34
+ num_samples (int): Number of samples to draw.
35
+ replacement (bool): Whether to draw with replacement or not.
36
+ Keywords args:
37
+ generator (torch.Generator): A pseudorandom number generator for sampling.
38
+ Returns:
39
+ torch.Tensor: Last dimension contains num_samples indices
40
+ sampled from the multinomial probability distribution
41
+ located in the last dimension of tensor input.
42
+ """
43
+
44
+ if num_samples == 1:
45
+ q = torch.empty_like(input).exponential_(1, generator=generator)
46
+ return torch.argmax(input / q, dim=-1, keepdim=True).to(torch.int64)
47
+
48
+ input_ = input.reshape(-1, input.shape[-1])
49
+ output_ = torch.multinomial(input_, num_samples=num_samples, replacement=replacement, generator=generator)
50
+ output = output_.reshape(*list(input.shape[:-1]), -1)
51
+ return output
52
+
53
+
54
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
55
+ """Sample next token from top K values along the last dimension of the input probs tensor.
56
+
57
+ Args:
58
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
59
+ k (int): The k in “top-k”.
60
+ Returns:
61
+ torch.Tensor: Sampled tokens.
62
+ """
63
+ top_k_value, _ = torch.topk(probs, k, dim=-1)
64
+ min_value_top_k = top_k_value[..., [-1]]
65
+ probs *= (probs >= min_value_top_k).float()
66
+ probs.div_(probs.sum(dim=-1, keepdim=True))
67
+ next_token = multinomial(probs, num_samples=1)
68
+ return next_token
69
+
70
+
71
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
72
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
73
+
74
+ Args:
75
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
76
+ p (int): The p in “top-p”.
77
+ Returns:
78
+ torch.Tensor: Sampled tokens.
79
+ """
80
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
81
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
82
+ mask = probs_sum - probs_sort > p
83
+ probs_sort *= (~mask).float()
84
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
85
+ next_token = multinomial(probs_sort, num_samples=1)
86
+ next_token = torch.gather(probs_idx, -1, next_token)
87
+ return next_token
88
+
89
+ def next_power_of_two(n):
90
+ return 2 ** (n - 1).bit_length()
91
+
92
+ def next_multiple_of_64(n):
93
+ return ((n + 63) // 64) * 64
94
+
95
+
96
+ # mask construction helpers
97
+
98
+ def mask_from_start_end_indices(
99
+ seq_len: int,
100
+ start: Tensor,
101
+ end: Tensor
102
+ ):
103
+ assert start.shape == end.shape
104
+ device = start.device
105
+
106
+ seq = torch.arange(seq_len, device = device, dtype = torch.long)
107
+ seq = seq.reshape(*((-1,) * start.ndim), seq_len)
108
+ seq = seq.expand(*start.shape, seq_len)
109
+
110
+ mask = seq >= start[..., None].long()
111
+ mask &= seq < end[..., None].long()
112
+ return mask
113
+
114
+ def mask_from_frac_lengths(
115
+ seq_len: int,
116
+ frac_lengths: Tensor
117
+ ):
118
+ device = frac_lengths.device
119
+
120
+ lengths = (frac_lengths * seq_len).long()
121
+ max_start = seq_len - lengths
122
+
123
+ rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
124
+ start = (max_start * rand).clamp(min = 0)
125
+ end = start + lengths
126
+
127
+ return mask_from_start_end_indices(seq_len, start, end)
128
+
129
+ def _build_spline(video_feat, video_t, target_t):
130
+ # Core cubic-spline interpolation step
131
+ coeffs = natural_cubic_spline_coeffs(video_t, video_feat.permute(0,2,1))
132
+ spline = NaturalCubicSpline(coeffs)
133
+ return spline.evaluate(target_t).permute(0,2,1)
134
+
135
+ def resample(video_feat, audio_latent):
136
+ """
137
+ Resample ~9 s of video features onto the audio latent time grid.
138
+ video_feat: [B, 72, D]
139
+ audio_latent: [B, D', 194] or int
140
+ """
141
+ B, Tv, D = video_feat.shape
142
+
143
+ if isinstance(audio_latent, torch.Tensor):
144
+ # audio_latent is a tensor
145
+ if audio_latent.shape[1] != D:
146
+ Ta = audio_latent.shape[1]
147
+ else:
148
+ Ta = audio_latent.shape[2]
149
+ elif isinstance(audio_latent, int):
150
+ # audio_latent is an int
151
+ Ta = audio_latent
152
+ else:
153
+ raise TypeError("audio_latent must be either a tensor or an int")
154
+
155
+ # Build the time stamps (key step)
156
+ video_time = torch.linspace(0, 9, Tv, device=video_feat.device)
157
+ audio_time = torch.linspace(0, 9, Ta, device=video_feat.device)
158
+
159
+ # Rearrange to (Batch, Feature, Time)
160
+ video_feat = video_feat.permute(0, 2, 1) # [B, D, Tv]
161
+
162
+ # Cubic spline interpolation
163
+ aligned_video = _build_spline(video_feat, video_time, audio_time) # [B, D, Ta]
164
+ return aligned_video.permute(0, 2, 1) # [B, Ta, D]
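
Not part of the commit: a small sketch of the resample helper above, which cubic-spline interpolates per-frame video features onto the audio latent time axis. It assumes the torchcubicspline dependency is installed and that the module is importable as ThinkSound.models.utils.

import torch
from ThinkSound.models.utils import resample  # assumed import path

video_feat = torch.randn(2, 72, 512)   # [B, Tv, D]: 72 visual frames covering ~9 s
aligned = resample(video_feat, 194)    # [B, 194, D]: time-aligned with 194 audio latent steps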
ThinkSound/training/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .factory import create_training_wrapper_from_config, create_demo_callback_from_config
ThinkSound/training/autoencoders.py ADDED
@@ -0,0 +1,504 @@
1
+ import torch
2
+ import torchaudio
3
+ import wandb
4
+ from einops import rearrange
5
+ from safetensors.torch import save_file, save_model
6
+ from ema_pytorch import EMA
7
+ from .losses.auraloss import SumAndDifferenceSTFTLoss, MultiResolutionSTFTLoss, SpatialSTFTLoss
8
+ # import pytorch_lightning as pl
9
+ import lightning as L
10
+ from lightning.pytorch.callbacks import Callback
11
+ from ..models.autoencoders import AudioAutoencoder
12
+ from ..models.bottleneck import VAEBottleneck, RVQBottleneck, DACRVQBottleneck, DACRVQVAEBottleneck, RVQVAEBottleneck, WassersteinBottleneck
13
+ from .losses import MultiLoss, AuralossLoss, ValueLoss, L1Loss
14
+ from .utils import create_optimizer_from_config, create_scheduler_from_config
15
+
16
+
17
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
18
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
19
+
20
+ class AutoencoderTrainingWrapper(L.LightningModule):
21
+ def __init__(
22
+ self,
23
+ autoencoder: AudioAutoencoder,
24
+ lr: float = 1e-4,
25
+ warmup_steps: int = 0,
26
+ encoder_freeze_on_warmup: bool = False,
27
+ sample_rate=48000,
28
+ loss_config: dict = None,
29
+ optimizer_configs: dict = None,
30
+ use_ema: bool = True,
31
+ ema_copy = None,
32
+ force_input_mono = False,
33
+ latent_mask_ratio = 0.0,
34
+ teacher_model: AudioAutoencoder = None
35
+ ):
36
+ super().__init__()
37
+
38
+ self.automatic_optimization = False
39
+
40
+ self.autoencoder = autoencoder
41
+
42
+ self.warmed_up = False
43
+ self.warmup_steps = warmup_steps
44
+ self.encoder_freeze_on_warmup = encoder_freeze_on_warmup
45
+ self.lr = lr
46
+
47
+ self.force_input_mono = force_input_mono
48
+
49
+ self.teacher_model = teacher_model
50
+
51
+ if optimizer_configs is None:
52
+ optimizer_configs ={
53
+ "autoencoder": {
54
+ "optimizer": {
55
+ "type": "AdamW",
56
+ "config": {
57
+ "lr": lr,
58
+ "betas": (.8, .99)
59
+ }
60
+ }
61
+ },
62
+ "discriminator": {
63
+ "optimizer": {
64
+ "type": "AdamW",
65
+ "config": {
66
+ "lr": lr,
67
+ "betas": (.8, .99)
68
+ }
69
+ }
70
+ }
71
+
72
+ }
73
+
74
+ self.optimizer_configs = optimizer_configs
75
+
76
+ if loss_config is None:
77
+ scales = [2048, 1024, 512, 256, 128, 64, 32]
78
+ hop_sizes = []
79
+ win_lengths = []
80
+ overlap = 0.75
81
+ for s in scales:
82
+ hop_sizes.append(int(s * (1 - overlap)))
83
+ win_lengths.append(s)
84
+
85
+ loss_config = {
86
+ "discriminator": {
87
+ "type": "encodec",
88
+ "config": {
89
+ "n_ffts": scales,
90
+ "hop_lengths": hop_sizes,
91
+ "win_lengths": win_lengths,
92
+ "filters": 32
93
+ },
94
+ "weights": {
95
+ "adversarial": 0.1,
96
+ "feature_matching": 5.0,
97
+ }
98
+ },
99
+ "spectral": {
100
+ "type": "mrstft",
101
+ "config": {
102
+ "fft_sizes": scales,
103
+ "hop_sizes": hop_sizes,
104
+ "win_lengths": win_lengths,
105
+ "perceptual_weighting": True
106
+ },
107
+ "weights": {
108
+ "mrstft": 1.0,
109
+ }
110
+ },
111
+ "time": {
112
+ "type": "l1",
113
+ "config": {},
114
+ "weights": {
115
+ "l1": 0.0,
116
+ }
117
+ }
118
+ }
119
+
120
+ self.loss_config = loss_config
121
+
122
+ # Spectral reconstruction loss
123
+
124
+ stft_loss_args = loss_config['spectral']['config']
125
+
126
+ if self.autoencoder.out_channels == 2:
127
+ self.sdstft = SumAndDifferenceSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
128
+ self.lrstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
129
+ elif self.autoencoder.out_channels == 4:
130
+ # self.sdstft = SpatialSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
131
+ self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
132
+ else:
133
+ self.sdstft = MultiResolutionSTFTLoss(sample_rate=sample_rate, **stft_loss_args)
134
+
135
+ # Discriminator (note: OobleckDiscriminator, EncodecDiscriminator and DACGANLoss are referenced below but are not imported in this file)
136
+
137
+ if loss_config['discriminator']['type'] == 'oobleck':
138
+ self.discriminator = OobleckDiscriminator(**loss_config['discriminator']['config'])
139
+ elif loss_config['discriminator']['type'] == 'encodec':
140
+ self.discriminator = EncodecDiscriminator(in_channels=self.autoencoder.out_channels, **loss_config['discriminator']['config'])
141
+ elif loss_config['discriminator']['type'] == 'dac':
142
+ self.discriminator = DACGANLoss(channels=self.autoencoder.out_channels, sample_rate=sample_rate, **loss_config['discriminator']['config'])
143
+
144
+ self.gen_loss_modules = []
145
+
146
+ # Adversarial and feature matching losses
147
+ self.gen_loss_modules += [
148
+ ValueLoss(key='loss_adv', weight=self.loss_config['discriminator']['weights']['adversarial'], name='loss_adv'),
149
+ ValueLoss(key='feature_matching_distance', weight=self.loss_config['discriminator']['weights']['feature_matching'], name='feature_matching'),
150
+ ]
151
+
152
+ if self.teacher_model is not None:
153
+ # Distillation losses
154
+
155
+ stft_loss_weight = self.loss_config['spectral']['weights']['mrstft'] * 0.25
156
+ self.gen_loss_modules += [
157
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=stft_loss_weight), # Reconstruction loss
158
+ AuralossLoss(self.sdstft, 'decoded', 'teacher_decoded', name='mrstft_loss_distill', weight=stft_loss_weight), # Distilled model's decoder is compatible with teacher's decoder
159
+ AuralossLoss(self.sdstft, 'reals', 'own_latents_teacher_decoded', name='mrstft_loss_own_latents_teacher', weight=stft_loss_weight), # Distilled model's encoder is compatible with teacher's decoder
160
+ AuralossLoss(self.sdstft, 'reals', 'teacher_latents_own_decoded', name='mrstft_loss_teacher_latents_own', weight=stft_loss_weight) # Teacher's encoder is compatible with distilled model's decoder
161
+ ]
162
+
163
+ else:
164
+
165
+ # Reconstruction loss
166
+ self.gen_loss_modules += [
167
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
168
+ ]
169
+
170
+ if self.autoencoder.out_channels == 2:
171
+
172
+ # Add left and right channel reconstruction losses in addition to the sum and difference
173
+ self.gen_loss_modules += [
174
+ AuralossLoss(self.lrstft, 'reals_left', 'decoded_left', name='stft_loss_left', weight=self.loss_config['spectral']['weights']['mrstft']/2),
175
+ AuralossLoss(self.lrstft, 'reals_right', 'decoded_right', name='stft_loss_right', weight=self.loss_config['spectral']['weights']['mrstft']/2),
176
+ ]
177
+ elif self.autoencoder.out_channels == 4:
178
+ # self.gen_loss_modules += [
179
+ # AuralossLoss(self.lrstft, 'reals', 'decoded', name='stft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
180
+ # ]
181
+ # Add left and right channel reconstruction losses in addition to the sum and difference
182
+ self.gen_loss_modules += [
183
+ AuralossLoss(self.sdstft, 'reals_w', 'decoded_w', name='stft_loss_w', weight=self.loss_config['spectral']['weights']['mrstft']/4),
184
+ AuralossLoss(self.sdstft, 'reals_x', 'decoded_x', name='stft_loss_x', weight=self.loss_config['spectral']['weights']['mrstft']/4),
185
+ AuralossLoss(self.sdstft, 'reals_y', 'decoded_y', name='stft_loss_y', weight=self.loss_config['spectral']['weights']['mrstft']/4),
186
+ AuralossLoss(self.sdstft, 'reals_z', 'decoded_z', name='stft_loss_z', weight=self.loss_config['spectral']['weights']['mrstft']/4),
187
+ ]
188
+
189
+ self.gen_loss_modules += [
190
+ AuralossLoss(self.sdstft, 'reals', 'decoded', name='mrstft_loss', weight=self.loss_config['spectral']['weights']['mrstft']),
191
+ ]
192
+
193
+ if self.loss_config['time']['weights']['l1'] > 0.0:
194
+ self.gen_loss_modules.append(L1Loss(key_a='reals', key_b='decoded', weight=self.loss_config['time']['weights']['l1'], name='l1_time_loss'))
195
+
196
+ if self.autoencoder.bottleneck is not None:
197
+ self.gen_loss_modules += create_loss_modules_from_bottleneck(self.autoencoder.bottleneck, self.loss_config)
198
+
199
+ self.losses_gen = MultiLoss(self.gen_loss_modules)
200
+
201
+ self.disc_loss_modules = [
202
+ ValueLoss(key='loss_dis', weight=1.0, name='discriminator_loss'),
203
+ ]
204
+
205
+ self.losses_disc = MultiLoss(self.disc_loss_modules)
206
+
207
+ # Set up EMA for model weights
208
+ self.autoencoder_ema = None
209
+
210
+ self.use_ema = use_ema
211
+
212
+ if self.use_ema:
213
+ self.autoencoder_ema = EMA(
214
+ self.autoencoder,
215
+ ema_model=ema_copy,
216
+ beta=0.9999,
217
+ power=3/4,
218
+ update_every=1,
219
+ update_after_step=1
220
+ )
221
+
222
+ self.latent_mask_ratio = latent_mask_ratio
223
+
224
+ def configure_optimizers(self):
225
+
226
+ opt_gen = create_optimizer_from_config(self.optimizer_configs['autoencoder']['optimizer'], self.autoencoder.parameters())
227
+ opt_disc = create_optimizer_from_config(self.optimizer_configs['discriminator']['optimizer'], self.discriminator.parameters())
228
+
229
+ if "scheduler" in self.optimizer_configs['autoencoder'] and "scheduler" in self.optimizer_configs['discriminator']:
230
+ sched_gen = create_scheduler_from_config(self.optimizer_configs['autoencoder']['scheduler'], opt_gen)
231
+ sched_disc = create_scheduler_from_config(self.optimizer_configs['discriminator']['scheduler'], opt_disc)
232
+ return [opt_gen, opt_disc], [sched_gen, sched_disc]
233
+
234
+ return [opt_gen, opt_disc]
235
+
236
+ def training_step(self, batch, batch_idx):
237
+ reals, _ = batch
238
+
239
+ # Remove extra dimension added by WebDataset
240
+ if reals.ndim == 4 and reals.shape[0] == 1:
241
+ reals = reals[0]
242
+
243
+ if self.global_step >= self.warmup_steps:
244
+ self.warmed_up = True
245
+
246
+ loss_info = {}
247
+
248
+ loss_info["reals"] = reals
249
+
250
+ encoder_input = reals
251
+
252
+ if self.force_input_mono and encoder_input.shape[1] > 1:
253
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
254
+
255
+ loss_info["encoder_input"] = encoder_input
256
+
257
+ data_std = encoder_input.std()
258
+
259
+ if self.warmed_up and self.encoder_freeze_on_warmup:
260
+ with torch.no_grad():
261
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
262
+ else:
263
+ latents, encoder_info = self.autoencoder.encode(encoder_input, return_info=True)
264
+
265
+ loss_info["latents"] = latents
266
+
267
+ loss_info.update(encoder_info)
268
+
269
+ # Encode with teacher model for distillation
270
+ if self.teacher_model is not None:
271
+ with torch.no_grad():
272
+ teacher_latents = self.teacher_model.encode(encoder_input, return_info=False)
273
+ loss_info['teacher_latents'] = teacher_latents
274
+
275
+ # Optionally mask out some latents for noise resistance
276
+ if self.latent_mask_ratio > 0.0:
277
+ mask = torch.rand_like(latents) < self.latent_mask_ratio
278
+ latents = torch.where(mask, torch.zeros_like(latents), latents)
279
+ decoded = self.autoencoder.decode(latents)
280
+
281
+ loss_info["decoded"] = decoded
282
+
283
+ if self.autoencoder.out_channels == 2:
284
+ loss_info["decoded_left"] = decoded[:, 0:1, :]
285
+ loss_info["decoded_right"] = decoded[:, 1:2, :]
286
+ loss_info["reals_left"] = reals[:, 0:1, :]
287
+ loss_info["reals_right"] = reals[:, 1:2, :]
288
+ elif self.autoencoder.out_channels == 4:
289
+ loss_info["decoded_w"] = decoded[:, 0:1, :]
290
+ loss_info["decoded_x"] = decoded[:, 1:2, :]
291
+ loss_info["decoded_y"] = decoded[:, 2:3, :]
292
+ loss_info["decoded_z"] = decoded[:, 3:4, :]
293
+ loss_info["reals_w"] = reals[:, 0:1, :]
294
+ loss_info["reals_x"] = reals[:, 1:2, :]
295
+ loss_info["reals_y"] = reals[:, 2:3, :]
296
+ loss_info["reals_z"] = reals[:, 3:4, :]
297
+
298
+ # Distillation
299
+ if self.teacher_model is not None:
300
+ with torch.no_grad():
301
+ teacher_decoded = self.teacher_model.decode(teacher_latents)
302
+ own_latents_teacher_decoded = self.teacher_model.decode(latents) #Distilled model's latents decoded by teacher
303
+ teacher_latents_own_decoded = self.autoencoder.decode(teacher_latents) #Teacher's latents decoded by distilled model
304
+
305
+ loss_info['teacher_decoded'] = teacher_decoded
306
+ loss_info['own_latents_teacher_decoded'] = own_latents_teacher_decoded
307
+ loss_info['teacher_latents_own_decoded'] = teacher_latents_own_decoded
308
+
309
+
310
+ if self.warmed_up:
311
+ loss_dis, loss_adv, feature_matching_distance = self.discriminator.loss(reals, decoded)
312
+ else:
313
+ loss_dis = torch.tensor(0.).to(reals)
314
+ loss_adv = torch.tensor(0.).to(reals)
315
+ feature_matching_distance = torch.tensor(0.).to(reals)
316
+
317
+ loss_info["loss_dis"] = loss_dis
318
+ loss_info["loss_adv"] = loss_adv
319
+ loss_info["feature_matching_distance"] = feature_matching_distance
320
+
321
+ opt_gen, opt_disc = self.optimizers()
322
+
323
+ lr_schedulers = self.lr_schedulers()
324
+
325
+ sched_gen = None
326
+ sched_disc = None
327
+
328
+ if lr_schedulers is not None:
329
+ sched_gen, sched_disc = lr_schedulers
330
+
331
+ # Train the discriminator
332
+ if self.global_step % 2 and self.warmed_up:
333
+ loss, losses = self.losses_disc(loss_info)
334
+
335
+ log_dict = {
336
+ 'train/disc_lr': opt_disc.param_groups[0]['lr']
337
+ }
338
+
339
+ opt_disc.zero_grad()
340
+ self.manual_backward(loss)
341
+
342
+
343
+ opt_disc.step()
344
+
345
+ if sched_disc is not None:
346
+ # sched step every step
347
+ sched_disc.step()
348
+
349
+ # Train the generator
350
+ else:
351
+
352
+ # import ipdb
353
+ # ipdb.set_trace()
354
+ loss, losses = self.losses_gen(loss_info)
355
+
356
+ if self.use_ema:
357
+ self.autoencoder_ema.update()
358
+
359
+ opt_gen.zero_grad()
360
+ self.manual_backward(loss)
361
+ opt_gen.step()
362
+
363
+ if sched_gen is not None:
364
+ # scheduler step every step
365
+ sched_gen.step()
366
+
367
+ log_dict = {
368
+ 'train/loss': loss.detach(),
369
+ 'train/latent_std': latents.std().detach(),
370
+ 'train/data_std': data_std.detach(),
371
+ 'train/gen_lr': opt_gen.param_groups[0]['lr']
372
+ }
373
+
374
+ for loss_name, loss_value in losses.items():
375
+ log_dict[f'train/{loss_name}'] = loss_value.detach()
376
+
377
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
378
+
379
+ return loss
380
+
381
+ def export_model(self, path, use_safetensors=False):
382
+ if self.autoencoder_ema is not None:
383
+ model = self.autoencoder_ema.ema_model
384
+ else:
385
+ model = self.autoencoder
386
+
387
+ if use_safetensors:
388
+ save_model(model, path)
389
+ else:
390
+ torch.save({"state_dict": model.state_dict()}, path)
391
+
392
+
393
+ class AutoencoderDemoCallback(Callback):
394
+ def __init__(
395
+ self,
396
+ demo_dl,
397
+ demo_every=2000,
398
+ sample_size=65536,
399
+ sample_rate=48000
400
+ ):
401
+ super().__init__()
402
+ self.demo_every = demo_every
403
+ self.demo_samples = sample_size
404
+ self.demo_dl = iter(demo_dl)
405
+ self.sample_rate = sample_rate
406
+ self.last_demo_step = -1
407
+
408
+ @rank_zero_only
409
+ @torch.no_grad()
410
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
411
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
412
+ return
413
+
414
+ self.last_demo_step = trainer.global_step
415
+
416
+ module.eval()
417
+
418
+ try:
419
+ demo_reals, _ = next(self.demo_dl)
420
+
421
+ # Remove extra dimension added by WebDataset
422
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
423
+ demo_reals = demo_reals[0]
424
+
425
+ encoder_input = demo_reals
426
+
427
+ encoder_input = encoder_input.to(module.device)
428
+
429
+ if module.force_input_mono:
430
+ encoder_input = encoder_input.mean(dim=1, keepdim=True)
431
+
432
+ demo_reals = demo_reals.to(module.device)
433
+
434
+ with torch.no_grad():
435
+ if module.use_ema:
436
+
437
+ latents = module.autoencoder_ema.ema_model.encode(encoder_input)
438
+
439
+ fakes = module.autoencoder_ema.ema_model.decode(latents)
440
+ else:
441
+ latents = module.autoencoder.encode(encoder_input)
442
+
443
+ fakes = module.autoencoder.decode(latents)
444
+
445
+ #Interleave reals and fakes
446
+ reals_fakes = rearrange([demo_reals, fakes], 'i b d n -> (b i) d n')
447
+
448
+ # Put the demos together
449
+ reals_fakes = rearrange(reals_fakes, 'b d n -> d (b n)')
450
+
451
+ log_dict = {}
452
+
453
+ filename = f'demos/recon_{trainer.global_step:08}.wav'
454
+ reals_fakes = reals_fakes.to(torch.float32).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
455
+ torchaudio.save(filename, reals_fakes, self.sample_rate)
456
+
457
+ log_dict[f'recon'] = wandb.Audio(filename,
458
+ sample_rate=self.sample_rate,
459
+ caption=f'Reconstructed')
460
+
461
+ log_dict[f'embeddings_3dpca'] = pca_point_cloud(latents)
462
+ log_dict[f'embeddings_spec'] = wandb.Image(tokens_spectrogram_image(latents))
463
+
464
+ log_dict[f'recon_melspec_left'] = wandb.Image(audio_spectrogram_image(reals_fakes))
465
+
466
+ trainer.logger.experiment.log(log_dict)
467
+ except Exception as e:
468
+ print(f'{type(e).__name__}: {e}')
469
+ raise e
470
+ finally:
471
+ module.train()
472
+
473
+ def create_loss_modules_from_bottleneck(bottleneck, loss_config):
474
+ losses = []
475
+
476
+ if isinstance(bottleneck, VAEBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
477
+ try:
478
+ kl_weight = loss_config['bottleneck']['weights']['kl']
479
+ except:
480
+ kl_weight = 1e-6
481
+
482
+ kl_loss = ValueLoss(key='kl', weight=kl_weight, name='kl_loss')
483
+ losses.append(kl_loss)
484
+
485
+ if isinstance(bottleneck, RVQBottleneck) or isinstance(bottleneck, RVQVAEBottleneck):
486
+ quantizer_loss = ValueLoss(key='quantizer_loss', weight=1.0, name='quantizer_loss')
487
+ losses.append(quantizer_loss)
488
+
489
+ if isinstance(bottleneck, DACRVQBottleneck) or isinstance(bottleneck, DACRVQVAEBottleneck):
490
+ codebook_loss = ValueLoss(key='vq/codebook_loss', weight=1.0, name='codebook_loss')
491
+ commitment_loss = ValueLoss(key='vq/commitment_loss', weight=0.25, name='commitment_loss')
492
+ losses.append(codebook_loss)
493
+ losses.append(commitment_loss)
494
+
495
+ if isinstance(bottleneck, WassersteinBottleneck):
496
+ try:
497
+ mmd_weight = loss_config['bottleneck']['weights']['mmd']
498
+ except:
499
+ mmd_weight = 100
500
+
501
+ mmd_loss = ValueLoss(key='mmd', weight=mmd_weight, name='mmd_loss')
502
+ losses.append(mmd_loss)
503
+
504
+ return losses
ThinkSound/training/diffusion.py ADDED
@@ -0,0 +1,1076 @@
1
+ # import pytorch_lightning as pl
2
+ import lightning as L
3
+ from lightning.pytorch.callbacks import Callback
4
+ import sys, gc
5
+ import random
6
+ import torch
7
+ import torchaudio
8
+ import typing as tp
9
+ import wandb
10
+ # from beartype.typing import Tuple
11
+ from aeiou.viz import pca_point_cloud, audio_spectrogram_image, tokens_spectrogram_image
12
+ import auraloss
13
+ from ema_pytorch import EMA
14
+ from einops import rearrange
15
+ from safetensors.torch import save_file
16
+ from torch import optim
17
+ from torch.nn import functional as F
18
+ from pytorch_lightning.utilities.rank_zero import rank_zero_only
19
+
20
+ from ..inference.sampling import get_alphas_sigmas, sample, sample_discrete_euler
21
+ from ..models.diffusion import DiffusionModelWrapper, ConditionedDiffusionModelWrapper
22
+ from ..models.autoencoders import DiffusionAutoencoder
23
+ from .autoencoders import create_loss_modules_from_bottleneck
24
+ from .losses import AuralossLoss, MSELoss, MultiLoss
25
+ from .utils import create_optimizer_from_config, create_scheduler_from_config, mask_from_frac_lengths, generate_mask, generate_channel_mask
26
+ import os
27
+ from pathlib import Path
28
+ from time import time
29
+ import numpy as np
30
+
31
+ class Profiler:
32
+
33
+ def __init__(self):
34
+ self.ticks = [[time(), None]]
35
+
36
+ def tick(self, msg):
37
+ self.ticks.append([time(), msg])
38
+
39
+ def __repr__(self):
40
+ rep = 80 * "=" + "\n"
41
+ for i in range(1, len(self.ticks)):
42
+ msg = self.ticks[i][1]
43
+ elapsed = self.ticks[i][0] - self.ticks[i - 1][0]
44
+ rep += msg + f": {elapsed*1000:.2f}ms\n"
45
+ rep += 80 * "=" + "\n\n\n"
46
+ return rep
47
+
48
+ class DiffusionUncondTrainingWrapper(L.LightningModule):
49
+ '''
50
+ Wrapper for training an unconditional audio diffusion model (like Dance Diffusion).
51
+ '''
52
+ def __init__(
53
+ self,
54
+ model: DiffusionModelWrapper,
55
+ lr: float = 1e-4,
56
+ pre_encoded: bool = False
57
+ ):
58
+ super().__init__()
59
+
60
+ self.diffusion = model
61
+
62
+ self.diffusion_ema = EMA(
63
+ self.diffusion.model,
64
+ beta=0.9999,
65
+ power=3/4,
66
+ update_every=1,
67
+ update_after_step=1
68
+ )
69
+
70
+ self.lr = lr
71
+
72
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
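+ # Scrambled Sobol sequence: a low-discrepancy sampler that spreads diffusion
+ # timesteps more evenly over [0, 1] than independent uniform draws.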
73
+
74
+ loss_modules = [
75
+ MSELoss("v",
76
+ "targets",
77
+ weight=1.0,
78
+ name="mse_loss"
79
+ )
80
+ ]
81
+
82
+ self.losses = MultiLoss(loss_modules)
83
+
84
+ self.pre_encoded = pre_encoded
85
+
86
+ def configure_optimizers(self):
87
+ return optim.Adam([*self.diffusion.parameters()], lr=self.lr)
88
+
89
+ def training_step(self, batch, batch_idx):
90
+ reals = batch[0]
91
+
92
+ if reals.ndim == 4 and reals.shape[0] == 1:
93
+ reals = reals[0]
94
+
95
+ diffusion_input = reals
96
+
97
+ loss_info = {}
98
+
99
+ if not self.pre_encoded:
100
+ loss_info["audio_reals"] = diffusion_input
101
+
102
+ if self.diffusion.pretransform is not None:
103
+ if not self.pre_encoded:
104
+ with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
105
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
106
+ else:
107
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
108
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
109
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
110
+
111
+ loss_info["reals"] = diffusion_input
112
+
113
+ # Draw uniformly distributed continuous timesteps
114
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
115
+
116
+ # Calculate the noise schedule parameters for those timesteps
117
+ alphas, sigmas = get_alphas_sigmas(t)
118
+
119
+ # Combine the ground truth data and the noise
120
+ alphas = alphas[:, None, None]
121
+ sigmas = sigmas[:, None, None]
122
+ noise = torch.randn_like(diffusion_input)
123
+ noised_inputs = diffusion_input * alphas + noise * sigmas
124
+ targets = noise * alphas - diffusion_input * sigmas
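+ # v-objective: the network regresses v = alpha * noise - sigma * x, from which the clean
+ # signal can be recovered as x = alpha * x_t - sigma * v (with the usual alpha^2 + sigma^2 = 1 schedule).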
125
+
126
+ with torch.amp.autocast('cuda'):
127
+ v = self.diffusion(noised_inputs, t)
128
+
129
+ loss_info.update({
130
+ "v": v,
131
+ "targets": targets
132
+ })
133
+
134
+ loss, losses = self.losses(loss_info)
135
+
136
+ log_dict = {
137
+ 'train/loss': loss.detach(),
138
+ 'train/std_data': diffusion_input.std(),
139
+ }
140
+
141
+ for loss_name, loss_value in losses.items():
142
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
143
+
144
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
145
+ return loss
146
+
147
+ def on_before_zero_grad(self, *args, **kwargs):
148
+ self.diffusion_ema.update()
149
+
150
+ def export_model(self, path, use_safetensors=False):
151
+
152
+ self.diffusion.model = self.diffusion_ema.ema_model
153
+
154
+ if use_safetensors:
155
+ save_file(self.diffusion.state_dict(), path)
156
+ else:
157
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
158
+
159
+ class DiffusionUncondDemoCallback(Callback):
160
+ def __init__(self,
161
+ demo_every=2000,
162
+ num_demos=8,
163
+ demo_steps=250,
164
+ sample_rate=48000
165
+ ):
166
+ super().__init__()
167
+
168
+ self.demo_every = demo_every
169
+ self.num_demos = num_demos
170
+ self.demo_steps = demo_steps
171
+ self.sample_rate = sample_rate
172
+ self.last_demo_step = -1
173
+
174
+ @rank_zero_only
175
+ @torch.no_grad()
176
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
177
+
178
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
179
+ return
180
+
181
+ self.last_demo_step = trainer.global_step
182
+
183
+ demo_samples = module.diffusion.sample_size
184
+
185
+ if module.diffusion.pretransform is not None:
186
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
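+ # With a pretransform (e.g. a latent VAE), demo_samples counts latent frames,
+ # hence the division by the pretransform's downsampling ratio.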
187
+
188
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
189
+
190
+ try:
191
+ with torch.amp.autocast('cuda'):
192
+ fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
193
+
194
+ if module.diffusion.pretransform is not None:
195
+ fakes = module.diffusion.pretransform.decode(fakes)
196
+
197
+ # Put the demos together
198
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
199
+
200
+ log_dict = {}
201
+
202
+ filename = f'demo_{trainer.global_step:08}.wav'
203
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
204
+ torchaudio.save(filename, fakes, self.sample_rate)
205
+
206
+ log_dict[f'demo'] = wandb.Audio(filename,
207
+ sample_rate=self.sample_rate,
208
+ caption=f'Reconstructed')
209
+
210
+ log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
211
+
212
+ trainer.logger.experiment.log(log_dict)
213
+
214
+ del fakes
215
+
216
+ except Exception as e:
217
+ print(f'{type(e).__name__}: {e}')
218
+ finally:
219
+ gc.collect()
220
+ torch.cuda.empty_cache()
221
+
222
+ class DiffusionInfillTrainingWrapper(L.LightningModule):
223
+ '''
224
+ Wrapper for training an audio diffusion model for infilling (masked-span inpainting).
225
+ '''
226
+ def __init__(
227
+ self,
228
+ model: ConditionedDiffusionModelWrapper,
229
+ lr: float = 1e-4,
230
+ optimizer_configs: dict = None,
231
+ pre_encoded: bool = False,
232
+ frac_lengths_mask = (0.7, 1.),
233
+ min_span_len = 10,
234
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
235
+ diffusion_objective = 'rectified_flow',
236
+ ctx_drop: float = 0.1,
237
+ r_drop: float = 0.0,
238
+ ):
239
+ super().__init__()
240
+
241
+ self.diffusion = model
242
+
243
+ self.diffusion_ema = EMA(
244
+ self.diffusion.model,
245
+ beta=0.9999,
246
+ power=3/4,
247
+ update_every=1,
248
+ update_after_step=1
249
+ )
250
+
251
+ if optimizer_configs is None:
252
+ optimizer_configs = {
253
+ "diffusion": {
254
+ "optimizer": {
255
+ "type": "Adam",
256
+ "config": {
257
+ "lr": lr
258
+ }
259
+ }
260
+ }
261
+ }
262
+ else:
263
+ if lr is not None:
264
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
265
+
266
+ self.optimizer_configs = optimizer_configs
267
+
268
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
269
+ self.frac_lengths_mask = frac_lengths_mask
270
+ self.min_span_len = min_span_len
271
+ self.timestep_sampler = timestep_sampler
272
+ self.ctx_drop = ctx_drop
273
+ self.r_drop = r_drop
274
+ self.diffusion_objective = diffusion_objective
275
+ print(f'Training in the {diffusion_objective} formulation')
276
+ loss_modules = [
277
+ MSELoss("v",
278
+ "targets",
279
+ weight=1.0,
280
+ name="mse_loss",
281
+ mask_key="mask"
282
+ )
283
+ ]
284
+
285
+ self.losses = MultiLoss(loss_modules)
286
+
287
+ self.pre_encoded = pre_encoded
288
+
289
+ def configure_optimizers(self):
290
+ diffusion_opt_config = self.optimizer_configs['diffusion']
291
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
292
+
293
+ if "scheduler" in diffusion_opt_config:
294
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
295
+ sched_diff_config = {
296
+ "scheduler": sched_diff,
297
+ "interval": "step"
298
+ }
299
+ return [opt_diff], [sched_diff_config]
300
+
301
+ return [opt_diff]
302
+
303
+ def training_step(self, batch, batch_idx):
304
+ reals, metadata = batch
305
+ if reals.ndim == 4 and reals.shape[0] == 1:
306
+ reals = reals[0]
307
+ # import ipdb
308
+ # ipdb.set_trace()
309
+ p_drop = torch.rand(1).item()
310
+ # r_drop = torch.rand(1).item()
311
+ # if p_drop >= self.ctx_drop and self.r_drop > 0.0 and r_drop < self.r_drop:
312
+ # generate_channel_mask(reals)
313
+
314
+ diffusion_input = reals
315
+ assert torch.all(torch.isfinite(diffusion_input)), "Non-finite values detected in diffusion_input"
316
+ p = Profiler()
317
+ loss_info = {}
318
+ if not self.pre_encoded:
319
+ loss_info["audio_reals"] = diffusion_input
320
+
321
+ p.tick("setup")
322
+
323
+ conditioning = {}
324
+
325
+ p.tick("conditioning")
326
+ if self.diffusion.pretransform is not None:
327
+ if not self.pre_encoded:
328
+ with torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
329
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
330
+ else:
331
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
332
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
333
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
334
+
335
+ loss_info["reals"] = diffusion_input
336
+
337
+ if self.timestep_sampler == "uniform":
338
+ # Draw uniformly distributed continuous timesteps
339
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
340
+ elif self.timestep_sampler == "logit_normal":
341
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
342
+
343
+ # # Calculate the noise schedule parameters for those timesteps
344
+ # alphas, sigmas = get_alphas_sigmas(t)
345
+ # Calculate the noise schedule parameters for those timesteps
346
+ if self.diffusion_objective == "v":
347
+ alphas, sigmas = get_alphas_sigmas(t)
348
+ elif self.diffusion_objective == "rectified_flow":
349
+ alphas, sigmas = 1-t, t
350
+
351
+ # Combine the ground truth data and the noise
352
+ alphas = alphas[:, None, None]
353
+ sigmas = sigmas[:, None, None]
354
+ noise = torch.randn_like(diffusion_input)
355
+ noised_inputs = diffusion_input * alphas + noise * sigmas
356
+ # x_ctx = diffusion_input.detach().clone().transpose(1,2)
357
+ bsz, dim, seq_len = diffusion_input.shape
358
+
359
+
360
+ if p_drop < self.ctx_drop:
361
+ ctx_mask = torch.ones((bsz, seq_len), device = diffusion_input.device, dtype = torch.bool)
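+ # With probability ctx_drop the whole sequence is marked for infilling, which zeroes
+ # out the context below and effectively trains an unconditional step.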
362
+ # elif self.r_drop > 0.0 and r_drop < self.r_drop:
363
+ # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool)
364
+ else:
365
+ # Compute frac_lengths ahead of time for the span-mask generation below
366
+ frac_lengths = torch.zeros((bsz,), device=diffusion_input.device).uniform_(*self.frac_lengths_mask)
367
+ # if self.r_drop > 0.0 and r_drop < self.r_drop:
368
+ # import ipdb
369
+ # ipdb.set_trace()
370
+
371
+ # ctx_mask = torch.zeros((bsz, seq_len), device=diffusion_input.device, dtype=torch.bool)
372
+ # else:
373
+ ctx_mask = generate_mask(bsz, seq_len, frac_lengths, self.min_span_len)
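+ # True entries mark the spans to infill: they are removed from the context signal
+ # below and are the positions scored by the masked MSE loss.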
374
+
375
+ if ctx_mask.dim() == 2:
376
+ ctx_mask = ctx_mask.unsqueeze(1)
377
+ masked_sequence = diffusion_input * ~ctx_mask
378
+ conditioning['x_ctx'] = [masked_sequence]
379
+ if self.diffusion_objective == "v":
380
+ targets = noise * alphas - diffusion_input * sigmas
381
+ elif self.diffusion_objective == "rectified_flow":
382
+ targets = noise - diffusion_input
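+ # Rectified flow regresses the constant velocity (noise - x) of the straight path
+ # x_t = (1 - t) * x + t * noise used to build noised_inputs above.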
383
+ with torch.amp.autocast('cuda'):
384
+ p.tick("amp")
385
+ v = self.diffusion(noised_inputs, t, cond=conditioning)
386
+ p.tick("diffusion")
387
+ loss_info.update({
388
+ "v": v,
389
+ "targets": targets,
390
+ "mask": ctx_mask.squeeze(-1)
391
+ })
392
+ # import ipdb
393
+ # ipdb.set_trace()
394
+ loss, losses = self.losses(loss_info)
395
+
396
+ log_dict = {
397
+ 'train/loss': loss.detach(),
398
+ 'train/std_data': diffusion_input.std(),
399
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
400
+ }
401
+
402
+ for loss_name, loss_value in losses.items():
403
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
404
+
405
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
406
+ p.tick("log")
407
+ return loss
408
+
409
+ def on_before_zero_grad(self, *args, **kwargs):
410
+ self.diffusion_ema.update()
411
+
412
+ def export_model(self, path, use_safetensors=False):
413
+
414
+ self.diffusion.model = self.diffusion_ema.ema_model
415
+
416
+ if use_safetensors:
417
+ save_file(self.diffusion.state_dict(), path)
418
+ else:
419
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
420
+
421
+ class DiffusionInfillDemoCallback(Callback):
422
+ def __init__(self,
423
+ demo_dl,
424
+ demo_every=2000,
425
+ num_demos=8,
426
+ demo_steps=250,
427
+ sample_rate=48000
428
+ ):
429
+ super().__init__()
430
+ self.demo_dl = iter(demo_dl)
431
+ self.demo_every = demo_every
432
+ self.num_demos = num_demos
433
+ self.demo_steps = demo_steps
434
+ self.sample_rate = sample_rate
435
+ self.last_demo_step = -1
436
+
437
+ @rank_zero_only
438
+ @torch.no_grad()
439
+ def on_train_batch_end(self, trainer, module, outputs, batch, batch_idx):
440
+
441
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
442
+ return
443
+
444
+ self.last_demo_step = trainer.global_step
445
+
446
+
447
+ try:
448
+ demo_reals, _ = next(self.demo_dl)
449
+ # Remove extra dimension added by WebDataset
450
+ if demo_reals.ndim == 4 and demo_reals.shape[0] == 1:
451
+ demo_reals = demo_reals[0]
452
+
453
+ demo_reals = demo_reals.to(module.device)
454
+ reals = demo_reals
455
+ log_dict = {}
456
+
457
+ if not module.pre_encoded:
458
+ # Log the real audio
459
+ log_dict[f'demo_reals_melspec_left'] = wandb.Image(audio_spectrogram_image(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu()))
460
+ # log_dict[f'demo_reals'] = wandb.Audio(rearrange(demo_reals, "b d n -> d (b n)").mul(32767).to(torch.int16).cpu(), sample_rate=self.sample_rate, caption="demo reals")
461
+
462
+ if module.diffusion.pretransform is not None:
463
+ module.diffusion.pretransform.to(module.device)
464
+ with torch.amp.autocast('cuda'):
465
+ demo_reals = module.diffusion.pretransform.encode(demo_reals)
466
+
467
+ demo_samples = demo_reals.shape[2]
468
+
469
+ # Get conditioning
470
+ conditioning = {}
471
+
472
+ noise = torch.randn([demo_reals.shape[0], module.diffusion.io_channels, demo_samples]).to(module.device)
473
+ frac_lengths = torch.zeros((demo_reals.shape[0],), device = module.device).uniform_(*(0.3,0.5))
474
+ ctx_mask = generate_mask(demo_reals.shape[0],demo_reals.shape[2], frac_lengths, module.min_span_len)
475
+ # x_ctx = (demo_reals * ~ctx_mask.unsqueeze(1)).transpose(1,2)
476
+ x_ctx = demo_reals * ~ctx_mask.unsqueeze(1)
477
+
478
+ conditioning['x_ctx'] = [x_ctx]
479
+ # x_ctx_mask = x_ctx * ~ctx_mask.unsqueeze(-1)
480
+ if module.diffusion.pretransform is not None:
481
+ log_dict[f'demo_masked_input'] = wandb.Image(tokens_spectrogram_image(x_ctx.cpu()))
482
+ else:
483
+ log_dict[f'demo_masked_input'] = wandb.Image(audio_spectrogram_image(rearrange(x_ctx, "b c t -> c (b t)").mul(32767).to(torch.int16).cpu()))
484
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
485
+ with torch.amp.autocast('cuda'):
486
+ if module.diffusion_objective == "v":
487
+ fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
488
+ elif module.diffusion_objective == "rectified_flow":
489
+ fakes = sample_discrete_euler(module.diffusion_ema, noise, self.demo_steps, **cond_inputs)
490
+ # fakes = sample(module.diffusion_ema, noise, self.demo_steps, 0)
491
+
492
+ if module.diffusion.pretransform is not None:
493
+ fakes = module.diffusion.pretransform.decode(fakes)
494
+
495
+ # #Interleave reals and fakes
496
+ # reals_fakes = rearrange([reals, fakes], 'i b d n -> (b i) d n')
497
+ # Put the demos together
498
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
499
+
500
+
501
+ filename = f'results/audio_ssl/demo_ssl_{trainer.global_step:08}.wav'
502
+ os.makedirs(Path(filename).parent,exist_ok=True)
503
+ fakes = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
504
+ torchaudio.save(filename, fakes, self.sample_rate)
505
+
506
+ log_dict[f'demo'] = wandb.Audio(filename,
507
+ sample_rate=self.sample_rate,
508
+ caption=f'Reconstructed')
509
+
510
+ log_dict[f'demo_melspec_left'] = wandb.Image(audio_spectrogram_image(fakes))
511
+
512
+ trainer.logger.experiment.log(log_dict)
513
+
514
+ del fakes
515
+
516
+ except Exception as e:
517
+ print(f'{type(e).__name__}: {e}')
518
+ finally:
519
+ gc.collect()
520
+ torch.cuda.empty_cache()
521
+
522
+ class DiffusionCondTrainingWrapper(L.LightningModule):
523
+ '''
524
+ Wrapper for training a conditional audio diffusion model.
525
+ '''
526
+ def __init__(
527
+ self,
528
+ model: ConditionedDiffusionModelWrapper,
529
+ lr: float = None,
530
+ mask_padding: bool = False,
531
+ mask_padding_dropout: float = 0.0,
532
+ use_ema: bool = True,
533
+ log_loss_info: bool = False,
534
+ optimizer_configs: dict = None,
535
+ diffusion_objective: tp.Literal["rectified_flow", "v"] = "v",
536
+ pre_encoded: bool = False,
537
+ cfg_dropout_prob = 0.1,
538
+ timestep_sampler: tp.Literal["uniform", "logit_normal"] = "uniform",
539
+ max_mask_segments = 0,
540
+ ):
541
+ super().__init__()
542
+
543
+ self.diffusion = model
544
+
545
+ if use_ema:
546
+ self.diffusion_ema = EMA(
547
+ self.diffusion.model,
548
+ beta=0.9999,
549
+ power=3/4,
550
+ update_every=1,
551
+ update_after_step=1,
552
+ include_online_model=False
553
+ )
554
+ else:
555
+ self.diffusion_ema = None
556
+
557
+ self.mask_padding = mask_padding
558
+ self.mask_padding_dropout = mask_padding_dropout
559
+
560
+ self.cfg_dropout_prob = cfg_dropout_prob
561
+
562
+ self.rng = torch.quasirandom.SobolEngine(1, scramble=True)
563
+
564
+ self.timestep_sampler = timestep_sampler
565
+
566
+ self.diffusion_objective = model.diffusion_objective
567
+ print(f'Training in the {self.diffusion_objective} formulation with timestep sampler: {timestep_sampler}')
568
+
569
+ self.max_mask_segments = max_mask_segments
570
+
571
+ self.loss_modules = [
572
+ MSELoss("output",
573
+ "targets",
574
+ weight=1.0,
575
+ mask_key="padding_mask" if self.mask_padding else None,
576
+ name="mse_loss"
577
+ )
578
+ ]
579
+
580
+ self.losses = MultiLoss(self.loss_modules)
581
+
582
+ self.log_loss_info = log_loss_info
583
+
584
+ assert lr is not None or optimizer_configs is not None, "Must specify either lr or optimizer_configs in training config"
585
+
586
+ if optimizer_configs is None:
587
+ optimizer_configs = {
588
+ "diffusion": {
589
+ "optimizer": {
590
+ "type": "Adam",
591
+ "config": {
592
+ "lr": lr
593
+ }
594
+ }
595
+ }
596
+ }
597
+ else:
598
+ if lr is not None:
599
+ print(f"WARNING: learning_rate and optimizer_configs both specified in config. Ignoring learning_rate and using optimizer_configs.")
600
+
601
+ self.optimizer_configs = optimizer_configs
602
+
603
+ self.pre_encoded = pre_encoded
604
+
605
+ def configure_optimizers(self):
606
+ diffusion_opt_config = self.optimizer_configs['diffusion']
607
+ opt_diff = create_optimizer_from_config(diffusion_opt_config['optimizer'], self.diffusion.parameters())
608
+
609
+ if "scheduler" in diffusion_opt_config:
610
+ sched_diff = create_scheduler_from_config(diffusion_opt_config['scheduler'], opt_diff)
611
+ sched_diff_config = {
612
+ "scheduler": sched_diff,
613
+ "interval": "step"
614
+ }
615
+ return [opt_diff], [sched_diff_config]
616
+
617
+ return [opt_diff]
618
+
619
+ def training_step(self, batch, batch_idx):
620
+ reals, metadata = batch
621
+ # import ipdb
622
+ # ipdb.set_trace()
623
+ p = Profiler()
624
+ if reals.ndim == 4 and reals.shape[0] == 1:
625
+ reals = reals[0]
626
+
627
+ loss_info = {}
628
+
629
+ diffusion_input = reals
630
+ if not self.pre_encoded:
631
+ loss_info["audio_reals"] = diffusion_input
632
+
633
+ p.tick("setup")
634
+
635
+ with torch.amp.autocast('cuda'):
636
+
637
+ conditioning = self.diffusion.conditioner(metadata, self.device)
638
+
639
+
640
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
641
+ conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat
642
+ conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat
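+ # Samples without a video stream fall back to the model's placeholder "empty" CLIP/sync
+ # features, so audio-only and audio-visual examples can share a batch.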
643
+ # If mask_padding is on, randomly drop the padding masks to allow for learning silence padding
644
+ use_padding_mask = self.mask_padding and random.random() > self.mask_padding_dropout
645
+
646
+ # Create batch tensor of attention masks from the "mask" field of the metadata array
647
+ if use_padding_mask:
648
+ padding_masks = torch.stack([md["padding_mask"][0] for md in metadata], dim=0).to(self.device) # Shape (batch_size, sequence_length)
649
+
650
+ p.tick("conditioning")
651
+
652
+ if self.diffusion.pretransform is not None:
653
+ self.diffusion.pretransform.to(self.device)
654
+
655
+ if not self.pre_encoded:
656
+ with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
657
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
658
+
659
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
660
+ p.tick("pretransform")
661
+
662
+ # If mask_padding is on, interpolate the padding masks to the size of the pretransformed input
663
+ if use_padding_mask:
664
+ padding_masks = F.interpolate(padding_masks.unsqueeze(1).float(), size=diffusion_input.shape[2], mode="nearest").squeeze(1).bool()
665
+ else:
666
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
667
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
668
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
669
+
670
+ if self.max_mask_segments > 0:
671
+ # Max mask size is the full sequence length
672
+ max_mask_length = diffusion_input.shape[2]
673
+
674
+ # Create a mask of random length for a random slice of the input
675
+ masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
676
+
677
+ conditioning['inpaint_mask'] = [mask]
678
+ conditioning['inpaint_masked_input'] = masked_input
679
+
680
+ if self.timestep_sampler == "uniform":
681
+ # Draw uniformly distributed continuous timesteps
682
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
683
+ elif self.timestep_sampler == "logit_normal":
684
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
685
+ # import ipdb
686
+ # ipdb.set_trace()
687
+ # Calculate the noise schedule parameters for those timesteps
688
+ if self.diffusion_objective == "v":
689
+ alphas, sigmas = get_alphas_sigmas(t)
690
+ elif self.diffusion_objective == "rectified_flow":
691
+ alphas, sigmas = 1-t, t
692
+
693
+ # Combine the ground truth data and the noise
694
+ alphas = alphas[:, None, None]
695
+ sigmas = sigmas[:, None, None]
696
+ noise = torch.randn_like(diffusion_input)
697
+ noised_inputs = diffusion_input * alphas + noise * sigmas
698
+
699
+ if self.diffusion_objective == "v":
700
+ targets = noise * alphas - diffusion_input * sigmas
701
+ elif self.diffusion_objective == "rectified_flow":
702
+ targets = noise - diffusion_input
703
+
704
+ p.tick("noise")
705
+
706
+ extra_args = {}
707
+
708
+ if use_padding_mask:
709
+ extra_args["mask"] = padding_masks
710
+
711
+ with torch.amp.autocast('cuda'):
712
+ p.tick("amp")
713
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = self.cfg_dropout_prob, **extra_args)
714
+ p.tick("diffusion")
715
+
716
+ loss_info.update({
717
+ "output": output,
718
+ "targets": targets,
719
+ "padding_mask": padding_masks if use_padding_mask else None,
720
+ })
721
+
722
+ loss, losses = self.losses(loss_info)
723
+
724
+ p.tick("loss")
725
+
726
+ if self.log_loss_info:
727
+ # Loss debugging logs
728
+ num_loss_buckets = 10
729
+ bucket_size = 1 / num_loss_buckets
730
+ loss_all = F.mse_loss(output, targets, reduction="none")
731
+
732
+ sigmas = rearrange(self.all_gather(sigmas), "w b c n -> (w b) c n").squeeze()
733
+
734
+ # gather loss_all across all GPUs
735
+ loss_all = rearrange(self.all_gather(loss_all), "w b c n -> (w b) c n")
736
+
737
+ # Bucket loss values based on corresponding sigma values, bucketing sigma values by bucket_size
738
+ loss_all = torch.stack([loss_all[(sigmas >= i) & (sigmas < i + bucket_size)].mean() for i in torch.arange(0, 1, bucket_size).to(self.device)])
739
+
740
+ # Log bucketed losses with corresponding sigma bucket values, if it's not NaN
741
+ debug_log_dict = {
742
+ f"model/loss_all_{i/num_loss_buckets:.1f}": loss_all[i].detach() for i in range(num_loss_buckets) if not torch.isnan(loss_all[i])
743
+ }
744
+
745
+ self.log_dict(debug_log_dict)
746
+
747
+
748
+ log_dict = {
749
+ 'train/loss': loss.detach(),
750
+ 'train/std_data': diffusion_input.std(),
751
+ 'train/lr': self.trainer.optimizers[0].param_groups[0]['lr']
752
+ }
753
+
754
+ for loss_name, loss_value in losses.items():
755
+ log_dict[f"train/{loss_name}"] = loss_value.detach()
756
+
757
+ self.log_dict(log_dict, prog_bar=True, on_step=True)
758
+ p.tick("log")
759
+ #print(f"Profiler: {p}")
760
+ return loss
761
+
762
+ def validation_step(self, batch, batch_idx):
763
+ reals, metadata = batch
764
+ # breakpoint()
765
+ if reals.ndim == 4 and reals.shape[0] == 1:
766
+ reals = reals[0]
767
+
768
+ loss_info = {}
769
+
770
+ diffusion_input = reals
771
+
772
+ if not self.pre_encoded:
773
+ loss_info["audio_reals"] = diffusion_input
774
+
775
+
776
+ with torch.amp.autocast('cuda'):
777
+
778
+ conditioning = self.diffusion.conditioner(metadata, self.device)
779
+
780
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
781
+ conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat
782
+ conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat
783
+
784
+ if self.diffusion.pretransform is not None:
785
+
786
+ if not self.pre_encoded:
787
+ self.diffusion.pretransform.to(self.device)
788
+ with torch.amp.autocast('cuda') and torch.set_grad_enabled(self.diffusion.pretransform.enable_grad):
789
+ self.diffusion.pretransform.train(self.diffusion.pretransform.enable_grad)
790
+
791
+ diffusion_input = self.diffusion.pretransform.encode(diffusion_input)
792
+ else:
793
+ # Apply scale to pre-encoded latents if needed, as the pretransform encode function will not be run
794
+ if hasattr(self.diffusion.pretransform, "scale") and self.diffusion.pretransform.scale != 1.0:
795
+ diffusion_input = diffusion_input / self.diffusion.pretransform.scale
796
+ if self.max_mask_segments > 0:
797
+ # Max mask size is the full sequence length
798
+ max_mask_length = diffusion_input.shape[2]
799
+
800
+ # Create a mask of random length for a random slice of the input
801
+ masked_input, mask = self.random_mask(diffusion_input, max_mask_length)
802
+
803
+ conditioning['inpaint_mask'] = [mask]
804
+ conditioning['inpaint_masked_input'] = masked_input
805
+ if self.timestep_sampler == "uniform":
806
+ # Draw uniformly distributed continuous timesteps
807
+ t = self.rng.draw(reals.shape[0])[:, 0].to(self.device)
808
+ elif self.timestep_sampler == "logit_normal":
809
+ t = torch.sigmoid(torch.randn(reals.shape[0], device=self.device))
810
+
811
+ # Calculate the noise schedule parameters for those timesteps
812
+ if self.diffusion_objective == "v":
813
+ alphas, sigmas = get_alphas_sigmas(t)
814
+ elif self.diffusion_objective == "rectified_flow":
815
+ alphas, sigmas = 1-t, t
816
+
817
+ # Combine the ground truth data and the noise
818
+ alphas = alphas[:, None, None]
819
+ sigmas = sigmas[:, None, None]
820
+ noise = torch.randn_like(diffusion_input)
821
+ noised_inputs = diffusion_input * alphas + noise * sigmas
822
+
823
+ if self.diffusion_objective == "v":
824
+ targets = noise * alphas - diffusion_input * sigmas
825
+ elif self.diffusion_objective == "rectified_flow":
826
+ targets = noise - diffusion_input
827
+
828
+
829
+ with torch.amp.autocast('cuda'):
830
+ output = self.diffusion(noised_inputs, t, cond=conditioning, cfg_dropout_prob = 0.0)
831
+
832
+ loss_info.update({
833
+ "output": output,
834
+ "targets": targets,
835
+ })
836
+
837
+ loss, losses = self.losses(loss_info)
838
+
839
+
840
+ log_dict = {
841
+ 'val_loss': loss.detach(),
842
+ }
843
+
844
+ self.log_dict(log_dict, prog_bar=True, batch_size=diffusion_input.size(0))
845
+
846
+ def predict_step(self, batch, batch_idx):
847
+ reals, metadata = batch
848
+ ids = [item['id'] for item in metadata]
849
+ batch_size, length = reals.shape[0], reals.shape[2]
850
+ print(f"Predicting {batch_size} samples with length {length} for ids: {ids}")
851
+ with torch.amp.autocast('cuda'):
852
+ conditioning = self.diffusion.conditioner(metadata, self.device)
853
+
854
+ video_exist = torch.stack([item['video_exist'] for item in metadata],dim=0)
855
+ conditioning['metaclip_features'][~video_exist] = self.diffusion.model.model.empty_clip_feat
856
+ conditioning['sync_features'][~video_exist] = self.diffusion.model.model.empty_sync_feat
857
+
858
+ cond_inputs = self.diffusion.get_conditioning_inputs(conditioning)
859
+ if batch_size > 1:
860
+ noise_list = []
861
+ for _ in range(batch_size):
862
+ noise_1 = torch.randn([1, self.diffusion.io_channels, length]).to(self.device) # each draw advances the RNG state
863
+ noise_list.append(noise_1)
864
+ noise = torch.cat(noise_list, dim=0)
865
+ else:
866
+ noise = torch.randn([batch_size, self.diffusion.io_channels, length]).to(self.device)
867
+ with torch.amp.autocast('cuda'):
868
+
869
+ model = self.diffusion.model
870
+ if self.diffusion_objective == "v":
871
+ fakes = sample(model, noise, 24, 0, **cond_inputs, cfg_scale=5, batch_cfg=True)
872
+ elif self.diffusion_objective == "rectified_flow":
873
+ import time
874
+ start_time = time.time()
875
+ fakes = sample_discrete_euler(model, noise, 24, **cond_inputs, cfg_scale=5, batch_cfg=True)
876
+ end_time = time.time()
877
+ execution_time = end_time - start_time
878
+ print(f"Execution time: {execution_time:.2f} s")
879
+ if self.diffusion.pretransform is not None:
880
+ fakes = self.diffusion.pretransform.decode(fakes)
881
+
882
+ audios = fakes.to(torch.float32).div(torch.max(torch.abs(fakes))).clamp(-1, 1).mul(32767).to(torch.int16).cpu()
883
+ return audios
884
+ # # Put the demos together
885
+ # fakes = rearrange(fakes, 'b d n -> d (b n)')
886
+
887
+ def random_mask(self, sequence, max_mask_length):
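+ # Mask convention: 0 marks regions to inpaint (hidden from the model), 1 marks
+ # regions kept as conditioning context.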
888
+ b, _, sequence_length = sequence.size()
889
+
890
+ # Create a mask tensor for each batch element
891
+ masks = []
892
+
893
+ for i in range(b):
894
+ mask_type = random.randint(0, 2)
895
+
896
+ if mask_type == 0: # Random mask with multiple segments
897
+ num_segments = random.randint(1, self.max_mask_segments)
898
+ max_segment_length = max_mask_length // num_segments
899
+
900
+ segment_lengths = random.sample(range(1, max_segment_length + 1), num_segments)
901
+
902
+ mask = torch.ones((1, 1, sequence_length))
903
+ for length in segment_lengths:
904
+ mask_start = random.randint(0, sequence_length - length)
905
+ mask[:, :, mask_start:mask_start + length] = 0
906
+
907
+ elif mask_type == 1: # Full mask
908
+ mask = torch.zeros((1, 1, sequence_length))
909
+
910
+ elif mask_type == 2: # Causal mask
911
+ mask = torch.ones((1, 1, sequence_length))
912
+ mask_length = random.randint(1, max_mask_length)
913
+ mask[:, :, -mask_length:] = 0
914
+
915
+ mask = mask.to(sequence.device)
916
+ masks.append(mask)
917
+
918
+ # Concatenate the mask tensors into a single tensor
919
+ mask = torch.cat(masks, dim=0).to(sequence.device)
920
+
921
+ # Apply the mask to the sequence tensor for each batch element
922
+ masked_sequence = sequence * mask
923
+
924
+ return masked_sequence, mask
925
+
926
+ def on_before_zero_grad(self, *args, **kwargs):
927
+ if self.diffusion_ema is not None:
928
+ self.diffusion_ema.update()
929
+
930
+ def export_model(self, path, use_safetensors=False):
931
+ if self.diffusion_ema is not None:
932
+ self.diffusion.model = self.diffusion_ema.ema_model
933
+
934
+ if use_safetensors:
935
+ save_file(self.diffusion.state_dict(), path)
936
+ else:
937
+ torch.save({"state_dict": self.diffusion.state_dict()}, path)
938
+
939
+ class DiffusionCondDemoCallback(Callback):
940
+ def __init__(self,
941
+ demo_every=2000,
942
+ num_demos=8,
943
+ sample_size=65536,
944
+ demo_steps=250,
945
+ sample_rate=48000,
946
+ demo_conditioning: tp.Optional[tp.Dict[str, tp.Any]] = {},
947
+ demo_cfg_scales: tp.Optional[tp.List[int]] = [3, 5, 7],
948
+ demo_cond_from_batch: bool = False,
949
+ display_audio_cond: bool = False
950
+ ):
951
+ super().__init__()
952
+
953
+ self.demo_every = demo_every
954
+ self.num_demos = num_demos
955
+ self.demo_samples = sample_size
956
+ self.demo_steps = demo_steps
957
+ self.sample_rate = sample_rate
958
+ self.last_demo_step = -1
959
+ self.demo_conditioning = demo_conditioning
960
+ self.demo_cfg_scales = demo_cfg_scales
961
+
962
+ # If true, the callback will use the metadata from the batch to generate the demo conditioning
963
+ self.demo_cond_from_batch = demo_cond_from_batch
964
+
965
+ # If true, the callback will display the audio conditioning
966
+ self.display_audio_cond = display_audio_cond
967
+
968
+ @rank_zero_only
969
+ @torch.no_grad()
970
+ def on_train_batch_end(self, trainer, module: DiffusionCondTrainingWrapper, outputs, batch, batch_idx):
971
+
972
+ if (trainer.global_step - 1) % self.demo_every != 0 or self.last_demo_step == trainer.global_step:
973
+ return
974
+
975
+ module.eval()
976
+
977
+ print(f"Generating demo")
978
+ self.last_demo_step = trainer.global_step
979
+
980
+ demo_samples = self.demo_samples
981
+
982
+ demo_cond = self.demo_conditioning
983
+
984
+ if self.demo_cond_from_batch:
985
+ # Get metadata from the batch
986
+ demo_cond = batch[1][:self.num_demos]
987
+
988
+ if '.pth' in demo_cond[0]:
989
+ demo_cond_data = []
990
+ for path in demo_cond:
991
+ # info = {}
992
+ data = torch.load(path, weights_only=True)
993
+ if 'caption_t5' not in data.keys():
994
+ data['caption_t5'] = data['caption']
995
+ data['seconds_start'] = 0
996
+ data['seconds_total'] = 10
997
+ demo_cond_data.append(data)
998
+ demo_cond = demo_cond_data
999
+ elif '.npz' in demo_cond[0]:
1000
+ demo_cond_data = []
1001
+ for path in demo_cond:
1002
+ # info = {}
1003
+ npz_data = np.load(path,allow_pickle=True)
1004
+ data = {key: npz_data[key] for key in npz_data.files}
1005
+ for key in data.keys():
1006
+ # print(key)
1007
+ if isinstance(data[key], np.ndarray) and np.issubdtype(data[key].dtype, np.number):
1008
+ data[key] = torch.from_numpy(data[key])
1009
+
1010
+ demo_cond_data.append(data)
1011
+ demo_cond = demo_cond_data
1012
+ if module.diffusion.pretransform is not None:
1013
+ demo_samples = demo_samples // module.diffusion.pretransform.downsampling_ratio
1014
+
1015
+ noise = torch.randn([self.num_demos, module.diffusion.io_channels, demo_samples]).to(module.device)
1016
+
1017
+ try:
1018
+ print("Getting conditioning")
1019
+ with torch.amp.autocast('cuda'):
1020
+ conditioning = module.diffusion.conditioner(demo_cond, module.device)
1021
+
1022
+ cond_inputs = module.diffusion.get_conditioning_inputs(conditioning)
1023
+
1024
+ log_dict = {}
1025
+
1026
+ if self.display_audio_cond:
1027
+ audio_inputs = torch.cat([cond["audio"] for cond in demo_cond], dim=0)
1028
+ audio_inputs = rearrange(audio_inputs, 'b d n -> d (b n)')
1029
+
1030
+ filename = f'demo_audio_cond_{trainer.global_step:08}.wav'
1031
+ audio_inputs = audio_inputs.to(torch.float32).mul(32767).to(torch.int16).cpu()
1032
+ torchaudio.save(filename, audio_inputs, self.sample_rate)
1033
+ log_dict[f'demo_audio_cond'] = wandb.Audio(filename, sample_rate=self.sample_rate, caption="Audio conditioning")
1034
+ log_dict[f"demo_audio_cond_melspec_left"] = wandb.Image(audio_spectrogram_image(audio_inputs))
1035
+ trainer.logger.experiment.log(log_dict)
1036
+
1037
+ for cfg_scale in self.demo_cfg_scales:
1038
+
1039
+ print(f"Generating demo for cfg scale {cfg_scale}")
1040
+
1041
+ with torch.amp.autocast('cuda'):
1042
+ # model = module.diffusion_ema.model if module.diffusion_ema is not None else module.diffusion.model
1043
+ model = module.diffusion.model
1044
+
1045
+ if module.diffusion_objective == "v":
1046
+ fakes = sample(model, noise, self.demo_steps, 0, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1047
+ elif module.diffusion_objective == "rectified_flow":
1048
+ fakes = sample_discrete_euler(model, noise, self.demo_steps, **cond_inputs, cfg_scale=cfg_scale, batch_cfg=True)
1049
+
1050
+ if module.diffusion.pretransform is not None:
1051
+ fakes = module.diffusion.pretransform.decode(fakes)
1052
+
1053
+ # Put the demos together
1054
+ fakes = rearrange(fakes, 'b d n -> d (b n)')
1055
+
1056
+ log_dict = {}
1057
+
1058
+ filename = f'demos/demo_cfg_{cfg_scale}_{trainer.global_step:08}.wav'
1059
+ fakes = fakes.div(torch.max(torch.abs(fakes))).mul(32767).to(torch.int16).cpu()
1060
+ torchaudio.save(filename, fakes, self.sample_rate)
1061
+
1062
+ log_dict[f'demo_cfg_{cfg_scale}'] = wandb.Audio(filename,
1063
+ sample_rate=self.sample_rate,
1064
+ caption=f'Reconstructed')
1065
+
1066
+ log_dict[f'demo_melspec_left_cfg_{cfg_scale}'] = wandb.Image(audio_spectrogram_image(fakes))
1067
+ trainer.logger.experiment.log(log_dict)
1068
+
1069
+ del fakes
1070
+
1071
+ except Exception as e:
1072
+ raise e
1073
+ finally:
1074
+ gc.collect()
1075
+ torch.cuda.empty_cache()
1076
+ module.train()
ThinkSound/training/factory.py ADDED
@@ -0,0 +1,262 @@
1
+ import torch
2
+ from torch.nn import Parameter
3
+ from ..models.factory import create_model_from_config
4
+
5
+ def create_training_wrapper_from_config(model_config, model):
6
+ model_type = model_config.get('model_type', None)
7
+ assert model_type is not None, 'model_type must be specified in model config'
8
+
9
+ training_config = model_config.get('training', None)
10
+ assert training_config is not None, 'training config must be specified in model config'
11
+ if model_type == 'autoencoder':
12
+ from .autoencoders import AutoencoderTrainingWrapper
13
+
14
+ ema_copy = None
15
+
16
+ if training_config.get("use_ema", False):
17
+ ema_copy = create_model_from_config(model_config)
18
+ ema_copy = create_model_from_config(model_config) # I don't know why this needs to be called twice but it broke when I called it once
19
+ # Copy each weight to the ema copy
20
+ for name, param in model.state_dict().items():
21
+ if isinstance(param, Parameter):
22
+ # backwards compatibility for serialized parameters
23
+ param = param.data
24
+ ema_copy.state_dict()[name].copy_(param)
25
+
26
+ use_ema = training_config.get("use_ema", False)
27
+
28
+ latent_mask_ratio = training_config.get("latent_mask_ratio", 0.0)
29
+
30
+ teacher_model = training_config.get("teacher_model", None)
31
+ if teacher_model is not None:
32
+ teacher_model = create_model_from_config(teacher_model)
33
+ teacher_model = teacher_model.eval().requires_grad_(False)
34
+
35
+ teacher_model_ckpt = training_config.get("teacher_model_ckpt", None)
36
+ if teacher_model_ckpt is not None:
37
+ teacher_model.load_state_dict(torch.load(teacher_model_ckpt)["state_dict"])
38
+ else:
39
+ raise ValueError("teacher_model_ckpt must be specified if teacher_model is specified")
40
+
41
+ return AutoencoderTrainingWrapper(
42
+ model,
43
+ lr=training_config["learning_rate"],
44
+ warmup_steps=training_config.get("warmup_steps", 0),
45
+ encoder_freeze_on_warmup=training_config.get("encoder_freeze_on_warmup", False),
46
+ sample_rate=model_config["sample_rate"],
47
+ loss_config=training_config.get("loss_configs", None),
48
+ optimizer_configs=training_config.get("optimizer_configs", None),
49
+ use_ema=use_ema,
50
+ ema_copy=ema_copy if use_ema else None,
51
+ force_input_mono=training_config.get("force_input_mono", False),
52
+ latent_mask_ratio=latent_mask_ratio,
53
+ teacher_model=teacher_model
54
+ )
55
+ elif model_type == 'diffusion_uncond':
56
+ from .diffusion import DiffusionUncondTrainingWrapper
57
+ return DiffusionUncondTrainingWrapper(
58
+ model,
59
+ lr=training_config["learning_rate"],
60
+ pre_encoded=training_config.get("pre_encoded", False),
61
+ )
62
+ elif model_type == 'diffusion_infill':
63
+ from .diffusion import DiffusionInfillTrainingWrapper
64
+ return DiffusionInfillTrainingWrapper(
65
+ model,
66
+ lr=training_config.get("learning_rate", None),
67
+ optimizer_configs=training_config.get("optimizer_configs", None),
68
+ pre_encoded=training_config.get("pre_encoded", False),
69
+ frac_lengths_mask=training_config.get("frac_lengths_mask", (0.7, 1.)),
70
+ min_span_len=training_config.get("min_span_len", 10),
71
+ timestep_sampler = training_config.get("timestep_sampler", "uniform"),
72
+ ctx_drop = training_config.get("ctx_drop", 0.1),
73
+ r_drop = training_config.get("r_drop", 0.0)
74
+ )
75
+ elif model_type == 'diffusion_cond' or model_type == 'mm_diffusion_cond':
76
+ from .diffusion import DiffusionCondTrainingWrapper
77
+ return DiffusionCondTrainingWrapper(
78
+ model,
79
+ lr=training_config.get("learning_rate", None),
80
+ mask_padding=training_config.get("mask_padding", False),
81
+ mask_padding_dropout=training_config.get("mask_padding_dropout", 0.0),
82
+ use_ema = training_config.get("use_ema", True),
83
+ log_loss_info=training_config.get("log_loss_info", False),
84
+ optimizer_configs=training_config.get("optimizer_configs", None),
85
+ pre_encoded=training_config.get("pre_encoded", False),
86
+ diffusion_objective=training_config.get("diffusion_objective","v"),
87
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
88
+ timestep_sampler = training_config.get("timestep_sampler", "uniform"),
89
+ max_mask_segments = training_config.get("max_mask_segments", 0)
90
+ )
91
+ elif model_type == 'diffusion_prior':
92
+ from .diffusion import DiffusionPriorTrainingWrapper
93
+ from ..models.diffusion_prior import PriorType
94
+
95
+ ema_copy = create_model_from_config(model_config)
96
+
97
+ # Copy each weight to the ema copy
98
+ for name, param in model.state_dict().items():
99
+ if isinstance(param, Parameter):
100
+ # backwards compatibility for serialized parameters
101
+ param = param.data
102
+ ema_copy.state_dict()[name].copy_(param)
103
+
104
+ prior_type = training_config.get("prior_type", "mono_stereo")
105
+
106
+ if prior_type == "mono_stereo":
107
+ prior_type_enum = PriorType.MonoToStereo
108
+ else:
109
+ raise ValueError(f"Unknown prior type: {prior_type}")
110
+
111
+ return DiffusionPriorTrainingWrapper(
112
+ model,
113
+ lr=training_config["learning_rate"],
114
+ ema_copy=ema_copy,
115
+ prior_type=prior_type_enum,
116
+ log_loss_info=training_config.get("log_loss_info", False),
117
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False),
118
+ )
119
+ elif model_type == 'diffusion_cond_inpaint':
120
+ from .diffusion import DiffusionCondInpaintTrainingWrapper
121
+ return DiffusionCondInpaintTrainingWrapper(
122
+ model,
123
+ lr=training_config.get("learning_rate", None),
124
+ max_mask_segments = training_config.get("max_mask_segments", 10),
125
+ log_loss_info=training_config.get("log_loss_info", False),
126
+ optimizer_configs=training_config.get("optimizer_configs", None),
127
+ use_ema=training_config.get("use_ema", True),
128
+ pre_encoded=training_config.get("pre_encoded", False),
129
+ cfg_dropout_prob = training_config.get("cfg_dropout_prob", 0.1),
130
+ timestep_sampler = training_config.get("timestep_sampler", "uniform")
131
+ )
132
+ elif model_type == 'diffusion_autoencoder' :
133
+ from .diffusion import DiffusionAutoencoderTrainingWrapper
134
+
135
+ ema_copy = create_model_from_config(model_config)
136
+
137
+ # Copy each weight to the ema copy
138
+ for name, param in model.state_dict().items():
139
+ if isinstance(param, Parameter):
140
+ # backwards compatibility for serialized parameters
141
+ param = param.data
142
+ ema_copy.state_dict()[name].copy_(param)
143
+
144
+ return DiffusionAutoencoderTrainingWrapper(
145
+ model,
146
+ ema_copy=ema_copy,
147
+ lr=training_config["learning_rate"],
148
+ use_reconstruction_loss=training_config.get("use_reconstruction_loss", False)
149
+ )
150
+ elif model_type == 'lm':
151
+ from .lm import AudioLanguageModelTrainingWrapper
152
+
153
+ ema_copy = create_model_from_config(model_config)
154
+
155
+ for name, param in model.state_dict().items():
156
+ if isinstance(param, Parameter):
157
+ # backwards compatibility for serialized parameters
158
+ param = param.data
159
+ ema_copy.state_dict()[name].copy_(param)
160
+
161
+ return AudioLanguageModelTrainingWrapper(
162
+ model,
163
+ ema_copy=ema_copy,
164
+ lr=training_config.get("learning_rate", None),
165
+ use_ema=training_config.get("use_ema", False),
166
+ optimizer_configs=training_config.get("optimizer_configs", None),
167
+ pre_encoded=training_config.get("pre_encoded", False),
168
+ )
169
+
170
+ else:
171
+ raise NotImplementedError(f'Unknown model type: {model_type}')
172
+
173
+ def create_demo_callback_from_config(model_config, **kwargs):
174
+ model_type = model_config.get('model_type', None)
175
+ assert model_type is not None, 'model_type must be specified in model config'
176
+
177
+ training_config = model_config.get('training', None)
178
+ assert training_config is not None, 'training config must be specified in model config'
179
+
180
+ demo_config = training_config.get("demo", {})
181
+
182
+ if model_type == 'autoencoder':
183
+ from .autoencoders import AutoencoderDemoCallback
184
+ return AutoencoderDemoCallback(
185
+ demo_every=demo_config.get("demo_every", 2000),
186
+ sample_size=model_config["sample_size"],
187
+ sample_rate=model_config["sample_rate"],
188
+ **kwargs
189
+ )
190
+ elif model_type == 'diffusion_uncond':
191
+ from .diffusion import DiffusionUncondDemoCallback
192
+ return DiffusionUncondDemoCallback(
193
+ demo_every=demo_config.get("demo_every", 2000),
194
+ demo_steps=demo_config.get("demo_steps", 250),
195
+ sample_rate=model_config["sample_rate"]
196
+ )
197
+ elif model_type == 'diffusion_infill':
198
+ from .diffusion import DiffusionInfillDemoCallback
199
+ return DiffusionInfillDemoCallback(
200
+ demo_every=demo_config.get("demo_every", 2000),
201
+ demo_steps=demo_config.get("demo_steps", 250),
202
+ sample_rate=model_config["sample_rate"],
203
+ **kwargs
204
+ )
205
+ elif model_type == "diffusion_autoencoder":
206
+ from .diffusion import DiffusionAutoencoderDemoCallback
207
+ return DiffusionAutoencoderDemoCallback(
208
+ demo_every=demo_config.get("demo_every", 2000),
209
+ demo_steps=demo_config.get("demo_steps", 250),
210
+ sample_size=model_config["sample_size"],
211
+ sample_rate=model_config["sample_rate"],
212
+ **kwargs
213
+ )
214
+ elif model_type == "diffusion_prior":
215
+ from .diffusion import DiffusionPriorDemoCallback
216
+ return DiffusionPriorDemoCallback(
217
+ demo_every=demo_config.get("demo_every", 2000),
218
+ demo_steps=demo_config.get("demo_steps", 250),
219
+ sample_size=model_config["sample_size"],
220
+ sample_rate=model_config["sample_rate"],
221
+ **kwargs
222
+ )
223
+ elif model_type == "diffusion_cond" or model_type == 'mm_diffusion_cond':
224
+ from .diffusion import DiffusionCondDemoCallback
225
+
226
+ return DiffusionCondDemoCallback(
227
+ demo_every=demo_config.get("demo_every", 2000),
228
+ sample_size=model_config["sample_size"],
229
+ sample_rate=model_config["sample_rate"],
230
+ demo_steps=demo_config.get("demo_steps", 250),
231
+ num_demos=demo_config["num_demos"],
232
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
233
+ demo_conditioning=demo_config.get("demo_cond", {}),
234
+ demo_cond_from_batch=demo_config.get("demo_cond_from_batch", False),
235
+ display_audio_cond=demo_config.get("display_audio_cond", False),
236
+ )
237
+ elif model_type == "diffusion_cond_inpaint":
238
+ from .diffusion import DiffusionCondInpaintDemoCallback
239
+
240
+ return DiffusionCondInpaintDemoCallback(
241
+ demo_every=demo_config.get("demo_every", 2000),
242
+ sample_size=model_config["sample_size"],
243
+ sample_rate=model_config["sample_rate"],
244
+ demo_steps=demo_config.get("demo_steps", 250),
245
+ demo_cfg_scales=demo_config["demo_cfg_scales"],
246
+ **kwargs
247
+ )
248
+
249
+ elif model_type == "lm":
250
+ from .lm import AudioLanguageModelDemoCallback
251
+
252
+ return AudioLanguageModelDemoCallback(
253
+ demo_every=demo_config.get("demo_every", 2000),
254
+ sample_size=model_config["sample_size"],
255
+ sample_rate=model_config["sample_rate"],
256
+ demo_cfg_scales=demo_config.get("demo_cfg_scales", [1]),
257
+ demo_conditioning=demo_config.get("demo_cond", None),
258
+ num_demos=demo_config.get("num_demos", 8),
259
+ **kwargs
260
+ )
261
+ else:
262
+ raise NotImplementedError(f'Unknown model type: {model_type}')
ThinkSound/training/losses/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .losses import *
ThinkSound/training/losses/auraloss.py ADDED
@@ -0,0 +1,691 @@
1
+ # Copied and modified from https://github.com/csteinmetz1/auraloss/blob/main/auraloss/freq.py under Apache License 2.0
2
+ # You can find the license at LICENSES/LICENSE_AURALOSS.txt
3
+
4
+ import torch
5
+ import numpy as np
6
+ from typing import List, Any
7
+ import scipy.signal
8
+
9
+ def apply_reduction(losses, reduction="none"):
10
+ """Apply reduction to collection of losses."""
11
+ if reduction == "mean":
12
+ losses = losses.mean()
13
+ elif reduction == "sum":
14
+ losses = losses.sum()
15
+ return losses
16
+
17
+ def compute_direction(w, x, y, z):
18
+ # 计算各个声道的权重
19
+ phi = torch.atan2(y, x)
20
+ theta = torch.atan2(torch.sqrt(x**2 + y**2), z)
21
+ return phi.unsqueeze(1), theta.unsqueeze(1)
22
+
23
+ def get_window(win_type: str, win_length: int):
24
+ """Return a window function.
25
+
26
+ Args:
27
+ win_type (str): Window type. Can either be one of the window function provided in PyTorch
28
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
29
+ or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
30
+ win_length (int): Window length
31
+
32
+ Returns:
33
+ win: The window as a 1D torch tensor
34
+ """
35
+
36
+ try:
37
+ win = getattr(torch, win_type)(win_length)
38
+ except Exception:
39
+ win = torch.from_numpy(scipy.signal.windows.get_window(win_type, win_length))
40
+
41
+ return win
42
+
43
+ class SumAndDifference(torch.nn.Module):
44
+ """Sum and difference signal extraction module."""
45
+
46
+ def __init__(self):
47
+ """Initialize sum and difference extraction module."""
48
+ super(SumAndDifference, self).__init__()
49
+
50
+ def forward(self, x):
51
+ """Calculate forward propagation.
52
+
53
+ Args:
54
+ x (Tensor): Predicted signal (B, #channels, #samples).
55
+ Returns:
56
+ Tensor: Sum signal.
57
+ Tensor: Difference signal.
58
+ """
59
+ if not (x.size(1) == 2): # inputs must be stereo
60
+ raise ValueError(f"Input must be stereo: {x.size(1)} channel(s).")
61
+
62
+ sum_sig = self.sum(x).unsqueeze(1)
63
+ diff_sig = self.diff(x).unsqueeze(1)
64
+
65
+ return sum_sig, diff_sig
66
+
67
+ @staticmethod
68
+ def sum(x):
69
+ return x[:, 0, :] + x[:, 1, :]
70
+
71
+ @staticmethod
72
+ def diff(x):
73
+ return x[:, 0, :] - x[:, 1, :]
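+
+ # Note: the stereo sum/difference pair is the classic mid/side (M/S) decomposition
+ # that SumAndDifferenceSTFTLoss below operates on.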
74
+
75
+
76
+ class FIRFilter(torch.nn.Module):
77
+ """FIR pre-emphasis filtering module.
78
+
79
+ Args:
80
+ filter_type (str): Shape of the desired FIR filter ("hp", "fd", "aw"). Default: "hp"
81
+ coef (float): Coefficient value for the filter tap (only applicable for "hp" and "fd"). Default: 0.85
82
+ ntaps (int): Number of FIR filter taps for constructing A-weighting filters. Default: 101
83
+ plot (bool): Plot the magnitude response of the filter. Default: False
84
+
85
+ Based upon the perceptual loss pre-emphasis filters proposed by
86
+ [Wright & Välimäki, 2019](https://arxiv.org/abs/1911.08922).
87
+
88
+ A-weighting filter - "aw"
89
+ First-order highpass - "hp"
90
+ Folded differentiator - "fd"
91
+
92
+ Note that the default coefficient value of 0.85 is optimized for
93
+ a sampling rate of 44.1 kHz; consider adjusting this value for other sampling rates.
94
+ """
95
+
96
+ def __init__(self, filter_type="hp", coef=0.85, fs=44100, ntaps=101, plot=False):
97
+ """Initialize FIR pre-emphasis filtering module."""
98
+ super(FIRFilter, self).__init__()
99
+ self.filter_type = filter_type
100
+ self.coef = coef
101
+ self.fs = fs
102
+ self.ntaps = ntaps
103
+ self.plot = plot
104
+
105
+ import scipy.signal
106
+
107
+ if ntaps % 2 == 0:
108
+ raise ValueError(f"ntaps must be odd (ntaps={ntaps}).")
109
+
110
+ if filter_type == "hp":
111
+ self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
112
+ self.fir.weight.requires_grad = False
113
+ self.fir.weight.data = torch.tensor([1, -coef, 0]).view(1, 1, -1)
114
+ elif filter_type == "fd":
115
+ self.fir = torch.nn.Conv1d(1, 1, kernel_size=3, bias=False, padding=1)
116
+ self.fir.weight.requires_grad = False
117
+ self.fir.weight.data = torch.tensor([1, 0, -coef]).view(1, 1, -1)
118
+ elif filter_type == "aw":
119
+ # Definition of analog A-weighting filter according to IEC/CD 1672.
120
+ f1 = 20.598997
121
+ f2 = 107.65265
122
+ f3 = 737.86223
123
+ f4 = 12194.217
124
+ A1000 = 1.9997
125
+
126
+ NUMs = [(2 * np.pi * f4) ** 2 * (10 ** (A1000 / 20)), 0, 0, 0, 0]
127
+ DENs = np.polymul(
128
+ [1, 4 * np.pi * f4, (2 * np.pi * f4) ** 2],
129
+ [1, 4 * np.pi * f1, (2 * np.pi * f1) ** 2],
130
+ )
131
+ DENs = np.polymul(
132
+ np.polymul(DENs, [1, 2 * np.pi * f3]), [1, 2 * np.pi * f2]
133
+ )
134
+
135
+ # convert analog filter to digital filter
136
+ b, a = scipy.signal.bilinear(NUMs, DENs, fs=fs)
137
+
138
+ # compute the digital filter frequency response
139
+ w_iir, h_iir = scipy.signal.freqz(b, a, worN=512, fs=fs)
140
+
141
+ # then we fit to 101 tap FIR filter with least squares
142
+ taps = scipy.signal.firls(ntaps, w_iir, abs(h_iir), fs=fs)
143
+
144
+ # now implement this digital FIR filter as a Conv1d layer
145
+ self.fir = torch.nn.Conv1d(
146
+ 1, 1, kernel_size=ntaps, bias=False, padding=ntaps // 2
147
+ )
148
+ self.fir.weight.requires_grad = False
149
+ self.fir.weight.data = torch.tensor(taps.astype("float32")).view(1, 1, -1)
150
+
151
+ if plot:
152
+ from .plotting import compare_filters
153
+ compare_filters(b, a, taps, fs=fs)
154
+
155
+ def forward(self, input, target):
156
+ """Calculate forward propagation.
157
+ Args:
158
+ input (Tensor): Predicted signal (B, #channels, #samples).
159
+ target (Tensor): Groundtruth signal (B, #channels, #samples).
160
+ Returns:
161
+ Tensor: Filtered signal.
162
+ """
163
+ input = torch.nn.functional.conv1d(
164
+ input, self.fir.weight.data, padding=self.ntaps // 2
165
+ )
166
+ target = torch.nn.functional.conv1d(
167
+ target, self.fir.weight.data, padding=self.ntaps // 2
168
+ )
169
+ return input, target
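+
+ # Illustrative usage (hypothetical tensors): apply A-weighting before a loss.
+ #   prefilter = FIRFilter(filter_type="aw", fs=44100)
+ #   y_hat_w, y_w = prefilter(y_hat, y)  # both shaped (batch * channels, 1, samples)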
170
+
171
+ class SpectralConvergenceLoss(torch.nn.Module):
172
+ """Spectral convergence loss module.
173
+
174
+ See [Arik et al., 2018](https://arxiv.org/abs/1808.06719).
175
+ """
176
+
177
+ def __init__(self):
178
+ super(SpectralConvergenceLoss, self).__init__()
179
+
180
+ def forward(self, x_mag, y_mag):
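+ # Relative Frobenius-norm error ||y_mag - x_mag||_F / ||y_mag||_F, averaged over the batch.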
181
+ return (torch.norm(y_mag - x_mag, p="fro", dim=[-1, -2]) / torch.norm(y_mag, p="fro", dim=[-1, -2])).mean()
182
+
183
+ class STFTMagnitudeLoss(torch.nn.Module):
184
+ """STFT magnitude loss module.
185
+
186
+ See [Arik et al., 2018](https://arxiv.org/abs/1808.06719)
187
+ and [Engel et al., 2020](https://arxiv.org/abs/2001.04643v1)
188
+
189
+ Log-magnitudes are calculated with `log(log_fac*x + log_eps)`, where `log_fac` controls the
190
+ compression strength (larger value results in more compression), and `log_eps` can be used
191
+ to control the range of the compressed output values (e.g., `log_eps>=1` ensures positive
192
+ output values). The default values `log_fac=1` and `log_eps=0` correspond to plain log-compression.
193
+
194
+ Args:
195
+ log (bool, optional): Log-scale the STFT magnitudes,
196
+ or use linear scale. Default: True
197
+ log_eps (float, optional): Constant value added to the magnitudes before evaluating the logarithm.
198
+ Default: 0.0
199
+ log_fac (float, optional): Constant multiplication factor for the magnitudes before evaluating the logarithm.
200
+ Default: 1.0
201
+ distance (str, optional): Distance function ["L1", "L2"]. Default: "L1"
202
+ reduction (str, optional): Reduction of the loss elements. Default: "mean"
203
+ """
204
+
205
+ def __init__(self, log=True, log_eps=0.0, log_fac=1.0, distance="L1", reduction="mean"):
206
+ super(STFTMagnitudeLoss, self).__init__()
207
+
208
+ self.log = log
209
+ self.log_eps = log_eps
210
+ self.log_fac = log_fac
211
+
212
+ if distance == "L1":
213
+ self.distance = torch.nn.L1Loss(reduction=reduction)
214
+ elif distance == "L2":
215
+ self.distance = torch.nn.MSELoss(reduction=reduction)
216
+ else:
217
+ raise ValueError(f"Invalid distance: '{distance}'.")
218
+
219
+ def forward(self, x_mag, y_mag):
220
+ if self.log:
221
+ x_mag = torch.log(self.log_fac * x_mag + self.log_eps)
222
+ y_mag = torch.log(self.log_fac * y_mag + self.log_eps)
223
+ return self.distance(x_mag, y_mag)
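+
+ # For example, log_fac=10.0 with log_eps=1.0 compares log(10 * mag + 1) spectra,
+ # which stay non-negative for non-negative magnitudes.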
224
+
225
+
226
+ class STFTLoss(torch.nn.Module):
227
+ """STFT loss module.
228
+
229
+ See [Yamamoto et al. 2019](https://arxiv.org/abs/1904.04472).
230
+
231
+ Args:
232
+ fft_size (int, optional): FFT size in samples. Default: 1024
233
+ hop_size (int, optional): Hop size of the FFT in samples. Default: 256
234
+ win_length (int, optional): Length of the FFT analysis window. Default: 1024
235
+ window (str, optional): Window to apply before FFT, can either be one of the window functions provided in PyTorch
236
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
237
+ or any of the windows provided by [SciPy](https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.windows.get_window.html).
238
+ Default: 'hann_window'
239
+ w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
240
+ w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
241
+ w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
242
+ w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
243
+ sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
244
+ scale (str, optional): Optional frequency scaling method, options include:
245
+ ['mel', 'chroma']
246
+ Default: None
247
+ n_bins (int, optional): Number of scaling frequency bins. Default: None.
248
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
249
+ scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
250
+ eps (float, optional): Small epsilon value for stability. Default: 1e-8
251
+ output (str, optional): Format of the loss returned.
252
+ 'loss' : Return only the raw, aggregate loss term.
253
+ 'full' : Return the raw loss, plus intermediate loss terms.
254
+ Default: 'loss'
255
+ reduction (str, optional): Specifies the reduction to apply to the output:
256
+ 'none': no reduction will be applied,
257
+ 'mean': the sum of the output will be divided by the number of elements in the output,
258
+ 'sum': the output will be summed.
259
+ Default: 'mean'
260
+ mag_distance (str, optional): Distance function ["L1", "L2"] for the magnitude loss terms.
261
+ device (str, optional): Place the filterbanks on specified device. Default: None
262
+
263
+ Returns:
264
+ loss:
265
+ Aggregate loss term. Only returned if output='loss' (the default).
266
+ loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss:
267
+ Aggregate and intermediate loss terms. Only returned if output='full'.
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ fft_size: int = 1024,
273
+ hop_size: int = 256,
274
+ win_length: int = 1024,
275
+ window: str = "hann_window",
276
+ w_sc: float = 1.0,
277
+ w_log_mag: float = 1.0,
278
+ w_lin_mag: float = 0.0,
279
+ w_phs: float = 0.0,
280
+ sample_rate: float = None,
281
+ scale: str = None,
282
+ n_bins: int = None,
283
+ perceptual_weighting: bool = False,
284
+ scale_invariance: bool = False,
285
+ eps: float = 1e-8,
286
+ output: str = "loss",
287
+ reduction: str = "mean",
288
+ mag_distance: str = "L1",
289
+ device: Any = None,
290
+ **kwargs
291
+ ):
292
+ super().__init__()
293
+ self.fft_size = fft_size
294
+ self.hop_size = hop_size
295
+ self.win_length = win_length
296
+ self.window = get_window(window, win_length)
297
+ self.w_sc = w_sc
298
+ self.w_log_mag = w_log_mag
299
+ self.w_lin_mag = w_lin_mag
300
+ self.w_phs = w_phs
301
+ self.sample_rate = sample_rate
302
+ self.scale = scale
303
+ self.n_bins = n_bins
304
+ self.perceptual_weighting = perceptual_weighting
305
+ self.scale_invariance = scale_invariance
306
+ self.eps = eps
307
+ self.output = output
308
+ self.reduction = reduction
309
+ self.mag_distance = mag_distance
310
+ self.device = device
311
+
312
+ self.phs_used = bool(self.w_phs)
313
+
314
+ self.spectralconv = SpectralConvergenceLoss()
315
+ self.logstft = STFTMagnitudeLoss(
316
+ log=True,
317
+ reduction=reduction,
318
+ distance=mag_distance,
319
+ **kwargs
320
+ )
321
+ self.linstft = STFTMagnitudeLoss(
322
+ log=False,
323
+ reduction=reduction,
324
+ distance=mag_distance,
325
+ **kwargs
326
+ )
327
+
328
+ # setup mel filterbank
329
+ if scale is not None:
330
+ try:
331
+ import librosa.filters
332
+ except Exception as e:
333
+ print(e)
334
+ print("Try `pip install auraloss[all]`.")
335
+
336
+ if self.scale == "mel":
337
+ assert sample_rate is not None # Must set sample rate to use mel scale
338
+ assert n_bins <= fft_size # Must be more FFT bins than Mel bins
339
+ fb = librosa.filters.mel(sr=sample_rate, n_fft=fft_size, n_mels=n_bins)
340
+ fb = torch.tensor(fb).unsqueeze(0)
341
+
342
+ elif self.scale == "chroma":
343
+ assert sample_rate is not None # Must set sample rate to use chroma scale
344
+ assert n_bins <= fft_size # Must be more FFT bins than chroma bins
345
+ fb = librosa.filters.chroma(
346
+ sr=sample_rate, n_fft=fft_size, n_chroma=n_bins
347
+ )
348
+
349
+ else:
350
+ raise ValueError(
351
+ f"Invalid scale: {self.scale}. Must be 'mel' or 'chroma'."
352
+ )
353
+
354
+ self.register_buffer("fb", fb)
355
+
356
+ if scale is not None and device is not None:
357
+ self.fb = self.fb.to(self.device) # move filterbank to device
358
+
359
+ if self.perceptual_weighting:
360
+ if sample_rate is None:
361
+ raise ValueError(
362
+ f"`sample_rate` must be supplied when `perceptual_weighting = True`."
363
+ )
364
+ self.prefilter = FIRFilter(filter_type="aw", fs=sample_rate)
365
+
366
+ def stft(self, x):
367
+ """Perform STFT.
368
+ Args:
369
+ x (Tensor): Input signal tensor (B, T).
370
+
371
+ Returns:
372
+ Tensor: x_mag, x_phs
373
+ Magnitude and phase spectra (B, fft_size // 2 + 1, frames).
374
+ """
375
+ x_stft = torch.stft(
376
+ x,
377
+ self.fft_size,
378
+ self.hop_size,
379
+ self.win_length,
380
+ self.window,
381
+ return_complex=True,
382
+ )
383
+ x_mag = torch.sqrt(
384
+ torch.clamp((x_stft.real**2) + (x_stft.imag**2), min=self.eps)
385
+ )
386
+
387
+ # torch.angle is expensive, so it is only evaluated if the values are used in the loss
388
+ if self.phs_used:
389
+ x_phs = torch.angle(x_stft)
390
+ else:
391
+ x_phs = None
392
+
393
+ return x_mag, x_phs
394
+
395
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
396
+ bs, chs, seq_len = input.size()
397
+
398
+ if self.perceptual_weighting: # apply optional A-weighting via FIR filter
399
+ # since FIRFilter only support mono audio we will move channels to batch dim
400
+ input = input.view(bs * chs, 1, -1)
401
+ target = target.view(bs * chs, 1, -1)
402
+
403
+ # now apply the filter to both
404
+ self.prefilter.to(input.device)
405
+ input, target = self.prefilter(input, target)
406
+
407
+ # now move the channels back
408
+ input = input.view(bs, chs, -1)
409
+ target = target.view(bs, chs, -1)
410
+
411
+ # compute the magnitude and phase spectra of input and target
412
+ self.window = self.window.to(input.device)
413
+
414
+ x_mag, x_phs = self.stft(input.view(-1, input.size(-1)))
415
+ y_mag, y_phs = self.stft(target.view(-1, target.size(-1)))
416
+
417
+ # apply relevant transforms
418
+ if self.scale is not None:
419
+ self.fb = self.fb.to(input.device)
420
+ x_mag = torch.matmul(self.fb, x_mag)
421
+ y_mag = torch.matmul(self.fb, y_mag)
422
+
423
+ # normalize scales
424
+ if self.scale_invariance:
425
+ alpha = (x_mag * y_mag).sum([-2, -1]) / ((y_mag**2).sum([-2, -1]))
426
+ y_mag = y_mag * alpha.unsqueeze(-1)
427
+
428
+ # compute loss terms
429
+ sc_mag_loss = self.spectralconv(x_mag, y_mag) if self.w_sc else 0.0
430
+ log_mag_loss = self.logstft(x_mag, y_mag) if self.w_log_mag else 0.0
431
+ lin_mag_loss = self.linstft(x_mag, y_mag) if self.w_lin_mag else 0.0
432
+ phs_loss = torch.nn.functional.mse_loss(x_phs, y_phs) if self.phs_used else 0.0
433
+
434
+ # combine loss terms
435
+ loss = (
436
+ (self.w_sc * sc_mag_loss)
437
+ + (self.w_log_mag * log_mag_loss)
438
+ + (self.w_lin_mag * lin_mag_loss)
439
+ + (self.w_phs * phs_loss)
440
+ )
441
+
442
+ loss = apply_reduction(loss, reduction=self.reduction)
443
+
444
+ if self.output == "loss":
445
+ return loss
446
+ elif self.output == "full":
447
+ return loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
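+
+ # Illustrative usage (hypothetical values): single-resolution mel-scaled loss on
+ # waveforms shaped (batch, channels, samples).
+ #   stft_loss = STFTLoss(fft_size=2048, hop_size=512, win_length=2048,
+ #                        scale="mel", n_bins=128, sample_rate=44100)
+ #   loss = stft_loss(y_hat, y)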
448
+
449
+ class MultiResolutionSTFTLoss(torch.nn.Module):
450
+ """Multi-resolution STFT loss module.
451
+
452
+ See [Yamamoto et al., 2019](https://arxiv.org/abs/1910.11480)
453
+
454
+ Args:
455
+ fft_sizes (list): List of FFT sizes.
456
+ hop_sizes (list): List of hop sizes.
457
+ win_lengths (list): List of window lengths.
458
+ window (str, optional): Window to apply before FFT, options include:
459
+ ['hann_window', 'bartlett_window', 'blackman_window', 'hamming_window', 'kaiser_window']
460
+ Default: 'hann_window'
461
+ w_sc (float, optional): Weight of the spectral convergence loss term. Default: 1.0
462
+ w_log_mag (float, optional): Weight of the log magnitude loss term. Default: 1.0
463
+ w_lin_mag (float, optional): Weight of the linear magnitude loss term. Default: 0.0
464
+ w_phs (float, optional): Weight of the spectral phase loss term. Default: 0.0
465
+ sample_rate (int, optional): Sample rate. Required when scale = 'mel'. Default: None
466
+ scale (str, optional): Optional frequency scaling method, options include:
467
+ ['mel', 'chroma']
468
+ Default: None
469
+ n_bins (int, optional): Number of mel frequency bins. Required when scale = 'mel'. Default: None.
470
+ scale_invariance (bool, optional): Perform an optimal scaling of the target. Default: False
471
+ """
472
+
473
+ def __init__(
474
+ self,
475
+ fft_sizes: List[int] = [1024, 2048, 512],
476
+ hop_sizes: List[int] = [120, 240, 50],
477
+ win_lengths: List[int] = [600, 1200, 240],
478
+ window: str = "hann_window",
479
+ w_sc: float = 1.0,
480
+ w_log_mag: float = 1.0,
481
+ w_lin_mag: float = 0.0,
482
+ w_phs: float = 0.0,
483
+ sample_rate: float = None,
484
+ scale: str = None,
485
+ n_bins: int = None,
486
+ perceptual_weighting: bool = False,
487
+ scale_invariance: bool = False,
488
+ **kwargs,
489
+ ):
490
+ super().__init__()
491
+ assert len(fft_sizes) == len(hop_sizes) == len(win_lengths) # must define all
492
+ self.fft_sizes = fft_sizes
493
+ self.hop_sizes = hop_sizes
494
+ self.win_lengths = win_lengths
495
+
496
+ self.stft_losses = torch.nn.ModuleList()
497
+ for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
498
+ self.stft_losses += [
499
+ STFTLoss(
500
+ fs,
501
+ ss,
502
+ wl,
503
+ window,
504
+ w_sc,
505
+ w_log_mag,
506
+ w_lin_mag,
507
+ w_phs,
508
+ sample_rate,
509
+ scale,
510
+ n_bins,
511
+ perceptual_weighting,
512
+ scale_invariance,
513
+ **kwargs,
514
+ )
515
+ ]
516
+
517
+ def forward(self, x, y):
518
+ mrstft_loss = 0.0
519
+ sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss = [], [], [], []
522
+ for f in self.stft_losses:
523
+ if f.output == "full": # extract just first term
524
+ tmp_loss = f(x, y)
525
+ mrstft_loss += tmp_loss[0]
526
+ sc_mag_loss.append(tmp_loss[1])
527
+ log_mag_loss.append(tmp_loss[2])
528
+ lin_mag_loss.append(tmp_loss[3])
529
+ phs_loss.append(tmp_loss[4])
530
+ else:
531
+ mrstft_loss += f(x, y)
532
+
533
+ mrstft_loss /= len(self.stft_losses)
534
+
535
+ if f.output == "loss":
536
+ return mrstft_loss
537
+ else:
538
+ return mrstft_loss, sc_mag_loss, log_mag_loss, lin_mag_loss, phs_loss
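+
+ # Illustrative usage (hypothetical values): average STFT losses over three resolutions.
+ #   mrstft = MultiResolutionSTFTLoss(fft_sizes=[2048, 1024, 512],
+ #                                    hop_sizes=[512, 256, 128],
+ #                                    win_lengths=[2048, 1024, 512])
+ #   loss = mrstft(y_hat, y)  # y_hat, y: (batch, channels, samples)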
539
+
540
+
541
+ class SumAndDifferenceSTFTLoss(torch.nn.Module):
542
+ """Sum and difference stereo STFT loss module.
543
+
544
+ See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
545
+
546
+ Args:
547
+ fft_sizes (List[int]): List of FFT sizes.
548
+ hop_sizes (List[int]): List of hop sizes.
549
+ win_lengths (List[int]): List of window lengths.
550
+ window (str, optional): Window function type.
551
+ w_sum (float, optional): Weight of the sum loss component. Default: 1.0
552
+ w_diff (float, optional): Weight of the difference loss component. Default: 1.0
553
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
554
+ mel_stft (bool, optional): Use multi-resolution mel spectrograms. Default: False
555
+ n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
556
+ sample_rate (float, optional): Audio sample rate. Default: None
557
+ output (str, optional): Format of the loss returned.
558
+ 'loss' : Return only the raw, aggregate loss term.
559
+ 'full' : Return the raw loss, plus intermediate loss terms.
560
+ Default: 'loss'
561
+ """
562
+
563
+ def __init__(
564
+ self,
565
+ fft_sizes: List[int],
566
+ hop_sizes: List[int],
567
+ win_lengths: List[int],
568
+ window: str = "hann_window",
569
+ w_sum: float = 1.0,
570
+ w_diff: float = 1.0,
571
+ output: str = "loss",
572
+ **kwargs,
573
+ ):
574
+ super().__init__()
575
+ self.sd = SumAndDifference()
576
+ self.w_sum = w_sum
577
+ self.w_diff = w_diff
578
+ self.output = output
579
+ self.mrstft = MultiResolutionSTFTLoss(
580
+ fft_sizes,
581
+ hop_sizes,
582
+ win_lengths,
583
+ window,
584
+ **kwargs,
585
+ )
586
+
587
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
588
+ """This loss function assumes batched input of stereo audio in the time domain.
589
+
590
+ Args:
591
+ input (torch.Tensor): Input tensor with shape (batch size, 2, seq_len).
592
+ target (torch.Tensor): Target tensor with shape (batch size, 2, seq_len).
593
+
594
+ Returns:
595
+ loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
596
+ loss (torch.Tensor), sum_loss (torch.Tensor), diff_loss (torch.Tensor):
597
+ Aggregate and intermediate loss terms. Only returned if output='full'.
598
+ """
599
+ assert input.shape == target.shape # must have same shape
600
+ bs, chs, seq_len = input.size()
601
+
602
+ # compute sum and difference signals for both
603
+ input_sum, input_diff = self.sd(input)
604
+ target_sum, target_diff = self.sd(target)
605
+
606
+ # compute error in STFT domain
607
+ sum_loss = self.mrstft(input_sum, target_sum)
608
+ diff_loss = self.mrstft(input_diff, target_diff)
609
+ loss = ((self.w_sum * sum_loss) + (self.w_diff * diff_loss)) / 2
610
+
611
+ if self.output == "loss":
612
+ return loss
613
+ elif self.output == "full":
614
+ return loss, sum_loss, diff_loss
615
+
616
+ class SpatialSTFTLoss(torch.nn.Module):
617
+ """Spatial STFT loss module.
618
+
619
+ See [Steinmetz et al., 2020](https://arxiv.org/abs/2010.10291)
620
+
621
+ Args:
622
+ fft_sizes (List[int]): List of FFT sizes.
623
+ hop_sizes (List[int]): List of hop sizes.
624
+ win_lengths (List[int]): List of window lengths.
625
+ window (str, optional): Window function type.
626
+ w_phi (float, optional): Weight of the azimuth (phi) loss component. Default: 1.0
627
+ w_theta (float, optional): Weight of the elevation (theta) loss component. Default: 1.0
628
+ perceptual_weighting (bool, optional): Apply perceptual A-weighting (Sample rate must be supplied). Default: False
629
+ mel_stft (bool, optional): Use multi-resolution mel spectrograms. Default: False
630
+ n_mel_bins (int, optional): Number of mel bins to use when mel_stft = True. Default: 128
631
+ sample_rate (float, optional): Audio sample rate. Default: None
632
+ output (str, optional): Format of the loss returned.
633
+ 'loss' : Return only the raw, aggregate loss term.
634
+ 'full' : Return the raw loss, plus intermediate loss terms.
635
+ Default: 'loss'
636
+ """
637
+
638
+ def __init__(
639
+ self,
640
+ fft_sizes: List[int],
641
+ hop_sizes: List[int],
642
+ win_lengths: List[int],
643
+ window: str = "hann_window",
644
+ w_phi: float = 1.0,
645
+ w_theta: float = 1.0,
646
+ output: str = "loss",
647
+ **kwargs,
648
+ ):
649
+ super().__init__()
650
+ self.w_phi = w_phi
651
+ self.w_theta = w_theta
652
+ self.output = output
653
+ self.mrstft = MultiResolutionSTFTLoss(
654
+ fft_sizes,
655
+ hop_sizes,
656
+ win_lengths,
657
+ window,
658
+ **kwargs,
659
+ )
660
+
661
+
662
+ def forward(self, input: torch.Tensor, target: torch.Tensor):
663
+ """This loss function assumes batched input of four-channel (W, X, Y, Z) audio in the time domain.
664
+
665
+ Args:
666
+ input (torch.Tensor): Input tensor with shape (batch size, 4, seq_len).
667
+ target (torch.Tensor): Target tensor with shape (batch size, 4, seq_len).
668
+
669
+ Returns:
670
+ loss (torch.Tensor): Aggregate loss term. Only returned if output='loss'.
671
+ loss (torch.Tensor), phi_loss (torch.Tensor), theta_loss (torch.Tensor):
672
+ Aggregate and intermediate loss terms. Only returned if output='full'.
673
+ """
674
+ assert input.shape == target.shape # must have same shape
675
+ bs, chs, seq_len = input.size()
676
+
677
+ w_o, x_o, y_o, z_o = input[:, 0], input[:, 1], input[:, 2], input[:, 3]
678
+ w_r, x_r, y_r, z_r = target[:, 0], target[:, 1], target[:, 2], target[:, 3]
679
+
680
+ phi_o, theta_o = compute_direction(w_o, x_o, y_o, z_o)
681
+ phi_r, theta_r = compute_direction(w_r, x_r, y_r, z_r)
682
+
683
+ # compute error in STFT domain
684
+ phi_loss = self.mrstft(phi_o, phi_r)
685
+ theta_loss = self.mrstft(theta_o, theta_r)
686
+ loss = ((self.w_phi * phi_loss) + (self.w_theta * theta_loss)) / 2
687
+
688
+ if self.output == "loss":
689
+ return loss
690
+ elif self.output == "full":
691
+ return loss, phi_loss, theta_loss
ThinkSound/training/losses/losses.py ADDED
@@ -0,0 +1,100 @@
1
+ import typing as tp
2
+
3
+ from torch.nn import functional as F
4
+ from torch import nn
5
+
6
+ class LossModule(nn.Module):
7
+ def __init__(self, name: str, weight: float = 1.0):
8
+ super().__init__()
9
+
10
+ self.name = name
11
+ self.weight = weight
12
+
13
+ def forward(self, info, *args, **kwargs):
14
+ raise NotImplementedError
15
+
16
+ class ValueLoss(LossModule):
17
+ def __init__(self, key: str, name, weight: float = 1.0):
18
+ super().__init__(name=name, weight=weight)
19
+
20
+ self.key = key
21
+
22
+ def forward(self, info):
23
+ return self.weight * info[self.key]
24
+
25
+ class L1Loss(LossModule):
26
+ def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'l1_loss'):
27
+ super().__init__(name=name, weight=weight)
28
+
29
+ self.key_a = key_a
30
+ self.key_b = key_b
31
+
32
+ self.mask_key = mask_key
33
+
34
+ def forward(self, info):
35
+ mse_loss = F.l1_loss(info[self.key_a], info[self.key_b], reduction='none')
36
+
37
+ if self.mask_key is not None and self.mask_key in info:
38
+ mse_loss = mse_loss[info[self.mask_key]]
39
+
40
+ mse_loss = mse_loss.mean()
41
+
42
+ return self.weight * mse_loss
43
+
44
+ class MSELoss(LossModule):
45
+ def __init__(self, key_a: str, key_b: str, weight: float = 1.0, mask_key: str = None, name: str = 'mse_loss'):
46
+ super().__init__(name=name, weight=weight)
47
+
48
+ self.key_a = key_a
49
+ self.key_b = key_b
50
+
51
+ self.mask_key = mask_key
52
+
53
+ def forward(self, info):
54
+ mse_loss = F.mse_loss(info[self.key_a], info[self.key_b], reduction='none')
55
+ if self.mask_key is not None and self.mask_key in info and info[self.mask_key] is not None:
56
+ mask = info[self.mask_key]
57
+
58
+ if mask.ndim == 2 and mse_loss.ndim == 3:
59
+ mask = mask.unsqueeze(1)
60
+
61
+ if mask.shape[1] != mse_loss.shape[1]:
62
+ mask = mask.repeat(1, mse_loss.shape[1], 1)
63
+
64
+ mse_loss = mse_loss[mask]
65
+
66
+ mse_loss = mse_loss.mean()
67
+
68
+ return self.weight * mse_loss
69
+
70
+ class AuralossLoss(LossModule):
71
+ def __init__(self, auraloss_module, input_key: str, target_key: str, name: str, weight: float = 1):
72
+ super().__init__(name, weight)
73
+
74
+ self.auraloss_module = auraloss_module
75
+
76
+ self.input_key = input_key
77
+ self.target_key = target_key
78
+
79
+ def forward(self, info):
80
+ loss = self.auraloss_module(info[self.input_key], info[self.target_key])
81
+
82
+ return self.weight * loss
83
+
84
+ class MultiLoss(nn.Module):
85
+ def __init__(self, losses: tp.List[LossModule]):
86
+ super().__init__()
87
+
88
+ self.losses = nn.ModuleList(losses)
89
+
90
+ def forward(self, info):
91
+ total_loss = 0
92
+
93
+ losses = {}
94
+
95
+ for loss_module in self.losses:
96
+ module_loss = loss_module(info)
97
+ total_loss += module_loss
98
+ losses[loss_module.name] = module_loss
99
+
100
+ return total_loss, losses
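+
+ # Illustrative usage (hypothetical keys, weights, and module names): weighted sum of an
+ # L1 term and an auraloss multi-resolution STFT term, both read from the shared `info` dict.
+ #   loss_fn = MultiLoss([
+ #       L1Loss(key_a="audio_recon", key_b="audio", weight=0.1),
+ #       AuralossLoss(mrstft_module, input_key="audio_recon", target_key="audio",
+ #                    name="mrstft", weight=1.0),
+ #   ])
+ #   total, per_loss = loss_fn({"audio": target, "audio_recon": recon})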
ThinkSound/training/utils.py ADDED
@@ -0,0 +1,200 @@
1
+ import torch
2
+ import os
3
+ from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor
4
+ import random
5
+
6
+
7
+
8
+ def get_rank():
9
+ """Get rank of current process."""
10
+
11
+ print(os.environ.keys())
12
+
13
+ if "SLURM_PROCID" in os.environ:
14
+ return int(os.environ["SLURM_PROCID"])
15
+
16
+ if not torch.distributed.is_available() or not torch.distributed.is_initialized():
17
+ return 0
18
+
19
+ return torch.distributed.get_rank()
20
+
21
+ class InverseLR(torch.optim.lr_scheduler._LRScheduler):
22
+ """Implements an inverse decay learning rate schedule with an optional exponential
23
+ warmup. When last_epoch=-1, sets initial lr as lr.
24
+ inv_gamma is the number of steps/epochs required for the learning rate to decay to
25
+ (1 / 2)**power of its original value.
26
+ Args:
27
+ optimizer (Optimizer): Wrapped optimizer.
28
+ inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
29
+ power (float): Exponential factor of learning rate decay. Default: 1.
30
+ warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
31
+ Default: 0.
32
+ final_lr (float): The final learning rate. Default: 0.
33
+ last_epoch (int): The index of last epoch. Default: -1.
34
+ verbose (bool): If ``True``, prints a message to stdout for
35
+ each update. Default: ``False``.
36
+ """
37
+
38
+ def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., final_lr=0.,
39
+ last_epoch=-1, verbose=False):
40
+ self.inv_gamma = inv_gamma
41
+ self.power = power
42
+ if not 0. <= warmup < 1:
43
+ raise ValueError('Invalid value for warmup')
44
+ self.warmup = warmup
45
+ self.final_lr = final_lr
46
+ super().__init__(optimizer, last_epoch, verbose)
47
+
48
+ def get_lr(self):
49
+ if not self._get_lr_called_within_step:
50
+ import warnings
51
+ warnings.warn("To get the last learning rate computed by the scheduler, "
52
+ "please use `get_last_lr()`.")
53
+
54
+ return self._get_closed_form_lr()
55
+
56
+ def _get_closed_form_lr(self):
57
+ warmup = 1 - self.warmup ** (self.last_epoch + 1)
58
+ lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
59
+ return [warmup * max(self.final_lr, base_lr * lr_mult)
60
+ for base_lr in self.base_lrs]
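+
+ # Closed form: lr(t) = (1 - warmup ** (t + 1)) * max(final_lr, base_lr * (1 + t / inv_gamma) ** -power)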
61
+
62
+ def copy_state_dict(model, state_dict):
63
+ """Load state_dict to model, but only for keys that match exactly.
64
+
65
+ Args:
66
+ model (nn.Module): model to load state_dict.
67
+ state_dict (OrderedDict): state_dict to load.
68
+ """
69
+ model_state_dict = model.state_dict()
70
+
71
+ # Collect checkpoint/model keys that do not match
72
+ missing_keys = []
73
+ unexpected_keys = []
74
+ # Manually check each checkpoint key for unexpected or shape-mismatched entries
75
+ for key in state_dict:
76
+ if key not in model_state_dict:
77
+ unexpected_keys.append(key)
78
+ elif state_dict[key].shape != model_state_dict[key].shape:
79
+ unexpected_keys.append(key)
80
+
81
+ for key in model_state_dict:
82
+ if key not in state_dict:
83
+ missing_keys.append(key)
84
+
85
+ # Report the mismatched keys
86
+ print("Missing keys in state_dict:", missing_keys)
87
+ print("Unexpected keys in state_dict:", unexpected_keys)
88
+ for key in state_dict:
89
+ if key in model_state_dict and state_dict[key].shape == model_state_dict[key].shape:
90
+ if isinstance(state_dict[key], torch.nn.Parameter):
91
+ # backwards compatibility for serialized parameters
92
+ state_dict[key] = state_dict[key].data
93
+ model_state_dict[key] = state_dict[key]
94
+
95
+ model.load_state_dict(model_state_dict, strict=False)
96
+
97
+ def create_optimizer_from_config(optimizer_config, parameters):
98
+ """Create optimizer from config.
99
+
100
+ Args:
101
+ parameters (iterable): parameters to optimize.
102
+ optimizer_config (dict): optimizer config.
103
+
104
+ Returns:
105
+ torch.optim.Optimizer: optimizer.
106
+ """
107
+
108
+ optimizer_type = optimizer_config["type"]
109
+
110
+ if optimizer_type == "FusedAdam":
111
+ from deepspeed.ops.adam import FusedAdam
112
+ optimizer = FusedAdam(parameters, **optimizer_config["config"])
113
+ else:
114
+ optimizer_fn = getattr(torch.optim, optimizer_type)
115
+ optimizer = optimizer_fn(parameters, **optimizer_config["config"])
116
+ return optimizer
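+
+ # Illustrative config (hypothetical values):
+ #   {"type": "AdamW", "config": {"lr": 1e-4, "betas": [0.9, 0.999], "weight_decay": 1e-3}}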
117
+
118
+ def create_scheduler_from_config(scheduler_config, optimizer):
119
+ """Create scheduler from config.
120
+
121
+ Args:
122
+ scheduler_config (dict): scheduler config.
123
+ optimizer (torch.optim.Optimizer): optimizer.
124
+
125
+ Returns:
126
+ torch.optim.lr_scheduler._LRScheduler: scheduler.
127
+ """
128
+ if scheduler_config["type"] == "InverseLR":
129
+ scheduler_fn = InverseLR
130
+ else:
131
+ scheduler_fn = getattr(torch.optim.lr_scheduler, scheduler_config["type"])
132
+ scheduler = scheduler_fn(optimizer, **scheduler_config["config"])
133
+ return scheduler
134
+
135
+ # mask construction helpers
136
+
137
+ def mask_from_start_end_indices(
138
+ seq_len: int,
139
+ start: Tensor,
140
+ end: Tensor
141
+ ):
142
+ assert start.shape == end.shape
143
+ device = start.device
144
+
145
+ seq = torch.arange(seq_len, device = device, dtype = torch.long)
146
+ seq = seq.reshape(*((-1,) * start.ndim), seq_len)
147
+ seq = seq.expand(*start.shape, seq_len)
148
+
149
+ mask = seq >= start[..., None].long()
150
+ mask &= seq < end[..., None].long()
151
+ return mask
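+
+ # Example: seq_len=8, start=tensor([2]), end=tensor([5]) ->
+ #   tensor([[False, False, True, True, True, False, False, False]])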
152
+
153
+ def mask_from_frac_lengths(
154
+ seq_len: int,
155
+ frac_lengths: Tensor
156
+ ):
157
+ device = frac_lengths.device
158
+
159
+ lengths = (frac_lengths * seq_len).long()
160
+ max_start = seq_len - lengths
161
+
162
+ rand = torch.zeros_like(frac_lengths, device = device).float().uniform_(0, 1)
163
+ start = (max_start * rand).clamp(min = 0)
164
+ end = start + lengths
165
+
166
+ return mask_from_start_end_indices(seq_len, start, end)
167
+
168
+ def generate_mask(batch_size, seq_len, frac_lengths, min_span_len):
169
+ # Number of spans to mask for each sample
170
+ n_mask = (frac_lengths * seq_len // min_span_len).long() # each span is 10 frames
171
+ # Initialize the mask tensor to all zeros (nothing masked)
172
+ mask_tensor = torch.zeros((batch_size, seq_len), device=frac_lengths.device, dtype=torch.bool)
173
+
174
+ for b in range(batch_size):
175
+ # Randomly pick the span start frames
176
+ start_frames = random.sample(range(0, seq_len - min_span_len + 1), n_mask[b]) # range 0 to seq_len - 10
177
+
178
+ for start in start_frames:
179
+ # Mark the 10-frame span as 1 (masked)
180
+ mask_tensor[b, start:start + 10] = 1.0
181
+
182
+ return mask_tensor
183
+
184
+ def generate_channel_mask(diffusion_input):
185
+
186
+ # If r_drop is below the threshold, fully mask one randomly chosen channel per sample
187
+ batchsize, num_channels, dim = diffusion_input.shape
188
+ for i in range(batchsize):
189
+ channel_means = torch.mean(torch.abs(diffusion_input[i]), dim=1) # Mean of the absolute values for each channel
190
+ # Determine if any channel is 'small enough'
191
+ if torch.all(channel_means > 0.01):
192
+ # If all channels are not 'small enough', apply the mask
193
+ channel = torch.randint(num_channels, (1,)).item()
194
+ diffusion_input[i, channel, :] = 1e-8 # Mask the channel by setting its values
195
+ else:
196
+ # Optionally log that at least one channel is 'small enough' and no mask is applied
197
+ print(f"Sample {i}: At least one channel is 'small enough', skipping masking.")
198
+
199
+ return diffusion_input
200
+