junxiliu commited on
Commit
3a1da90
·
1 Parent(s): 6ec9214

add needed model with proper LFS tracking

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. LICENSE +21 -0
  3. MeanAudio +0 -1
  4. config/__init__.py +0 -0
  5. config/base_config.yaml +65 -0
  6. config/data/t5_clap.yaml +58 -0
  7. config/eval_config.yaml +23 -0
  8. config/hydra/job_logging/custom-eval.yaml +32 -0
  9. config/hydra/job_logging/custom-no-rank.yaml +32 -0
  10. config/hydra/job_logging/custom-simplest.yaml +26 -0
  11. config/hydra/job_logging/custom.yaml +33 -0
  12. config/train_config.yaml +46 -0
  13. data/.gitkeep +0 -0
  14. eval.py +151 -0
  15. infer.py +143 -0
  16. meanaudio/__init__.py +0 -0
  17. meanaudio/data/__init__.py +0 -0
  18. meanaudio/data/av_utils.py +162 -0
  19. meanaudio/data/data_setup.py +137 -0
  20. meanaudio/data/eval/__init__.py +0 -0
  21. meanaudio/data/eval/audiocaps.py +39 -0
  22. meanaudio/data/eval/moviegen.py +131 -0
  23. meanaudio/data/eval/video_dataset.py +197 -0
  24. meanaudio/data/extracted_audio.py +175 -0
  25. meanaudio/data/extraction/__init__.py +0 -0
  26. meanaudio/data/extraction/vgg_sound.py +195 -0
  27. meanaudio/data/extraction/wav_dataset.py +153 -0
  28. meanaudio/data/mm_dataset.py +50 -0
  29. meanaudio/data/utils.py +148 -0
  30. meanaudio/eval_utils.py +167 -0
  31. meanaudio/ext/__init__.py +1 -0
  32. meanaudio/ext/autoencoder/__init__.py +1 -0
  33. meanaudio/ext/autoencoder/autoencoder.py +52 -0
  34. meanaudio/ext/autoencoder/edm2_utils.py +168 -0
  35. meanaudio/ext/autoencoder/vae.py +369 -0
  36. meanaudio/ext/autoencoder/vae_modules.py +117 -0
  37. meanaudio/ext/bigvgan/LICENSE +21 -0
  38. meanaudio/ext/bigvgan/__init__.py +1 -0
  39. meanaudio/ext/bigvgan/activations.py +120 -0
  40. meanaudio/ext/bigvgan/alias_free_torch/__init__.py +6 -0
  41. meanaudio/ext/bigvgan/alias_free_torch/act.py +28 -0
  42. meanaudio/ext/bigvgan/alias_free_torch/filter.py +95 -0
  43. meanaudio/ext/bigvgan/alias_free_torch/resample.py +49 -0
  44. meanaudio/ext/bigvgan/bigvgan.py +32 -0
  45. meanaudio/ext/bigvgan/bigvgan_vocoder.yml +63 -0
  46. meanaudio/ext/bigvgan/env.py +18 -0
  47. meanaudio/ext/bigvgan/incl_licenses/LICENSE_1 +21 -0
  48. meanaudio/ext/bigvgan/incl_licenses/LICENSE_2 +21 -0
  49. meanaudio/ext/bigvgan/incl_licenses/LICENSE_3 +201 -0
  50. meanaudio/ext/bigvgan/incl_licenses/LICENSE_4 +29 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.pdf filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Sony Research Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
MeanAudio DELETED
@@ -1 +0,0 @@
1
- Subproject commit 5f221b4b30ba3f89e8711c54961461c48d4999b8
 
 
config/__init__.py ADDED
File without changes
config/base_config.yaml ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: t5_clap # chenge here to load different data in testing (data.AudioCaps_test)
3
+ - override hydra/job_logging: custom-simplest
4
+ - _self_
5
+
6
+ hydra:
7
+ run:
8
+ dir: ./exps/${exp_id}
9
+ output_subdir: ${now:%Y-%m-%d_%H-%M-%S}-hydra
10
+
11
+ enable_email: False
12
+
13
+ ## model
14
+ model: meanaudio_mf
15
+ text_encoder_name: t5_clap # [t5, clip, t5_clap, t5_clap_cat]: change here for different feature utils (only for runner-FeatureUtils/infer, not used for using pre-computed dataset)
16
+ concat_text_fc: False
17
+
18
+ exp_id: default
19
+ debug: False
20
+ cudnn_benchmark: True
21
+ compile: False # set compile to false by default
22
+ amp: True
23
+ weights: null
24
+ # weights: null
25
+
26
+ checkpoint: null
27
+
28
+ seed: 14159265
29
+ num_workers: 10 # per-GPU
30
+ pin_memory: False # set to True if your system can handle it, i.e., have enough memory
31
+
32
+ # NOTE: This DOSE NOT affect the model during inference in any way
33
+ # they are just for the dataloader to fill in the missing data in multi-modal loading
34
+ # to change the sequence length for the model, see networks.py
35
+ data_dim:
36
+ text_seq_len: 77
37
+ text_dim: 1024
38
+ text_c_dim: 512 # 1024 for pooled T5, 512 for CLAP
39
+
40
+ # ema configuration
41
+ ema:
42
+ enable: True
43
+ sigma_rels: [0.05, 0.1]
44
+ update_every: 1
45
+ checkpoint_every: 10_000
46
+ checkpoint_folder: ${hydra:run.dir}/ema_ckpts
47
+ default_output_sigma: 0.05
48
+
49
+
50
+ # sampling, only for flow matching
51
+ sampling:
52
+ mean: 0.0
53
+ scale: 1.0
54
+ min_sigma: 0.0
55
+ method: euler
56
+ num_steps: 25
57
+
58
+ # classifier-free guidance
59
+ null_condition_probability: 0.1
60
+ cfg_strength: 1
61
+
62
+ # checkpoint paths to external modules
63
+ vae_16k_ckpt: ./weights/v1-16.pth
64
+ vae_44k_ckpt: ./weights/v1-44.pth
65
+ bigvgan_vocoder_ckpt: ./weights/best_netG.pt
config/data/t5_clap.yaml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AudioCaps
2
+ AudioCaps_npz:
3
+ tag: train
4
+ tsv: data/audiocaps/train-memmap.tsv
5
+ npz_dir: data/audiocaps/train-npz-t5-clap
6
+ output_subdir: null
7
+ repa_npz_dir: null
8
+
9
+ AudioCaps_val_npz:
10
+ tag: val
11
+ tsv: data/audiocaps/val-memmap.tsv
12
+ npz_dir: data/audiocaps/val-npz-t5-clap
13
+ output_subdir: null
14
+ repa_npz_dir: null
15
+ gt_cache: data/audiocaps/val-features
16
+
17
+ AudioCaps_test_npz:
18
+ tag: test
19
+ tsv: data/audiocaps/test-memmap.tsv
20
+ npz_dir: data/audiocaps/test-npz-t5-clap
21
+ output_subdir: null
22
+ repa_npz_dir: null
23
+ gt_cache: data/audiocaps/test-features
24
+
25
+ latent_mean: 'sets/latent_mean.pt'
26
+ latent_std: 'sets/latent_std.pt'
27
+
28
+ # Clotho
29
+ Clotho_npz:
30
+ tsv: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/clotho/dev-memmap-t5-clap.tsv
31
+ npz_dir: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/clotho/dev-npz-t5-clap
32
+ repa_npz_dir: null
33
+
34
+ # WavCaps
35
+ AudioSetSL_npz:
36
+ tsv: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/wavcaps/audioset-sl-memmap-t5-clap.tsv
37
+ npz_dir: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/wavcaps/audioset-sl-npz-t5-clap
38
+ repa_npz_dir: null
39
+
40
+ BBCSound_npz:
41
+ tsv: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/wavcaps/bbc-sound-effects-memmap-t5-clap.tsv
42
+ npz_dir: /hpc_stor03/sjtu_home/xiquan.li/data/MMAudio/wavcaps/bbc-sound-effects-npz-t5-clap
43
+ repa_npz_dir: null
44
+
45
+ FreeSound1_npz:
46
+ tsv: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-memmap-t5-clap-1.tsv
47
+ npz_dir: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-npz-t5-clap-1
48
+ repa_npz_dir: null
49
+
50
+ FreeSound2_npz:
51
+ tsv: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-memmap-t5-clap-2.tsv
52
+ npz_dir: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-npz-t5-clap-2
53
+ repa_npz_dir: null
54
+
55
+ FreeSound3_npz:
56
+ tsv: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-memmap-t5-clap-3.tsv
57
+ npz_dir: /hpc_stor03/sjtu_home/junxi.liu/shared/freesound-npz-t5-clap-3
58
+ repa_npz_dir: null
config/eval_config.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## This config fire is no longer used
2
+ ## We pass everything by train_config to ensure training/eval consistency
3
+
4
+ defaults:
5
+ - base_config_at
6
+ - override hydra/job_logging: custom-simplest
7
+ - _self_
8
+
9
+ hydra:
10
+ run:
11
+ dir: ./exps/${exp_id}
12
+ output_subdir: eval-${now:%Y-%m-%d_%H-%M-%S}-hydra
13
+
14
+ exp_id: ${model}
15
+ dataset: audiocaps
16
+ duration_s: 10.0
17
+
18
+ # for inference, this is the per-GPU batch size
19
+ batch_size: 16 # eval batch size
20
+
21
+ output_name: null
22
+
23
+ enable_grad_scaler: False
config/hydra/job_logging/custom-eval.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/eval-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
config/hydra/job_logging/custom-no-rank.yaml ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ file:
23
+ class: logging.FileHandler
24
+ formatter: simple
25
+ # absolute file path
26
+ filename: ${hydra.runtime.output_dir}/${now:%Y-%m-%d_%H-%M-%S}-eval.log
27
+ mode: w
28
+ root:
29
+ level: INFO
30
+ handlers: [console, file]
31
+
32
+ disable_existing_loggers: false
config/hydra/job_logging/custom-simplest.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # python logging configuration for tasks
2
+ version: 1
3
+ formatters:
4
+ simple:
5
+ format: '[%(asctime)s][%(levelname)s] - %(message)s'
6
+ datefmt: '%Y-%m-%d %H:%M:%S'
7
+ colorlog:
8
+ '()': 'colorlog.ColoredFormatter'
9
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
10
+ datefmt: '%Y-%m-%d %H:%M:%S'
11
+ log_colors:
12
+ DEBUG: purple
13
+ INFO: green
14
+ WARNING: yellow
15
+ ERROR: red
16
+ CRITICAL: red
17
+ handlers:
18
+ console:
19
+ class: logging.StreamHandler
20
+ formatter: colorlog
21
+ stream: ext://sys.stdout
22
+ root:
23
+ level: INFO
24
+ handlers: [console]
25
+
26
+ disable_existing_loggers: false
config/hydra/job_logging/custom.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # @package hydra.job_logging
2
+ # python logging configuration for tasks
3
+ version: 1
4
+ formatters:
5
+ simple:
6
+ format: '[%(asctime)s][%(levelname)s][r${oc.env:LOCAL_RANK}] - %(message)s'
7
+ datefmt: '%Y-%m-%d %H:%M:%S'
8
+ colorlog:
9
+ '()': 'colorlog.ColoredFormatter'
10
+ format: '[%(cyan)s%(asctime)s%(reset)s][%(blue)sr${oc.env:LOCAL_RANK}%(reset)s][%(log_color)s%(levelname)s%(reset)s] - %(message)s'
11
+ datefmt: '%Y-%m-%d %H:%M:%S'
12
+ log_colors:
13
+ DEBUG: purple
14
+ INFO: green
15
+ WARNING: yellow
16
+ ERROR: red
17
+ CRITICAL: red
18
+ handlers:
19
+ console:
20
+ class: logging.StreamHandler
21
+ formatter: colorlog
22
+ stream: ext://sys.stdout
23
+ file:
24
+ class: logging.FileHandler
25
+ formatter: simple
26
+ # absolute file path
27
+ filename: ${hydra.runtime.output_dir}/train-${now:%Y-%m-%d_%H-%M-%S}-rank${oc.env:LOCAL_RANK}.log
28
+ mode: w
29
+ root:
30
+ level: INFO
31
+ handlers: [console, file]
32
+
33
+ disable_existing_loggers: false
config/train_config.yaml ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - base_config
3
+ - override data: t5_clap # change here for loading different text features in training/evaluation
4
+ - override hydra/job_logging: custom
5
+ - _self_
6
+
7
+ hydra:
8
+ run:
9
+ dir: ./exps/${exp_id}
10
+ output_subdir: train-${now:%Y-%m-%d_%H-%M-%S}-hydra
11
+
12
+ ema:
13
+ start: 0
14
+
15
+ mini_train: False
16
+ example_train: False
17
+ enable_grad_scaler: True
18
+ ac_oversample_rate: 5
19
+
20
+ log_text_interval: 50
21
+ log_extra_interval: 10_000
22
+ val_interval: 10_000
23
+ eval_interval: 10_000
24
+ save_eval_interval: 10_000
25
+ save_weights_interval: 5_000
26
+ save_checkpoint_interval: 10_000
27
+ save_copy_iterations: []
28
+
29
+ batch_size: 128
30
+ eval_batch_size: 4
31
+
32
+ num_iterations: 100_000
33
+ learning_rate: 1e-4
34
+ linear_warmup_steps: 1_000
35
+
36
+ lr_schedule: step
37
+ lr_schedule_steps: [40_000, 45_000] # this is not used, lr_schedule_steps will be determined by the number of iterations
38
+ lr_schedule_gamma: 0.1
39
+
40
+ clip_grad_norm: 1.0
41
+ weight_decay: 1.0e-6
42
+
43
+ output_name: null # for eval
44
+
45
+ use_meanflow: True
46
+ use_repa: False
data/.gitkeep ADDED
File without changes
eval.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from argparse import ArgumentParser
3
+ from pathlib import Path
4
+ import os
5
+ import torch
6
+ import torchaudio
7
+ import csv
8
+ from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_fm, generate_mf, setup_eval_logging)
9
+ from meanaudio.model.flow_matching import FlowMatching
10
+ from meanaudio.model.mean_flow import MeanFlow
11
+ from meanaudio.model.networks import MeanAudio, get_mean_audio
12
+ from meanaudio.model.utils.features_utils import FeaturesUtils
13
+
14
+ torch.backends.cuda.matmul.allow_tf32 = True
15
+ torch.backends.cudnn.allow_tf32 = True
16
+
17
+ from tqdm import tqdm
18
+ log = logging.getLogger()
19
+
20
+
21
+ @torch.inference_mode()
22
+ def main():
23
+ setup_eval_logging()
24
+
25
+ parser = ArgumentParser()
26
+ parser.add_argument('--variant',
27
+ type=str,
28
+ default='meanaudio_mf',
29
+ help='meanaudio_mf, fluxaudio_fm')
30
+
31
+ parser.add_argument('--audio_path', type=str, help='Input audio', default='')
32
+ parser.add_argument('--duration', type=float, default=9.975) # for 312 latents, seq_config should has a duration of 9.975s
33
+ parser.add_argument('--cfg_strength', type=float, default=4.5,
34
+ help='If you use meanflow, CFG is integrated in model training. So simply set this <1 to avoid an additional unconditional infer.')
35
+ parser.add_argument('--num_steps', type=int, default=25)
36
+ parser.add_argument('--output', type=Path, help='Output directory', default='./output')
37
+ parser.add_argument('--seed', type=int, help='Random seed', default=42)
38
+ parser.add_argument('--full_precision', action='store_true')
39
+ parser.add_argument('--model_path', type=str, help='Ckpt path of trained model')
40
+ parser.add_argument('--encoder_name', choices=['clip', 't5', 't5_clap'], type=str, help='text encoder name')
41
+ parser.add_argument('--use_rope', action='store_true', help='Whether or not use position embedding for model')
42
+ parser.add_argument('--text_c_dim', type=int, default=512,
43
+ help='Dim of the text_features_c, 1024 for pooled T5 and 512 for CLAP')
44
+ parser.add_argument('--debug', action='store_true')
45
+ parser.add_argument('--use_meanflow', action='store_true', help='Whether or not use mean flow for inference')
46
+ args = parser.parse_args()
47
+
48
+ if args.debug:
49
+ import debugpy
50
+ debugpy.listen(6665)
51
+ print("Waiting for debugger attach (rank 0)...")
52
+ debugpy.wait_for_client()
53
+
54
+ if args.variant not in all_model_cfg:
55
+ raise ValueError(f'Unknown model variant: {args.variant}')
56
+ model: ModelConfig = all_model_cfg[args.variant] # model is just the model config
57
+ # model.download_if_needed()
58
+ seq_cfg = model.seq_cfg
59
+
60
+ negative_prompt: str = ''
61
+ output_dir: str = args.output.expanduser()
62
+ seed: int = args.seed
63
+ num_steps: int = args.num_steps
64
+ duration: float = args.duration
65
+ cfg_strength: float = args.cfg_strength
66
+
67
+ device = 'cpu'
68
+ if torch.cuda.is_available():
69
+ device = 'cuda'
70
+ elif torch.backends.mps.is_available():
71
+ device = 'mps'
72
+ else:
73
+ log.warning('CUDA/MPS are not available, running on CPU')
74
+ dtype = torch.float32 if args.full_precision else torch.bfloat16
75
+
76
+ output_dir.mkdir(parents=True, exist_ok=True)
77
+ print(model.model_name)
78
+ # load a pretrained model
79
+ net: MeanAudio = get_mean_audio(model.model_name,
80
+ use_rope=args.use_rope,
81
+ text_c_dim=args.text_c_dim).to(device, dtype).eval()
82
+ net.load_weights(torch.load(args.model_path, map_location=device, weights_only=True))
83
+ log.info(f'Loaded weights from {args.model_path}')
84
+
85
+ # misc setup
86
+ rng = torch.Generator(device=device)
87
+ rng.manual_seed(seed)
88
+ if args.use_meanflow:
89
+ mf = MeanFlow(steps=num_steps)
90
+ else:
91
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
92
+
93
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
94
+ enable_conditions=True,
95
+ encoder_name=args.encoder_name,
96
+ mode=model.mode,
97
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
98
+ need_vae_encoder=False)
99
+ feature_utils = feature_utils.to(device, dtype).eval()
100
+
101
+ seq_cfg.duration = duration
102
+ net.update_seq_lengths(seq_cfg.latent_seq_len)
103
+
104
+ eval_file = './sets/test-audiocaps.tsv'
105
+ audio_ids=[]
106
+ text_prompts=[]
107
+ with open(eval_file, 'r') as f:
108
+ reader = csv.DictReader(f, delimiter='\t')
109
+ for row in reader:
110
+ audio_ids.append(row['id'])
111
+ text_prompts.append(row['caption'])
112
+
113
+ for k in tqdm(range(0, len(text_prompts))):
114
+ prompt = text_prompts[k]
115
+ if args.use_meanflow:
116
+ log.info(f'Prompt: {prompt}')
117
+ log.info(f'Negative prompt: {negative_prompt}')
118
+ audios = generate_mf([prompt],
119
+ negative_text=[negative_prompt],
120
+ feature_utils=feature_utils,
121
+ net=net,
122
+ mf=mf,
123
+ rng=rng,
124
+ cfg_strength=cfg_strength)
125
+ audio = audios.float().cpu()[0]
126
+ save_paths = output_dir / f'{audio_ids[k]}.wav'
127
+ torchaudio.save(save_paths, audio, seq_cfg.sampling_rate)
128
+ log.info(f'Audio saved to {save_paths}')
129
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
130
+
131
+ else:
132
+ prompt = text_prompts[k]
133
+ log.info(f'Prompt: {prompt}')
134
+ log.info(f'Negative prompt: {negative_prompt}')
135
+ audios = generate_fm([prompt],
136
+ negative_text=[negative_prompt],
137
+ feature_utils=feature_utils,
138
+ net=net,
139
+ fm=fm,
140
+ rng=rng,
141
+ cfg_strength=cfg_strength)
142
+ audio = audios.float().cpu()[0]
143
+
144
+ save_paths = output_dir / f'{audio_ids[k]}.wav'
145
+ torchaudio.save(save_paths, audio, seq_cfg.sampling_rate)
146
+ log.info(f'Audio saved to {save_paths}')
147
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
148
+
149
+
150
+ if __name__ == '__main__':
151
+ main()
infer.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings("ignore", category=FutureWarning)
3
+
4
+ import logging
5
+ from argparse import ArgumentParser
6
+ from pathlib import Path
7
+ import torch
8
+ import torchaudio
9
+ from meanaudio.eval_utils import (ModelConfig, all_model_cfg, generate_mf, generate_fm, setup_eval_logging)
10
+ from meanaudio.model.flow_matching import FlowMatching
11
+ from meanaudio.model.mean_flow import MeanFlow
12
+ from meanaudio.model.networks import MeanAudio, get_mean_audio
13
+ from meanaudio.model.utils.features_utils import FeaturesUtils
14
+
15
+ torch.backends.cuda.matmul.allow_tf32 = True
16
+ torch.backends.cudnn.allow_tf32 = True
17
+ from tqdm import tqdm
18
+ log = logging.getLogger()
19
+
20
+
21
+ @torch.inference_mode()
22
+ def main():
23
+ setup_eval_logging()
24
+
25
+ parser = ArgumentParser()
26
+ parser.add_argument('--variant',
27
+ type=str,
28
+ default='small_16k_mf',
29
+ help='small_16k_mf, small_16k_fm')
30
+
31
+ parser.add_argument('--prompt', type=str, help='Input prompt', default='')
32
+ parser.add_argument('--negative_prompt', type=str, help='Negative prompt', default='')
33
+ parser.add_argument('--duration', type=float, default=9.975) # for 312 latents, seq_config should has a duration of 9.975s
34
+ parser.add_argument('--cfg_strength', type=float, default=4.5)
35
+ parser.add_argument('--num_steps', type=int, default=25)
36
+
37
+ parser.add_argument('--output', type=Path, help='Output directory', default='./output')
38
+ parser.add_argument('--seed', type=int, help='Random seed', default=42)
39
+ parser.add_argument('--full_precision', action='store_true')
40
+ parser.add_argument('--model_path', type=str, help='Ckpt path of trained model')
41
+ parser.add_argument('--encoder_name', choices=['clip', 't5', 't5_clap'], type=str, help='text encoder name')
42
+ parser.add_argument('--use_rope', action='store_true', help='Whether or not use position embedding for model')
43
+ parser.add_argument('--text_c_dim', type=int, default=512,
44
+ help='Dim of the text_features_c, 1024 for pooled T5 and 512 for CLAP')
45
+ parser.add_argument('--debug', action='store_true')
46
+ parser.add_argument('--use_meanflow', action='store_true', help='Whether or not use mean flow for inference')
47
+ args = parser.parse_args()
48
+
49
+ if args.debug:
50
+ import debugpy
51
+ debugpy.listen(6666)
52
+ print("Waiting for debugger attach (rank 0)...")
53
+ debugpy.wait_for_client()
54
+
55
+ if args.variant not in all_model_cfg:
56
+ raise ValueError(f'Unknown model variant: {args.variant}')
57
+ model: ModelConfig = all_model_cfg[args.variant] # model is just the model config
58
+ seq_cfg = model.seq_cfg
59
+
60
+ negative_prompt: str = args.negative_prompt
61
+ output_dir: str = args.output.expanduser()
62
+ seed: int = args.seed
63
+ num_steps: int = args.num_steps
64
+ duration: float = args.duration
65
+ cfg_strength: float = args.cfg_strength
66
+
67
+ device = 'cpu'
68
+ if torch.cuda.is_available():
69
+ device = 'cuda'
70
+ elif torch.backends.mps.is_available():
71
+ device = 'mps'
72
+ else:
73
+ log.warning('CUDA/MPS are not available, running on CPU')
74
+ dtype = torch.float32 if args.full_precision else torch.bfloat16
75
+
76
+ output_dir.mkdir(parents=True, exist_ok=True)
77
+ # load a pretrained model
78
+ net: MeanAudio = get_mean_audio(model.model_name,
79
+ use_rope=args.use_rope,
80
+ text_c_dim=args.text_c_dim).to(device, dtype).eval()
81
+ net.load_weights(torch.load(args.model_path, map_location=device, weights_only=True))
82
+ log.info(f'Loaded weights from {args.model_path}')
83
+
84
+ # misc setup
85
+ rng = torch.Generator(device=device)
86
+ rng.manual_seed(seed)
87
+ if args.use_meanflow:
88
+ mf = MeanFlow(steps=num_steps)
89
+ else:
90
+ fm = FlowMatching(min_sigma=0, inference_mode='euler', num_steps=num_steps)
91
+
92
+ feature_utils = FeaturesUtils(tod_vae_ckpt=model.vae_path,
93
+ enable_conditions=True,
94
+ encoder_name=args.encoder_name,
95
+ mode=model.mode,
96
+ bigvgan_vocoder_ckpt=model.bigvgan_16k_path,
97
+ need_vae_encoder=False)
98
+ feature_utils = feature_utils.to(device, dtype).eval()
99
+
100
+ seq_cfg.duration = duration
101
+ net.update_seq_lengths(seq_cfg.latent_seq_len)
102
+ prompts: str = [args.prompt]
103
+
104
+
105
+ if args.use_meanflow:
106
+ for prompt in tqdm(prompts):
107
+ log.info(f'Prompt: {prompt}')
108
+ log.info(f'Negative prompt: {negative_prompt}')
109
+ audios = generate_mf([prompt],
110
+ negative_text=[negative_prompt],
111
+ feature_utils=feature_utils,
112
+ net=net,
113
+ mf=mf,
114
+ rng=rng,
115
+ cfg_strength=cfg_strength)
116
+ audio = audios.float().cpu()[0]
117
+ safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
118
+ save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{args.seed}.wav'
119
+ torchaudio.save( save_path, audio, seq_cfg.sampling_rate)
120
+ log.info(f'Audio saved to {save_path}')
121
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
122
+ else:
123
+ for prompt in tqdm(prompts):
124
+ log.info(f'Prompt: {prompt}')
125
+ log.info(f'Negative prompt: {negative_prompt}')
126
+ audios = generate_fm([prompt],
127
+ negative_text=[negative_prompt],
128
+ feature_utils=feature_utils,
129
+ net=net,
130
+ fm=fm,
131
+ rng=rng,
132
+ cfg_strength=cfg_strength)
133
+ audio = audios.float().cpu()[0]
134
+ safe_filename = prompt.replace(' ', '_').replace('/', '_').replace('.', '')
135
+ save_path = output_dir / f'{safe_filename}--numsteps{num_steps}--seed{args.seed}.wav'
136
+ torchaudio.save(save_path, audio, seq_cfg.sampling_rate)
137
+
138
+ log.info(f'Audio saved to {save_path}')
139
+ log.info('Memory usage: %.2f GB', torch.cuda.max_memory_allocated() / (2**30))
140
+
141
+
142
+ if __name__ == '__main__':
143
+ main()
meanaudio/__init__.py ADDED
File without changes
meanaudio/data/__init__.py ADDED
File without changes
meanaudio/data/av_utils.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from fractions import Fraction
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import av
7
+ import numpy as np
8
+ import torch
9
+ from av import AudioFrame
10
+
11
+
12
+ @dataclass
13
+ class VideoInfo:
14
+ duration_sec: float
15
+ fps: Fraction
16
+ clip_frames: torch.Tensor
17
+ sync_frames: torch.Tensor
18
+ all_frames: Optional[list[np.ndarray]]
19
+
20
+ @property
21
+ def height(self):
22
+ return self.all_frames[0].shape[0]
23
+
24
+ @property
25
+ def width(self):
26
+ return self.all_frames[0].shape[1]
27
+
28
+ @classmethod
29
+ def from_image_info(cls, image_info: 'ImageInfo', duration_sec: float,
30
+ fps: Fraction) -> 'VideoInfo':
31
+ num_frames = int(duration_sec * fps)
32
+ all_frames = [image_info.original_frame] * num_frames
33
+ return cls(duration_sec=duration_sec,
34
+ fps=fps,
35
+ clip_frames=image_info.clip_frames,
36
+ sync_frames=image_info.sync_frames,
37
+ all_frames=all_frames)
38
+
39
+
40
+ @dataclass
41
+ class ImageInfo:
42
+ clip_frames: torch.Tensor
43
+ sync_frames: torch.Tensor
44
+ original_frame: Optional[np.ndarray]
45
+
46
+ @property
47
+ def height(self):
48
+ return self.original_frame.shape[0]
49
+
50
+ @property
51
+ def width(self):
52
+ return self.original_frame.shape[1]
53
+
54
+
55
+ def read_frames(video_path: Path, list_of_fps: list[float], start_sec: float, end_sec: float,
56
+ need_all_frames: bool) -> tuple[list[np.ndarray], list[np.ndarray], Fraction]:
57
+ output_frames = [[] for _ in list_of_fps]
58
+ next_frame_time_for_each_fps = [0.0 for _ in list_of_fps]
59
+ time_delta_for_each_fps = [1 / fps for fps in list_of_fps]
60
+ all_frames = []
61
+
62
+ # container = av.open(video_path)
63
+ with av.open(video_path) as container:
64
+ stream = container.streams.video[0]
65
+ fps = stream.guessed_rate
66
+ stream.thread_type = 'AUTO'
67
+ for packet in container.demux(stream):
68
+ for frame in packet.decode():
69
+ frame_time = frame.time
70
+ if frame_time < start_sec:
71
+ continue
72
+ if frame_time > end_sec:
73
+ break
74
+
75
+ frame_np = None
76
+ if need_all_frames:
77
+ frame_np = frame.to_ndarray(format='rgb24')
78
+ all_frames.append(frame_np)
79
+
80
+ for i, _ in enumerate(list_of_fps):
81
+ this_time = frame_time
82
+ while this_time >= next_frame_time_for_each_fps[i]:
83
+ if frame_np is None:
84
+ frame_np = frame.to_ndarray(format='rgb24')
85
+
86
+ output_frames[i].append(frame_np)
87
+ next_frame_time_for_each_fps[i] += time_delta_for_each_fps[i]
88
+
89
+ output_frames = [np.stack(frames) for frames in output_frames]
90
+ return output_frames, all_frames, fps
91
+
92
+
93
+ def reencode_with_audio(video_info: VideoInfo, output_path: Path, audio: torch.Tensor,
94
+ sampling_rate: int):
95
+ container = av.open(output_path, 'w')
96
+ output_video_stream = container.add_stream('h264', video_info.fps)
97
+ output_video_stream.codec_context.bit_rate = 10 * 1e6 # 10 Mbps
98
+ output_video_stream.width = video_info.width
99
+ output_video_stream.height = video_info.height
100
+ output_video_stream.pix_fmt = 'yuv420p'
101
+
102
+ output_audio_stream = container.add_stream('aac', sampling_rate)
103
+
104
+ # encode video
105
+ for image in video_info.all_frames:
106
+ image = av.VideoFrame.from_ndarray(image)
107
+ packet = output_video_stream.encode(image)
108
+ container.mux(packet)
109
+
110
+ for packet in output_video_stream.encode():
111
+ container.mux(packet)
112
+
113
+ # convert float tensor audio to numpy array
114
+ audio_np = audio.numpy().astype(np.float32)
115
+ audio_frame = AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
116
+ audio_frame.sample_rate = sampling_rate
117
+
118
+ for packet in output_audio_stream.encode(audio_frame):
119
+ container.mux(packet)
120
+
121
+ for packet in output_audio_stream.encode():
122
+ container.mux(packet)
123
+
124
+ container.close()
125
+
126
+
127
+ def remux_with_audio(video_path: Path, audio: torch.Tensor, output_path: Path, sampling_rate: int):
128
+ """
129
+ NOTE: I don't think we can get the exact video duration right without re-encoding
130
+ so we are not using this but keeping it here for reference
131
+ """
132
+ video = av.open(video_path)
133
+ output = av.open(output_path, 'w')
134
+ input_video_stream = video.streams.video[0]
135
+ output_video_stream = output.add_stream(template=input_video_stream)
136
+ output_audio_stream = output.add_stream('aac', sampling_rate)
137
+
138
+ duration_sec = audio.shape[-1] / sampling_rate
139
+
140
+ for packet in video.demux(input_video_stream):
141
+ # We need to skip the "flushing" packets that `demux` generates.
142
+ if packet.dts is None:
143
+ continue
144
+ # We need to assign the packet to the new stream.
145
+ packet.stream = output_video_stream
146
+ output.mux(packet)
147
+
148
+ # convert float tensor audio to numpy array
149
+ audio_np = audio.numpy().astype(np.float32)
150
+ audio_frame = av.AudioFrame.from_ndarray(audio_np, format='flt', layout='mono')
151
+ audio_frame.sample_rate = sampling_rate
152
+
153
+ for packet in output_audio_stream.encode(audio_frame):
154
+ output.mux(packet)
155
+
156
+ for packet in output_audio_stream.encode():
157
+ output.mux(packet)
158
+
159
+ video.close()
160
+ output.close()
161
+
162
+ output.close()
meanaudio/data/data_setup.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from omegaconf import DictConfig
7
+ from torch.utils.data import DataLoader, Dataset
8
+ from torch.utils.data.dataloader import default_collate
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from meanaudio.data.extracted_audio import ExtractedAudio
12
+ from meanaudio.data.mm_dataset import MultiModalDataset
13
+ from meanaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+
18
+ # Re-seed randomness every time we start a worker
19
+ def worker_init_fn(worker_id: int):
20
+ worker_seed = torch.initial_seed() % (2**31) + worker_id + local_rank * 1000
21
+ np.random.seed(worker_seed)
22
+ random.seed(worker_seed)
23
+ log.debug(f'Worker {worker_id} re-seeded with seed {worker_seed} in rank {local_rank}')
24
+
25
+
26
+ def load_audio_data(cfg: DictConfig, data_cfg: DictConfig) -> Dataset:
27
+ dataset = ExtractedAudio(tsv_path=data_cfg.tsv,
28
+ concat_text_fc=cfg.concat_text_fc, # FIX here we determine usage of concat based on global config
29
+ data_dim=cfg.data_dim,
30
+ npz_dir=data_cfg.npz_dir,
31
+ repa_npz_dir=data_cfg.repa_npz_dir,
32
+ exclude_cls=cfg.get('exclude_cls', False),
33
+ repa_version=cfg.get('repa_version', 1))
34
+ return dataset
35
+
36
+
37
+ def setup_training_datasets(cfg: DictConfig) -> tuple[Dataset, DistributedSampler, DataLoader]:
38
+
39
+ if cfg.mini_train:
40
+ audiocaps_mini = load_audio_data(cfg, cfg.data.AudioCaps_val_npz) # use val set as the miniset
41
+ dataset = MultiModalDataset([],
42
+ [audiocaps_mini])
43
+
44
+ else:
45
+
46
+ audiocaps_npz = load_audio_data(cfg, cfg.data.AudioCaps_npz)
47
+ # !TODO: think of a better way to handle different datasets
48
+
49
+ # freesound1_npz = load_audio_data_npz(cfg, cfg.data.FreeSound1_npz)
50
+ # freesound2_npz = load_audio_data_npz(cfg, cfg.data.FreeSound2_npz)
51
+ # freesound3_npz = load_audio_data_npz(cfg, cfg.data.FreeSound3_npz)
52
+
53
+ # audioset_sl_npz = load_audio_data_npz(cfg, cfg.data.AudioSetSL_npz)
54
+ # bbcsound_npz = load_audio_data_npz(cfg, cfg.data.BBCSound_npz)
55
+ # clotho_npz = load_audio_data_npz(cfg, cfg.data.Clotho_npz)
56
+
57
+ dataset = MultiModalDataset([], [audiocaps_npz])
58
+ # dataset = MultiModalDataset([], [audiocaps_npz]*cfg.ac_oversample_rate + [audioset_sl_npz, bbcsound_npz, clotho_npz,
59
+ # freesound1_npz, freesound2_npz, freesound3_npz])
60
+
61
+
62
+ batch_size = cfg.batch_size # per-gpu batch size
63
+ num_workers = cfg.num_workers
64
+ pin_memory = cfg.pin_memory
65
+ sampler, loader = construct_loader(dataset,
66
+ batch_size,
67
+ num_workers,
68
+ shuffle=True,
69
+ drop_last=True,
70
+ pin_memory=pin_memory)
71
+
72
+ return dataset, sampler, loader
73
+
74
+
75
+ def setup_test_datasets(cfg): # used in sample
76
+ dataset = load_audio_data(cfg, cfg.data.AudioCaps_test_npz) # ALL with NPZ format
77
+
78
+ batch_size = cfg.eval_batch_size # FIX: from train config
79
+ num_workers = cfg.num_workers
80
+ pin_memory = cfg.pin_memory
81
+ sampler, loader = construct_loader(dataset,
82
+ batch_size,
83
+ num_workers,
84
+ shuffle=False,
85
+ drop_last=False,
86
+ pin_memory=pin_memory)
87
+
88
+ return dataset, sampler, loader
89
+
90
+
91
+ def setup_val_datasets(cfg: DictConfig) -> tuple[Dataset, DataLoader, DataLoader]:
92
+ dataset = load_audio_data(cfg, cfg.data.AudioCaps_val_npz)
93
+
94
+ val_batch_size = cfg.batch_size
95
+ val_eval_batch_size = cfg.eval_batch_size
96
+ num_workers = cfg.num_workers
97
+ pin_memory = cfg.pin_memory
98
+ _, val_loader = construct_loader(dataset,
99
+ val_batch_size,
100
+ num_workers,
101
+ shuffle=False,
102
+ drop_last=False,
103
+ pin_memory=pin_memory)
104
+ _, eval_loader = construct_loader(dataset,
105
+ val_eval_batch_size,
106
+ num_workers,
107
+ shuffle=False,
108
+ drop_last=False,
109
+ pin_memory=pin_memory)
110
+
111
+ return dataset, val_loader, eval_loader
112
+
113
+
114
+ def error_avoidance_collate(batch):
115
+ batch = list(filter(lambda x: x is not None, batch)) # batch = [x for x in batch if x is not None]
116
+ return default_collate(batch)
117
+
118
+
119
+ def construct_loader(dataset: Dataset,
120
+ batch_size: int,
121
+ num_workers: int,
122
+ *,
123
+ shuffle: bool = True,
124
+ drop_last: bool = True,
125
+ pin_memory: bool = False,
126
+ error_avoidance: bool = False) -> tuple[DistributedSampler, DataLoader]:
127
+ train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=shuffle)
128
+ train_loader = DataLoader(dataset,
129
+ batch_size,
130
+ sampler=train_sampler,
131
+ num_workers=num_workers,
132
+ worker_init_fn=worker_init_fn,
133
+ drop_last=drop_last,
134
+ persistent_workers=num_workers > 0,
135
+ pin_memory=pin_memory,
136
+ collate_fn=error_avoidance_collate if error_avoidance else None)
137
+ return train_sampler, train_loader
meanaudio/data/eval/__init__.py ADDED
File without changes
meanaudio/data/eval/audiocaps.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ from collections import defaultdict
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+
11
+ log = logging.getLogger()
12
+
13
+
14
+ class AudioCapsData(Dataset):
15
+
16
+ def __init__(self, audio_path: Union[str, Path], csv_path: Union[str, Path]):
17
+ df = pd.read_csv(csv_path).to_dict(orient='records')
18
+
19
+ audio_files = sorted(os.listdir(audio_path))
20
+ audio_files = set(
21
+ [Path(f).stem for f in audio_files if f.endswith('.wav') or f.endswith('.flac')])
22
+
23
+ self.data = []
24
+ for row in df:
25
+ self.data.append({
26
+ 'name': row['name'],
27
+ 'caption': row['caption'],
28
+ })
29
+
30
+ self.audio_path = Path(audio_path)
31
+ self.csv_path = Path(csv_path)
32
+
33
+ log.info(f'Found {len(self.data)} matching audio files in {self.audio_path}')
34
+
35
+ def __getitem__(self, idx: int) -> torch.Tensor:
36
+ return self.data[idx]
37
+
38
+ def __len__(self):
39
+ return len(self.data)
meanaudio/data/eval/moviegen.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import torch
8
+ from torch.utils.data.dataset import Dataset
9
+ from torchvision.transforms import v2
10
+ from torio.io import StreamingMediaDecoder
11
+
12
+ from mmaudio.utils.dist_utils import local_rank
13
+
14
+ log = logging.getLogger()
15
+
16
+ _CLIP_SIZE = 384
17
+ _CLIP_FPS = 8.0
18
+
19
+ _SYNC_SIZE = 224
20
+ _SYNC_FPS = 25.0
21
+
22
+
23
+ class MovieGenData(Dataset):
24
+
25
+ def __init__(
26
+ self,
27
+ video_root: Union[str, Path],
28
+ sync_root: Union[str, Path],
29
+ jsonl_root: Union[str, Path],
30
+ *,
31
+ duration_sec: float = 10.0,
32
+ read_clip: bool = True,
33
+ ):
34
+ self.video_root = Path(video_root)
35
+ self.sync_root = Path(sync_root)
36
+ self.jsonl_root = Path(jsonl_root)
37
+ self.read_clip = read_clip
38
+
39
+ videos = sorted(os.listdir(self.video_root))
40
+ videos = [v[:-4] for v in videos] # remove extensions
41
+ self.captions = {}
42
+
43
+ for v in videos:
44
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
45
+ data = json.load(f)
46
+ self.captions[v] = data['audio_prompt']
47
+
48
+ if local_rank == 0:
49
+ log.info(f'{len(videos)} videos found in {video_root}')
50
+
51
+ self.duration_sec = duration_sec
52
+
53
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
54
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
55
+
56
+ self.clip_augment = v2.Compose([
57
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
58
+ v2.ToImage(),
59
+ v2.ToDtype(torch.float32, scale=True),
60
+ ])
61
+
62
+ self.sync_augment = v2.Compose([
63
+ v2.Resize((_SYNC_SIZE, _SYNC_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
64
+ v2.CenterCrop(_SYNC_SIZE),
65
+ v2.ToImage(),
66
+ v2.ToDtype(torch.float32, scale=True),
67
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
68
+ ])
69
+
70
+ self.videos = videos
71
+
72
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
73
+ video_id = self.videos[idx]
74
+ caption = self.captions[video_id]
75
+
76
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
77
+ reader.add_basic_video_stream(
78
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
79
+ frame_rate=_CLIP_FPS,
80
+ format='rgb24',
81
+ )
82
+ reader.add_basic_video_stream(
83
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
84
+ frame_rate=_SYNC_FPS,
85
+ format='rgb24',
86
+ )
87
+
88
+ reader.fill_buffer()
89
+ data_chunk = reader.pop_chunks()
90
+
91
+ clip_chunk = data_chunk[0]
92
+ sync_chunk = data_chunk[1]
93
+ if clip_chunk is None:
94
+ raise RuntimeError(f'CLIP video returned None {video_id}')
95
+ if clip_chunk.shape[0] < self.clip_expected_length:
96
+ raise RuntimeError(f'CLIP video too short {video_id}')
97
+
98
+ if sync_chunk is None:
99
+ raise RuntimeError(f'Sync video returned None {video_id}')
100
+ if sync_chunk.shape[0] < self.sync_expected_length:
101
+ raise RuntimeError(f'Sync video too short {video_id}')
102
+
103
+ # truncate the video
104
+ clip_chunk = clip_chunk[:self.clip_expected_length]
105
+ if clip_chunk.shape[0] != self.clip_expected_length:
106
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
107
+ f'expected {self.clip_expected_length}, '
108
+ f'got {clip_chunk.shape[0]}')
109
+ clip_chunk = self.clip_augment(clip_chunk)
110
+
111
+ sync_chunk = sync_chunk[:self.sync_expected_length]
112
+ if sync_chunk.shape[0] != self.sync_expected_length:
113
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
114
+ f'expected {self.sync_expected_length}, '
115
+ f'got {sync_chunk.shape[0]}')
116
+ sync_chunk = self.sync_augment(sync_chunk)
117
+
118
+ data = {
119
+ 'name': video_id,
120
+ 'caption': caption,
121
+ 'clip_video': clip_chunk,
122
+ 'sync_video': sync_chunk,
123
+ }
124
+
125
+ return data
126
+
127
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
128
+ return self.sample(idx)
129
+
130
+ def __len__(self):
131
+ return len(self.captions)
meanaudio/data/eval/video_dataset.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import os
4
+ from pathlib import Path
5
+ from typing import Union
6
+
7
+ import pandas as pd
8
+ import torch
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from mmaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VideoDataset(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ video_root: Union[str, Path],
29
+ *,
30
+ duration_sec: float = 8.0,
31
+ ):
32
+ self.video_root = Path(video_root)
33
+
34
+ self.duration_sec = duration_sec
35
+
36
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
37
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
38
+
39
+ self.clip_transform = v2.Compose([
40
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
41
+ v2.ToImage(),
42
+ v2.ToDtype(torch.float32, scale=True),
43
+ ])
44
+
45
+ self.sync_transform = v2.Compose([
46
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
47
+ v2.CenterCrop(_SYNC_SIZE),
48
+ v2.ToImage(),
49
+ v2.ToDtype(torch.float32, scale=True),
50
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
51
+ ])
52
+
53
+ # to be implemented by subclasses
54
+ self.captions = {}
55
+ self.videos = sorted(list(self.captions.keys()))
56
+
57
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
58
+ video_id = self.videos[idx]
59
+ caption = self.captions[video_id]
60
+
61
+ reader = StreamingMediaDecoder(self.video_root / (video_id + '.mp4'))
62
+ reader.add_basic_video_stream(
63
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
64
+ frame_rate=_CLIP_FPS,
65
+ format='rgb24',
66
+ )
67
+ reader.add_basic_video_stream(
68
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
69
+ frame_rate=_SYNC_FPS,
70
+ format='rgb24',
71
+ )
72
+
73
+ reader.fill_buffer()
74
+ data_chunk = reader.pop_chunks()
75
+
76
+ clip_chunk = data_chunk[0]
77
+ sync_chunk = data_chunk[1]
78
+ if clip_chunk is None:
79
+ raise RuntimeError(f'CLIP video returned None {video_id}')
80
+ if clip_chunk.shape[0] < self.clip_expected_length:
81
+ raise RuntimeError(
82
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
83
+ )
84
+
85
+ if sync_chunk is None:
86
+ raise RuntimeError(f'Sync video returned None {video_id}')
87
+ if sync_chunk.shape[0] < self.sync_expected_length:
88
+ raise RuntimeError(
89
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
90
+ )
91
+
92
+ # truncate the video
93
+ clip_chunk = clip_chunk[:self.clip_expected_length]
94
+ if clip_chunk.shape[0] != self.clip_expected_length:
95
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
96
+ f'expected {self.clip_expected_length}, '
97
+ f'got {clip_chunk.shape[0]}')
98
+ clip_chunk = self.clip_transform(clip_chunk)
99
+
100
+ sync_chunk = sync_chunk[:self.sync_expected_length]
101
+ if sync_chunk.shape[0] != self.sync_expected_length:
102
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
103
+ f'expected {self.sync_expected_length}, '
104
+ f'got {sync_chunk.shape[0]}')
105
+ sync_chunk = self.sync_transform(sync_chunk)
106
+
107
+ data = {
108
+ 'name': video_id,
109
+ 'caption': caption,
110
+ 'clip_video': clip_chunk,
111
+ 'sync_video': sync_chunk,
112
+ }
113
+
114
+ return data
115
+
116
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
117
+ try:
118
+ return self.sample(idx)
119
+ except Exception as e:
120
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
121
+ return None
122
+
123
+ def __len__(self):
124
+ return len(self.captions)
125
+
126
+
127
+ class VGGSound(VideoDataset):
128
+
129
+ def __init__(
130
+ self,
131
+ video_root: Union[str, Path],
132
+ csv_path: Union[str, Path],
133
+ *,
134
+ duration_sec: float = 8.0,
135
+ ):
136
+ super().__init__(video_root, duration_sec=duration_sec)
137
+ self.video_root = Path(video_root)
138
+ self.csv_path = Path(csv_path)
139
+
140
+ videos = sorted(os.listdir(self.video_root))
141
+ if local_rank == 0:
142
+ log.info(f'{len(videos)} videos found in {video_root}')
143
+ self.captions = {}
144
+
145
+ df = pd.read_csv(csv_path, header=None, names=['id', 'sec', 'caption',
146
+ 'split']).to_dict(orient='records')
147
+
148
+ videos_no_found = []
149
+ for row in df:
150
+ if row['split'] == 'test':
151
+ start_sec = int(row['sec'])
152
+ video_id = str(row['id'])
153
+ # this is how our videos are named
154
+ video_name = f'{video_id}_{start_sec:06d}'
155
+ if video_name + '.mp4' not in videos:
156
+ videos_no_found.append(video_name)
157
+ continue
158
+
159
+ self.captions[video_name] = row['caption']
160
+
161
+ if local_rank == 0:
162
+ log.info(f'{len(videos)} videos found in {video_root}')
163
+ log.info(f'{len(self.captions)} useable videos found')
164
+ if videos_no_found:
165
+ log.info(f'{len(videos_no_found)} found in {csv_path} but not in {video_root}')
166
+ log.info(
167
+ 'A small amount is expected, as not all videos are still available on YouTube')
168
+
169
+ self.videos = sorted(list(self.captions.keys()))
170
+
171
+
172
+ class MovieGen(VideoDataset):
173
+
174
+ def __init__(
175
+ self,
176
+ video_root: Union[str, Path],
177
+ jsonl_root: Union[str, Path],
178
+ *,
179
+ duration_sec: float = 10.0,
180
+ ):
181
+ super().__init__(video_root, duration_sec=duration_sec)
182
+ self.video_root = Path(video_root)
183
+ self.jsonl_root = Path(jsonl_root)
184
+
185
+ videos = sorted(os.listdir(self.video_root))
186
+ videos = [v[:-4] for v in videos] # remove extensions
187
+ self.captions = {}
188
+
189
+ for v in videos:
190
+ with open(self.jsonl_root / (v + '.jsonl')) as f:
191
+ data = json.load(f)
192
+ self.captions[v] = data['audio_prompt']
193
+
194
+ if local_rank == 0:
195
+ log.info(f'{len(videos)} videos found in {video_root}')
196
+
197
+ self.videos = videos
meanaudio/data/extracted_audio.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from pathlib import Path
3
+ from typing import Union, Optional
4
+
5
+ import pandas as pd
6
+ import torch
7
+ from tensordict import TensorDict
8
+ from torch.utils.data.dataset import Dataset
9
+ from torch.utils.data import DataLoader
10
+
11
+ from meanaudio.utils.dist_utils import local_rank
12
+ import numpy as np
13
+ import glob
14
+ import torch.nn.functional as F
15
+ log = logging.getLogger()
16
+
17
+
18
+ class ExtractedAudio(Dataset):
19
+ def __init__(
20
+ self,
21
+ tsv_path: Union[str, Path],
22
+ *,
23
+ concat_text_fc: bool,
24
+ npz_dir: Union[str, Path],
25
+ data_dim: dict[str, int],
26
+ repa_npz_dir: Optional[Union[str, Path]], # if passed, repa features (zs) would be returned
27
+ exclude_cls: Optional[bool],
28
+ repa_version: Optional[int],
29
+ ):
30
+ super().__init__()
31
+ self.data_dim = data_dim
32
+ self.df_list = pd.read_csv(tsv_path, sep='\t').to_dict('records') # id, caption
33
+ self.ids = [str(d['id']) for d in self.df_list]
34
+ npz_files = glob.glob(f"{npz_dir}/*.npz")
35
+ self.concat_text_fc = concat_text_fc
36
+ self.exclude_cls = exclude_cls
37
+ self.repa_version = repa_version
38
+
39
+ if self.concat_text_fc:
40
+ log.info(f'We will concat the pooled text_features and text_features_c for text condition')
41
+
42
+ # dimension check
43
+ sample = np.load(f'{npz_dir}/0.npz')
44
+ mean_s = [len(npz_files)] + list(sample['mean'].shape)
45
+ std_s = [len(npz_files)] + list(sample['std'].shape)
46
+ text_features_s = [len(npz_files)] + list(sample['text_features'].shape)
47
+ text_features_c_s = [len(npz_files)] + list(sample['text_features_c'].shape)
48
+ if self.concat_text_fc:
49
+ text_features_c_s[-1] = text_features_c_s[-1] + text_features_s[-1]
50
+
51
+ log.info(f'Loading {len(npz_files)} npz files from {npz_dir}')
52
+ log.info(f'Loaded mean: {mean_s}.')
53
+ log.info(f'Loaded std: {std_s}.')
54
+ log.info(f'Loaded text features: {text_features_s}.')
55
+ log.info(f'Loaded text features_c: {text_features_c_s}.')
56
+
57
+ assert len(npz_files) == len(self.df_list), 'Number mismatch between npz files and tsv items'
58
+ assert mean_s[1] == self.data_dim['latent_seq_len'], \
59
+ f'{mean_s[1]} != {self.data_dim["latent_seq_len"]}'
60
+ assert std_s[1] == self.data_dim['latent_seq_len'], \
61
+ f'{std_s[1]} != {self.data_dim["latent_seq_len"]}'
62
+ assert text_features_s[1] == self.data_dim['text_seq_len'], \
63
+ f'{text_features_s[1]} != {self.data_dim["text_seq_len"]}'
64
+ assert text_features_s[-1] == self.data_dim['text_dim'], \
65
+ f'{text_features_s[-1]} != {self.data_dim["text_dim"]}'
66
+ assert text_features_c_s[-1] == self.data_dim['text_c_dim'], \
67
+ f'{text_features_c_s[-1]} != {self.data_dim["text_c_dim"]}'
68
+
69
+ self.npz_dir = npz_dir
70
+ if repa_npz_dir != None:
71
+ self.repa_npz_dir = repa_npz_dir
72
+ sample = np.load(f'{repa_npz_dir}/0.npz')
73
+ repa_npz_files = glob.glob(f"{repa_npz_dir}/*.npz")
74
+ log.info(f'Loading {len(repa_npz_files)} npz representations from {repa_npz_dir}')
75
+ es_s = [len(repa_npz_files)] + list(sample['es'].shape)
76
+ if self.repa_version == 2:
77
+ es_s[1] = 65 # ad-hoc 8x downsampling for EAT
78
+ elif self.repa_version == 3:
79
+ es_s[1] = 1 # we only use cls token for alignment
80
+ else:
81
+ if self.exclude_cls:
82
+ es_s[1] = es_s[1] - 1
83
+
84
+ log.info(f'Loaded es: {es_s}')
85
+ assert len(repa_npz_files) == len(npz_files), 'Number mismatch between repa npz files and latent npz files'
86
+ assert es_s[1] == self.data_dim['repa_seq_len'], \
87
+ f'{es_s[1]} != {self.data_dim["repa_seq_len"]}'
88
+ assert es_s[-1] == self.data_dim['repa_seq_dim'], \
89
+ f'{es_s[-1]} != {self.data_dim["repa_seq_dim"]}'
90
+ else:
91
+ self.repa_npz_dir = None
92
+
93
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
94
+ # !TODO here we may consider load pre-computed latent mean & std
95
+ raise NotImplementedError('Please manually compute latent stats outside. ')
96
+
97
+ def __getitem__(self, idx):
98
+ npz_path = f'{self.npz_dir}/{idx}.npz'
99
+ np_data = np.load(npz_path)
100
+ text_features = torch.from_numpy(np_data['text_features'])
101
+ text_features_c = torch.from_numpy(np_data['text_features_c'])
102
+ if self.concat_text_fc:
103
+ text_features_c = torch.cat([text_features.mean(dim=-2),
104
+ text_features_c], dim=-1) # [b, d+d_c]
105
+
106
+ out_dict = {
107
+ 'id': str(self.df_list[idx]['id']),
108
+ 'a_mean': torch.from_numpy(np_data['mean']),
109
+ 'a_std': torch.from_numpy(np_data['std']),
110
+ 'text_features': text_features,
111
+ 'text_features_c': text_features_c,
112
+ 'caption': self.df_list[idx]['caption'],
113
+ }
114
+ if self.repa_npz_dir != None:
115
+ repa_npz_path = f'{self.repa_npz_dir}/{idx}.npz'
116
+ repa_np_data = np.load(repa_npz_path)
117
+ zs = torch.from_numpy(repa_np_data['es'])
118
+
119
+ if self.repa_version == 1:
120
+ if self.exclude_cls:
121
+ zs = zs[1:,:]
122
+ if self.repa_version == 2:
123
+ z_cls = zs[0] # (dim)
124
+ # zs = zs[1:,:].view(64, 8, 768)
125
+ zs = F.avg_pool2d(zs[1:,:].unsqueeze(0),
126
+ kernel_size=(8, 1),
127
+ stride=(8, 1)).squeeze() # (64, 768)
128
+ zs = torch.cat((z_cls.unsqueeze(0), zs), dim=0)
129
+ elif self.repa_version == 3: # cls token
130
+ zs = zs[0].unsqueeze(0)
131
+
132
+ out_dict['zs'] = zs #!TODO Here field is WRONG for eat features (should be zs)
133
+
134
+ return out_dict
135
+
136
+ def __len__(self):
137
+ return len(self.ids)
138
+
139
+
140
+ if __name__ == '__main__':
141
+
142
+ from meanaudio.utils.dist_utils import info_if_rank_zero, local_rank, world_size
143
+ import torch.distributed as distributed
144
+ from datetime import timedelta
145
+ from torch.utils.data.distributed import DistributedSampler
146
+
147
+
148
+ def distributed_setup():
149
+ distributed.init_process_group(backend="nccl", timeout=timedelta(hours=2))
150
+ log.info(f'Initialized: local_rank={local_rank}, world_size={world_size}')
151
+ return local_rank, world_size
152
+
153
+ distributed_setup()
154
+
155
+ tsv_path = '/hpc_stor03/sjtu_home/xiquan.li/TTA/MMAudio/training/audiocaps/train-memmap-t5-clap.tsv'
156
+
157
+ data_dim = {'latent_seq_len': 312,
158
+ 'text_seq_len': 77,
159
+ 'text_dim': 1024,
160
+ 'text_c_dim': 512}
161
+
162
+ dataset = ExtractedAudio(tsv_path=tsv_path,
163
+ npz_dir=npz_dir,
164
+ data_dim=data_dim)
165
+ loader = DataLoader(dataset,
166
+ 16,
167
+ num_workers=8,
168
+ persistent_workers=True,
169
+ pin_memory=False)
170
+ train_sampler = DistributedSampler(dataset, rank=local_rank, shuffle=True)
171
+
172
+
173
+ for b in loader:
174
+ print(b['a_mean'].shape)
175
+ break
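
compute_latent_stats() above deliberately raises, so the per-dimension latent statistics have to be computed offline. A minimal sketch of such an offline pass, assuming a dataset constructed as in the test block above (estimate_latent_stats is an illustrative helper, not part of this commit):

import torch

def estimate_latent_stats(dataset) -> tuple[torch.Tensor, torch.Tensor]:
    # Average the per-clip posterior means and stds stored in the extracted .npz files.
    means, stds = [], []
    for i in range(len(dataset)):
        item = dataset[i]
        means.append(item['a_mean'])
        stds.append(item['a_std'])
    return torch.stack(means).mean(dim=0), torch.stack(stds).mean(dim=0)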
meanaudio/data/extraction/__init__.py ADDED
File without changes
meanaudio/data/extraction/vgg_sound.py ADDED
@@ -0,0 +1,195 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Optional, Union
5
+
6
+ import pandas as pd
7
+ import torch
8
+ import torchaudio
9
+ from torch.utils.data.dataset import Dataset
10
+ from torchvision.transforms import v2
11
+ from torio.io import StreamingMediaDecoder
12
+
13
+ from meanaudio.utils.dist_utils import local_rank
14
+
15
+ log = logging.getLogger()
16
+
17
+ _CLIP_SIZE = 384
18
+ _CLIP_FPS = 8.0
19
+
20
+ _SYNC_SIZE = 224
21
+ _SYNC_FPS = 25.0
22
+
23
+
24
+ class VGGSound(Dataset):
25
+
26
+ def __init__(
27
+ self,
28
+ root: Union[str, Path],
29
+ *,
30
+ tsv_path: Union[str, Path] = 'sets/vgg3-train.tsv',
31
+ sample_rate: int = 16_000,
32
+ duration_sec: float = 8.0,
33
+ audio_samples: Optional[int] = None,
34
+ normalize_audio: bool = False,
35
+ ):
36
+ self.root = Path(root)
37
+ self.normalize_audio = normalize_audio
38
+ if audio_samples is None:
39
+ self.audio_samples = int(sample_rate * duration_sec)
40
+ else:
41
+ self.audio_samples = audio_samples
42
+ effective_duration = audio_samples / sample_rate
43
+ # make sure the duration is close enough, within 15ms
44
+ assert abs(effective_duration - duration_sec) < 0.015, \
45
+ f'audio_samples {audio_samples} does not match duration_sec {duration_sec}'
46
+
47
+ print("Loading videos started")
48
+ videos = sorted(os.listdir(self.root))
49
+ videos = set([Path(v).stem for v in videos]) # remove extensions
50
+ print("Loading videos ended")
51
+ self.labels = {}
52
+ self.videos = []
53
+ missing_videos = []
54
+
55
+ # read the tsv for subset information
56
+ df_list = pd.read_csv(tsv_path, sep='\t', dtype={'id': str}).to_dict('records')
57
+ for record in df_list:
58
+ id = record['id']
59
+ label = record['label']
60
+ if id in videos:
61
+ self.labels[id] = label
62
+ self.videos.append(id)
63
+ else:
64
+ missing_videos.append(id)
65
+
66
+ if local_rank == 0:
67
+ log.info(f'{len(videos)} videos found in {root}')
68
+ log.info(f'{len(self.videos)} videos found in {tsv_path}')
69
+ log.info(f'{len(missing_videos)} videos missing in {root}')
70
+
71
+ self.sample_rate = sample_rate
72
+ self.duration_sec = duration_sec
73
+
74
+ self.expected_audio_length = self.audio_samples
75
+ self.clip_expected_length = int(_CLIP_FPS * self.duration_sec)
76
+ self.sync_expected_length = int(_SYNC_FPS * self.duration_sec)
77
+
78
+ self.clip_transform = v2.Compose([
79
+ v2.Resize((_CLIP_SIZE, _CLIP_SIZE), interpolation=v2.InterpolationMode.BICUBIC),
80
+ v2.ToImage(),
81
+ v2.ToDtype(torch.float32, scale=True),
82
+ ])
83
+
84
+ self.sync_transform = v2.Compose([
85
+ v2.Resize(_SYNC_SIZE, interpolation=v2.InterpolationMode.BICUBIC),
86
+ v2.CenterCrop(_SYNC_SIZE),
87
+ v2.ToImage(),
88
+ v2.ToDtype(torch.float32, scale=True),
89
+ v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
90
+ ])
91
+
92
+ self.resampler = {}
93
+
94
+ def sample(self, idx: int) -> dict[str, torch.Tensor]:
95
+ video_id = self.videos[idx]
96
+ label = self.labels[video_id]
97
+
98
+ reader = StreamingMediaDecoder(self.root / (video_id + '.mp4'))
99
+ reader.add_basic_video_stream(
100
+ frames_per_chunk=int(_CLIP_FPS * self.duration_sec),
101
+ frame_rate=_CLIP_FPS,
102
+ format='rgb24',
103
+ )
104
+ reader.add_basic_video_stream(
105
+ frames_per_chunk=int(_SYNC_FPS * self.duration_sec),
106
+ frame_rate=_SYNC_FPS,
107
+ format='rgb24',
108
+ )
109
+ reader.add_basic_audio_stream(frames_per_chunk=2**30, )
110
+
111
+ reader.fill_buffer()
112
+ data_chunk = reader.pop_chunks()
113
+
114
+ clip_chunk = data_chunk[0]
115
+ sync_chunk = data_chunk[1]
116
+ audio_chunk = data_chunk[2]
117
+
118
+ if clip_chunk is None:
119
+ raise RuntimeError(f'CLIP video returned None {video_id}')
120
+ if clip_chunk.shape[0] < self.clip_expected_length:
121
+ raise RuntimeError(
122
+ f'CLIP video too short {video_id}, expected {self.clip_expected_length}, got {clip_chunk.shape[0]}'
123
+ )
124
+
125
+ if sync_chunk is None:
126
+ raise RuntimeError(f'Sync video returned None {video_id}')
127
+ if sync_chunk.shape[0] < self.sync_expected_length:
128
+ raise RuntimeError(
129
+ f'Sync video too short {video_id}, expected {self.sync_expected_length}, got {sync_chunk.shape[0]}'
130
+ )
131
+
132
+ # process audio
133
+ sample_rate = int(reader.get_out_stream_info(2).sample_rate)
134
+ audio_chunk = audio_chunk.transpose(0, 1)
135
+ audio_chunk = audio_chunk.mean(dim=0) # mono
136
+ if self.normalize_audio:
137
+ abs_max = audio_chunk.abs().max()
138
+ if abs_max <= 1e-6:
139
+ raise RuntimeError(f'Audio is silent {video_id}')
140
+ audio_chunk = audio_chunk / abs_max * 0.95
141
+
142
+ # resample
143
+ if sample_rate == self.sample_rate:
144
+ audio_chunk = audio_chunk
145
+ else:
146
+ if sample_rate not in self.resampler:
147
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
148
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
149
+ sample_rate,
150
+ self.sample_rate,
151
+ lowpass_filter_width=64,
152
+ rolloff=0.9475937167399596,
153
+ resampling_method='sinc_interp_kaiser',
154
+ beta=14.769656459379492,
155
+ )
156
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
157
+
158
+ if audio_chunk.shape[0] < self.expected_audio_length:
159
+ raise RuntimeError(f'Audio too short {video_id}')
160
+ audio_chunk = audio_chunk[:self.expected_audio_length]
161
+
162
+ # truncate the video
163
+ clip_chunk = clip_chunk[:self.clip_expected_length]
164
+ if clip_chunk.shape[0] != self.clip_expected_length:
165
+ raise RuntimeError(f'CLIP video wrong length {video_id}, '
166
+ f'expected {self.clip_expected_length}, '
167
+ f'got {clip_chunk.shape[0]}')
168
+ clip_chunk = self.clip_transform(clip_chunk)
169
+
170
+ sync_chunk = sync_chunk[:self.sync_expected_length]
171
+ if sync_chunk.shape[0] != self.sync_expected_length:
172
+ raise RuntimeError(f'Sync video wrong length {video_id}, '
173
+ f'expected {self.sync_expected_length}, '
174
+ f'got {sync_chunk.shape[0]}')
175
+ sync_chunk = self.sync_transform(sync_chunk)
176
+
177
+ data = {
178
+ 'id': video_id,
179
+ 'caption': label,
180
+ 'audio': audio_chunk,
181
+ 'clip_video': clip_chunk,
182
+ 'sync_video': sync_chunk,
183
+ }
184
+
185
+ return data
186
+
187
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
188
+ try:
189
+ return self.sample(idx)
190
+ except Exception as e:
191
+ log.error(f'Error loading video {self.videos[idx]}: {e}')
192
+ return None
193
+
194
+ def __len__(self):
195
+ return len(self.labels)
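
Because __getitem__ returns None whenever decoding fails, a DataLoader built on this dataset needs a collate function that drops failed samples. A possible sketch (skip_none_collate is an illustrative helper, not part of this commit; paths are placeholders):

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

def skip_none_collate(batch):
    # Drop samples that failed to decode (returned as None) before collating.
    batch = [sample for sample in batch if sample is not None]
    if len(batch) == 0:
        return None  # the training loop should skip such a batch
    return default_collate(batch)

# dataset = VGGSound('/path/to/videos', tsv_path='sets/vgg3-train.tsv')
# loader = DataLoader(dataset, batch_size=4, num_workers=8, collate_fn=skip_none_collate)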
meanaudio/data/extraction/wav_dataset.py ADDED
@@ -0,0 +1,153 @@
1
+ import logging
2
+ import os
3
+ from pathlib import Path
4
+ from typing import Union
5
+
6
+ import open_clip
7
+ import pandas as pd
8
+ import torch
9
+ import torchaudio
10
+ from torch.utils.data.dataset import Dataset
11
+ import torch.nn.functional as F
12
+
13
+ log = logging.getLogger()
14
+
15
+
16
+ class WavTextClipsDataset(Dataset):
17
+
18
+ def __init__(
19
+ self,
20
+ root: Union[str, Path],
21
+ *,
22
+ captions_tsv: Union[str, Path],
23
+ clips_tsv: Union[str, Path],
24
+ sample_rate: int,
25
+ num_samples: int,
26
+ duration: int = 10,
27
+ normalize_audio: bool = False,
28
+ reject_silent: bool = False,
29
+ tokenizer_id: str = 'ViT-H-14-378-quickgelu',
30
+ multi_caption: bool = False
31
+ ):
32
+ self.root = Path(root)
33
+ self.sample_rate = sample_rate
34
+ self.num_samples = num_samples
35
+ self.normalize_audio = normalize_audio
36
+ self.reject_silent = reject_silent
37
+ self.duration = duration
38
+ self.tokenizer = open_clip.get_tokenizer(tokenizer_id) # only for clip, for t5 and clap we will get caption embeddings outside
39
+
40
+ audios = sorted(os.listdir(self.root))
41
+ audios = set([
42
+ Path(audio).stem for audio in audios # file name w/o extension
43
+ if audio.endswith('.wav') or audio.endswith('.flac')
44
+ ])
45
+ self.captions = {}
46
+
47
+ # read the caption tsv
48
+ df_list = pd.read_csv(captions_tsv, sep='\t', dtype={'id': str}).to_dict('records')
49
+ for record in df_list:
50
+ id = record['id'] # file name
51
+ caption = record['caption']
52
+ if not multi_caption:
53
+ self.captions[id] = caption # captions: {name(no partition index): caption} !Only ONE caption will be selected for an audio clip
54
+ else:
55
+ if id not in self.captions:
56
+ self.captions[id] = [caption]
57
+ else:
58
+ self.captions[id].append(caption)
59
+
60
+ # read the clip tsv
61
+ df_list = pd.read_csv(clips_tsv, sep='\t', dtype={
62
+ 'id': str,
63
+ 'name': str
64
+ }).to_dict('records')
65
+ self.clips = []
66
+ for record in df_list: # partition
67
+ name = record['name']
68
+ if name not in self.captions:
69
+ log.warning(f'Audio {name} not found in {captions_tsv}')
70
+ continue
71
+
72
+ if not multi_caption:
73
+ record['caption'] = self.captions[name]
74
+ self.clips.append(record) # add caption to partition csv
75
+ else:
76
+ for caption in self.captions[name]:
77
+ r = record.copy()
78
+ r['caption'] = caption
79
+ self.clips.append(r) # add caption to partition csv
80
+
81
+ log.info(f'Found {len(self.clips)} audio files in {self.root}')
82
+
83
+ self.resampler = {}
84
+
85
+ def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
86
+ try:
87
+ clip = self.clips[idx]
88
+ audio_name = clip['name']
89
+ audio_id = clip['id']
90
+ caption = clip['caption']
91
+ start_sample = clip['start_sample']
92
+ end_sample = clip['end_sample']
93
+
94
+ audio_path = self.root / f'{audio_name}.flac'
95
+ if not audio_path.exists():
96
+ audio_path = self.root / f'{audio_name}.wav'
97
+ assert audio_path.exists()
98
+
99
+ audio_chunk, sample_rate = torchaudio.load(audio_path)
100
+ audio_chunk = audio_chunk.mean(dim=0) # mono
101
+ abs_max = audio_chunk.abs().max()
102
+ if self.normalize_audio:
103
+ audio_chunk = audio_chunk / abs_max * 0.95
104
+
105
+ if self.reject_silent and abs_max < 1e-6:
106
+ log.warning(f'Rejecting silent audio: {audio_path}')
107
+ return None
108
+ if audio_chunk.size(0) < end_sample:
109
+ audio_chunk = F.pad(
110
+ audio_chunk,
111
+ (0, end_sample - audio_chunk.size(0)),
112
+ mode='constant',
113
+ value=0
114
+ )
115
+ else:
116
+ audio_chunk = audio_chunk[start_sample:end_sample]
117
+
118
+ # resample
119
+ if sample_rate == self.sample_rate:
120
+ audio_chunk = audio_chunk
121
+ else:
122
+ if sample_rate not in self.resampler:
123
+ # https://pytorch.org/audio/stable/tutorials/audio_resampling_tutorial.html#kaiser-best
124
+ self.resampler[sample_rate] = torchaudio.transforms.Resample(
125
+ sample_rate,
126
+ self.sample_rate,
127
+ lowpass_filter_width=64,
128
+ rolloff=0.9475937167399596,
129
+ resampling_method='sinc_interp_kaiser',
130
+ beta=14.769656459379492,
131
+ )
132
+ audio_chunk = self.resampler[sample_rate](audio_chunk)
133
+
134
+ if audio_chunk.shape[0] < self.num_samples:
135
+ raise ValueError('Audio is too short')
136
+ audio_chunk = audio_chunk[:self.num_samples]
137
+
138
+ tokens = self.tokenizer([caption])[0]
139
+
140
+ output = {
141
+ 'waveform': audio_chunk,
142
+ 'id': audio_id,
143
+ 'caption': caption,
144
+ 'tokens': tokens,
145
+ }
146
+
147
+ return output
148
+ except Exception as e:
149
+ log.error(f'Error reading {audio_path}: {e}')
150
+ return None
151
+
152
+ def __len__(self):
153
+ return len(self.clips)
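
For reference, the dataset above expects a captions TSV with 'id' and 'caption' columns and a clips TSV with 'id', 'name', 'start_sample' and 'end_sample' columns. A minimal construction sketch with placeholder paths (16 kHz, 10 s clips, as configured elsewhere in this commit):

from meanaudio.data.extraction.wav_dataset import WavTextClipsDataset

dataset = WavTextClipsDataset(
    '/path/to/audio',                       # directory of .wav / .flac files
    captions_tsv='/path/to/captions.tsv',   # columns: id, caption
    clips_tsv='/path/to/clips.tsv',         # columns: id, name, start_sample, end_sample
    sample_rate=16_000,
    num_samples=16_000 * 10,
)
item = dataset[0]
print(item['id'], item['waveform'].shape, item['caption'])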
meanaudio/data/mm_dataset.py ADDED
@@ -0,0 +1,50 @@
1
+ import bisect
2
+
3
+ import torch
4
+ from torch.utils.data.dataset import Dataset
5
+
6
+
7
+ # modified from https://pytorch.org/docs/stable/_modules/torch/utils/data/dataset.html#ConcatDataset
8
+ class MultiModalDataset(Dataset):
9
+ datasets: list[Dataset]
10
+ cumulative_sizes: list[int]
11
+
12
+ @staticmethod
13
+ def cumsum(sequence):
14
+ r, s = [], 0
15
+ for e in sequence:
16
+ l = len(e)
17
+ r.append(l + s)
18
+ s += l
19
+ return r
20
+
21
+ def __init__(self, video_datasets: list[Dataset], audio_datasets: list[Dataset]):
22
+ super().__init__()
23
+ self.video_datasets = list(video_datasets) if video_datasets else []
24
+ self.audio_datasets = list(audio_datasets) if audio_datasets else []
25
+ self.datasets = self.video_datasets + self.audio_datasets
26
+
27
+ self.cumulative_sizes = self.cumsum(self.datasets)
28
+
29
+ def __len__(self):
30
+ return self.cumulative_sizes[-1]
31
+
32
+ def __getitem__(self, idx):
33
+ if idx < 0:
34
+ if -idx > len(self):
35
+ raise ValueError("absolute value of index should not exceed dataset length")
36
+ idx = len(self) + idx
37
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx) # which dataset idx falls into
38
+ if dataset_idx == 0:
39
+ sample_idx = idx
40
+ else:
41
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
42
+ return self.datasets[dataset_idx][sample_idx]
43
+
44
+ def compute_latent_stats(self) -> tuple[torch.Tensor, torch.Tensor]:
45
+ if self.video_datasets == []:
46
+ raise NotImplementedError('This function should not be called for an audio-text dataset; '
47
+ 'please load latent stats manually instead')
48
+ return self.audio_datasets[0].compute_latent_stats() # audio-text training
49
+ else:
50
+ return self.video_datasets[0].compute_latent_stats() # video-text training
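
The global index is mapped to a (dataset, local index) pair through the cumulative-size table and bisect_right. A tiny self-contained check of that mapping (DummySet is illustrative only):

from torch.utils.data.dataset import Dataset
from meanaudio.data.mm_dataset import MultiModalDataset

class DummySet(Dataset):
    # Illustrative stand-in for an audio-text or video-text dataset.
    def __init__(self, name, n):
        self.name, self.n = name, n
    def __len__(self):
        return self.n
    def __getitem__(self, idx):
        return f'{self.name}[{idx}]'

combined = MultiModalDataset(video_datasets=[DummySet('video', 3)],
                             audio_datasets=[DummySet('audio', 2)])
print(len(combined))              # 5
print(combined[0], combined[3])   # video[0] audio[0]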
meanaudio/data/utils.py ADDED
@@ -0,0 +1,148 @@
1
+ import logging
2
+ import os
3
+ import random
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Any, Optional, Union
7
+
8
+ import torch
9
+ import torch.distributed as dist
10
+ from tensordict import MemoryMappedTensor
11
+ from torch.utils.data import DataLoader
12
+ from torch.utils.data.dataset import Dataset
13
+ from tqdm import tqdm
14
+
15
+ from meanaudio.utils.dist_utils import local_rank, world_size
16
+
17
+ scratch_path = Path(os.environ['SLURM_SCRATCH'] if 'SLURM_SCRATCH' in os.environ else '/dev/shm')
18
+ shm_path = Path('/dev/shm')
19
+
20
+ log = logging.getLogger()
21
+
22
+
23
+ def reseed(seed):
24
+ random.seed(seed)
25
+ torch.manual_seed(seed)
26
+
27
+
28
+ def local_scatter_torch(obj: Optional[Any]):
29
+ if world_size == 1:
30
+ # Just one worker. Do nothing.
31
+ return obj
32
+
33
+ array = [obj] * world_size
34
+ target_array = [None]
35
+ if local_rank == 0:
36
+ dist.scatter_object_list(target_array, scatter_object_input_list=array, src=0)
37
+ else:
38
+ dist.scatter_object_list(target_array, scatter_object_input_list=None, src=0)
39
+ return target_array[0]
40
+
41
+
42
+ class ShardDataset(Dataset):
43
+
44
+ def __init__(self, root):
45
+ self.root = root
46
+ self.shards = sorted(os.listdir(root))
47
+
48
+ def __len__(self):
49
+ return len(self.shards)
50
+
51
+ def __getitem__(self, idx):
52
+ return torch.load(os.path.join(self.root, self.shards[idx]), weights_only=True)
53
+
54
+
55
+ def get_tmp_dir(in_memory: bool) -> Path:
56
+ return shm_path if in_memory else scratch_path
57
+
58
+
59
+ def load_shards_and_share(data_path: Union[str, Path], ids: list[int],
60
+ in_memory: bool) -> MemoryMappedTensor:
61
+ if local_rank == 0:
62
+ with tempfile.NamedTemporaryFile(prefix='shared-tensor-', dir=get_tmp_dir(in_memory)) as f:
63
+ log.info(f'Loading shards from {data_path} into {f.name}...')
64
+ data = load_shards(data_path, ids=ids, tmp_file_path=f.name)
65
+ data = share_tensor_to_all(data)
66
+ torch.distributed.barrier()
67
+ f.close() # why does the context manager not close the file for me?
68
+ else:
69
+ log.info('Waiting for the data to be shared with me...')
70
+ data = share_tensor_to_all(None)
71
+ torch.distributed.barrier()
72
+
73
+ return data
74
+
75
+
76
+ def load_shards(
77
+ data_path: Union[str, Path],
78
+ ids: list[int],
79
+ *,
80
+ tmp_file_path: str,
81
+ ) -> Union[torch.Tensor, dict[str, torch.Tensor]]:
82
+
83
+ id_set = set(ids)
84
+ shards = sorted(os.listdir(data_path))
85
+ log.info(f'Found {len(shards)} shards in {data_path}.')
86
+ first_shard = torch.load(os.path.join(data_path, shards[0]), weights_only=True)
87
+
88
+ log.info(f'Rank {local_rank} created file {tmp_file_path}')
89
+ first_item = next(iter(first_shard.values()))
90
+ log.info(f'First item shape: {first_item.shape}')
91
+ mm_tensor = MemoryMappedTensor.empty(shape=(len(ids), *first_item.shape),
92
+ dtype=torch.float32,
93
+ filename=tmp_file_path,
94
+ existsok=True)
95
+ total_count = 0
96
+ used_index = set()
97
+ id_indexing = {i: idx for idx, i in enumerate(ids)}
98
+ # faster with no workers; otherwise we need to set_sharing_strategy('file_system')
99
+ loader = DataLoader(ShardDataset(data_path), batch_size=1, num_workers=0)
100
+ for data in tqdm(loader, desc='Loading shards'):
101
+ for i, v in data.items():
102
+ if i not in id_set:
103
+ continue
104
+
105
+ # tensor_index = ids.index(i)
106
+ tensor_index = id_indexing[i]
107
+ if tensor_index in used_index:
108
+ raise ValueError(f'Duplicate id {i} found in {data_path}.')
109
+ used_index.add(tensor_index)
110
+ mm_tensor[tensor_index] = v
111
+ total_count += 1
112
+
113
+ assert total_count == len(ids), f'Expected {len(ids)} tensors, got {total_count}.'
114
+ log.info(f'Loaded {total_count} tensors from {data_path}.')
115
+
116
+ return mm_tensor
117
+
118
+
119
+ def share_tensor_to_all(x: Optional[MemoryMappedTensor]) -> MemoryMappedTensor:
120
+ """
121
+ x: the tensor to be shared; None if local_rank != 0
122
+ return: the shared tensor
123
+ """
124
+
125
+ # there is no need to share your stuff with anyone if you are alone; must be in memory
126
+ if world_size == 1:
127
+ return x
128
+
129
+ if local_rank == 0:
130
+ assert x is not None, 'x must not be None if local_rank == 0'
131
+ else:
132
+ assert x is None, 'x must be None if local_rank != 0'
133
+
134
+ if local_rank == 0:
135
+ filename = x.filename
136
+ meta_information = (filename, x.shape, x.dtype)
137
+ else:
138
+ meta_information = None
139
+
140
+ filename, data_shape, data_type = local_scatter_torch(meta_information)
141
+ if local_rank == 0:
142
+ data = x
143
+ else:
144
+ data = MemoryMappedTensor.from_filename(filename=filename,
145
+ dtype=data_type,
146
+ shape=data_shape)
147
+
148
+ return data
meanaudio/eval_utils.py ADDED
@@ -0,0 +1,167 @@
1
+ import dataclasses
2
+ import logging
3
+ from pathlib import Path
4
+ from typing import Optional
5
+
6
+ import numpy as np
7
+ import torch
8
+ from colorlog import ColoredFormatter
9
+ from PIL import Image
10
+ from torchvision.transforms import v2
11
+
12
+ from meanaudio.data.av_utils import ImageInfo, VideoInfo, read_frames, reencode_with_audio
13
+ from meanaudio.model.flow_matching import FlowMatching
14
+ from meanaudio.model.mean_flow import MeanFlow
15
+ from meanaudio.model.networks import MeanAudio, FluxAudio
16
+ from meanaudio.model.sequence_config import CONFIG_16K, CONFIG_44K, SequenceConfig
17
+ from meanaudio.model.utils.features_utils import FeaturesUtils
18
+ from meanaudio.utils.download_utils import download_model_if_needed
19
+
20
+ log = logging.getLogger()
21
+
22
+
23
+ @dataclasses.dataclass
24
+ class ModelConfig:
25
+ model_name: str
26
+ model_path: Path
27
+ vae_path: Path
28
+ bigvgan_16k_path: Optional[Path]
29
+ mode: str
30
+
31
+ @property
32
+ def seq_cfg(self) -> SequenceConfig:
33
+ if self.mode == '16k':
34
+ return CONFIG_16K # get sequence config when accessing cfg.seq_cfg
35
+ elif self.mode == '44k':
36
+ return CONFIG_44K
37
+
38
+ def download_if_needed(self):
39
+ raise NotImplementedError("Downloading models is not supported")
40
+ download_model_if_needed(self.model_path)
41
+ download_model_if_needed(self.vae_path)
42
+ if self.bigvgan_16k_path is not None:
43
+ download_model_if_needed(self.bigvgan_16k_path)
44
+
45
+
46
+ fluxaudio_fm = ModelConfig(model_name='fluxaudio_fm',
47
+ model_path=Path('./weights/fluxaudio_fm.pth'),
48
+ vae_path=Path('./weights/v1-16.pth'),
49
+ bigvgan_16k_path=Path('./weights/best_netG.pt'),
50
+ mode='16k')
51
+ meanaudio_mf = ModelConfig(model_name='meanaudio_mf',
52
+ model_path=Path('./weights/meanaudio_mf.pth'),
53
+ vae_path=Path('./weights/v1-16.pth'),
54
+ bigvgan_16k_path=Path('./weights/best_netG.pt'),
55
+ mode='16k')
56
+
57
+ all_model_cfg: dict[str, ModelConfig] = {
58
+ 'fluxaudio_fm': fluxaudio_fm,
59
+ 'meanaudio_mf': meanaudio_mf,
60
+ }
61
+
62
+
63
+ def generate_fm(
64
+ text: Optional[list[str]],
65
+ *,
66
+ negative_text: Optional[list[str]] = None,
67
+ feature_utils: FeaturesUtils,
68
+ net: FluxAudio,
69
+ fm: FlowMatching,
70
+ rng: torch.Generator,
71
+ cfg_strength: float,
72
+ ) -> torch.Tensor:
73
+ # generate audio with vanilla flow matching
74
+
75
+ device = feature_utils.device
76
+ dtype = feature_utils.dtype
77
+
78
+ bs = len(text) if text is not None else 1
79
+
80
+ if text is not None:
81
+ text_features, text_features_c = feature_utils.encode_text(text)
82
+ else:
83
+ text_features, text_features_c = net.get_empty_string_sequence(bs)
84
+
85
+ if negative_text is not None:
86
+ assert len(negative_text) == bs
87
+ negative_text_features = feature_utils.encode_text(negative_text)
88
+ else:
89
+ negative_text_features = net.get_empty_string_sequence(bs)
90
+
91
+ x0 = torch.randn(bs,
92
+ net.latent_seq_len,
93
+ net.latent_dim,
94
+ device=device,
95
+ dtype=dtype,
96
+ generator=rng)
97
+ preprocessed_conditions = net.preprocess_conditions(text_features, text_features_c)
98
+ empty_conditions = net.get_empty_conditions(
99
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
100
+
101
+ cfg_ode_wrapper = lambda t, x: net.ode_wrapper(t, x, preprocessed_conditions, empty_conditions,
102
+ cfg_strength)
103
+ x1 = fm.to_data(cfg_ode_wrapper, x0)
104
+ x1 = net.unnormalize(x1)
105
+ spec = feature_utils.decode(x1)
106
+ audio = feature_utils.vocode(spec)
107
+ return audio
108
+
109
+
110
+ def generate_mf(
111
+ text: Optional[list[str]],
112
+ *,
113
+ negative_text: Optional[list[str]] = None,
114
+ feature_utils: FeaturesUtils,
115
+ net: MeanAudio,
116
+ mf: MeanFlow,
117
+ rng: torch.Generator,
118
+ cfg_strength: float,
119
+ ) -> torch.Tensor:
120
+ # generate audio with mean flow
121
+ device = feature_utils.device
122
+ dtype = feature_utils.dtype
123
+
124
+ bs = len(text) if text is not None else 1
125
+
126
+ if text is not None:
127
+ text_features, text_features_c = feature_utils.encode_text(text)
128
+ else:
129
+ text_features, text_features_c = net.get_empty_string_sequence(bs)
130
+
131
+ if negative_text is not None:
132
+ assert len(negative_text) == bs
133
+ negative_text_features = feature_utils.encode_text(negative_text)
134
+ else:
135
+ negative_text_features = net.get_empty_string_sequence(bs)
136
+
137
+ x0 = torch.randn(bs,
138
+ net.latent_seq_len,
139
+ net.latent_dim,
140
+ device=device,
141
+ dtype=dtype,
142
+ generator=rng)
143
+ preprocessed_conditions = net.preprocess_conditions(text_features, text_features_c)
144
+ empty_conditions = net.get_empty_conditions(
145
+ bs, negative_text_features=negative_text_features if negative_text is not None else None)
146
+
147
+ cfg_ode_wrapper = lambda t, r, x: net.ode_wrapper(t, r, x, preprocessed_conditions, empty_conditions,
148
+ cfg_strength)
149
+ x1 = mf.to_data(cfg_ode_wrapper, x0)
150
+ x1 = net.unnormalize(x1)
151
+ spec = feature_utils.decode(x1)
152
+ audio = feature_utils.vocode(spec)
153
+ return audio
154
+
155
+
156
+ LOGFORMAT = "[%(log_color)s%(levelname)-8s%(reset)s]: %(log_color)s%(message)s%(reset)s"
157
+
158
+
159
+ def setup_eval_logging(log_level: int = logging.INFO):
160
+ logging.root.setLevel(log_level) # set up root logger <=> logging.getLogger().setLevel(log_level)
161
+ formatter = ColoredFormatter(LOGFORMAT)
162
+ stream = logging.StreamHandler() # to Console
163
+ stream.setLevel(log_level)
164
+ stream.setFormatter(formatter)
165
+ log = logging.getLogger()
166
+ log.setLevel(log_level)
167
+ log.addHandler(stream)
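
A minimal text-to-audio sampling sketch with generate_mf, assuming net (MeanAudio), feature_utils (FeaturesUtils) and mf (MeanFlow) have already been built and loaded from one of the checkpoints in all_model_cfg; the prompts and cfg_strength value are illustrative:

import torch
from meanaudio.eval_utils import generate_mf, setup_eval_logging

setup_eval_logging()
rng = torch.Generator(device=feature_utils.device)
rng.manual_seed(42)

prompts = ['a dog barking in the distance', 'rain falling on a tin roof']
audio = generate_mf(prompts,
                    feature_utils=feature_utils,
                    net=net,
                    mf=mf,
                    rng=rng,
                    cfg_strength=4.5)
print(audio.shape)  # batch of waveforms from the BigVGAN vocoder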
meanaudio/ext/__init__.py ADDED
@@ -0,0 +1 @@
1
+
meanaudio/ext/autoencoder/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .autoencoder import AutoEncoderModule
meanaudio/ext/autoencoder/autoencoder.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Literal, Optional
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from meanaudio.ext.autoencoder.vae import VAE, get_my_vae
7
+ from meanaudio.ext.bigvgan import BigVGAN
8
+ from meanaudio.ext.bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
9
+ from meanaudio.model.utils.distributions import DiagonalGaussianDistribution
10
+
11
+
12
+ class AutoEncoderModule(nn.Module):
13
+
14
+ def __init__(self,
15
+ *,
16
+ vae_ckpt_path,
17
+ vocoder_ckpt_path: Optional[str] = None,
18
+ mode: Literal['16k', '44k'],
19
+ need_vae_encoder: bool = True):
20
+ super().__init__()
21
+ self.vae: VAE = get_my_vae(mode).eval()
22
+ vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
23
+ self.vae.load_state_dict(vae_state_dict)
24
+ self.vae.remove_weight_norm()
25
+
26
+ if mode == '16k':
27
+ assert vocoder_ckpt_path is not None
28
+ self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
29
+ elif mode == '44k':
30
+ self.vocoder = BigVGANv2.from_pretrained('nvidia/bigvgan_v2_44khz_128band_512x',
31
+ use_cuda_kernel=False)
32
+ self.vocoder.remove_weight_norm()
33
+ else:
34
+ raise ValueError(f'Unknown mode: {mode}')
35
+
36
+ for param in self.parameters():
37
+ param.requires_grad = False
38
+
39
+ if not need_vae_encoder:
40
+ del self.vae.encoder
41
+
42
+ @torch.inference_mode()
43
+ def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
44
+ return self.vae.encode(x)
45
+
46
+ @torch.inference_mode()
47
+ def decode(self, z: torch.Tensor) -> torch.Tensor:
48
+ return self.vae.decode(z)
49
+
50
+ @torch.inference_mode()
51
+ def vocode(self, spec: torch.Tensor) -> torch.Tensor:
52
+ return self.vocoder(spec)
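
A round-trip sketch through AutoEncoderModule, using the 16 kHz weight paths referenced in eval_utils.py; the (batch, 80 mel bins, frames) input shape and frame count are assumptions:

import torch
from meanaudio.ext.autoencoder import AutoEncoderModule

ae = AutoEncoderModule(vae_ckpt_path='./weights/v1-16.pth',
                       vocoder_ckpt_path='./weights/best_netG.pt',
                       mode='16k')

mel = torch.randn(1, 80, 312)   # stand-in mel spectrogram; real inputs come from the feature extractor
dist = ae.encode(mel)           # DiagonalGaussianDistribution over latents
z = dist.mode()                 # deterministic latent (posterior mode)
recon = ae.decode(z)            # back to mel space
wav = ae.vocode(recon)          # waveform via BigVGAN
print(z.shape, recon.shape, wav.shape)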
meanaudio/ext/autoencoder/edm2_utils.py ADDED
@@ -0,0 +1,168 @@
1
+ # Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # This work is licensed under a Creative Commons
4
+ # Attribution-NonCommercial-ShareAlike 4.0 International License.
5
+ # You should have received a copy of the license along with this
6
+ # work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/
7
+ """Improved diffusion model architecture proposed in the paper
8
+ "Analyzing and Improving the Training Dynamics of Diffusion Models"."""
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ #----------------------------------------------------------------------------
14
+ # Variant of constant() that inherits dtype and device from the given
15
+ # reference tensor by default.
16
+
17
+ _constant_cache = dict()
18
+
19
+
20
+ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
21
+ value = np.asarray(value)
22
+ if shape is not None:
23
+ shape = tuple(shape)
24
+ if dtype is None:
25
+ dtype = torch.get_default_dtype()
26
+ if device is None:
27
+ device = torch.device('cpu')
28
+ if memory_format is None:
29
+ memory_format = torch.contiguous_format
30
+
31
+ key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format)
32
+ tensor = _constant_cache.get(key, None)
33
+ if tensor is None:
34
+ tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device)
35
+ if shape is not None:
36
+ tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
37
+ tensor = tensor.contiguous(memory_format=memory_format)
38
+ _constant_cache[key] = tensor
39
+ return tensor
40
+
41
+
42
+ def const_like(ref, value, shape=None, dtype=None, device=None, memory_format=None):
43
+ if dtype is None:
44
+ dtype = ref.dtype
45
+ if device is None:
46
+ device = ref.device
47
+ return constant(value, shape=shape, dtype=dtype, device=device, memory_format=memory_format)
48
+
49
+
50
+ #----------------------------------------------------------------------------
51
+ # Normalize given tensor to unit magnitude with respect to the given
52
+ # dimensions. Default = all dimensions except the first.
53
+
54
+
55
+ def normalize(x, dim=None, eps=1e-4):
56
+ if dim is None:
57
+ dim = list(range(1, x.ndim))
58
+ norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
59
+ norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
60
+ return x / norm.to(x.dtype)
61
+
62
+
63
+ class Normalize(torch.nn.Module):
64
+
65
+ def __init__(self, dim=None, eps=1e-4):
66
+ super().__init__()
67
+ self.dim = dim
68
+ self.eps = eps
69
+
70
+ def forward(self, x):
71
+ return normalize(x, dim=self.dim, eps=self.eps)
72
+
73
+
74
+ #----------------------------------------------------------------------------
75
+ # Upsample or downsample the given tensor with the given filter,
76
+ # or keep it as is.
77
+
78
+
79
+ def resample(x, f=[1, 1], mode='keep'):
80
+ if mode == 'keep':
81
+ return x
82
+ f = np.float32(f)
83
+ assert f.ndim == 1 and len(f) % 2 == 0
84
+ pad = (len(f) - 1) // 2
85
+ f = f / f.sum()
86
+ f = np.outer(f, f)[np.newaxis, np.newaxis, :, :]
87
+ f = const_like(x, f)
88
+ c = x.shape[1]
89
+ if mode == 'down':
90
+ return torch.nn.functional.conv2d(x,
91
+ f.tile([c, 1, 1, 1]),
92
+ groups=c,
93
+ stride=2,
94
+ padding=(pad, ))
95
+ assert mode == 'up'
96
+ return torch.nn.functional.conv_transpose2d(x, (f * 4).tile([c, 1, 1, 1]),
97
+ groups=c,
98
+ stride=2,
99
+ padding=(pad, ))
100
+
101
+
102
+ #----------------------------------------------------------------------------
103
+ # Magnitude-preserving SiLU (Equation 81).
104
+
105
+
106
+ def mp_silu(x):
107
+ return torch.nn.functional.silu(x) / 0.596
108
+
109
+
110
+ class MPSiLU(torch.nn.Module):
111
+
112
+ def forward(self, x):
113
+ return mp_silu(x)
114
+
115
+
116
+ #----------------------------------------------------------------------------
117
+ # Magnitude-preserving sum (Equation 88).
118
+
119
+
120
+ def mp_sum(a, b, t=0.5):
121
+ return a.lerp(b, t) / np.sqrt((1 - t)**2 + t**2)
122
+
123
+
124
+ #----------------------------------------------------------------------------
125
+ # Magnitude-preserving concatenation (Equation 103).
126
+
127
+
128
+ def mp_cat(a, b, dim=1, t=0.5):
129
+ Na = a.shape[dim]
130
+ Nb = b.shape[dim]
131
+ C = np.sqrt((Na + Nb) / ((1 - t)**2 + t**2))
132
+ wa = C / np.sqrt(Na) * (1 - t)
133
+ wb = C / np.sqrt(Nb) * t
134
+ return torch.cat([wa * a, wb * b], dim=dim)
135
+
136
+
137
+ #----------------------------------------------------------------------------
138
+ # Magnitude-preserving convolution or fully-connected layer (Equation 47)
139
+ # with force weight normalization (Equation 66).
140
+
141
+
142
+ class MPConv1D(torch.nn.Module):
143
+
144
+ def __init__(self, in_channels, out_channels, kernel_size):
145
+ super().__init__()
146
+ self.out_channels = out_channels
147
+ self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, kernel_size))
148
+
149
+ self.weight_norm_removed = False
150
+
151
+ def forward(self, x, gain=1):
152
+ assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
153
+
154
+ w = self.weight * gain
155
+ if w.ndim == 2:
156
+ return x @ w.t()
157
+ assert w.ndim == 3
158
+ return torch.nn.functional.conv1d(x, w, padding=(w.shape[-1] // 2, ))
159
+
160
+ def remove_weight_norm(self):
161
+ w = self.weight.to(torch.float32)
162
+ w = normalize(w) # traditional weight normalization
163
+ w = w / np.sqrt(w[0].numel())
164
+ w = w.to(self.weight.dtype)
165
+ self.weight.data.copy_(w)
166
+
167
+ self.weight_norm_removed = True
168
+ return self
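
A quick numerical sketch of the magnitude-preserving helpers above: mp_sum keeps roughly unit variance for unit-variance inputs, and MPConv1D requires remove_weight_norm() before its forward pass:

import torch
from meanaudio.ext.autoencoder.edm2_utils import MPConv1D, mp_sum

a = torch.randn(4, 64, 128)
b = torch.randn(4, 64, 128)
print(mp_sum(a, b, t=0.3).std())    # stays close to 1.0

conv = MPConv1D(in_channels=64, out_channels=32, kernel_size=3)
conv.remove_weight_norm()           # folds the weight normalization into the stored weights
y = conv(torch.randn(4, 64, 128))
print(y.shape)                      # torch.Size([4, 32, 128])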
meanaudio/ext/autoencoder/vae.py ADDED
@@ -0,0 +1,369 @@
1
+ import logging
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from meanaudio.ext.autoencoder.edm2_utils import MPConv1D
8
+ from meanaudio.ext.autoencoder.vae_modules import (AttnBlock1D, Downsample1D, ResnetBlock1D,
9
+ Upsample1D, nonlinearity)
10
+ from meanaudio.model.utils.distributions import DiagonalGaussianDistribution
11
+
12
+ log = logging.getLogger()
13
+
14
+ DATA_MEAN_80D = [
15
+ -1.6058, -1.3676, -1.2520, -1.2453, -1.2078, -1.2224, -1.2419, -1.2439, -1.2922, -1.2927,
16
+ -1.3170, -1.3543, -1.3401, -1.3836, -1.3907, -1.3912, -1.4313, -1.4152, -1.4527, -1.4728,
17
+ -1.4568, -1.5101, -1.5051, -1.5172, -1.5623, -1.5373, -1.5746, -1.5687, -1.6032, -1.6131,
18
+ -1.6081, -1.6331, -1.6489, -1.6489, -1.6700, -1.6738, -1.6953, -1.6969, -1.7048, -1.7280,
19
+ -1.7361, -1.7495, -1.7658, -1.7814, -1.7889, -1.8064, -1.8221, -1.8377, -1.8417, -1.8643,
20
+ -1.8857, -1.8929, -1.9173, -1.9379, -1.9531, -1.9673, -1.9824, -2.0042, -2.0215, -2.0436,
21
+ -2.0766, -2.1064, -2.1418, -2.1855, -2.2319, -2.2767, -2.3161, -2.3572, -2.3954, -2.4282,
22
+ -2.4659, -2.5072, -2.5552, -2.6074, -2.6584, -2.7107, -2.7634, -2.8266, -2.8981, -2.9673
23
+ ]
24
+
25
+ DATA_STD_80D = [
26
+ 1.0291, 1.0411, 1.0043, 0.9820, 0.9677, 0.9543, 0.9450, 0.9392, 0.9343, 0.9297, 0.9276, 0.9263,
27
+ 0.9242, 0.9254, 0.9232, 0.9281, 0.9263, 0.9315, 0.9274, 0.9247, 0.9277, 0.9199, 0.9188, 0.9194,
28
+ 0.9160, 0.9161, 0.9146, 0.9161, 0.9100, 0.9095, 0.9145, 0.9076, 0.9066, 0.9095, 0.9032, 0.9043,
29
+ 0.9038, 0.9011, 0.9019, 0.9010, 0.8984, 0.8983, 0.8986, 0.8961, 0.8962, 0.8978, 0.8962, 0.8973,
30
+ 0.8993, 0.8976, 0.8995, 0.9016, 0.8982, 0.8972, 0.8974, 0.8949, 0.8940, 0.8947, 0.8936, 0.8939,
31
+ 0.8951, 0.8956, 0.9017, 0.9167, 0.9436, 0.9690, 1.0003, 1.0225, 1.0381, 1.0491, 1.0545, 1.0604,
32
+ 1.0761, 1.0929, 1.1089, 1.1196, 1.1176, 1.1156, 1.1117, 1.1070
33
+ ]
34
+
35
+ DATA_MEAN_128D = [
36
+ -3.3462, -2.6723, -2.4893, -2.3143, -2.2664, -2.3317, -2.1802, -2.4006, -2.2357, -2.4597,
37
+ -2.3717, -2.4690, -2.5142, -2.4919, -2.6610, -2.5047, -2.7483, -2.5926, -2.7462, -2.7033,
38
+ -2.7386, -2.8112, -2.7502, -2.9594, -2.7473, -3.0035, -2.8891, -2.9922, -2.9856, -3.0157,
39
+ -3.1191, -2.9893, -3.1718, -3.0745, -3.1879, -3.2310, -3.1424, -3.2296, -3.2791, -3.2782,
40
+ -3.2756, -3.3134, -3.3509, -3.3750, -3.3951, -3.3698, -3.4505, -3.4509, -3.5089, -3.4647,
41
+ -3.5536, -3.5788, -3.5867, -3.6036, -3.6400, -3.6747, -3.7072, -3.7279, -3.7283, -3.7795,
42
+ -3.8259, -3.8447, -3.8663, -3.9182, -3.9605, -3.9861, -4.0105, -4.0373, -4.0762, -4.1121,
43
+ -4.1488, -4.1874, -4.2461, -4.3170, -4.3639, -4.4452, -4.5282, -4.6297, -4.7019, -4.7960,
44
+ -4.8700, -4.9507, -5.0303, -5.0866, -5.1634, -5.2342, -5.3242, -5.4053, -5.4927, -5.5712,
45
+ -5.6464, -5.7052, -5.7619, -5.8410, -5.9188, -6.0103, -6.0955, -6.1673, -6.2362, -6.3120,
46
+ -6.3926, -6.4797, -6.5565, -6.6511, -6.8130, -6.9961, -7.1275, -7.2457, -7.3576, -7.4663,
47
+ -7.6136, -7.7469, -7.8815, -8.0132, -8.1515, -8.3071, -8.4722, -8.7418, -9.3975, -9.6628,
48
+ -9.7671, -9.8863, -9.9992, -10.0860, -10.1709, -10.5418, -11.2795, -11.3861
49
+ ]
50
+
51
+ DATA_STD_128D = [
52
+ 2.3804, 2.4368, 2.3772, 2.3145, 2.2803, 2.2510, 2.2316, 2.2083, 2.1996, 2.1835, 2.1769, 2.1659,
53
+ 2.1631, 2.1618, 2.1540, 2.1606, 2.1571, 2.1567, 2.1612, 2.1579, 2.1679, 2.1683, 2.1634, 2.1557,
54
+ 2.1668, 2.1518, 2.1415, 2.1449, 2.1406, 2.1350, 2.1313, 2.1415, 2.1281, 2.1352, 2.1219, 2.1182,
55
+ 2.1327, 2.1195, 2.1137, 2.1080, 2.1179, 2.1036, 2.1087, 2.1036, 2.1015, 2.1068, 2.0975, 2.0991,
56
+ 2.0902, 2.1015, 2.0857, 2.0920, 2.0893, 2.0897, 2.0910, 2.0881, 2.0925, 2.0873, 2.0960, 2.0900,
57
+ 2.0957, 2.0958, 2.0978, 2.0936, 2.0886, 2.0905, 2.0845, 2.0855, 2.0796, 2.0840, 2.0813, 2.0817,
58
+ 2.0838, 2.0840, 2.0917, 2.1061, 2.1431, 2.1976, 2.2482, 2.3055, 2.3700, 2.4088, 2.4372, 2.4609,
59
+ 2.4731, 2.4847, 2.5072, 2.5451, 2.5772, 2.6147, 2.6529, 2.6596, 2.6645, 2.6726, 2.6803, 2.6812,
60
+ 2.6899, 2.6916, 2.6931, 2.6998, 2.7062, 2.7262, 2.7222, 2.7158, 2.7041, 2.7485, 2.7491, 2.7451,
61
+ 2.7485, 2.7233, 2.7297, 2.7233, 2.7145, 2.6958, 2.6788, 2.6439, 2.6007, 2.4786, 2.2469, 2.1877,
62
+ 2.1392, 2.0717, 2.0107, 1.9676, 1.9140, 1.7102, 0.9101, 0.7164
63
+ ]
64
+
65
+
66
+ class VAE(nn.Module):
67
+
68
+ def __init__(
69
+ self,
70
+ *,
71
+ data_dim: int,
72
+ embed_dim: int,
73
+ hidden_dim: int,
74
+ ):
75
+ super().__init__()
76
+
77
+ if data_dim == 80:
78
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_80D, dtype=torch.float32))
79
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_80D, dtype=torch.float32))
80
+ elif data_dim == 128:
81
+ self.data_mean = nn.Buffer(torch.tensor(DATA_MEAN_128D, dtype=torch.float32))
82
+ self.data_std = nn.Buffer(torch.tensor(DATA_STD_128D, dtype=torch.float32))
83
+
84
+ self.data_mean = self.data_mean.view(1, -1, 1)
85
+ self.data_std = self.data_std.view(1, -1, 1)
86
+
87
+ self.encoder = Encoder1D(
88
+ dim=hidden_dim,
89
+ ch_mult=(1, 2, 4),
90
+ num_res_blocks=2,
91
+ attn_layers=[3],
92
+ down_layers=[0],
93
+ in_dim=data_dim,
94
+ embed_dim=embed_dim,
95
+ )
96
+ self.decoder = Decoder1D(
97
+ dim=hidden_dim,
98
+ ch_mult=(1, 2, 4),
99
+ num_res_blocks=2,
100
+ attn_layers=[3],
101
+ down_layers=[0],
102
+ in_dim=data_dim,
103
+ out_dim=data_dim,
104
+ embed_dim=embed_dim,
105
+ )
106
+
107
+ self.embed_dim = embed_dim
108
+ # self.quant_conv = nn.Conv1d(2 * embed_dim, 2 * embed_dim, 1)
109
+ # self.post_quant_conv = nn.Conv1d(embed_dim, embed_dim, 1)
110
+
111
+ self.initialize_weights()
112
+
113
+ def initialize_weights(self):
114
+ pass
115
+
116
+ def encode(self, x: torch.Tensor, normalize: bool = True) -> DiagonalGaussianDistribution:
117
+ if normalize:
118
+ x = self.normalize(x)
119
+ moments = self.encoder(x)
120
+ posterior = DiagonalGaussianDistribution(moments)
121
+ return posterior
122
+
123
+ def decode(self, z: torch.Tensor, unnormalize: bool = True) -> torch.Tensor:
124
+ dec = self.decoder(z)
125
+ if unnormalize:
126
+ dec = self.unnormalize(dec)
127
+ return dec
128
+
129
+ def normalize(self, x: torch.Tensor) -> torch.Tensor:
130
+ return (x - self.data_mean) / self.data_std
131
+
132
+ def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
133
+ return x * self.data_std + self.data_mean
134
+
135
+ def forward(
136
+ self,
137
+ x: torch.Tensor,
138
+ sample_posterior: bool = True,
139
+ rng: Optional[torch.Generator] = None,
140
+ normalize: bool = True,
141
+ unnormalize: bool = True,
142
+ ) -> tuple[torch.Tensor, DiagonalGaussianDistribution]:
143
+
144
+ posterior = self.encode(x, normalize=normalize)
145
+ if sample_posterior:
146
+ z = posterior.sample(rng)
147
+ else:
148
+ z = posterior.mode()
149
+ dec = self.decode(z, unnormalize=unnormalize)
150
+ return dec, posterior
151
+
152
+ def load_weights(self, src_dict) -> None:
153
+ self.load_state_dict(src_dict, strict=True)
154
+
155
+ @property
156
+ def device(self) -> torch.device:
157
+ return next(self.parameters()).device
158
+
159
+ def get_last_layer(self):
160
+ return self.decoder.conv_out.weight
161
+
162
+ def remove_weight_norm(self):
163
+ for name, m in self.named_modules():
164
+ if isinstance(m, MPConv1D):
165
+ m.remove_weight_norm()
166
+ log.debug(f"Removed weight norm from {name}")
167
+ return self
168
+
169
+
170
+ class Encoder1D(nn.Module):
171
+
172
+ def __init__(self,
173
+ *,
174
+ dim: int,
175
+ ch_mult: tuple[int] = (1, 2, 4, 8),
176
+ num_res_blocks: int,
177
+ attn_layers: list[int] = [],
178
+ down_layers: list[int] = [],
179
+ resamp_with_conv: bool = True,
180
+ in_dim: int,
181
+ embed_dim: int,
182
+ double_z: bool = True,
183
+ kernel_size: int = 3,
184
+ clip_act: float = 256.0):
185
+ super().__init__()
186
+ self.dim = dim
187
+ self.num_layers = len(ch_mult)
188
+ self.num_res_blocks = num_res_blocks
189
+ self.in_channels = in_dim
190
+ self.clip_act = clip_act
191
+ self.down_layers = down_layers
192
+ self.attn_layers = attn_layers
193
+ self.conv_in = MPConv1D(in_dim, self.dim, kernel_size=kernel_size)
194
+
195
+ in_ch_mult = (1, ) + tuple(ch_mult)
196
+ self.in_ch_mult = in_ch_mult
197
+ # downsampling
198
+ self.down = nn.ModuleList()
199
+ for i_level in range(self.num_layers):
200
+ block = nn.ModuleList()
201
+ attn = nn.ModuleList()
202
+ block_in = dim * in_ch_mult[i_level]
203
+ block_out = dim * ch_mult[i_level]
204
+ for i_block in range(self.num_res_blocks):
205
+ block.append(
206
+ ResnetBlock1D(in_dim=block_in,
207
+ out_dim=block_out,
208
+ kernel_size=kernel_size,
209
+ use_norm=True))
210
+ block_in = block_out
211
+ if i_level in attn_layers:
212
+ attn.append(AttnBlock1D(block_in))
213
+ down = nn.Module()
214
+ down.block = block
215
+ down.attn = attn
216
+ if i_level in down_layers:
217
+ down.downsample = Downsample1D(block_in, resamp_with_conv)
218
+ self.down.append(down)
219
+
220
+ # middle
221
+ self.mid = nn.Module()
222
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in,
223
+ out_dim=block_in,
224
+ kernel_size=kernel_size,
225
+ use_norm=True)
226
+ self.mid.attn_1 = AttnBlock1D(block_in)
227
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in,
228
+ out_dim=block_in,
229
+ kernel_size=kernel_size,
230
+ use_norm=True)
231
+
232
+ # end
233
+ self.conv_out = MPConv1D(block_in,
234
+ 2 * embed_dim if double_z else embed_dim,
235
+ kernel_size=kernel_size)
236
+
237
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
238
+
239
+ def forward(self, x):
240
+
241
+ # downsampling
242
+ hs = [self.conv_in(x)]
243
+ for i_level in range(self.num_layers):
244
+ for i_block in range(self.num_res_blocks):
245
+ h = self.down[i_level].block[i_block](hs[-1])
246
+ if len(self.down[i_level].attn) > 0:
247
+ h = self.down[i_level].attn[i_block](h)
248
+ h = h.clamp(-self.clip_act, self.clip_act)
249
+ hs.append(h)
250
+ if i_level in self.down_layers:
251
+ hs.append(self.down[i_level].downsample(hs[-1]))
252
+
253
+ # middle
254
+ h = hs[-1]
255
+ h = self.mid.block_1(h)
256
+ h = self.mid.attn_1(h)
257
+ h = self.mid.block_2(h)
258
+ h = h.clamp(-self.clip_act, self.clip_act)
259
+
260
+ # end
261
+ h = nonlinearity(h)
262
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
263
+ return h
264
+
265
+
266
+ class Decoder1D(nn.Module):
267
+
268
+ def __init__(self,
269
+ *,
270
+ dim: int,
271
+ out_dim: int,
272
+ ch_mult: tuple[int] = (1, 2, 4, 8),
273
+ num_res_blocks: int,
274
+ attn_layers: list[int] = [],
275
+ down_layers: list[int] = [],
276
+ kernel_size: int = 3,
277
+ resamp_with_conv: bool = True,
278
+ in_dim: int,
279
+ embed_dim: int,
280
+ clip_act: float = 256.0):
281
+ super().__init__()
282
+ self.ch = dim
283
+ self.num_layers = len(ch_mult)
284
+ self.num_res_blocks = num_res_blocks
285
+ self.in_channels = in_dim
286
+ self.clip_act = clip_act
287
+ self.down_layers = [i + 1 for i in down_layers] # each downlayer add one
288
+
289
+ # compute in_ch_mult, block_in and curr_res at lowest res
290
+ block_in = dim * ch_mult[self.num_layers - 1]
291
+
292
+ # z to block_in
293
+ self.conv_in = MPConv1D(embed_dim, block_in, kernel_size=kernel_size)
294
+
295
+ # middle
296
+ self.mid = nn.Module()
297
+ self.mid.block_1 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
298
+ self.mid.attn_1 = AttnBlock1D(block_in)
299
+ self.mid.block_2 = ResnetBlock1D(in_dim=block_in, out_dim=block_in, use_norm=True)
300
+
301
+ # upsampling
302
+ self.up = nn.ModuleList()
303
+ for i_level in reversed(range(self.num_layers)):
304
+ block = nn.ModuleList()
305
+ attn = nn.ModuleList()
306
+ block_out = dim * ch_mult[i_level]
307
+ for i_block in range(self.num_res_blocks + 1):
308
+ block.append(ResnetBlock1D(in_dim=block_in, out_dim=block_out, use_norm=True))
309
+ block_in = block_out
310
+ if i_level in attn_layers:
311
+ attn.append(AttnBlock1D(block_in))
312
+ up = nn.Module()
313
+ up.block = block
314
+ up.attn = attn
315
+ if i_level in self.down_layers:
316
+ up.upsample = Upsample1D(block_in, resamp_with_conv)
317
+ self.up.insert(0, up) # prepend to get consistent order
318
+
319
+ # end
320
+ self.conv_out = MPConv1D(block_in, out_dim, kernel_size=kernel_size)
321
+ self.learnable_gain = nn.Parameter(torch.zeros([]))
322
+
323
+ def forward(self, z):
324
+ # z to block_in
325
+ h = self.conv_in(z)
326
+
327
+ # middle
328
+ h = self.mid.block_1(h)
329
+ h = self.mid.attn_1(h)
330
+ h = self.mid.block_2(h)
331
+ h = h.clamp(-self.clip_act, self.clip_act)
332
+
333
+ # upsampling
334
+ for i_level in reversed(range(self.num_layers)):
335
+ for i_block in range(self.num_res_blocks + 1):
336
+ h = self.up[i_level].block[i_block](h)
337
+ if len(self.up[i_level].attn) > 0:
338
+ h = self.up[i_level].attn[i_block](h)
339
+ h = h.clamp(-self.clip_act, self.clip_act)
340
+ if i_level in self.down_layers:
341
+ h = self.up[i_level].upsample(h)
342
+
343
+ h = nonlinearity(h)
344
+ h = self.conv_out(h, gain=(self.learnable_gain + 1))
345
+ return h
346
+
347
+
348
+ def VAE_16k(**kwargs) -> VAE:
349
+ return VAE(data_dim=80, embed_dim=20, hidden_dim=384, **kwargs)
350
+
351
+
352
+ def VAE_44k(**kwargs) -> VAE:
353
+ return VAE(data_dim=128, embed_dim=40, hidden_dim=512, **kwargs)
354
+
355
+
356
+ def get_my_vae(name: str, **kwargs) -> VAE:
357
+ if name == '16k':
358
+ return VAE_16k(**kwargs)
359
+ if name == '44k':
360
+ return VAE_44k(**kwargs)
361
+ raise ValueError(f'Unknown model: {name}')
362
+
363
+
364
+ if __name__ == '__main__':
365
+ network = get_my_vae('16k')
366
+
367
+ # print the number of parameters in terms of millions
368
+ num_params = sum(p.numel() for p in network.parameters()) / 1e6
369
+ print(f'Number of parameters: {num_params:.2f}M')
meanaudio/ext/autoencoder/vae_modules.py ADDED
@@ -0,0 +1,117 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import rearrange
5
+
6
+ from meanaudio.ext.autoencoder.edm2_utils import (MPConv1D, mp_silu, mp_sum, normalize)
7
+
8
+
9
+ def nonlinearity(x):
10
+ # swish
11
+ return mp_silu(x)
12
+
13
+
14
+ class ResnetBlock1D(nn.Module):
15
+
16
+ def __init__(self, *, in_dim, out_dim=None, conv_shortcut=False, kernel_size=3, use_norm=True):
17
+ super().__init__()
18
+ self.in_dim = in_dim
19
+ out_dim = in_dim if out_dim is None else out_dim
20
+ self.out_dim = out_dim
21
+ self.use_conv_shortcut = conv_shortcut
22
+ self.use_norm = use_norm
23
+
24
+ self.conv1 = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
25
+ self.conv2 = MPConv1D(out_dim, out_dim, kernel_size=kernel_size)
26
+ if self.in_dim != self.out_dim:
27
+ if self.use_conv_shortcut:
28
+ self.conv_shortcut = MPConv1D(in_dim, out_dim, kernel_size=kernel_size)
29
+ else:
30
+ self.nin_shortcut = MPConv1D(in_dim, out_dim, kernel_size=1)
31
+
32
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
33
+
34
+ # pixel norm
35
+ if self.use_norm:
36
+ x = normalize(x, dim=1)
37
+
38
+ h = x
39
+ h = nonlinearity(h)
40
+ h = self.conv1(h)
41
+
42
+ h = nonlinearity(h)
43
+ h = self.conv2(h)
44
+
45
+ if self.in_dim != self.out_dim:
46
+ if self.use_conv_shortcut:
47
+ x = self.conv_shortcut(x)
48
+ else:
49
+ x = self.nin_shortcut(x)
50
+
51
+ return mp_sum(x, h, t=0.3)
52
+
53
+
54
+ class AttnBlock1D(nn.Module):
55
+
56
+ def __init__(self, in_channels, num_heads=1):
57
+ super().__init__()
58
+ self.in_channels = in_channels
59
+
60
+ self.num_heads = num_heads
61
+ self.qkv = MPConv1D(in_channels, in_channels * 3, kernel_size=1)
62
+ self.proj_out = MPConv1D(in_channels, in_channels, kernel_size=1)
63
+
64
+ def forward(self, x):
65
+ h = x
66
+ y = self.qkv(h)
67
+ y = y.reshape(y.shape[0], self.num_heads, -1, 3, y.shape[-1])
68
+ q, k, v = normalize(y, dim=2).unbind(3)
69
+
70
+ q = rearrange(q, 'b h c l -> b h l c')
71
+ k = rearrange(k, 'b h c l -> b h l c')
72
+ v = rearrange(v, 'b h c l -> b h l c')
73
+
74
+ h = F.scaled_dot_product_attention(q, k, v)
75
+ h = rearrange(h, 'b h l c -> b (h c) l')
76
+
77
+ h = self.proj_out(h)
78
+
79
+ return mp_sum(x, h, t=0.3)
80
+
81
+
82
+ class Upsample1D(nn.Module):
83
+
84
+ def __init__(self, in_channels, with_conv):
85
+ super().__init__()
86
+ self.with_conv = with_conv
87
+ if self.with_conv:
88
+ self.conv = MPConv1D(in_channels, in_channels, kernel_size=3)
89
+
90
+ def forward(self, x):
91
+ x = F.interpolate(x, scale_factor=2.0, mode='nearest-exact') # support 3D tensor(B,C,T)
92
+ if self.with_conv:
93
+ x = self.conv(x)
94
+ return x
95
+
96
+
97
+ class Downsample1D(nn.Module):
98
+
99
+ def __init__(self, in_channels, with_conv):
100
+ super().__init__()
101
+ self.with_conv = with_conv
102
+ if self.with_conv:
103
+ # no asymmetric padding in torch conv, must do it ourselves
104
+ self.conv1 = MPConv1D(in_channels, in_channels, kernel_size=1)
105
+ self.conv2 = MPConv1D(in_channels, in_channels, kernel_size=1)
106
+
107
+ def forward(self, x):
108
+
109
+ if self.with_conv:
110
+ x = self.conv1(x)
111
+
112
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
113
+
114
+ if self.with_conv:
115
+ x = self.conv2(x)
116
+
117
+ return x
meanaudio/ext/bigvgan/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 NVIDIA CORPORATION.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
meanaudio/ext/bigvgan/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .bigvgan import BigVGAN
meanaudio/ext/bigvgan/activations.py ADDED
@@ -0,0 +1,120 @@
1
+ # Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
2
+ # LICENSE is in incl_licenses directory.
3
+
4
+ import torch
5
+ from torch import nn, sin, pow
6
+ from torch.nn import Parameter
7
+
8
+
9
+ class Snake(nn.Module):
10
+ '''
11
+ Implementation of a sine-based periodic activation function
12
+ Shape:
13
+ - Input: (B, C, T)
14
+ - Output: (B, C, T), same shape as the input
15
+ Parameters:
16
+ - alpha - trainable parameter
17
+ References:
18
+ - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
19
+ https://arxiv.org/abs/2006.08195
20
+ Examples:
21
+ >>> a1 = Snake(256)
22
+ >>> x = torch.randn(256)
23
+ >>> x = a1(x)
24
+ '''
25
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
26
+ '''
27
+ Initialization.
28
+ INPUT:
29
+ - in_features: shape of the input
30
+ - alpha: trainable parameter
31
+ alpha is initialized to 1 by default, higher values = higher-frequency.
32
+ alpha will be trained along with the rest of your model.
33
+ '''
34
+ super(Snake, self).__init__()
35
+ self.in_features = in_features
36
+
37
+ # initialize alpha
38
+ self.alpha_logscale = alpha_logscale
39
+ if self.alpha_logscale: # log scale alphas initialized to zeros
40
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
41
+ else: # linear scale alphas initialized to ones
42
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
43
+
44
+ self.alpha.requires_grad = alpha_trainable
45
+
46
+ self.no_div_by_zero = 0.000000001
47
+
48
+ def forward(self, x):
49
+ '''
50
+ Forward pass of the function.
51
+ Applies the function to the input elementwise.
52
+ Snake ∶= x + 1/a * sin^2 (xa)
53
+ '''
54
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
55
+ if self.alpha_logscale:
56
+ alpha = torch.exp(alpha)
57
+ x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
58
+
59
+ return x
60
+
61
+
62
+ class SnakeBeta(nn.Module):
63
+ '''
64
+ A modified Snake function which uses separate parameters for the magnitude of the periodic components
65
+ Shape:
66
+ - Input: (B, C, T)
67
+ - Output: (B, C, T), same shape as the input
68
+ Parameters:
69
+ - alpha - trainable parameter that controls frequency
70
+ - beta - trainable parameter that controls magnitude
71
+ References:
72
+ - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
73
+ https://arxiv.org/abs/2006.08195
74
+ Examples:
75
+ >>> a1 = snakebeta(256)
76
+ >>> x = torch.randn(256)
77
+ >>> x = a1(x)
78
+ '''
79
+ def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
80
+ '''
81
+ Initialization.
82
+ INPUT:
83
+ - in_features: shape of the input
84
+ - alpha - trainable parameter that controls frequency
85
+ - beta - trainable parameter that controls magnitude
86
+ alpha is initialized to 1 by default, higher values = higher-frequency.
87
+ beta is initialized to 1 by default, higher values = higher-magnitude.
88
+ alpha will be trained along with the rest of your model.
89
+ '''
90
+ super(SnakeBeta, self).__init__()
91
+ self.in_features = in_features
92
+
93
+ # initialize alpha
94
+ self.alpha_logscale = alpha_logscale
95
+ if self.alpha_logscale: # log scale alphas initialized to zeros
96
+ self.alpha = Parameter(torch.zeros(in_features) * alpha)
97
+ self.beta = Parameter(torch.zeros(in_features) * alpha)
98
+ else: # linear scale alphas initialized to ones
99
+ self.alpha = Parameter(torch.ones(in_features) * alpha)
100
+ self.beta = Parameter(torch.ones(in_features) * alpha)
101
+
102
+ self.alpha.requires_grad = alpha_trainable
103
+ self.beta.requires_grad = alpha_trainable
104
+
105
+ self.no_div_by_zero = 0.000000001
106
+
107
+ def forward(self, x):
108
+ '''
109
+ Forward pass of the function.
110
+ Applies the function to the input elementwise.
111
+ SnakeBeta ∶= x + 1/b * sin^2 (xa)
112
+ '''
113
+ alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
114
+ beta = self.beta.unsqueeze(0).unsqueeze(-1)
115
+ if self.alpha_logscale:
116
+ alpha = torch.exp(alpha)
117
+ beta = torch.exp(beta)
118
+ x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
119
+
120
+ return x
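For reference, a minimal usage sketch of the Snake/SnakeBeta activations added above (illustrative only; the batch/channel/length sizes are assumptions, and the import assumes the repository is on PYTHONPATH):

# Hypothetical usage sketch, not part of the committed file.
import torch
from meanaudio.ext.bigvgan.activations import Snake, SnakeBeta

x = torch.randn(2, 64, 128)            # [B, C, T], illustrative sizes
act = Snake(64, alpha_logscale=True)   # one alpha per channel, log-scale init
y = act(x)                             # y = x + (1/alpha) * sin^2(alpha * x)
assert y.shape == x.shape

act_b = SnakeBeta(64)                  # separate alpha (frequency) and beta (magnitude)
assert act_b(x).shape == x.shape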
meanaudio/ext/bigvgan/alias_free_torch/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+
+ from .filter import *
+ from .resample import *
+ from .act import *
meanaudio/ext/bigvgan/alias_free_torch/act.py ADDED
@@ -0,0 +1,28 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+
+ import torch.nn as nn
+ from .resample import UpSample1d, DownSample1d
+
+
+ class Activation1d(nn.Module):
+     def __init__(self,
+                  activation,
+                  up_ratio: int = 2,
+                  down_ratio: int = 2,
+                  up_kernel_size: int = 12,
+                  down_kernel_size: int = 12):
+         super().__init__()
+         self.up_ratio = up_ratio
+         self.down_ratio = down_ratio
+         self.act = activation
+         self.upsample = UpSample1d(up_ratio, up_kernel_size)
+         self.downsample = DownSample1d(down_ratio, down_kernel_size)
+
+     # x: [B, C, T]
+     def forward(self, x):
+         x = self.upsample(x)
+         x = self.act(x)
+         x = self.downsample(x)
+
+         return x
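A short, hedged sketch of how Activation1d is typically combined with the Snake activations (sizes are illustrative; the direct module path is an assumption):

# Hypothetical sketch: apply the activation at 2x the sample rate, then
# low-pass filter back down, which is the anti-aliasing idea behind Activation1d.
import torch
from meanaudio.ext.bigvgan.activations import SnakeBeta
from meanaudio.ext.bigvgan.alias_free_torch.act import Activation1d

x = torch.randn(1, 32, 256)                   # [B, C, T], illustrative
act = Activation1d(activation=SnakeBeta(32))  # default up_ratio = down_ratio = 2
y = act(x)                                    # upsample -> activate -> downsample
assert y.shape == x.shape                     # time length is preserved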
meanaudio/ext/bigvgan/alias_free_torch/filter.py ADDED
@@ -0,0 +1,95 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import math
+
+ if 'sinc' in dir(torch):
+     sinc = torch.sinc
+ else:
+     # This code is adopted from adefossez's julius.core.sinc under the MIT License
+     # https://adefossez.github.io/julius/julius/core.html
+     # LICENSE is in incl_licenses directory.
+     def sinc(x: torch.Tensor):
+         """
+         Implementation of sinc, i.e. sin(pi * x) / (pi * x)
+         __Warning__: Different to julius.sinc, the input is multiplied by `pi`!
+         """
+         return torch.where(x == 0,
+                            torch.tensor(1., device=x.device, dtype=x.dtype),
+                            torch.sin(math.pi * x) / math.pi / x)
+
+
+ # This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
+ # https://adefossez.github.io/julius/julius/lowpass.html
+ # LICENSE is in incl_licenses directory.
+ def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):  # return filter [1,1,kernel_size]
+     even = (kernel_size % 2 == 0)
+     half_size = kernel_size // 2
+
+     # For kaiser window
+     delta_f = 4 * half_width
+     A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
+     if A > 50.:
+         beta = 0.1102 * (A - 8.7)
+     elif A >= 21.:
+         beta = 0.5842 * (A - 21)**0.4 + 0.07886 * (A - 21.)
+     else:
+         beta = 0.
+     window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
+
+     # ratio = 0.5/cutoff -> 2 * cutoff = 1 / ratio
+     if even:
+         time = (torch.arange(-half_size, half_size) + 0.5)
+     else:
+         time = torch.arange(kernel_size) - half_size
+     if cutoff == 0:
+         filter_ = torch.zeros_like(time)
+     else:
+         filter_ = 2 * cutoff * window * sinc(2 * cutoff * time)
+         # Normalize filter to have sum = 1, otherwise we will have a small leakage
+         # of the constant component in the input signal.
+         filter_ /= filter_.sum()
+     filter = filter_.view(1, 1, kernel_size)
+
+     return filter
+
+
+ class LowPassFilter1d(nn.Module):
+     def __init__(self,
+                  cutoff=0.5,
+                  half_width=0.6,
+                  stride: int = 1,
+                  padding: bool = True,
+                  padding_mode: str = 'replicate',
+                  kernel_size: int = 12):
+         # kernel_size should be even number for stylegan3 setup,
+         # in this implementation, odd number is also possible.
+         super().__init__()
+         if cutoff < -0.:
+             raise ValueError("Minimum cutoff must be larger than zero.")
+         if cutoff > 0.5:
+             raise ValueError("A cutoff above 0.5 does not make sense.")
+         self.kernel_size = kernel_size
+         self.even = (kernel_size % 2 == 0)
+         self.pad_left = kernel_size // 2 - int(self.even)
+         self.pad_right = kernel_size // 2
+         self.stride = stride
+         self.padding = padding
+         self.padding_mode = padding_mode
+         filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
+         self.register_buffer("filter", filter)
+
+     # input [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+
+         if self.padding:
+             x = F.pad(x, (self.pad_left, self.pad_right),
+                       mode=self.padding_mode)
+         out = F.conv1d(x, self.filter.expand(C, -1, -1),
+                        stride=self.stride, groups=C)
+
+         return out
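A small sketch of the filter utilities above (cutoff and half_width are in normalized frequency, where 0.5 is Nyquist; the values below are illustrative):

# Hypothetical sketch, not part of the committed file.
import torch
from meanaudio.ext.bigvgan.alias_free_torch.filter import LowPassFilter1d, kaiser_sinc_filter1d

taps = kaiser_sinc_filter1d(cutoff=0.25, half_width=0.3, kernel_size=12)
print(taps.shape)             # torch.Size([1, 1, 12]); taps are normalized to sum to 1

lpf = LowPassFilter1d(cutoff=0.25, half_width=0.3, stride=1, kernel_size=12)
x = torch.randn(1, 8, 100)    # [B, C, T], illustrative
print(lpf(x).shape)           # torch.Size([1, 8, 100]) with padding=True and stride=1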
meanaudio/ext/bigvgan/alias_free_torch/resample.py ADDED
@@ -0,0 +1,49 @@
+ # Adapted from https://github.com/junjun3518/alias-free-torch under the Apache License 2.0
+ # LICENSE is in incl_licenses directory.
+
+ import torch.nn as nn
+ from torch.nn import functional as F
+ from .filter import LowPassFilter1d
+ from .filter import kaiser_sinc_filter1d
+
+
+ class UpSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         self.stride = ratio
+         self.pad = self.kernel_size // ratio - 1
+         self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
+         self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
+         filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio,
+                                       half_width=0.6 / ratio,
+                                       kernel_size=self.kernel_size)
+         self.register_buffer("filter", filter)
+
+     # x: [B, C, T]
+     def forward(self, x):
+         _, C, _ = x.shape
+
+         x = F.pad(x, (self.pad, self.pad), mode='replicate')
+         x = self.ratio * F.conv_transpose1d(
+             x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
+         x = x[..., self.pad_left:-self.pad_right]
+
+         return x
+
+
+ class DownSample1d(nn.Module):
+     def __init__(self, ratio=2, kernel_size=None):
+         super().__init__()
+         self.ratio = ratio
+         self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
+         self.lowpass = LowPassFilter1d(cutoff=0.5 / ratio,
+                                        half_width=0.6 / ratio,
+                                        stride=ratio,
+                                        kernel_size=self.kernel_size)
+
+     def forward(self, x):
+         xx = self.lowpass(x)
+
+         return xx
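And a sketch of the resamplers, which pair the Kaiser-sinc filter with strided (transposed) convolutions (ratio and sizes are illustrative):

# Hypothetical sketch: 2x anti-aliased upsample / downsample round trip.
import torch
from meanaudio.ext.bigvgan.alias_free_torch.resample import UpSample1d, DownSample1d

x = torch.randn(1, 16, 200)   # [B, C, T], illustrative
up = UpSample1d(ratio=2)      # kernel_size defaults to 12 for ratio=2
down = DownSample1d(ratio=2)
y = up(x)                     # [1, 16, 400]
z = down(y)                   # [1, 16, 200]
print(y.shape, z.shape)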
meanaudio/ext/bigvgan/bigvgan.py ADDED
@@ -0,0 +1,32 @@
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from omegaconf import OmegaConf
+
+ from meanaudio.ext.bigvgan.models import BigVGANVocoder
+
+ _bigvgan_vocoder_path = Path(__file__).parent / 'bigvgan_vocoder.yml'
+
+
+ class BigVGAN(nn.Module):
+
+     def __init__(self, ckpt_path, config_path=_bigvgan_vocoder_path):
+         super().__init__()
+         vocoder_cfg = OmegaConf.load(config_path)
+         self.vocoder = BigVGANVocoder(vocoder_cfg).eval()
+         vocoder_ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)['generator']
+         self.vocoder.load_state_dict(vocoder_ckpt)
+
+         self.weight_norm_removed = False
+         self.remove_weight_norm()
+
+     @torch.inference_mode()
+     def forward(self, x):
+         assert self.weight_norm_removed, 'call remove_weight_norm() before inference'
+         return self.vocoder(x)
+
+     def remove_weight_norm(self):
+         self.vocoder.remove_weight_norm()
+         self.weight_norm_removed = True
+         return self
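A hedged sketch of how this wrapper might be used; the checkpoint path is a placeholder, and it assumes a BigVGAN generator checkpoint stored under a 'generator' key, as the constructor expects:

# Hypothetical sketch, not part of the committed file.
import torch
from meanaudio.ext.bigvgan.bigvgan import BigVGAN

vocoder = BigVGAN('path/to/bigvgan_generator.pt')  # placeholder checkpoint path
mel = torch.randn(1, 80, 256)                      # [B, num_mels, frames]; num_mels=80 per the config below
wav = vocoder(mel)                                 # mel spectrogram -> waveform
print(wav.shape)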
meanaudio/ext/bigvgan/bigvgan_vocoder.yml ADDED
@@ -0,0 +1,63 @@
+ resblock: '1'
+ num_gpus: 0
+ batch_size: 64
+ num_mels: 80
+ learning_rate: 0.0001
+ adam_b1: 0.8
+ adam_b2: 0.99
+ lr_decay: 0.999
+ seed: 1234
+ upsample_rates:
+ - 4
+ - 4
+ - 2
+ - 2
+ - 2
+ - 2
+ upsample_kernel_sizes:
+ - 8
+ - 8
+ - 4
+ - 4
+ - 4
+ - 4
+ upsample_initial_channel: 1536
+ resblock_kernel_sizes:
+ - 3
+ - 7
+ - 11
+ resblock_dilation_sizes:
+ - - 1
+   - 3
+   - 5
+ - - 1
+   - 3
+   - 5
+ - - 1
+   - 3
+   - 5
+ activation: snakebeta
+ snake_logscale: true
+ resolutions:
+ - - 1024
+   - 120
+   - 600
+ - - 2048
+   - 240
+   - 1200
+ - - 512
+   - 50
+   - 240
+ mpd_reshapes:
+ - 2
+ - 3
+ - 5
+ - 7
+ - 11
+ use_spectral_norm: false
+ discriminator_channel_mult: 1
+ num_workers: 4
+ dist_config:
+   dist_backend: nccl
+   dist_url: tcp://localhost:54341
+   world_size: 1
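For orientation, the upsample_rates above multiply out to the vocoder's hop size, i.e. how many waveform samples are produced per mel frame (a quick check, not part of the commit):

# 4 * 4 * 2 * 2 * 2 * 2 = 256 waveform samples per mel frame
upsample_rates = [4, 4, 2, 2, 2, 2]
hop = 1
for r in upsample_rates:
    hop *= r
print(hop)  # 256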
meanaudio/ext/bigvgan/env.py ADDED
@@ -0,0 +1,18 @@
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+ # LICENSE is in incl_licenses directory.
+
+ import os
+ import shutil
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def build_env(config, config_name, path):
+     t_path = os.path.join(path, config_name)
+     if config != t_path:
+         os.makedirs(path, exist_ok=True)
+         shutil.copyfile(config, os.path.join(path, config_name))
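A tiny sketch of the AttrDict helper above, which simply exposes dict entries as attributes (values are illustrative):

# Hypothetical sketch, not part of the committed file.
from meanaudio.ext.bigvgan.env import AttrDict

h = AttrDict({'num_mels': 80, 'sampling_rate': 22050})
print(h.num_mels, h['sampling_rate'])  # 80 22050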
meanaudio/ext/bigvgan/incl_licenses/LICENSE_1 ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Jungil Kong
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
meanaudio/ext/bigvgan/incl_licenses/LICENSE_2 ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2020 Edward Dixon
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
meanaudio/ext/bigvgan/incl_licenses/LICENSE_3 ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
meanaudio/ext/bigvgan/incl_licenses/LICENSE_4 ADDED
@@ -0,0 +1,29 @@
+ BSD 3-Clause License
+
+ Copyright (c) 2019, Seungwon Park 박승원
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without
+ modification, are permitted provided that the following conditions are met:
+
+ 1. Redistributions of source code must retain the above copyright notice, this
+    list of conditions and the following disclaimer.
+
+ 2. Redistributions in binary form must reproduce the above copyright notice,
+    this list of conditions and the following disclaimer in the documentation
+    and/or other materials provided with the distribution.
+
+ 3. Neither the name of the copyright holder nor the names of its
+    contributors may be used to endorse or promote products derived from
+    this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.