yhzx233 committed
Commit ea174b0 · 1 Parent(s): 8f33b55

feat: app.py
XY_Tokenizer/.gitignore ADDED
@@ -0,0 +1,193 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .env
106
+ .venv
107
+ env/
108
+ venv/
109
+ ENV/
110
+ env.bak/
111
+ venv.bak/
112
+
113
+ # Spyder project settings
114
+ .spyderproject
115
+ .spyproject
116
+
117
+ # Rope project settings
118
+ .ropeproject
119
+
120
+ # mkdocs documentation
121
+ /site
122
+
123
+ # mypy
124
+ .mypy_cache/
125
+ .dmypy.json
126
+ dmypy.json
127
+ __pycache__/
128
+
129
+ # Pyre type checker
130
+ .pyre/
131
+
132
+ # windows folder
133
+ *.ini
134
+
135
+ # models
136
+ *.pkl
137
+
138
+ *.wav
139
+ *.flac
140
+ *.mp3
141
+
142
+ # others
143
+ temp/
144
+
145
+ exp/
146
+
147
+ slurmlogs/
148
+ slurmlogs_*/
149
+
150
+ dev/
151
+ .vscode/
152
+
153
+ config/debug
154
+ .vscode/
155
+
156
+
157
+ submit_debug*
158
+
159
+ random_rep_for_v2.13
160
+
161
+ exp_eval
162
+
163
+ data/**/*.txt
164
+
165
+ tokenize_data/tokenize_result/
166
+
167
+
168
+ *.png
169
+
170
+ reconstruct_evaluation_backup/
171
+
172
+ semantic_evaluation/scripts/en
173
+
174
+ reconstruct_evaluation/scripts
175
+
176
+ *.jsonl
177
+ *.json
178
+ scripts/debug
179
+
180
+ *.hostfile
181
+ .deepspeed_env
182
+
183
+ *.idx
184
+
185
+ backup*
186
+
187
+ *.ckpt
188
+
189
+ # Project specific
190
+ output_wavs/
191
+ *.pt
192
+ *.pth
193
+ output.log
XY_Tokenizer/README.md ADDED
@@ -0,0 +1,72 @@
1
+ # XY Tokenizer
2
+
3
+ XY Tokenizer is a speech codec that simultaneously models both semantic and acoustic aspects of speech, converting audio into discrete tokens and decoding them back to high-quality audio. It achieves efficient speech representation at only 1kbps with RVQ8 quantization at 12.5Hz frame rate.
4
+
5
+ ## Features
6
+
7
+ - **Dual-channel modeling**: Simultaneously captures semantic meaning and acoustic details
8
+ - **Efficient representation**: 1kbps bitrate with RVQ8 quantization at 12.5Hz
9
+ - **High-quality audio tokenization**: Convert speech to discrete tokens and back with minimal quality loss
10
+ - **Long audio support**: Process audio files longer than 30 seconds using chunking with overlap
11
+ - **Batch processing**: Efficiently process multiple audio files in batches
12
+ - **24kHz output**: Generate high-quality 24kHz audio output
13
+
14
+ ## Installation
15
+
16
+ ```bash
17
+ # Create and activate conda environment
18
+ conda create -n xy_tokenizer python=3.10 -y && conda activate xy_tokenizer
19
+
20
+ # Install dependencies
21
+ pip install -r requirements.txt
22
+ ```
23
+
24
+ ## Usage
25
+
26
+ ### Basic Inference
27
+
28
+ To tokenize audio files and reconstruct them:
29
+
30
+ ```bash
31
+ python inference.py \
32
+ --config_path ./config/xy_tokenizer_config.yaml \
33
+ --checkpoint_path ./weights/xy_tokenizer.ckpt \
34
+ --input_dir ./input_wavs/ \
35
+ --output_dir ./output_wavs/
36
+ ```
37
+
38
+ ### Parameters
39
+
40
+ - `--config_path`: Path to the model configuration file
41
+ - `--checkpoint_path`: Path to the pre-trained model checkpoint
42
+ - `--input_dir`: Directory containing input WAV files
43
+ - `--output_dir`: Directory to save reconstructed audio files
44
+ - `--device`: Device to run inference on (default: "cuda")
45
+ - `--debug`, `--debug_ip`, `--debug_port`: Debugging options (disabled by default)
46
+
47
+ ## Project Structure
48
+
49
+ - `xy_tokenizer/`: Core model implementation
50
+ - `model.py`: Main XY_Tokenizer model class
51
+ - `nn/`: Neural network components
52
+ - `config/`: Configuration files
53
+ - `utils/`: Utility functions
54
+ - `weights/`: Pre-trained model weights
55
+ - `input_wavs/`: Directory for input audio files
56
+ - `output_wavs/`: Directory for output audio files
57
+
58
+ ## Model Architecture
59
+
60
+ XY Tokenizer uses a dual-channel architecture that simultaneously models:
61
+ 1. **Semantic Channel**: Captures high-level semantic information and linguistic content
62
+ 2. **Acoustic Channel**: Preserves detailed acoustic features including speaker characteristics and prosody
63
+
64
+ The model processes audio through several stages:
65
+ 1. Feature extraction (mel-spectrogram)
66
+ 2. Parallel semantic and acoustic encoding
67
+ 3. Residual Vector Quantization (RVQ8) at 12.5Hz frame rate (1kbps)
68
+ 4. Decoding and waveform generation
69
+
70
+ ## License
71
+
72
+ [Specify your license here]
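
The CLI above wraps a small Python API; for reference, here is a minimal sketch of driving the codec directly from Python, based on `inference.py` in this commit (the file names under `input_wavs/` and `output_wavs/` are placeholders):

```python
import torch

from utils.helpers import load_audio, save_audio
from xy_tokenizer.model import XY_Tokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained codec (paths as in the README).
model = XY_Tokenizer.load_from_checkpoint(
    config_path="./config/xy_tokenizer_config.yaml",
    ckpt_path="./weights/xy_tokenizer.ckpt",
).to(device).eval()

# Load one 16 kHz mono waveform as a (T,) tensor.
wav = load_audio("./input_wavs/example.wav", target_sample_rate=model.input_sample_rate)
wav = wav.squeeze().to(device)

# Tokenize to discrete codes (a list of (nq, T) tensors), then reconstruct 24 kHz audio.
codes_list = model.encode([wav], overlap_seconds=10, device=device)["codes_list"]
syn_wav_list = model.decode(codes_list, overlap_seconds=10, device=device)["syn_wav_list"]

save_audio("./output_wavs/example.wav", syn_wav_list[0].cpu().reshape(1, -1),
           sample_rate=model.output_sample_rate)
```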
XY_Tokenizer/config/xy_tokenizer_config.yaml ADDED
@@ -0,0 +1,122 @@
1
+ input_sample_rate: &input_sample_rate 16000
2
+ output_sample_rate: &output_sample_rate 24000
3
+
4
+ generator_params:
5
+ input_sample_rate: *input_sample_rate
6
+ output_sample_rate: *output_sample_rate
7
+
8
+ feature_extractor_kwargs:
9
+ chunk_length: 30
10
+ feature_size: 80
11
+ hop_length: 160
12
+ n_fft: 400
13
+ n_samples: 480000
14
+ nb_max_frames: 3000
15
+ padding_side: right
16
+ padding_value: 0.0
17
+ return_attention_mask: false
18
+ sampling_rate: *input_sample_rate
19
+
20
+ ## Codec Args
21
+
22
+ ## semantic channel
23
+ semantic_encoder_kwargs: # 100hz -> 50hz
24
+ num_mel_bins: 80
25
+ sampling_rate: *input_sample_rate
26
+ hop_length: 160
27
+ stride_size: 2
28
+ kernel_size: 3
29
+ d_model: 768
30
+ scale_embedding: false
31
+ max_audio_seconds: 30
32
+ encoder_layers: 12
33
+ encoder_attention_heads: 12
34
+ encoder_ffn_dim: 3072
35
+ activation_function: "gelu"
36
+
37
+ semantic_encoder_adapter_kwargs: # 50hz
38
+ input_dim: 768
39
+ output_dim: 768
40
+ d_model: 768
41
+ max_source_positions: 1500
42
+ encoder_layers: 4
43
+ encoder_attention_heads: 12
44
+ encoder_ffn_dim: 3072
45
+
46
+
47
+ ## acoustic channel
48
+ acoustic_encoder_kwargs: # 100hz -> 50hz
49
+ num_mel_bins: 80
50
+ sampling_rate: *input_sample_rate
51
+ hop_length: 160
52
+ stride_size: 2
53
+ kernel_size: 3
54
+ d_model: 768
55
+ scale_embedding: false
56
+ max_audio_seconds: 30
57
+ encoder_layers: 12
58
+ encoder_attention_heads: 12
59
+ encoder_ffn_dim: 3072
60
+ activation_function: "gelu"
61
+
62
+
63
+ ## semantic & acoustic shared parameters
64
+ pre_rvq_adapter_kwargs: # 50hz
65
+ input_dim: 1536
66
+ output_dim: 768
67
+ d_model: 768
68
+ max_source_positions: 1500
69
+ encoder_layers: 4
70
+ encoder_attention_heads: 12
71
+ encoder_ffn_dim: 3072
72
+
73
+ downsample_kwargs: # 50hz -> 12.5hz
74
+ d_model: 768
75
+ avg_pooler: 4
76
+
77
+ quantizer_kwargs: # 12.5hz
78
+ input_dim: 3072
79
+ rvq_dim: 512
80
+ output_dim: 3072
81
+ num_quantizers: 8
82
+ codebook_size: 1024
83
+ codebook_dim: 512
84
+ quantizer_dropout: 0.0
85
+ commitment: 1
86
+
87
+ post_rvq_adapter_kwargs: # 12.5hz
88
+ input_dim: 3072
89
+ output_dim: 3072
90
+ d_model: 768
91
+ max_source_positions: 375
92
+ encoder_layers: 4
93
+ encoder_attention_heads: 12
94
+ encoder_ffn_dim: 3072
95
+
96
+ upsample_kwargs: # 12.5hz -> 50hz
97
+ d_model: 768
98
+ stride: 4
99
+
100
+ ## acoustic channel
101
+ acoustic_decoder_kwargs: # 50hz -> 100hz
102
+ num_mel_bins: 80
103
+ sampling_rate: *input_sample_rate
104
+ hop_length: 160
105
+ stride_size: 2
106
+ kernel_size: 3
107
+ d_model: 768
108
+ scale_embedding: false
109
+ max_audio_seconds: 30
110
+ decoder_layers: 12
111
+ decoder_attention_heads: 12
112
+ decoder_ffn_dim: 3072
113
+ activation_function: "gelu"
114
+
115
+ vocos_kwargs: # 100hz -> 24khz
116
+ input_channels: 80
117
+ dim: 512
118
+ intermediate_dim: 4096
119
+ num_layers: 30
120
+ n_fft: 960
121
+ hop_size: 240
122
+ padding: "same"
XY_Tokenizer/inference.py ADDED
@@ -0,0 +1,71 @@
1
+ import os
2
+ import argparse
3
+ import logging
4
+ import torch
5
+
6
+ from utils.helpers import set_logging, waiting_for_debug, load_audio, save_audio, find_audio_files
7
+ from xy_tokenizer.model import XY_Tokenizer
8
+
9
+ if __name__ == "__main__":
10
+ set_logging()
11
+
12
+ parser = argparse.ArgumentParser()
13
+ parser.add_argument("--config_path", type=str, default="./config/xy_tokenizer_config.yaml")
14
+ parser.add_argument("--checkpoint_path", type=str, default="./weights/xy_tokenizer.ckpt")
15
+ parser.add_argument("--device", type=str, default="cuda")
16
+
17
+ parser.add_argument("--input_dir", type=str, required=True)
18
+ parser.add_argument("--output_dir", type=str, required=True)
19
+
20
+
21
+ parser.add_argument("--debug_ip", type=str)
22
+ parser.add_argument("--debug_port", type=int)
23
+ parser.add_argument("--debug", default=0, type=int, nargs="?",
24
+ help='whether debug or not',
25
+ )
26
+ args = parser.parse_args()
27
+ if args.debug == 1:
28
+ waiting_for_debug(args.debug_ip, args.debug_port)
29
+
30
+ device = torch.device(args.device)
31
+
32
+ ## Load codec model
33
+ generator = XY_Tokenizer.load_from_checkpoint(config_path=args.config_path, ckpt_path=args.checkpoint_path).to(device).eval()
34
+
35
+ ## Find audios
36
+ audio_paths = find_audio_files(input_dir=args.input_dir)
37
+
38
+ ## Create output directory if not exists
39
+ os.makedirs(args.output_dir, exist_ok=True)
40
+ logging.info(f"Processing {len(audio_paths)} audio files, output will be saved to {args.output_dir}")
41
+
42
+ with torch.no_grad():
43
+ ## Process audios in batches
44
+ batch_size = 8
45
+ for i in range(0, len(audio_paths), batch_size):
46
+ batch_paths = audio_paths[i:i + batch_size]
47
+ logging.info(f"Processing batch {i // batch_size + 1}/{len(audio_paths) // batch_size + 1}, files: {batch_paths}")
48
+
49
+ # Load audio files
50
+ wav_list = [load_audio(path, target_sample_rate=generator.input_sample_rate).squeeze().to(device) for path in batch_paths]
51
+ logging.info(f"Successfully loaded {len(wav_list)} audio files with lengths {[len(wav) for wav in wav_list]} samples")
52
+
53
+ # Encode
54
+ encode_result = generator.encode(wav_list, overlap_seconds=10)
55
+ codes_list = encode_result["codes_list"] # B * (nq, T)
56
+ logging.info(f"Encoding completed, code lengths: {[codes.shape[-1] for codes in codes_list]}")
57
+ logging.info(f"{codes_list = }")
58
+
59
+ # Decode
60
+ decode_result = generator.decode(codes_list, overlap_seconds=10)
61
+ syn_wav_list = decode_result["syn_wav_list"] # B * (T,)
62
+ logging.info(f"Decoding completed, generated waveform lengths: {[len(wav) for wav in syn_wav_list]} samples")
63
+
64
+ # Save generated audios
65
+ for path, syn_wav in zip(batch_paths, syn_wav_list):
66
+ output_path = os.path.join(args.output_dir, os.path.basename(path))
67
+ save_audio(output_path, syn_wav.cpu().reshape(1, -1), sample_rate=generator.output_sample_rate)
68
+ logging.info(f"Saved generated audio to {output_path}")
69
+
70
+
71
+ logging.info("All audio processing completed")
XY_Tokenizer/requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ beartype
2
+ tensorboard
3
+ numpy
4
+ torch
5
+ torchaudio
6
+ einops
7
+ scipy
8
+ huggingface-hub
9
+ soundfile
10
+ matplotlib
11
+ lion_pytorch
12
+ accelerate
13
+ debugpy
14
+ tensorboardX
15
+ librosa
16
+ pesq
17
+ tqdm
18
+ mir_eval
19
+ stopes
20
+ s3prl
21
+ onnxscript
22
+ jiwer
23
+ orjson
XY_Tokenizer/utils/helpers.py ADDED
@@ -0,0 +1,117 @@
1
+ import logging
2
+ import torchaudio
3
+ import os
4
+ import sys
5
+ import glob
6
+ import debugpy
7
+ import torch
8
+ import numpy as np
9
+ import re
10
+
11
+ def count_params_by_module(model_name, model):
12
+ logging.info(f"Counting num_parameters of {model_name}:")
13
+
14
+ param_stats = {}
15
+ total_params = 0 # Count total parameters
16
+ total_requires_grad_params = 0 # Count parameters with requires_grad=True
17
+ total_no_grad_params = 0 # Count parameters with requires_grad=False
18
+
19
+ for name, param in model.named_parameters():
20
+ module_name = name.split('.')[0]
21
+ if module_name not in param_stats:
22
+ param_stats[module_name] = {'total': 0, 'requires_grad': 0, 'no_grad': 0}
23
+
24
+ param_num = param.numel()
25
+ param_stats[module_name]['total'] += param_num
26
+ total_params += param_num
27
+
28
+ if param.requires_grad:
29
+ param_stats[module_name]['requires_grad'] += param_num
30
+ total_requires_grad_params += param_num
31
+ else:
32
+ param_stats[module_name]['no_grad'] += param_num
33
+ total_no_grad_params += param_num
34
+
35
+ # Calculate maximum width for each column
36
+ max_module_name_length = max(len(module) for module in param_stats)
37
+ max_param_length = max(len(f"{stats['total'] / 1e6:.2f}M") for stats in param_stats.values())
38
+
39
+ # Output parameter statistics for each module
40
+ for module, stats in param_stats.items():
41
+ logging.info(f"\t{module:<{max_module_name_length}}: "
42
+ f"Total: {stats['total'] / 1e6:<{max_param_length}.2f}M, "
43
+ f"Requires Grad: {stats['requires_grad'] / 1e6:<{max_param_length}.2f}M, "
44
+ f"No Grad: {stats['no_grad'] / 1e6:<{max_param_length}.2f}M")
45
+
46
+ # Output total parameter statistics
47
+ logging.info(f"\tTotal parameters: {total_params / 1e6:.2f}M parameters")
48
+ logging.info(f"\tRequires Grad parameters: {total_requires_grad_params / 1e6:.2f}M parameters")
49
+ logging.info(f"\tNo Grad parameters: {total_no_grad_params / 1e6:.2f}M parameters")
50
+ logging.info(f"################################################################")
51
+
52
+
53
+ def load_and_resample_audio(audio_path, target_sample_rate):
54
+ wav, raw_sample_rate = torchaudio.load(audio_path) # (1, T) tensor
55
+ if raw_sample_rate != target_sample_rate:
56
+ wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate) # tensor
57
+ return wav.squeeze()
58
+
59
+ def set_logging():
60
+ rank = os.environ.get("RANK", 0)
61
+ logging.basicConfig(
62
+ level=logging.INFO,
63
+ stream=sys.stdout,
64
+ format=f"%(asctime)s [RANK {rank}] (%(module)s:%(lineno)d) %(levelname)s : %(message)s",
65
+ )
66
+
67
+ def waiting_for_debug(ip, port):
68
+ rank = os.environ.get("RANK", "0")
69
 + debugpy.listen((ip, port)) # Listen on the given IP/port (use a node IP reachable from your debugger on a cluster)
70
+ logging.info(f"[rank = {rank}] Waiting for debugger attach...")
71
+ debugpy.wait_for_client()
72
+ logging.info(f"[rank = {rank}] Debugger attached")
73
+
74
+ def load_audio(audio_path, target_sample_rate):
75
+ # Load audio file, wav shape: (channels, time)
76
+ wav, raw_sample_rate = torchaudio.load(audio_path)
77
+
78
+ # If multi-channel, convert to mono by averaging across channels
79
+ if wav.shape[0] > 1:
80
+ wav = torch.mean(wav, dim=0, keepdim=True) # Average across channels, keep channel dim
81
+
82
+ # Resample if necessary
83
+ if raw_sample_rate != target_sample_rate:
84
+ wav = torchaudio.functional.resample(wav, raw_sample_rate, target_sample_rate)
85
+
86
+ # Convert to numpy, add channel dimension, then back to tensor with desired shape
87
+ wav = np.expand_dims(wav.squeeze(0).numpy(), axis=1) # Shape: (time, 1)
88
+ wav = torch.tensor(wav).reshape(1, 1, -1) # Shape: (1, 1, time)
89
+
90
+ return wav
91
+
92
+ def save_audio(audio_outpath, audio_out, sample_rate):
93
+ torchaudio.save(
94
+ audio_outpath,
95
+ audio_out,
96
+ sample_rate=sample_rate,
97
+ encoding='PCM_S',
98
+ bits_per_sample=16
99
+ )
100
+ logging.info(f"Successfully saved audio at {audio_outpath}")
101
+
102
+ def find_audio_files(input_dir):
103
+ audio_extensions = ['*.flac', '*.mp3', '*.wav']
104
+ audios_input = []
105
+ for ext in audio_extensions:
106
+ audios_input.extend(glob.glob(os.path.join(input_dir, '**', ext), recursive=True))
107
+ logging.info(f"Found {len(audios_input)} audio files in {input_dir}")
108
+ return sorted(audios_input)
109
+
110
+ def normalize_text(text):
111
+ # Remove all punctuation (including English and Chinese punctuation)
112
+ text = re.sub(r'[^\w\s\u4e00-\u9fff]', '', text)
113
+ # Convert to lowercase (effective for English, no effect on Chinese)
114
+ text = text.lower()
115
+ # Remove extra spaces
116
+ text = ' '.join(text.split())
117
+ return text
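
For reference, a quick illustration of `normalize_text` (assuming the module is importable as `utils.helpers`, as `inference.py` does): it strips punctuation (English and Chinese), lowercases, and collapses whitespace.

```python
from utils.helpers import normalize_text

print(normalize_text("Hello,  World!"))       # -> "hello world"
print(normalize_text("你好，世界。 Test!"))    # -> "你好世界 test"
```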
XY_Tokenizer/xy_tokenizer/model.py ADDED
@@ -0,0 +1,279 @@
1
+ # -*- coding: utf-8 -*-
2
+ import yaml
3
+ import logging
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ from .nn.feature_extractor import MelFeatureExtractor
10
+ from .nn.modules import OmniAudioEncoder, OmniAudioDecoder, ResidualDownConv, UpConv, Transformer, Vocos
11
+ from .nn.quantizer import ResidualVQ
12
+
13
+ class XY_Tokenizer(nn.Module):
14
+ def __init__(self, generator_params):
15
+ super().__init__()
16
+ # Basic parameters
17
+ self.input_sample_rate = generator_params['input_sample_rate']
18
+ self.output_sample_rate = generator_params['output_sample_rate']
19
+
20
+ self.encoder_downsample_rate = 1280
21
+ self.decoder_upsample_rate = 1920
22
+ self.code_dim = generator_params['quantizer_kwargs']['input_dim']
23
+
24
+ ## Codec part
25
+
26
+ ## Semantic channel
27
+ self.semantic_encoder = OmniAudioEncoder(**generator_params['semantic_encoder_kwargs'])
28
+
29
+ self.semantic_encoder_adapter = Transformer(**generator_params['semantic_encoder_adapter_kwargs'])
30
+
31
+ ## Acoustic channel
32
+ self.acoustic_encoder = OmniAudioEncoder(**generator_params['acoustic_encoder_kwargs'])
33
+
34
+ ## Semantic & acoustic shared parameters
35
+ self.pre_rvq_adapter = Transformer(**generator_params['pre_rvq_adapter_kwargs'])
36
+
37
+ self.downsample = ResidualDownConv(**generator_params['downsample_kwargs'])
38
+
39
+ self.quantizer = ResidualVQ(**generator_params['quantizer_kwargs'])
40
+ self.nq = generator_params['quantizer_kwargs']['num_quantizers']
41
+
42
+ self.post_rvq_adapter = Transformer(**generator_params['post_rvq_adapter_kwargs'])
43
+
44
+ ## Acoustic channel
45
+ self.upsample = UpConv(**generator_params['upsample_kwargs'])
46
+
47
+ self.acoustic_decoder = OmniAudioDecoder(**generator_params['acoustic_decoder_kwargs'])
48
+
49
+ self.enhanced_vocos = Vocos(**generator_params['vocos_kwargs'])
50
+
51
+ ## Feature extractor
52
+ self.feature_extractor = MelFeatureExtractor(**generator_params['feature_extractor_kwargs'])
53
+
54
+ @torch.inference_mode()
55
+ def inference_tokenize(self, x, input_lengths):
56
+ """
57
+ Input:
58
+ x: Waveform tensor # (B, 1, T), T <= 30s * sample_rate
59
+ input_lengths: Valid length for each sample # (B,)
60
+ Output:
61
+ dict: Contains the following key-value pairs
62
+ "zq": Quantized embeddings # (B, D, T)
63
+ "codes": Quantization codes # (nq, B, T)
64
+ "codes_lengths": Quantization code lengths # (B,)
65
+ """
66
+ list_x = [xi[:, :x_len].reshape(-1).cpu().numpy() for xi, x_len in zip(x, input_lengths)]
67
+ features = self.feature_extractor(
68
+ list_x,
69
+ sampling_rate=self.input_sample_rate,
70
+ return_tensors="pt",
71
+ return_attention_mask=True
72
+ )
73
+ input_mel = features['input_features'].to(x.device).to(x.dtype) # (B, D, 3000)
74
+ audio_attention_mask = features['attention_mask'].to(x.device) # (B, 3000)
75
+
76
+ # Get batch size and sequence length of the input
77
+ mel_output_length = torch.sum(audio_attention_mask, dim=-1).long() # (B,)
78
+
79
+ # Semantic channel
80
+ semantic_encoder_output, semantic_encoder_output_length = self.semantic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
81
+
82
+ semantic_encoder_adapter_output, semantic_encoder_adapter_output_length = self.semantic_encoder_adapter(semantic_encoder_output, semantic_encoder_output_length) # (B, D, T), 50hz
83
+
84
+ # Acoustic channel
85
+ acoustic_encoder_output, acoustic_encoder_output_length = self.acoustic_encoder(input_mel, mel_output_length) # (B, D, T), 100hz -> 50hz
86
+
87
+ # Semantic & acoustic mixing
88
+ concated_semantic_acoustic_channel = torch.concat([semantic_encoder_adapter_output, acoustic_encoder_output], dim=1) # (B, D, T)
89
+ concated_semantic_acoustic_channel_length = acoustic_encoder_output_length
90
+
91
+ pre_rvq_adapter_output, pre_rvq_adapter_output_length = self.pre_rvq_adapter(concated_semantic_acoustic_channel, concated_semantic_acoustic_channel_length) # (B, D, T), 50hz
92
+
93
+ downsample_output, downsample_output_length = self.downsample(pre_rvq_adapter_output, pre_rvq_adapter_output_length) # (B, D, T), 50hz -> 12.5hz
94
+
95
+ zq, codes, vq_loss, _, quantizer_output_length = self.quantizer(downsample_output, downsample_output_length) # (B, D, T), (nq, B, T), (nq,), (nq, B, D, T), (B,)
96
+
97
+ return {
98
+ "zq": zq, # (B, D, T)
99
+ "codes": codes, # (nq, B, T)
100
+ "codes_lengths": quantizer_output_length # (B,)
101
+ }
102
+
103
+ @torch.inference_mode()
104
+ def inference_detokenize(self, codes, codes_lengths):
105
+ """
106
+ Input:
107
+ codes: Quantization codes # (nq, B, T)
108
+ codes_lengths: Quantization code lengths for each sample # (B,)
109
+ Output:
110
+ dict: Contains the following key-value pairs
111
+ "y": Synthesized audio waveform # (B, 1, T)
112
+ "output_length": Output lengths # (B,)
113
+ """
114
+ zq = self.quantizer.decode_codes(codes) # (B, D, T)
115
+
116
+ post_rvq_adapter_output, post_rvq_adapter_output_length = self.post_rvq_adapter(zq, codes_lengths) # (B, D, T), 12.5hz
117
+
118
+ # Acoustic channel
119
+ upsample_output, upsample_output_length = self.upsample(post_rvq_adapter_output, post_rvq_adapter_output_length) # (B, D, T), 12.5hz -> 50hz
120
+
121
+ acoustic_decoder_output, acoustic_decoder_output_length = self.acoustic_decoder(upsample_output, upsample_output_length) # (B, D, T), 50hz -> 100hz
122
+
123
 + y, vocos_output_length = self.enhanced_vocos(acoustic_decoder_output, acoustic_decoder_output_length) # (B, 1, T), 100hz -> 24khz
124
+
125
+ return {
126
+ "y": y, # (B, 1, T)
127
+ "output_length": vocos_output_length, # (B,)
128
+ }
129
+
130
+ @torch.inference_mode()
131
+ def encode(self, wav_list, overlap_seconds=10, device=torch.device("cuda")):
132
+ """
133
+ Input:
134
+ wav_list: List of audio waveforms, each with potentially different length, may exceed 30 seconds # B * (T,)
135
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
136
+ Output:
137
+ dict: Contains the following key-value pairs
138
+ "codes_list": List of quantization codes # B * (nq, T)
139
+ """
140
+ duration_seconds = 30 - overlap_seconds
141
+ chunk_size = int(30 * self.input_sample_rate) # Maximum samples per chunk
142
+ duration_size = int(duration_seconds * self.input_sample_rate) # Valid output samples per chunk
143
+ code_duration_length = duration_size // self.encoder_downsample_rate # Valid code length per chunk
144
+
145
+ # Get maximum waveform length
146
+ max_length = max(len(wav) for wav in wav_list)
147
+ batch_size = len(wav_list)
148
+ wav_tensor = torch.zeros(batch_size, 1, max_length, device=device)
149
+ input_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
150
+ for i, wav in enumerate(wav_list):
151
+ wav_tensor[i, 0, :len(wav)] = wav
152
+ input_lengths[i] = len(wav) # (B,)
153
+
154
+ # Calculate number of chunks needed
155
+ max_chunks = (max_length + duration_size - 1) // duration_size
156
+ codes_list = []
157
+
158
+ # Process the entire batch in chunks
159
+ for chunk_idx in range(max_chunks):
160
+ start = chunk_idx * duration_size
161
+ end = min(start + chunk_size, max_length)
162
+ chunk = wav_tensor[:, :, start:end] # (B, 1, T')
163
+ chunk_lengths = torch.clamp(input_lengths - start, 0, end - start) # (B,)
164
+
165
+ # Skip empty chunks
166
+ if chunk_lengths.max() == 0:
167
+ continue
168
+
169
+ # Encode
170
+ result = self.inference_tokenize(chunk, chunk_lengths) # {"zq": (B, D, T'), "codes": (nq, B, T'), "codes_lengths": (B,)}
171
+ chunk_codes = result["codes"] # (nq, B, T')
172
+ chunk_code_lengths = result["codes_lengths"] # (B,)
173
+
174
+ # Extract valid portion
175
+ valid_code_lengths = torch.clamp(chunk_code_lengths, 0, code_duration_length) # (B,)
176
+ valid_chunk_codes = torch.zeros(self.nq, batch_size, code_duration_length, device=device, dtype=chunk_codes.dtype)
177
+ for b in range(batch_size):
178
+ if valid_code_lengths[b] > 0:
179
+ valid_chunk_codes[:, b, :valid_code_lengths[b]] = chunk_codes[:, b, :valid_code_lengths[b]] # (nq, B, valid_code_length)
180
+
181
+ codes_list.append(valid_chunk_codes) # (nq, B, valid_code_length)
182
+
183
+ # Concatenate all chunks
184
+ if codes_list:
185
+ codes_tensor = torch.cat(codes_list, dim=-1) # (nq, B, T_total)
186
+ codes_list = [codes_tensor[:, i, :input_lengths[i] // self.encoder_downsample_rate] for i in range(batch_size)] # B * (nq, T)
187
+ else:
188
+ codes_list = [torch.zeros(self.nq, 0, device=device, dtype=torch.long) for _ in range(batch_size)] # B * (nq, 0)
189
+
190
+ return {
191
+ "codes_list": codes_list # B * (nq, T)
192
+ }
193
+
194
+ @torch.inference_mode()
195
+ def decode(self, codes_list, overlap_seconds=10, device=torch.device("cuda")):
196
+ """
197
+ Input:
198
+ codes_list: List of quantization codes # B * (nq, T)
199
+ overlap_seconds: Overlap in seconds, process 30 seconds at a time, keeping (30 - overlap_seconds) seconds of valid output
200
+ Output:
201
+ dict: Contains the following key-value pairs
202
+ "syn_wav_list": List of synthesized audio waveforms # B * (T,)
203
+ """
204
+ duration_seconds = 30 - overlap_seconds
205
+ chunk_code_length = int(30 * self.input_sample_rate // self.encoder_downsample_rate) # Maximum code length per chunk
206
+ duration_code_length = int(duration_seconds * self.input_sample_rate // self.encoder_downsample_rate) # Valid code length per chunk
207
+ duration_wav_length = duration_code_length * self.decoder_upsample_rate # Valid waveform length per chunk
208
+
209
+ # Get maximum code length
210
+ max_code_length = max(codes.shape[-1] for codes in codes_list)
211
+ batch_size = len(codes_list)
212
+ codes_tensor = torch.zeros(self.nq, batch_size, max_code_length, device=device, dtype=torch.long)
213
+ code_lengths = torch.zeros(batch_size, dtype=torch.long, device=device)
214
+ for i, codes in enumerate(codes_list):
215
+ codes_tensor[:, i, :codes.shape[-1]] = codes.to(device)
216
+ code_lengths[i] = codes.shape[-1] # (B,)
217
+
218
+ # Calculate number of chunks needed
219
+ max_chunks = (max_code_length + duration_code_length - 1) // duration_code_length
220
+ wav_list = []
221
+
222
+ # Process the entire batch in chunks
223
+ for chunk_idx in range(max_chunks):
224
+ start = chunk_idx * duration_code_length
225
+ end = min(start + chunk_code_length, max_code_length)
226
+ chunk_codes = codes_tensor[:, :, start:end] # (nq, B, T')
227
+ chunk_code_lengths = torch.clamp(code_lengths - start, 0, end - start) # (B,)
228
+
229
+ # Skip empty chunks
230
+ if chunk_code_lengths.max() == 0:
231
+ continue
232
+
233
+ # Decode
234
+ result = self.inference_detokenize(chunk_codes, chunk_code_lengths) # {"y": (B, 1, T'), "output_length": (B,)}
235
+ chunk_wav = result["y"] # (B, 1, T')
236
+ chunk_wav_lengths = result["output_length"] # (B,)
237
+
238
+ # Extract valid portion
239
+ valid_wav_lengths = torch.clamp(chunk_wav_lengths, 0, duration_wav_length) # (B,)
240
+ valid_chunk_wav = torch.zeros(batch_size, 1, duration_wav_length, device=device)
241
+ for b in range(batch_size):
242
+ if valid_wav_lengths[b] > 0:
243
+ valid_chunk_wav[b, :, :valid_wav_lengths[b]] = chunk_wav[b, :, :valid_wav_lengths[b]] # (B, 1, valid_wav_length)
244
+
245
+ wav_list.append(valid_chunk_wav) # (B, 1, valid_wav_length)
246
+
247
+ # Concatenate all chunks
248
+ if wav_list:
249
+ wav_tensor = torch.cat(wav_list, dim=-1) # (B, 1, T_total)
250
+ syn_wav_list = [wav_tensor[i, 0, :code_lengths[i] * self.decoder_upsample_rate] for i in range(batch_size)] # B * (T,)
251
+ else:
252
+ syn_wav_list = [torch.zeros(0, device=device) for _ in range(batch_size)] # B * (0,)
253
+
254
+ return {
255
+ "syn_wav_list": syn_wav_list # B * (T,)
256
+ }
257
+
258
+ @classmethod
259
+ def load_from_checkpoint(cls, config_path: str, ckpt_path: str):
260
+ # Load model from configuration file and checkpoint
261
+ logging.info(f"Loading model from {config_path} and {ckpt_path}")
262
+
263
+ # Load configuration
264
+ with open(config_path, 'r') as f:
265
+ config = yaml.safe_load(f)
266
+
267
+ # Create model instance
268
+ model = cls(config['generator_params'])
269
+
270
+ # Load checkpoint
271
+ checkpoint = torch.load(ckpt_path, map_location='cpu')
272
+
273
+ # Check if checkpoint contains 'generator' key
274
+ if 'generator' in checkpoint:
275
+ model.load_state_dict(checkpoint['generator'])
276
+ else:
277
+ model.load_state_dict(checkpoint)
278
+
279
+ return model
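
The `encode()`/`decode()` docstrings above describe the 30-second sliding window with overlap; as a rough illustration (using only the constants defined in this class), here is how the chunk boundaries fall for a hypothetical 75-second input at the default `overlap_seconds=10`:

```python
# How encode() walks a long file: chunk starts advance by (30 - overlap_seconds) seconds,
# each chunk reads up to 30 s of audio, and only the first duration_size samples' worth
# of codes are kept from each chunk. decode() does the same walk in code-frame units.
input_sample_rate = 16000
encoder_downsample_rate = 1280
overlap_seconds = 10

duration_size = (30 - overlap_seconds) * input_sample_rate       # 320000 samples kept per chunk
chunk_size = 30 * input_sample_rate                               # 480000 samples read per chunk
codes_kept_per_chunk = duration_size // encoder_downsample_rate   # 250 code frames (20 s at 12.5 Hz)

total_samples = 75 * input_sample_rate  # a 75 s recording
starts = range(0, total_samples, duration_size)
print([(s, min(s + chunk_size, total_samples)) for s in starts])
# [(0, 480000), (320000, 800000), (640000, 1120000), (960000, 1200000)]
```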
XY_Tokenizer/xy_tokenizer/nn/feature_extractor.py ADDED
@@ -0,0 +1,237 @@
1
+ import numpy as np
2
+ import torch
3
+
4
+ from typing import Union, List, Optional
5
+ from transformers.feature_extraction_sequence_utils import SequenceFeatureExtractor
6
+ from transformers.feature_extraction_utils import BatchFeature
7
+ from transformers.utils import TensorType, logging
8
+ from transformers.utils.import_utils import is_torch_available
9
+ from transformers.audio_utils import mel_filter_bank, spectrogram, window_function
10
 +
 + logger = logging.get_logger(__name__)
11
+ class MelFeatureExtractor(SequenceFeatureExtractor):
12
+ model_input_names = ["input_features"]
13
+
14
+ def __init__(
15
+ self,
16
+ feature_size=80,
17
+ sampling_rate=16000,
18
+ hop_length=160,
19
+ chunk_length=30,
20
+ n_fft=400,
21
+ padding_value=0.0,
22
+ dither=0.0,
23
+ return_attention_mask=False,
24
+ max_frequency=None,
25
+ **kwargs,
26
+ ):
27
+ super().__init__(
28
+ feature_size=feature_size,
29
+ sampling_rate=sampling_rate,
30
+ padding_value=padding_value,
31
+ return_attention_mask=return_attention_mask,
32
+ **kwargs,
33
+ )
34
+ self.n_fft = n_fft
35
+ self.hop_length = hop_length
36
+ self.chunk_length = chunk_length
37
+ self.n_samples = chunk_length * sampling_rate
38
+ self.nb_max_frames = self.n_samples // hop_length
39
+ self.sampling_rate = sampling_rate
40
+ self.dither = dither
41
+ self.max_frequency = max_frequency if max_frequency is not None else sampling_rate / 2
42
+ self.mel_filters = mel_filter_bank(
43
+ num_frequency_bins=1 + n_fft // 2,
44
+ num_mel_filters=feature_size,
45
+ min_frequency=0.0,
46
+ max_frequency=self.max_frequency,
47
+ sampling_rate=sampling_rate,
48
+ norm="slaney",
49
+ mel_scale="slaney",
50
+ )
51
+
52
+ def _np_extract_fbank_features(self, waveform_batch: np.array, device: str) -> np.ndarray:
53
+ if device != "cpu":
54
+ raise ValueError(
55
+ f"Got device `{device}` for feature extraction, but feature extraction on CUDA accelerator "
56
+ "devices requires torch, which is not installed. Either set `device='cpu'`, or "
57
+ "install torch according to the official instructions: https://pytorch.org/get-started/locally/"
58
+ )
59
+ log_spec_batch = []
60
+ for waveform in waveform_batch:
61
+ log_spec = spectrogram(
62
+ waveform,
63
+ window_function(self.n_fft, "hann"),
64
+ frame_length=self.n_fft,
65
+ hop_length=self.hop_length,
66
+ power=2.0,
67
+ dither=self.dither,
68
+ mel_filters=self.mel_filters,
69
+ log_mel="log10",
70
+ )
71
+ log_spec = log_spec[:, :-1]
72
+ log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
73
+ log_spec = (log_spec + 4.0) / 4.0
74
+ log_spec_batch.append(log_spec)
75
+ log_spec_batch = np.array(log_spec_batch)
76
+ return log_spec_batch
77
+
78
+ def _torch_extract_fbank_features(self, waveform: np.array, device: str = "cpu") -> np.ndarray:
79
+ """
80
+ Compute the log-mel spectrogram of the audio using PyTorch's GPU-accelerated STFT implementation with batching,
81
+ yielding results similar to cpu computing with 1e-5 tolerance.
82
+ """
83
+ waveform = torch.from_numpy(waveform).to(device, torch.float32)
84
+ window = torch.hann_window(self.n_fft, device=device)
85
+
86
+ if self.dither != 0.0:
87
+ waveform += self.dither * torch.randn(waveform.shape, dtype=waveform.dtype, device=waveform.device)
88
+
89
+ stft = torch.stft(waveform, self.n_fft, self.hop_length, window=window, return_complex=True)
90
+ magnitudes = stft[..., :-1].abs() ** 2
91
+
92
+ mel_filters = torch.from_numpy(self.mel_filters).to(device, torch.float32)
93
+ mel_spec = mel_filters.T @ magnitudes
94
+
95
+ log_spec = torch.clamp(mel_spec, min=1e-10).log10()
96
+ if waveform.dim() == 2:
97
+ max_val = log_spec.max(dim=2, keepdim=True)[0].max(dim=1, keepdim=True)[0]
98
+ log_spec = torch.maximum(log_spec, max_val - 8.0)
99
+ else:
100
+ log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
101
+ log_spec = (log_spec + 4.0) / 4.0
102
+ if device != "cpu":
103
+ log_spec = log_spec.detach().cpu()
104
+ return log_spec.numpy()
105
+
106
+ @staticmethod
107
+ def zero_mean_unit_var_norm(
108
+ input_values: List[np.ndarray], attention_mask: List[np.ndarray], padding_value: float = 0.0
109
+ ) -> List[np.ndarray]:
110
+ """
111
+ Every array in the list is normalized to have zero mean and unit variance
112
+ """
113
+ if attention_mask is not None:
114
+ attention_mask = np.array(attention_mask, np.int32)
115
+ normed_input_values = []
116
+
117
+ for vector, length in zip(input_values, attention_mask.sum(-1)):
118
+ normed_slice = (vector - vector[:length].mean()) / np.sqrt(vector[:length].var() + 1e-7)
119
+ if length < normed_slice.shape[0]:
120
+ normed_slice[length:] = padding_value
121
+
122
+ normed_input_values.append(normed_slice)
123
+ else:
124
+ normed_input_values = [(x - x.mean()) / np.sqrt(x.var() + 1e-7) for x in input_values]
125
+
126
+ return normed_input_values
127
+
128
+ def __call__(
129
+ self,
130
+ raw_speech: Union[np.ndarray, List[float], List[np.ndarray], List[List[float]]],
131
+ truncation: bool = True,
132
+ pad_to_multiple_of: Optional[int] = None,
133
+ return_tensors: Optional[Union[str, TensorType]] = None,
134
+ return_attention_mask: Optional[bool] = None,
135
+ padding: Optional[str] = "max_length",
136
+ max_length: Optional[int] = None,
137
+ sampling_rate: Optional[int] = None,
138
+ do_normalize: Optional[bool] = None,
139
+ device: Optional[str] = "cpu",
140
+ return_token_timestamps: Optional[bool] = None,
141
+ **kwargs,
142
+ ) -> BatchFeature:
143
+ """
144
+ Main method to featurize and prepare for the model one or several sequence(s). Implementation uses PyTorch for
145
+ the STFT computation if available, otherwise a slower NumPy based one.
146
+
147
+ Args:
148
+ raw_speech (`np.ndarray`, `List[float]`, `List[np.ndarray]`, `List[List[float]]`):
149
+ The sequence or batch of sequences to be padded. Each sequence can be a numpy array, a list of float
150
+ values, a list of numpy arrays or a list of list of float values. Must be mono channel audio, not
151
+ stereo, i.e. single float per timestep.
152
+ truncation (`bool`, *optional*, default to `True`):
153
+ Activates truncation to cut input sequences longer than *max_length* to *max_length*.
154
+ pad_to_multiple_of (`int`, *optional*, defaults to None):
155
+ If set will pad the sequence to a multiple of the provided value.
156
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
157
+ If set, will return tensors instead of list of python integers. Acceptable values are:
158
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
159
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
160
+ - `'np'`: Return Numpy `np.ndarray` objects.
161
+ sampling_rate (`int`, *optional*):
162
+ The sampling rate at which the `raw_speech` input was sampled. If provided, it is checked against
163
+ the extractor's sampling rate.
164
+ padding_value (`float`, *optional*, defaults to 0.0):
165
+ The value that is used to fill the padding values / vectors.
166
+ do_normalize (`bool`, *optional*, defaults to `False`):
167
+ Whether or not to zero-mean unit-variance normalize the input.
168
+ device (`str`, *optional*, defaults to `'cpu'`):
169
+ Specifies the device for computation of the log-mel spectrogram.
170
+ return_token_timestamps (`bool`, *optional*, defaults to `None`):
171
+ Whether or not to return the number of frames of the input raw_speech.
172
+ """
173
+ if sampling_rate is not None and sampling_rate != self.sampling_rate:
174
+ logger.warning(
175
+ f"The provided `raw_speech` input was sampled at {sampling_rate}Hz, but the feature extractor "
176
+ f"is configured for {self.sampling_rate}Hz. You should resample the audio to match the "
177
+ f"extractor's sampling rate to ensure correct feature extraction."
178
+ )
179
+
180
+ is_batched_numpy = isinstance(raw_speech, np.ndarray) and len(raw_speech.shape) > 1
181
+ if is_batched_numpy and len(raw_speech.shape) > 2:
182
+ raise ValueError(f"Only mono-channel audio is supported for input to {self}")
183
+ is_batched = is_batched_numpy or (
184
+ isinstance(raw_speech, (list, tuple)) and (isinstance(raw_speech[0], (np.ndarray, tuple, list)))
185
+ )
186
+
187
+ if is_batched:
188
+ raw_speech = [np.asarray([speech], dtype=np.float32).T for speech in raw_speech]
189
+ elif not is_batched and not isinstance(raw_speech, np.ndarray):
190
+ raw_speech = np.asarray(raw_speech, dtype=np.float32)
191
+ elif isinstance(raw_speech, np.ndarray) and raw_speech.dtype is np.dtype(np.float64):
192
+ raw_speech = raw_speech.astype(np.float32)
193
+
194
+ if not is_batched:
195
+ raw_speech = [np.asarray([raw_speech]).T]
196
+
197
+ batched_speech = BatchFeature({"input_features": raw_speech})
198
+
199
+ padded_inputs = self.pad(
200
+ batched_speech,
201
+ padding=padding,
202
+ max_length=max_length if max_length else self.n_samples,
203
+ truncation=truncation,
204
+ pad_to_multiple_of=pad_to_multiple_of,
205
+ return_attention_mask=return_attention_mask or do_normalize,
206
+ )
207
+
208
+ if do_normalize:
209
+ padded_inputs["input_features"] = self.zero_mean_unit_var_norm(
210
+ padded_inputs["input_features"],
211
+ attention_mask=padded_inputs["attention_mask"],
212
+ padding_value=self.padding_value,
213
+ )
214
+ padded_inputs["input_features"] = np.stack(padded_inputs["input_features"], axis=0)
215
+
216
+ input_features = padded_inputs.get("input_features").transpose(2, 0, 1)
217
+
218
+ extract_fbank_features = (
219
+ self._torch_extract_fbank_features if is_torch_available() else self._np_extract_fbank_features
220
+ )
221
+ input_features = extract_fbank_features(input_features[0], device)
222
+
223
+ if isinstance(input_features[0], List):
224
+ padded_inputs["input_features"] = [np.asarray(feature, dtype=np.float32) for feature in input_features]
225
+ else:
226
+ padded_inputs["input_features"] = input_features
227
+
228
+ if return_attention_mask:
229
+ padded_inputs["attention_mask"] = padded_inputs["attention_mask"][:, :: self.hop_length]
230
+
231
+ if return_token_timestamps is not None:
232
+ padded_inputs["num_frames"] = [len(raw_speech_i) // self.hop_length for raw_speech_i in raw_speech]
233
+
234
+ if return_tensors is not None:
235
+ padded_inputs = padded_inputs.convert_to_tensors(return_tensors)
236
+
237
+ return padded_inputs
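
For reference, `XY_Tokenizer.inference_tokenize()` drives this extractor with a list of 1-D waveforms and `return_tensors="pt"`; a short sketch of that call pattern (the random waveform is only a stand-in for real 16 kHz audio, and the printed shapes follow the comments in `xy_tokenizer/model.py`):

```python
import numpy as np
from xy_tokenizer.nn.feature_extractor import MelFeatureExtractor

extractor = MelFeatureExtractor(feature_size=80, sampling_rate=16000, hop_length=160,
                                chunk_length=30, n_fft=400)
wavs = [np.random.randn(16000 * 5).astype(np.float32)]   # one 5 s clip at 16 kHz

features = extractor(wavs, sampling_rate=16000, return_tensors="pt", return_attention_mask=True)
print(features["input_features"].shape)   # (1, 80, 3000): mel frames padded to 30 s
print(features["attention_mask"].shape)   # (1, 3000); .sum(-1) gives the valid frame count
```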
XY_Tokenizer/xy_tokenizer/nn/modules.py ADDED
@@ -0,0 +1,1480 @@
1
+ import torch
2
+ import torch.distributed
3
+ import numpy as np
4
+ import logging
5
+ import math
6
+ import copy
7
+ import numpy as np
8
+ import scipy
9
+ import torch
10
+ import librosa
11
+
12
+ from typing import Optional, Tuple
13
+ from torch import nn, view_as_real, view_as_complex
14
+ from torch import nn
15
+ from torch.nn import functional as F
16
+ from torch.nn.utils import weight_norm, remove_weight_norm
17
+ from torchaudio.functional.functional import _hz_to_mel, _mel_to_hz
18
+ from transformers.activations import ACT2FN
19
+ from dataclasses import dataclass
20
+ from transformers.modeling_outputs import ModelOutput
21
+ from transformers import WhisperModel
22
+
23
+
24
+ # Define function to generate positional embeddings using sine and cosine functions to represent sequence position information
25
+ def sinusoids(length, channels, max_timescale=10000):
26
+ """Returns sinusoidal waves for positional embedding"""
27
+ assert channels % 2 == 0
28
+ log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
29
+ inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
30
+ scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
31
+ return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
32
+
33
+ # Generate sequence mask to distinguish valid sequence and padding parts
34
+ def get_sequence_mask(inputs, inputs_length):
35
+ if inputs.dim() == 3:
36
+ bsz, tgt_len, _ = inputs.size()
37
+ else:
38
+ bsz, tgt_len = inputs_length.shape[0], torch.max(inputs_length)
39
+ sequence_mask = torch.arange(0, tgt_len).to(inputs.device)
40
+ sequence_mask = torch.lt(sequence_mask, inputs_length.reshape(bsz, 1)).view(bsz, tgt_len, 1)
41
+ return sequence_mask
42
+
43
+ # Define RMSNorm layer for normalizing hidden states and stabilizing training process
44
+ class RMSNorm(nn.Module):
45
+ def __init__(self, hidden_size, eps=1e-6):
46
+ super().__init__()
47
+ self.weight = nn.Parameter(torch.ones(hidden_size))
48
+ self.variance_epsilon = eps
49
+
50
+ def forward(self, hidden_states):
51
+ variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
52
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
53
+ if self.weight.dtype in [torch.float16, torch.bfloat16]:
54
+ hidden_states = hidden_states.to(self.weight.dtype)
55
+ return self.weight * hidden_states
56
+
57
+ # Modified variable-length attention mechanism, supporting FP32 with unified interface
58
+ class VarLenAttention(nn.Module):
59
+ def __init__(self, embed_dim, num_heads, causal=False, dropout=0.0):
60
+ """
61
+ Initialize variable-length attention module.
62
+
63
+ Parameters:
64
+ embed_dim (int): Embedding dimension (model's hidden dimension)
65
+ num_heads (int): Number of attention heads
66
+ causal (bool): Whether to enable causal attention (only attend to current and previous positions)
67
+ dropout (float): Attention dropout probability
68
+ """
69
+ super().__init__()
70
+ self.embed_dim = embed_dim
71
+ self.num_heads = num_heads
72
+ self.head_dim = embed_dim // num_heads
73
+ assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
74
+ self.causal = causal
75
+ self.dropout = nn.Dropout(dropout)
76
+ self.scaling = self.head_dim ** -0.5 # Scaling factor
77
+
78
+ # Linear projection layers for Q, K, V and output
79
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=False)
80
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
81
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
82
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
83
+
84
+ def _create_attention_mask(self, seq_len, max_len, device, dtype):
85
+ """
86
+ Create attention mask supporting variable-length sequences and causality.
87
+
88
+ Parameters:
89
+ seq_len (torch.Tensor): Sequence length for each sample, shape [bsz]
90
+ max_len (int): Maximum sequence length in the batch
91
+ device: Device for tensor creation
92
+ dtype: Data type for mask values
93
+
94
+ Returns:
95
+ mask (torch.Tensor): Attention mask, shape [bsz, 1, max_len, max_len], invalid positions set to minimum value
96
+ """
97
+ bsz = seq_len.size(0)
98
+ # Initialize mask as 1 (valid positions)
99
+ mask = torch.ones(bsz, 1, max_len, max_len, device=device, dtype=dtype)
100
+
101
+ # Generate sequence indices
102
+ seq_indices = torch.arange(max_len, device=device).unsqueeze(0) # [1, max_len]
103
+ seq_len_expanded = seq_len.unsqueeze(1) # [bsz, 1]
104
+
105
+ # Mark valid positions (less than seq_len)
106
+ valid_mask = seq_indices < seq_len_expanded.unsqueeze(-1) # [bsz, 1, max_len]
107
+ mask = mask * (valid_mask.unsqueeze(2) & valid_mask.unsqueeze(3)).to(dtype) # [bsz, 1, max_len, max_len]
108
+
109
+ # If causal attention, add upper triangular mask
110
+ if self.causal:
111
+ causal_mask = torch.triu(torch.ones(max_len, max_len, device=device, dtype=torch.bool), diagonal=1)
112
+ mask = mask * (~causal_mask.unsqueeze(0).unsqueeze(1)).to(dtype) # Keep only lower triangular part
113
+
114
+ # Set invalid positions (0) to dtype's minimum value
115
+ mask = mask + (1.0 - mask) * torch.finfo(dtype).min # Valid positions unchanged, invalid positions to minimum value
116
+ return mask
117
+
118
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
119
+ """
120
+ Forward propagation, input and output are [bsz, max_len, embed_dim].
121
+
122
+ Parameters:
123
+ hidden_states (torch.Tensor): Input hidden states, shape [bsz, max_len, embed_dim]
124
+ seq_len (torch.Tensor): Sequence length for each sample, shape [bsz]
125
+
126
+ Returns:
127
+ attn_output (torch.Tensor): Attention output, shape [bsz, max_len, embed_dim]
128
+ """
129
+ bsz, max_len, _ = hidden_states.size()
130
+
131
+ # Project to Q, K, V
132
+ query = self.q_proj(hidden_states) * self.scaling # [bsz, max_len, embed_dim]
133
+ key = self.k_proj(hidden_states) # [bsz, max_len, embed_dim]
134
+ value = self.v_proj(hidden_states) # [bsz, max_len, embed_dim]
135
+
136
+ # Reshape to multi-head form
137
+ query = query.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
138
+ key = key.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
139
+ value = value.view(bsz, max_len, self.num_heads, self.head_dim).transpose(1, 2) # [bsz, num_heads, max_len, head_dim]
140
+
141
+ # Calculate attention scores
142
+ attn_scores = torch.matmul(query, key.transpose(-1, -2)) # [bsz, num_heads, max_len, max_len]
143
+
144
+ # Generate attention mask
145
+ attn_mask = self._create_attention_mask(seq_len, max_len, hidden_states.device, attn_scores.dtype) # [bsz, 1, max_len, max_len]
146
+ # Apply mask (additive form, consistent with HubertEncoder)
147
+ attn_scores = attn_scores + attn_mask # Invalid positions set to very small value
148
+
149
+ # Softmax calculate attention weights
150
+ attn_weights = F.softmax(attn_scores, dim=-1) # [bsz, num_heads, max_len, max_len]
151
+ attn_weights = self.dropout(attn_weights)
152
+
153
+ # Calculate attention output
154
+ attn_output = torch.matmul(attn_weights, value) # [bsz, num_heads, max_len, head_dim]
155
+ attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, max_len, self.embed_dim) # [bsz, max_len, embed_dim]
156
+
157
+ # Output projection
158
+ attn_output = self.out_proj(attn_output) # [bsz, max_len, embed_dim]
159
+
160
+ return attn_output
161
+
162
+ # Define Transformer layer containing attention mechanism and feedforward network for feature extraction and transformation
163
+ class OmniWhisperTransformerLayer(nn.Module):
164
+ def __init__(self, activation_function="gelu", d_model=1280, attention_heads=20, ffn_dim=5120, causal=False, ln_type="LayerNorm", attn_type="varlen"):
165
+ super().__init__()
166
+ self.embed_dim = d_model
167
+ # Only keep varlen attention mechanism
168
+ if attn_type != "varlen":
169
+ raise ValueError(f"Unknown attn_type: {attn_type}. Only 'varlen' is supported.")
170
+ self.self_attn = VarLenAttention(self.embed_dim, attention_heads, causal)
171
+ if ln_type == "LayerNorm":
172
+ self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim)
173
+ elif ln_type == "RMSNorm":
174
+ self.self_attn_layer_norm = RMSNorm(self.embed_dim)
175
+ else:
176
+ raise ValueError(f"Unknown ln_type: {ln_type}")
177
+ self.activation_fn = ACT2FN[activation_function]
178
+ self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
179
+ self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
180
+ if ln_type == "LayerNorm":
181
+ self.final_layer_norm = nn.LayerNorm(self.embed_dim)
182
+ elif ln_type == "RMSNorm":
183
+ self.final_layer_norm = RMSNorm(self.embed_dim)
184
+ else:
185
+ raise ValueError(f"Unknown ln_type: {ln_type}")
186
+
187
+ def forward(self, hidden_states: torch.Tensor, seq_len: torch.Tensor) -> torch.Tensor:
188
+ residual = hidden_states # [bsz, max_len, embed_dim]
189
+ hidden_states = self.self_attn_layer_norm(hidden_states)
190
+ # from torch.cuda.amp import autocast
191
+ # print(f"{residual.dtype = }")
192
+ # print(f"Autocast enabled: {torch.is_autocast_enabled():}")
193
+ # print(f"after layernorm {hidden_states.dtype = }")
194
+ hidden_states = self.self_attn(hidden_states, seq_len) # [bsz, max_len, embed_dim]
195
+ hidden_states = residual + hidden_states
196
+ residual = hidden_states
197
+ hidden_states = self.final_layer_norm(hidden_states)
198
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
199
+ hidden_states = self.fc2(hidden_states)
200
+ hidden_states = residual + hidden_states
201
+ if (hidden_states.dtype == torch.float16 or hidden_states.dtype == torch.bfloat16) and \
202
+ (torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any()):
203
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
204
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
205
+ return hidden_states
206
+
207
+ # Define audio encoder to convert input audio features to hidden state representation
208
+ class OmniAudioEncoder(nn.Module):
209
+ def __init__(
210
+ self,
211
+ num_mel_bins=128, # Input feature Mel band number, usually the dimension of Mel spectrogram
212
+ sampling_rate=16000, # Audio sampling rate, unit Hz
213
+ hop_length=160, # Frame shift length (sample number) when calculating Mel spectrogram
214
+ stride_size=2, # Convolution layer step, used for downsampling
215
+ kernel_size=3, # Convolution kernel size, controlling receptive field
216
+ d_model=1280, # Model's hidden state dimension (embedding dimension)
217
+ scale_embedding=True, # Whether to scale embedding (usually used for stabilizing training)
218
+ max_audio_seconds=30, # Maximum audio duration supported (seconds)
219
+ encoder_layers=32, # Transformer encoder layer number
220
+ encoder_attention_heads=20, # Attention head number for each Transformer layer
221
+ encoder_ffn_dim=5120, # Intermediate dimension for feedforward network
222
+ activation_function="gelu", # Activation function type, default GELU
223
+ attn_type="varlen" # New parameter, select attention mechanism type
224
+ ):
225
+ super().__init__()
226
+ # Calculate maximum sequence length: Convert sampling rate to frame number after considering downsampling step
227
+ self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
228
+ # Embedding scaling factor, if enabled sqrt(d_model), otherwise 1.0
229
+ self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
230
+ self.num_mel_bins = num_mel_bins # Save Mel band number
231
+ self.d_model = d_model # Save hidden state dimension
232
+ self.stride_size = stride_size
233
+
234
+ # First convolution layer: Convert Mel spectrogram features (num_mel_bins) to hidden dimension (d_model)
235
+ self.conv1 = nn.Conv1d(num_mel_bins, d_model, kernel_size=kernel_size, padding=1)
236
+ # Second convolution layer: Apply downsampling with stride_size
237
+ self.conv2 = nn.Conv1d(d_model, d_model, kernel_size=kernel_size, stride=stride_size, padding=1)
238
+
239
+ # Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model)
240
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
241
+
242
+ # Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network
243
+ self.layers = nn.ModuleList([
244
+ OmniWhisperTransformerLayer(
245
+ activation_function=activation_function,
246
+ d_model=d_model,
247
+ attention_heads=encoder_attention_heads,
248
+ ffn_dim=encoder_ffn_dim,
249
+ causal=False, # Encoder does not need causal attention
250
+ attn_type=attn_type # Pass attention type
251
+ ) for _ in range(encoder_layers)
252
+ ])
253
+
254
+ # Last layer normalization for stable output
255
+ self.layer_norm = nn.LayerNorm(d_model)
256
+
257
+ def forward(self, input_features, input_length, output_hidden_states=False):
258
+ """
259
+ Forward propagation function to convert input audio features to hidden state representation
260
+
261
+ Parameters:
262
+ input_features (torch.Tensor): Input Mel spectrogram features, shape [bsz, num_mel_bins, seq_len]
263
+ input_length (torch.Tensor): Input sequence length for each sample, shape [bsz]
264
+ output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False
265
+
266
+ Returns:
267
+ if output_hidden_states is False:
268
+ hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, d_model, tgt_len]
269
+ output_length (torch.Tensor): Output sequence length for each sample, shape [bsz]
270
+ else:
271
+ hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, d_model, tgt_len]
272
+ output_length (torch.Tensor): Output sequence length for each sample, shape [bsz]
273
+ hidden_states_all_layers (tuple): Tuple containing hidden states for each layer, including initial input
274
+ """
275
+ # Ensure input feature data type consistent with convolution layer weights
276
+ input_features = input_features.to(self.conv1.weight.dtype) # (B, D, T)
277
+
278
+ # First convolution + GELU: project Mel features (num_mel_bins) to the hidden dimension
279
+ inputs_embeds = nn.functional.gelu(self.conv1(input_features)) # (B, D, T)
280
+
281
+ # Second convolution + GELU: downsample the sequence by stride_size
282
+ inputs_embeds = nn.functional.gelu(self.conv2(inputs_embeds)) # (B, D, T)
283
+
284
+ # Calculate output length: Result after downsampling with stride_size
285
+ output_length = (input_length // self.stride_size).long() # (B,)
286
+
287
+ # Adjust dimension order to [bsz, seq_len, d_model] for Transformer input
288
+ hidden_states = inputs_embeds.permute(0, 2, 1) # (B, T, D)
289
+
290
+ # Get batch size and target sequence length
291
+ bsz, tgt_len, _ = hidden_states.size()
292
+
293
+ # Slice the positional embedding to the current sequence length, or use it in full
294
+ if tgt_len < self.positional_embedding.shape[0]:
295
+ current_positional_embedding = self.positional_embedding[:tgt_len]
296
+ else:
297
+ current_positional_embedding = self.positional_embedding
298
+
299
+ # Add input embedding to positional embedding, convert to float to avoid precision issues
300
+ hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype)
301
+
302
+ # Generate sequence mask for processing variable-length sequence
303
+ attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1]
304
+
305
+ # Initialize hidden states list for storing output for each layer (if needed)
306
+ hidden_states_all_layers = () if output_hidden_states else None
307
+
308
+ # Process hidden states through Transformer encoder layer by layer
309
+ for encoder_layer in self.layers:
310
+ if output_hidden_states:
311
+ hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
312
+ hidden_states = encoder_layer(hidden_states, output_length) # [bsz, tgt_len, d_model]
313
+
314
+ # Normalize hidden states
315
+ hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model]
316
+ if output_hidden_states:
317
+ hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
318
+
319
+ # Use mask to zero out padding parts and ensure output only retains valid data
320
+ hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model]
321
+ hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, tgt_len]
322
+
323
+ if not output_hidden_states:
324
+ return hidden_states, output_length
325
+ else:
326
+ return hidden_states, output_length, hidden_states_all_layers
327
+
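A quick shape walkthrough (illustrative): a 30 s clip at 16 kHz with hop_length=160 gives 3000 Mel frames, and stride_size=2 halves the sequence. The sketch below assumes the classes and their dependencies are in scope; encoder_layers is reduced only to keep it light:

import torch

encoder = OmniAudioEncoder(num_mel_bins=128, d_model=1280, encoder_layers=2)  # fewer layers than the default, for illustration
mel = torch.randn(1, 128, 3000)    # [bsz, num_mel_bins, seq_len]
mel_len = torch.tensor([3000])
hidden, out_len = encoder(mel, mel_len)
# hidden: [1, 1280, 1500], out_len: tensor([1500])  (seq_len // stride_size)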
328
+ # Define audio decoder to convert hidden states to Mel spectrogram
329
+ class OmniAudioDecoder(nn.Module):
330
+ def __init__(
331
+ self,
332
+ num_mel_bins=128,
333
+ sampling_rate=16000,
334
+ hop_length=160,
335
+ stride_size=2,
336
+ kernel_size=3,
337
+ d_model=1280,
338
+ scale_embedding=True,
339
+ max_audio_seconds=30,
340
+ decoder_layers=32,
341
+ decoder_attention_heads=20,
342
+ decoder_ffn_dim=5120,
343
+ activation_function="gelu",
344
+ attn_type="varlen" # New parameter, select attention mechanism type
345
+ ):
346
+ super().__init__()
347
+ self.max_source_positions = (max_audio_seconds * sampling_rate // hop_length) // stride_size
348
+ self.embed_scale = math.sqrt(d_model) if scale_embedding else 1.0
349
+ self.num_mel_bins = num_mel_bins
350
+ self.d_model = d_model
351
+ self.stride_size = stride_size
352
+
353
+ # Transposed convolution layers that upsample the sequence length back by roughly stride_size
354
+ self.deconv1 = nn.ConvTranspose1d(
355
+ d_model,
356
+ d_model,
357
+ kernel_size=kernel_size,
358
+ stride=stride_size,
359
+ padding=0, # Do not fill input side
360
+ output_padding=0 # Can be adjusted to precisely control length
361
+ )
362
+ self.deconv2 = nn.ConvTranspose1d(
363
+ d_model,
364
+ num_mel_bins,
365
+ kernel_size=kernel_size,
366
+ stride=1, # Only convert channels, do not change length
367
+ padding=0
368
+ )
369
+
370
+ # Positional embedding remains consistent
371
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model)) # (T, D)
372
+
373
+ # Transformer decoder layer
374
+ self.layers = nn.ModuleList([
375
+ OmniWhisperTransformerLayer(
376
+ activation_function=activation_function,
377
+ d_model=d_model,
378
+ attention_heads=decoder_attention_heads,
379
+ ffn_dim=decoder_ffn_dim,
380
+ causal=False, # The decoder here also uses bidirectional (non-causal) attention
381
+ attn_type=attn_type # Pass attention type
382
+ ) for _ in range(decoder_layers)
383
+ ])
384
+ self.layer_norm = nn.LayerNorm(d_model)
385
+
386
+ def forward(self, hidden_states, input_length): # (B, D, T)
387
+ # Input is hidden state output from encoder
388
+ hidden_states = hidden_states.transpose(1, 2) # (B, T, D)
389
+ bsz, tgt_len, _ = hidden_states.size()
390
+
391
+ # Add positional embedding
392
+ if tgt_len < self.positional_embedding.shape[0]:
393
+ current_positional_embedding = self.positional_embedding[:tgt_len] # (T, D)
394
+ else:
395
+ current_positional_embedding = self.positional_embedding
396
+ hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # (B, T, D)
397
+
398
+ # Generate sequence mask
399
+ attention_mask = get_sequence_mask(hidden_states, input_length) # [bsz, tgt_len, 1]
400
+
401
+ # Process through decoder layer
402
+ for decoder_layer in self.layers:
403
+ hidden_states = decoder_layer(hidden_states, input_length) # [bsz, tgt_len, d_model]
404
+
405
+ # Final layer normalization
406
+ hidden_states = self.layer_norm(hidden_states) # [bsz, tgt_len, d_model]
407
+
408
+ # Use mask to zero out padding parts
409
+ hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, tgt_len, d_model]
410
+
411
+ # Process through transpose convolution layer to reconstruct audio features
412
+ hidden_states = hidden_states.permute(0, 2, 1) # (B, D, T)
413
+ output_features = nn.functional.gelu(self.deconv1(hidden_states)) # (B, D, T)
414
+ output_features = nn.functional.gelu(self.deconv2(output_features)) # (B, D, T)
415
+
416
+ # If strictly stride_size times length is needed, can trim extra parts
417
+ expected_length = tgt_len * self.stride_size
418
+ if output_features.size(2) > expected_length:
419
+ output_features = output_features[:, :, :expected_length]
420
+
421
+ output_length = input_length * self.stride_size
422
+ # Output shape: [bsz, num_mel_bins, seq_len]
423
+ return output_features, output_length
424
+
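The decoder mirrors the encoder: it upsamples the latent sequence by stride_size and trims the transposed-convolution overshoot. A shape sketch under the same illustrative sizes (dependencies assumed in scope):

import torch

decoder = OmniAudioDecoder(num_mel_bins=128, d_model=1280, decoder_layers=2)  # reduced depth for illustration
latent = torch.randn(1, 1280, 1500)   # e.g. the encoder output above
latent_len = torch.tensor([1500])
mel_hat, mel_hat_len = decoder(latent, latent_len)
# mel_hat: [1, 128, 3000]  (trimmed to seq_len * stride_size), mel_hat_len: tensor([3000])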
425
+ # The following part remains unchanged
426
+ class ResidualDownConv(nn.Module):
427
+ def __init__(self, d_model=1280, avg_pooler=4):
428
+ """
429
+ Downsampling module containing residual connection and convolution operation
430
+
431
+ Parameters:
432
+ d_model (int): Input and output hidden dimension
433
+ avg_pooler (int): Downsampling factor (convolution step)
434
+ """
435
+ super().__init__()
436
+ self.d_model = d_model
437
+ self.avg_pooler = avg_pooler
438
+ self.intermediate_dim = d_model * avg_pooler
439
+
440
+ # Convolution layer for downsampling
441
+ self.gate_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
442
+ self.up_proj = nn.Conv1d(d_model, self.intermediate_dim, avg_pooler, avg_pooler, bias=False)
443
+
444
+ # Downsampled linear projection
445
+ self.down_proj = nn.Linear(self.intermediate_dim, self.intermediate_dim, bias=False)
446
+
447
+ # Activation function and layer normalization
448
+ self.act_fn = ACT2FN['silu']
449
+ self.layer_norm = nn.LayerNorm(self.intermediate_dim)
450
+
451
+ def forward(self, x, input_length):
452
+ """
453
+ Forward propagation, execute downsampling and residual processing
454
+
455
+ Parameters:
456
+ x (torch.Tensor): Input tensor, shape [B, D, T]
+ input_length (torch.Tensor): Valid input length for each sample, shape [B]
457
+
458
+ Returns:
459
+ res (torch.Tensor): Downsampled feature, shape [B, intermediate_dim, seq_len // avg_pooler]
460
+ output_length (torch.Tensor): Downsampled valid length for each sample, shape [B]
461
+ """
462
+ output_length = input_length // self.avg_pooler
463
+ x = x.transpose(1, 2) # (B, T, D)
464
+ batch_size, seq_len, _ = x.shape # (B, T, D)
465
+ if seq_len % self.avg_pooler != 0:
466
+ pad_size = self.avg_pooler - seq_len % self.avg_pooler
467
+ x = F.pad(x, (0, 0, 0, pad_size), "constant", 0) # pad the time axis (dim 1) so seq_len is divisible by avg_pooler
468
+
469
+ xt = x.permute(0, 2, 1) # (B, D, T)
470
+ g = self.gate_proj(xt).permute(0, 2, 1) # (B, T, D)
471
+ u = self.up_proj(xt).permute(0, 2, 1) # (B, T, D)
472
+ x = x.reshape(batch_size, -1, self.intermediate_dim) # (B, T, D)
473
+
474
+ c = self.down_proj(self.act_fn(g) * u) # (B, T, D)
475
+ res = self.layer_norm(c + x) # (B, T, D)
476
+ res = res.transpose(1, 2) # (B, D, T)
477
+ return res, output_length # (B, D, T)
478
+
479
+
480
+ class UpConv(nn.Module):
481
+ def __init__(self, d_model=1280, stride=4):
482
+ """
483
+ Simple upsampling module using transpose convolution
484
+
485
+ Parameters:
486
+ d_model (int): Input and output hidden dimension
487
+ stride (int): Upsampling factor (transpose convolution step)
488
+ """
489
+ super().__init__()
490
+ self.d_model = d_model
491
+ self.stride = stride
492
+
493
+ # Simple transpose convolution layer to keep channel number consistent
494
+ self.up_conv = nn.ConvTranspose1d(
495
+ self.stride * d_model,
496
+ d_model,
497
+ kernel_size=stride,
498
+ stride=stride,
499
+ bias=False
500
+ )
501
+
502
+ def forward(self, x, input_length):
503
+ """
504
+ Forward propagation, execute upsampling
505
+
506
+ Parameters:
507
+ x (torch.Tensor): Input tensor, shape [B, D * stride, T]
+ input_length (torch.Tensor): Valid input length for each sample, shape [B]
508
+
509
+ Returns:
510
+ res (torch.Tensor): Upsampled feature, shape [B, D, T * stride]
+ output_length (torch.Tensor): Upsampled valid length for each sample, shape [B]
511
+ """
512
+ # Directly apply transpose convolution
513
+ res = self.up_conv(x)
514
+ output_length = input_length * self.stride
515
+ return res, output_length
516
+
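ResidualDownConv and UpConv are intended as a matched pair: the former trades time resolution for channels (T -> T / avg_pooler, D -> D * avg_pooler), the latter reverses it. A shape sketch with illustrative sizes, assuming the classes above are in scope:

import torch

down = ResidualDownConv(d_model=1280, avg_pooler=4)
up = UpConv(d_model=1280, stride=4)
feats = torch.randn(1, 1280, 1500)               # [B, d_model, T]
feats_len = torch.tensor([1500])
pooled, pooled_len = down(feats, feats_len)      # [1, 5120, 375], lengths // 4
restored, restored_len = up(pooled, pooled_len)  # [1, 1280, 1500], lengths * 4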
517
+
518
+ # Define Transformer encoder containing multiple Transformer layers for feature extraction and transformation
519
+ class Transformer(nn.Module):
520
+ def __init__(
521
+ self,
522
+ input_dim=1280, # Input feature dimension
523
+ d_model=1280, # Model's hidden state dimension (embedding dimension)
524
+ output_dim=1280, # Output feature dimension
525
+ max_source_positions=1500, # Maximum sequence length for positional embedding
526
+ encoder_layers=32, # Transformer encoder layer number
527
+ encoder_attention_heads=20, # Attention head number for each Transformer layer
528
+ encoder_ffn_dim=5120, # Intermediate dimension for feedforward network
529
+ activation_function="gelu", # Activation function type, default GELU
530
+ attn_type="varlen" # Attention mechanism type
531
+ ):
532
+ super().__init__()
533
+ self.input_dim = input_dim # Save input dimension
534
+ self.d_model = d_model # Save hidden state dimension
535
+ self.output_dim = output_dim # Save output dimension
536
+ self.max_source_positions = max_source_positions # Save maximum sequence length
537
+
538
+ # If input dimension and model dimension are not consistent, add input projection layer
539
+ if input_dim != d_model:
540
+ self.proj = nn.Linear(input_dim, d_model, bias=True)
541
+ else:
542
+ self.proj = None # No need for input projection layer
543
+
544
+ # Register positional embedding buffer, using sine function to generate, shape (max_source_positions, d_model)
545
+ self.register_buffer("positional_embedding", sinusoids(self.max_source_positions, d_model))
546
+
547
+ # Create Transformer encoder layer list, each layer contains attention mechanism and feedforward network
548
+ self.layers = nn.ModuleList([
549
+ OmniWhisperTransformerLayer(
550
+ activation_function=activation_function,
551
+ d_model=d_model,
552
+ attention_heads=encoder_attention_heads,
553
+ ffn_dim=encoder_ffn_dim,
554
+ causal=False, # Encoder does not need causal attention
555
+ attn_type=attn_type # Pass attention type
556
+ ) for _ in range(encoder_layers)
557
+ ])
558
+
559
+ # Last layer normalization for stable output
560
+ self.layer_norm = nn.LayerNorm(d_model)
561
+
562
+ # If output dimension and model dimension are not consistent, add output projection layer
563
+ if output_dim != d_model:
564
+ self.out_proj = nn.Linear(d_model, output_dim, bias=True)
565
+ else:
566
+ self.out_proj = None # No need for output projection layer
567
+
568
+ def forward(self, input_features: torch.Tensor, input_length: torch.Tensor, output_hidden_states: bool = False):
569
+ """
570
+ Forward propagation function to convert input features through Transformer layer to hidden state representation
571
+
572
+ Parameters:
573
+ input_features (torch.Tensor): Input features, shape [bsz, input_dim, seq_len] (B, input_dim, T)
574
+ input_length (torch.Tensor): Input sequence length for each sample, shape [bsz]
575
+ output_hidden_states (bool, optional): Whether to return hidden states for each layer, default False
576
+
577
+ Returns:
578
+ if output_hidden_states is False:
579
+ hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, output_dim, seq_len] (B, output_dim, T)
580
+ output_length (torch.Tensor): Output sequence length for each sample, shape [bsz]
581
+ else:
582
+ hidden_states (torch.Tensor): Encoded hidden states, shape [bsz, output_dim, seq_len] (B, output_dim, T)
583
+ output_length (torch.Tensor): Output sequence length for each sample, shape [bsz]
584
+ hidden_states_all_layers (tuple): Tuple containing hidden states for each layer, each shape [bsz, seq_len, d_model]
585
+ """
586
+ # Output length is the same as input length, Transformer does not change sequence length
587
+ output_length = input_length.long() # [bsz]
588
+
589
+ # If there is input projection layer, map input features from input_dim to d_model
590
+ if self.proj is not None:
591
+ hidden_states = self.proj(input_features.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, d_model, seq_len] (B, D, T)
592
+ else:
593
+ hidden_states = input_features # [bsz, d_model, seq_len] (B, D, T)
594
+
595
+ # Adjust input dimension order to [bsz, seq_len, d_model] for Transformer input
596
+ hidden_states = hidden_states.permute(0, 2, 1) # [bsz, seq_len, d_model] (B, T, D)
597
+
598
+ # Get batch size and target sequence length
599
+ bsz, tgt_len, _ = hidden_states.size()
600
+
601
+ # Slice the positional embedding to the current sequence length, or use it in full
602
+ if tgt_len < self.positional_embedding.shape[0]:
603
+ current_positional_embedding = self.positional_embedding[:tgt_len] # [tgt_len, d_model]
604
+ else:
605
+ current_positional_embedding = self.positional_embedding # [max_source_positions, d_model]
606
+
607
+ # Add input features to positional embedding, convert to float to avoid precision issues
608
+ hidden_states = (hidden_states.to(torch.float32) + current_positional_embedding).to(hidden_states.dtype) # [bsz, seq_len, d_model]
609
+
610
+ # Generate sequence mask for processing variable-length sequence
611
+ attention_mask = get_sequence_mask(hidden_states, output_length) # [bsz, tgt_len, 1]
612
+
613
+ # Initialize hidden states list for storing output for each layer (if needed)
614
+ hidden_states_all_layers = () if output_hidden_states else None
615
+
616
+ # Process hidden states through Transformer encoder layer by layer
617
+ for encoder_layer in self.layers:
618
+ if output_hidden_states:
619
+ hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
620
+ hidden_states = encoder_layer(hidden_states, output_length) # [bsz, seq_len, d_model]
621
+
622
+ # Normalize hidden states
623
+ hidden_states = self.layer_norm(hidden_states) # [bsz, seq_len, d_model]
624
+ if output_hidden_states:
625
+ hidden_states_all_layers = hidden_states_all_layers + (hidden_states,)
626
+
627
+ # Use mask to zero out padding parts and ensure output only retains valid data
628
+ hidden_states = torch.where(attention_mask, hidden_states, 0) # [bsz, seq_len, d_model]
629
+
630
+ # Adjust dimension order to [bsz, d_model, seq_len]
631
+ hidden_states = hidden_states.transpose(1, 2) # [bsz, d_model, seq_len] (B, D, T)
632
+
633
+ # If there is output projection layer, map hidden states from d_model to output_dim
634
+ if self.out_proj is not None:
635
+ hidden_states = self.out_proj(hidden_states.permute(0, 2, 1)).permute(0, 2, 1) # [bsz, output_dim, seq_len] (B, output_dim, T)
636
+
637
+ if not output_hidden_states:
638
+ return hidden_states, output_length
639
+ else:
640
+ return hidden_states, output_length, hidden_states_all_layers
641
+
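The optional in/out projections make this block usable between stages of different widths. A shape sketch with illustrative (non-default) dimensions, assuming the class and its dependencies are in scope:

import torch

blk = Transformer(input_dim=2560, d_model=1280, output_dim=1024, encoder_layers=2)
x = torch.randn(1, 2560, 375)      # [bsz, input_dim, seq_len]
x_len = torch.tensor([375])
y, y_len = blk(x, x_len)           # y: [1, 1024, 375]; the sequence length is unchanged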
642
+
643
+ def safe_log(x: torch.Tensor, clip_val: float = 1e-7) -> torch.Tensor:
644
+ """
645
+ Computes the element-wise logarithm of the input tensor with clipping to avoid near-zero values.
646
+
647
+ Args:
648
+ x (Tensor): Input tensor.
649
+ clip_val (float, optional): Minimum value to clip the input tensor. Defaults to 1e-7.
650
+
651
+ Returns:
652
+ Tensor: Element-wise logarithm of the input tensor with clipping applied.
653
+ """
654
+ return torch.log(torch.clip(x, min=clip_val))
655
+
656
+
657
+ def symlog(x: torch.Tensor) -> torch.Tensor:
658
+ return torch.sign(x) * torch.log1p(x.abs())
659
+
660
+
661
+ def symexp(x: torch.Tensor) -> torch.Tensor:
662
+ return torch.sign(x) * (torch.exp(x.abs()) - 1)
663
+
664
+
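symexp is the exact inverse of symlog (sign(x) * (exp(log1p(|x|)) - 1) = x), which a one-line check confirms:

import torch

x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
torch.testing.assert_close(symexp(symlog(x)), x)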
665
+ class STFT(nn.Module):
666
+ def __init__(
667
+ self,
668
+ n_fft: int,
669
+ hop_length: int,
670
+ win_length: int,
671
+ center=True,
672
+ ):
673
+ super().__init__()
674
+ self.center = center
675
+ self.n_fft = n_fft
676
+ self.hop_length = hop_length
677
+ self.win_length = win_length
678
+ window = torch.hann_window(win_length)
679
+ self.register_buffer("window", window)
680
+
681
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
682
+ # x: (B, T * hop_length)
683
+
684
+ if not self.center:
685
+ pad = self.win_length - self.hop_length
686
+ x = torch.nn.functional.pad(x, (pad // 2, pad // 2), mode="reflect")
687
+
688
+ stft_spec = torch.stft(
689
+ x,
690
+ self.n_fft,
691
+ hop_length=self.hop_length,
692
+ win_length=self.win_length,
693
+ window=self.window,
694
+ center=self.center,
695
+ return_complex=False,
696
+ ) # (B, n_fft // 2 + 1, T, 2)
697
+
698
+ rea = stft_spec[:, :, :, 0] # (B, n_fft // 2 + 1, T)
699
+ imag = stft_spec[:, :, :, 1] # (B, n_fft // 2 + 1, T)
700
+
701
+ log_mag = torch.log(
702
+ torch.abs(torch.sqrt(torch.pow(rea, 2) + torch.pow(imag, 2))) + 1e-5
703
+ ) # (B, n_fft // 2 + 1, T)
704
+ phase = torch.atan2(imag, rea) # (B, n_fft // 2 + 1, T)
705
+
706
+ return log_mag, phase
707
+
708
+
709
+ class ISTFT(nn.Module):
710
+ """
711
+ Custom implementation of ISTFT since torch.istft doesn't allow custom padding (other than `center=True`) with
712
+ windowing. This is because the NOLA (Nonzero Overlap Add) check fails at the edges.
713
+ See issue: https://github.com/pytorch/pytorch/issues/62323
714
+ Specifically, in the context of neural vocoding we are interested in "same" padding analogous to CNNs.
715
+ The NOLA constraint is met as we trim padded samples anyway.
716
+
717
+ Args:
718
+ n_fft (int): Size of Fourier transform.
719
+ hop_length (int): The distance between neighboring sliding window frames.
720
+ win_length (int): The size of window frame and STFT filter.
721
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
722
+ """
723
+
724
+ def __init__(
725
+ self, n_fft: int, hop_length: int, win_length: int, padding: str = "same"
726
+ ):
727
+ super().__init__()
728
+ if padding not in ["center", "same"]:
729
+ raise ValueError("Padding must be 'center' or 'same'.")
730
+ self.padding = padding
731
+ self.n_fft = n_fft
732
+ self.hop_length = hop_length
733
+ self.win_length = win_length
734
+ window = torch.hann_window(win_length)
735
+ self.register_buffer("window", window)
736
+
737
+ def forward(self, spec: torch.Tensor) -> torch.Tensor:
738
+ """
739
+ Compute the Inverse Short Time Fourier Transform (ISTFT) of a complex spectrogram.
740
+
741
+ Args:
742
+ spec (Tensor): Input complex spectrogram of shape (B, N, T), where B is the batch size,
743
+ N is the number of frequency bins, and T is the number of time frames.
744
+
745
+ Returns:
746
+ Tensor: Reconstructed time-domain signal of shape (B, L), where L is the length of the output signal.
747
+ """
748
+ if self.padding == "center":
749
+ # Fallback to pytorch native implementation
750
+ return torch.istft(
751
+ spec,
752
+ self.n_fft,
753
+ self.hop_length,
754
+ self.win_length,
755
+ self.window,
756
+ center=True,
757
+ )
758
+ elif self.padding == "same":
759
+ pad = (self.win_length - self.hop_length) // 2
760
+ else:
761
+ raise ValueError("Padding must be 'center' or 'same'.")
762
+
763
+ assert spec.dim() == 3, "Expected a 3D tensor as input"
764
+ B, N, T = spec.shape
765
+
766
+ # Inverse FFT
767
+ ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
768
+ ifft = ifft * self.window[None, :, None]
769
+
770
+ # Overlap and Add
771
+ output_size = (T - 1) * self.hop_length + self.win_length
772
+ y = torch.nn.functional.fold(
773
+ ifft,
774
+ output_size=(1, output_size),
775
+ kernel_size=(1, self.win_length),
776
+ stride=(1, self.hop_length),
777
+ )[:, 0, 0, pad:-pad]
778
+
779
+ # Window envelope
780
+ window_sq = self.window.square().expand(1, T, -1).transpose(1, 2)
781
+ window_envelope = torch.nn.functional.fold(
782
+ window_sq,
783
+ output_size=(1, output_size),
784
+ kernel_size=(1, self.win_length),
785
+ stride=(1, self.hop_length),
786
+ ).squeeze()[pad:-pad]
787
+
788
+ # Normalize
789
+ assert (window_envelope > 1e-11).all()
790
+ y = y / window_envelope
791
+
792
+ return y
793
+
794
+
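With matching parameters, STFT and ISTFT approximately invert each other; a small round-trip sketch using the sizes that appear elsewhere in this file (n_fft=640, hop 160), assuming both classes are in scope:

import torch

stft = STFT(n_fft=640, hop_length=160, win_length=640, center=True)
istft = ISTFT(n_fft=640, hop_length=160, win_length=640, padding="center")
wav = torch.randn(1, 16000)                        # 1 s at 16 kHz
log_mag, phase = stft(wav)                         # each [1, 321, 101]
spec = torch.exp(log_mag) * torch.exp(1j * phase)
wav_hat = istft(spec)                              # ~ wav, up to the 1e-5 log offset and edge effects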
795
+ class MDCT(nn.Module):
796
+ """
797
+ Modified Discrete Cosine Transform (MDCT) module.
798
+
799
+ Args:
800
+ frame_len (int): Length of the MDCT frame.
801
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
802
+ """
803
+
804
+ def __init__(self, frame_len: int, padding: str = "same"):
805
+ super().__init__()
806
+ if padding not in ["center", "same"]:
807
+ raise ValueError("Padding must be 'center' or 'same'.")
808
+ self.padding = padding
809
+ self.frame_len = frame_len
810
+ N = frame_len // 2
811
+ n0 = (N + 1) / 2
812
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
813
+ self.register_buffer("window", window)
814
+
815
+ pre_twiddle = torch.exp(-1j * torch.pi * torch.arange(frame_len) / frame_len)
816
+ post_twiddle = torch.exp(-1j * torch.pi * n0 * (torch.arange(N) + 0.5) / N)
817
+ # view_as_real: NCCL Backend does not support ComplexFloat data type
818
+ # https://github.com/pytorch/pytorch/issues/71613
819
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
820
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
821
+
822
+ def forward(self, audio: torch.Tensor) -> torch.Tensor:
823
+ """
824
+ Apply the Modified Discrete Cosine Transform (MDCT) to the input audio.
825
+
826
+ Args:
827
+ audio (Tensor): Input audio waveform of shape (B, T), where B is the batch size
828
+ and T is the length of the audio.
829
+
830
+ Returns:
831
+ Tensor: MDCT coefficients of shape (B, L, N), where L is the number of output frames
832
+ and N is the number of frequency bins.
833
+ """
834
+ if self.padding == "center":
835
+ audio = torch.nn.functional.pad(
836
+ audio, (self.frame_len // 2, self.frame_len // 2)
837
+ )
838
+ elif self.padding == "same":
839
+ # hop_length is 1/2 frame_len
840
+ audio = torch.nn.functional.pad(
841
+ audio, (self.frame_len // 4, self.frame_len // 4)
842
+ )
843
+ else:
844
+ raise ValueError("Padding must be 'center' or 'same'.")
845
+
846
+ x = audio.unfold(-1, self.frame_len, self.frame_len // 2)
847
+ N = self.frame_len // 2
848
+ x = x * self.window.expand(x.shape)
849
+ X = torch.fft.fft(
850
+ x * view_as_complex(self.pre_twiddle).expand(x.shape), dim=-1
851
+ )[..., :N]
852
+ res = X * view_as_complex(self.post_twiddle).expand(X.shape) * np.sqrt(1 / N)
853
+ return torch.real(res) * np.sqrt(2)
854
+
855
+
856
+ class IMDCT(nn.Module):
857
+ """
858
+ Inverse Modified Discrete Cosine Transform (IMDCT) module.
859
+
860
+ Args:
861
+ frame_len (int): Length of the MDCT frame.
862
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
863
+ """
864
+
865
+ def __init__(self, frame_len: int, padding: str = "same"):
866
+ super().__init__()
867
+ if padding not in ["center", "same"]:
868
+ raise ValueError("Padding must be 'center' or 'same'.")
869
+ self.padding = padding
870
+ self.frame_len = frame_len
871
+ N = frame_len // 2
872
+ n0 = (N + 1) / 2
873
+ window = torch.from_numpy(scipy.signal.cosine(frame_len)).float()
874
+ self.register_buffer("window", window)
875
+
876
+ pre_twiddle = torch.exp(1j * torch.pi * n0 * torch.arange(N * 2) / N)
877
+ post_twiddle = torch.exp(1j * torch.pi * (torch.arange(N * 2) + n0) / (N * 2))
878
+ self.register_buffer("pre_twiddle", view_as_real(pre_twiddle))
879
+ self.register_buffer("post_twiddle", view_as_real(post_twiddle))
880
+
881
+ def forward(self, X: torch.Tensor) -> torch.Tensor:
882
+ """
883
+ Apply the Inverse Modified Discrete Cosine Transform (IMDCT) to the input MDCT coefficients.
884
+
885
+ Args:
886
+ X (Tensor): Input MDCT coefficients of shape (B, L, N), where B is the batch size,
887
+ L is the number of frames, and N is the number of frequency bins.
888
+
889
+ Returns:
890
+ Tensor: Reconstructed audio waveform of shape (B, T), where T is the length of the audio.
891
+ """
892
+ B, L, N = X.shape
893
+ Y = torch.zeros((B, L, N * 2), dtype=X.dtype, device=X.device)
894
+ Y[..., :N] = X
895
+ Y[..., N:] = -1 * torch.conj(torch.flip(X, dims=(-1,)))
896
+ y = torch.fft.ifft(
897
+ Y * view_as_complex(self.pre_twiddle).expand(Y.shape), dim=-1
898
+ )
899
+ y = (
900
+ torch.real(y * view_as_complex(self.post_twiddle).expand(y.shape))
901
+ * np.sqrt(N)
902
+ * np.sqrt(2)
903
+ )
904
+ result = y * self.window.expand(y.shape)
905
+ output_size = (1, (L + 1) * N)
906
+ audio = torch.nn.functional.fold(
907
+ result.transpose(1, 2),
908
+ output_size=output_size,
909
+ kernel_size=(1, self.frame_len),
910
+ stride=(1, self.frame_len // 2),
911
+ )[:, 0, 0, :]
912
+
913
+ if self.padding == "center":
914
+ pad = self.frame_len // 2
915
+ elif self.padding == "same":
916
+ pad = self.frame_len // 4
917
+ else:
918
+ raise ValueError("Padding must be 'center' or 'same'.")
919
+
920
+ audio = audio[:, pad:-pad]
921
+ return audio
922
+
923
+
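MDCT/IMDCT with the cosine window form a TDAC pair, so a round trip reconstructs the signal up to frame-edge effects. A shape sketch with an illustrative frame length, assuming both classes are in scope:

import torch

mdct = MDCT(frame_len=320, padding="same")
imdct = IMDCT(frame_len=320, padding="same")
wav = torch.randn(1, 16000)
coeffs = mdct(wav)          # [1, 100, 160]  (hop = frame_len // 2)
wav_hat = imdct(coeffs)     # [1, 16000], ~ wav away from the edges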
924
+ class FourierHead(nn.Module):
925
+ """Base class for inverse fourier modules."""
926
+
927
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
928
+ """
929
+ Args:
930
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
931
+ L is the sequence length, and H denotes the model dimension.
932
+
933
+ Returns:
934
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
935
+ """
936
+ raise NotImplementedError("Subclasses must implement the forward method.")
937
+
938
+
939
+ class ISTFTHead(FourierHead):
940
+ """
941
+ ISTFT Head module for predicting STFT complex coefficients.
942
+
943
+ Args:
944
+ dim (int): Hidden dimension of the model.
945
+ n_fft (int): Size of Fourier transform.
946
+ hop_length (int): The distance between neighboring sliding window frames, which should align with
947
+ the resolution of the input features.
948
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
949
+ """
950
+
951
+ def __init__(self, dim: int, n_fft: int, hop_length: int, padding: str = "same"):
952
+ super().__init__()
953
+ out_dim = n_fft + 2
954
+ self.out = torch.nn.Linear(dim, out_dim)
955
+ self.istft = ISTFT(
956
+ n_fft=n_fft, hop_length=hop_length, win_length=n_fft, padding=padding
957
+ )
958
+
959
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
960
+ """
961
+ Forward pass of the ISTFTHead module.
962
+
963
+ Args:
964
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
965
+ L is the sequence length, and H denotes the model dimension.
966
+
967
+ Returns:
968
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
969
+ """
970
+ x = self.out(x).transpose(1, 2)
971
+ mag, p = x.chunk(2, dim=1)
972
+ mag = torch.exp(mag)
973
+ mag = torch.clip(
974
+ mag, max=1e2
975
+ ) # safeguard to prevent excessively large magnitudes
976
+ # wrapping happens here. These two lines produce real and imaginary value
977
+ x = torch.cos(p)
978
+ y = torch.sin(p)
979
+ # recalculating phase here does not produce anything new
980
+ # only costs time
981
+ # phase = torch.atan2(y, x)
982
+ # S = mag * torch.exp(phase * 1j)
983
+ # better directly produce the complex value
984
+ original_dtype = x.dtype
985
+ S = mag.float() * (x.float() + 1j * y.float())
986
+ audio = self.istft(S)
987
+ audio = audio.to(original_dtype)
988
+ return audio
989
+
990
+
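The head maps backbone features of shape (B, L, dim) to a waveform of roughly L * hop_length samples; a sketch with illustrative sizes, assuming ISTFTHead and ISTFT are in scope:

import torch

head = ISTFTHead(dim=512, n_fft=640, hop_length=160, padding="same")
feats = torch.randn(1, 100, 512)   # [B, L, H] backbone output
audio = head(feats)                # [1, 16000] == [B, L * hop_length]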
991
+ class IMDCTSymExpHead(FourierHead):
992
+ """
993
+ IMDCT Head module for predicting MDCT coefficients with symmetric exponential function
994
+
995
+ Args:
996
+ dim (int): Hidden dimension of the model.
997
+ mdct_frame_len (int): Length of the MDCT frame.
998
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
999
+ sample_rate (int, optional): The sample rate of the audio. If provided, the last layer will be initialized
1000
+ based on perceptual scaling. Defaults to None.
1001
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
1002
+ """
1003
+
1004
+ def __init__(
1005
+ self,
1006
+ dim: int,
1007
+ mdct_frame_len: int,
1008
+ padding: str = "same",
1009
+ sample_rate: Optional[int] = None,
1010
+ clip_audio: bool = False,
1011
+ ):
1012
+ super().__init__()
1013
+ out_dim = mdct_frame_len // 2
1014
+ self.out = nn.Linear(dim, out_dim)
1015
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
1016
+ self.clip_audio = clip_audio
1017
+
1018
+ if sample_rate is not None:
1019
+ # optionally init the last layer following mel-scale
1020
+ m_max = _hz_to_mel(sample_rate // 2)
1021
+ m_pts = torch.linspace(0, m_max, out_dim)
1022
+ f_pts = _mel_to_hz(m_pts)
1023
+ scale = 1 - (f_pts / f_pts.max())
1024
+
1025
+ with torch.no_grad():
1026
+ self.out.weight.mul_(scale.view(-1, 1))
1027
+
1028
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1029
+ """
1030
+ Forward pass of the IMDCTSymExpHead module.
1031
+
1032
+ Args:
1033
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
1034
+ L is the sequence length, and H denotes the model dimension.
1035
+
1036
+ Returns:
1037
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
1038
+ """
1039
+ x = self.out(x)
1040
+ x = symexp(x)
1041
+ x = torch.clip(
1042
+ x, min=-1e2, max=1e2
1043
+ ) # safeguard to prevent excessively large magnitudes
1044
+ audio = self.imdct(x)
1045
+ if self.clip_audio:
1046
+ audio = torch.clip(audio, min=-1.0, max=1.0) # clip the reconstructed audio, not the pre-IMDCT coefficients
1047
+
1048
+ return audio
1049
+
1050
+
1051
+ class IMDCTCosHead(FourierHead):
1052
+ """
1053
+ IMDCT Head module for predicting MDCT coefficients with parametrizing MDCT = exp(m) · cos(p)
1054
+
1055
+ Args:
1056
+ dim (int): Hidden dimension of the model.
1057
+ mdct_frame_len (int): Length of the MDCT frame.
1058
+ padding (str, optional): Type of padding. Options are "center" or "same". Defaults to "same".
1059
+ clip_audio (bool, optional): Whether to clip the audio output within the range of [-1.0, 1.0]. Defaults to False.
1060
+ """
1061
+
1062
+ def __init__(
1063
+ self,
1064
+ dim: int,
1065
+ mdct_frame_len: int,
1066
+ padding: str = "same",
1067
+ clip_audio: bool = False,
1068
+ ):
1069
+ super().__init__()
1070
+ self.clip_audio = clip_audio
1071
+ self.out = nn.Linear(dim, mdct_frame_len)
1072
+ self.imdct = IMDCT(frame_len=mdct_frame_len, padding=padding)
1073
+
1074
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1075
+ """
1076
+ Forward pass of the IMDCTCosHead module.
1077
+
1078
+ Args:
1079
+ x (Tensor): Input tensor of shape (B, L, H), where B is the batch size,
1080
+ L is the sequence length, and H denotes the model dimension.
1081
+
1082
+ Returns:
1083
+ Tensor: Reconstructed time-domain audio signal of shape (B, T), where T is the length of the output signal.
1084
+ """
1085
+ x = self.out(x)
1086
+ m, p = x.chunk(2, dim=2)
1087
+ m = torch.exp(m).clip(
1088
+ max=1e2
1089
+ ) # safeguard to prevent excessively large magnitudes
1090
+ audio = self.imdct(m * torch.cos(p))
1091
+ if self.clip_audio:
1092
+ audio = torch.clip(audio, min=-1.0, max=1.0) # clip the reconstructed audio, not the head output
1093
+ return audio
1094
+
1095
+
1096
+ class ConvNeXtBlock(nn.Module):
1097
+ """ConvNeXt Block adapted from https://github.com/facebookresearch/ConvNeXt to 1D audio signal.
1098
+
1099
+ Args:
1100
+ dim (int): Number of input channels.
1101
+ intermediate_dim (int): Dimensionality of the intermediate layer.
1102
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
1103
+ Defaults to None.
1104
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
1105
+ None means non-conditional LayerNorm. Defaults to None.
1106
+ """
1107
+
1108
+ def __init__(
1109
+ self,
1110
+ dim: int,
1111
+ intermediate_dim: int,
1112
+ layer_scale_init_value: float,
1113
+ adanorm_num_embeddings: Optional[int] = None,
1114
+ ):
1115
+ super().__init__()
1116
+ self.dwconv = nn.Conv1d(
1117
+ dim, dim, kernel_size=7, padding=3, groups=dim
1118
+ ) # depthwise conv
1119
+ self.adanorm = adanorm_num_embeddings is not None
1120
+ if adanorm_num_embeddings:
1121
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
1122
+ else:
1123
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
1124
+ self.pwconv1 = nn.Linear(
1125
+ dim, intermediate_dim
1126
+ ) # pointwise/1x1 convs, implemented with linear layers
1127
+ self.act = nn.GELU()
1128
+ self.pwconv2 = nn.Linear(intermediate_dim, dim)
1129
+ self.gamma = (
1130
+ nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
1131
+ if layer_scale_init_value > 0
1132
+ else None
1133
+ )
1134
+
1135
+ def forward(
1136
+ self, x: torch.Tensor, cond_embedding_id: Optional[torch.Tensor] = None
1137
+ ) -> torch.Tensor:
1138
+ residual = x
1139
+ x = self.dwconv(x)
1140
+ x = x.transpose(1, 2) # (B, C, T) -> (B, T, C)
1141
+ if self.adanorm:
1142
+ assert cond_embedding_id is not None
1143
+ x = self.norm(x, cond_embedding_id)
1144
+ else:
1145
+ x = self.norm(x)
1146
+ x = self.pwconv1(x)
1147
+ x = self.act(x)
1148
+ x = self.pwconv2(x)
1149
+ if self.gamma is not None:
1150
+ x = self.gamma * x
1151
+ x = x.transpose(1, 2) # (B, T, C) -> (B, C, T)
1152
+
1153
+ x = residual + x
1154
+ return x
1155
+
1156
+
1157
+ class AdaLayerNorm(nn.Module):
1158
+ """
1159
+ Adaptive Layer Normalization module with learnable embeddings per `num_embeddings` classes
1160
+
1161
+ Args:
1162
+ num_embeddings (int): Number of embeddings.
1163
+ embedding_dim (int): Dimension of the embeddings.
1164
+ """
1165
+
1166
+ def __init__(self, num_embeddings: int, embedding_dim: int, eps: float = 1e-6):
1167
+ super().__init__()
1168
+ self.eps = eps
1169
+ self.dim = embedding_dim
1170
+ self.scale = nn.Embedding(
1171
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
1172
+ )
1173
+ self.shift = nn.Embedding(
1174
+ num_embeddings=num_embeddings, embedding_dim=embedding_dim
1175
+ )
1176
+ torch.nn.init.ones_(self.scale.weight)
1177
+ torch.nn.init.zeros_(self.shift.weight)
1178
+
1179
+ def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor) -> torch.Tensor:
1180
+ scale = self.scale(cond_embedding_id)
1181
+ shift = self.shift(cond_embedding_id)
1182
+ x = nn.functional.layer_norm(x, (self.dim,), eps=self.eps)
1183
+ x = x * scale + shift
1184
+ return x
1185
+
1186
+
1187
+ class ResBlock1(nn.Module):
1188
+ """
1189
+ ResBlock adapted from HiFi-GAN V1 (https://github.com/jik876/hifi-gan) with dilated 1D convolutions,
1190
+ but without upsampling layers.
1191
+
1192
+ Args:
1193
+ dim (int): Number of input channels.
1194
+ kernel_size (int, optional): Size of the convolutional kernel. Defaults to 3.
1195
+ dilation (tuple[int], optional): Dilation factors for the dilated convolutions.
1196
+ Defaults to (1, 3, 5).
1197
+ lrelu_slope (float, optional): Negative slope of the LeakyReLU activation function.
1198
+ Defaults to 0.1.
1199
+ layer_scale_init_value (float, optional): Initial value for the layer scale. None means no scaling.
1200
+ Defaults to None.
1201
+ """
1202
+
1203
+ def __init__(
1204
+ self,
1205
+ dim: int,
1206
+ kernel_size: int = 3,
1207
+ dilation: Tuple[int, int, int] = (1, 3, 5),
1208
+ lrelu_slope: float = 0.1,
1209
+ layer_scale_init_value: Optional[float] = None,
1210
+ ):
1211
+ super().__init__()
1212
+ self.lrelu_slope = lrelu_slope
1213
+ self.convs1 = nn.ModuleList(
1214
+ [
1215
+ weight_norm(
1216
+ nn.Conv1d(
1217
+ dim,
1218
+ dim,
1219
+ kernel_size,
1220
+ 1,
1221
+ dilation=dilation[0],
1222
+ padding=self.get_padding(kernel_size, dilation[0]),
1223
+ )
1224
+ ),
1225
+ weight_norm(
1226
+ nn.Conv1d(
1227
+ dim,
1228
+ dim,
1229
+ kernel_size,
1230
+ 1,
1231
+ dilation=dilation[1],
1232
+ padding=self.get_padding(kernel_size, dilation[1]),
1233
+ )
1234
+ ),
1235
+ weight_norm(
1236
+ nn.Conv1d(
1237
+ dim,
1238
+ dim,
1239
+ kernel_size,
1240
+ 1,
1241
+ dilation=dilation[2],
1242
+ padding=self.get_padding(kernel_size, dilation[2]),
1243
+ )
1244
+ ),
1245
+ ]
1246
+ )
1247
+
1248
+ self.convs2 = nn.ModuleList(
1249
+ [
1250
+ weight_norm(
1251
+ nn.Conv1d(
1252
+ dim,
1253
+ dim,
1254
+ kernel_size,
1255
+ 1,
1256
+ dilation=1,
1257
+ padding=self.get_padding(kernel_size, 1),
1258
+ )
1259
+ ),
1260
+ weight_norm(
1261
+ nn.Conv1d(
1262
+ dim,
1263
+ dim,
1264
+ kernel_size,
1265
+ 1,
1266
+ dilation=1,
1267
+ padding=self.get_padding(kernel_size, 1),
1268
+ )
1269
+ ),
1270
+ weight_norm(
1271
+ nn.Conv1d(
1272
+ dim,
1273
+ dim,
1274
+ kernel_size,
1275
+ 1,
1276
+ dilation=1,
1277
+ padding=self.get_padding(kernel_size, 1),
1278
+ )
1279
+ ),
1280
+ ]
1281
+ )
1282
+
1283
+ self.gamma = nn.ParameterList(
1284
+ [
1285
+ (
1286
+ nn.Parameter(
1287
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
1288
+ )
1289
+ if layer_scale_init_value is not None
1290
+ else None
1291
+ ),
1292
+ (
1293
+ nn.Parameter(
1294
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
1295
+ )
1296
+ if layer_scale_init_value is not None
1297
+ else None
1298
+ ),
1299
+ (
1300
+ nn.Parameter(
1301
+ layer_scale_init_value * torch.ones(dim, 1), requires_grad=True
1302
+ )
1303
+ if layer_scale_init_value is not None
1304
+ else None
1305
+ ),
1306
+ ]
1307
+ )
1308
+
1309
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
1310
+ for c1, c2, gamma in zip(self.convs1, self.convs2, self.gamma):
1311
+ xt = torch.nn.functional.leaky_relu(x, negative_slope=self.lrelu_slope)
1312
+ xt = c1(xt)
1313
+ xt = torch.nn.functional.leaky_relu(xt, negative_slope=self.lrelu_slope)
1314
+ xt = c2(xt)
1315
+ if gamma is not None:
1316
+ xt = gamma * xt
1317
+ x = xt + x
1318
+ return x
1319
+
1320
+ def remove_weight_norm(self):
1321
+ for l in self.convs1:
1322
+ remove_weight_norm(l)
1323
+ for l in self.convs2:
1324
+ remove_weight_norm(l)
1325
+
1326
+ @staticmethod
1327
+ def get_padding(kernel_size: int, dilation: int = 1) -> int:
1328
+ return int((kernel_size * dilation - dilation) / 2)
1329
+
1330
+
1331
+ class Backbone(nn.Module):
1332
+ """Base class for the generator's backbone. It preserves the same temporal resolution across all layers."""
1333
+
1334
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
1335
+ """
1336
+ Args:
1337
+ x (Tensor): Input tensor of shape (B, C, L), where B is the batch size,
1338
+ C denotes output features, and L is the sequence length.
1339
+
1340
+ Returns:
1341
+ Tensor: Output of shape (B, L, H), where B is the batch size, L is the sequence length,
1342
+ and H denotes the model dimension.
1343
+ """
1344
+ raise NotImplementedError("Subclasses must implement the forward method.")
1345
+
1346
+
1347
+ class VocosBackbone(Backbone):
1348
+ """
1349
+ Vocos backbone module built with ConvNeXt blocks. Supports additional conditioning with Adaptive Layer Normalization
1350
+
1351
+ Args:
1352
+ input_channels (int): Number of input features channels.
1353
+ dim (int): Hidden dimension of the model.
1354
+ intermediate_dim (int): Intermediate dimension used in ConvNeXtBlock.
1355
+ num_layers (int): Number of ConvNeXtBlock layers.
1356
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to `1 / num_layers`.
1357
+ adanorm_num_embeddings (int, optional): Number of embeddings for AdaLayerNorm.
1358
+ None means non-conditional model. Defaults to None.
1359
+ """
1360
+
1361
+ def __init__(
1362
+ self,
1363
+ input_channels: int,
1364
+ dim: int,
1365
+ intermediate_dim: int,
1366
+ num_layers: int,
1367
+ layer_scale_init_value: Optional[float] = None,
1368
+ adanorm_num_embeddings: Optional[int] = None,
1369
+ ):
1370
+ super().__init__()
1371
+ self.input_channels = input_channels
1372
+ self.embed = nn.Conv1d(input_channels, dim, kernel_size=7, padding=3)
1373
+ self.adanorm = adanorm_num_embeddings is not None
1374
+ if adanorm_num_embeddings:
1375
+ self.norm = AdaLayerNorm(adanorm_num_embeddings, dim, eps=1e-6)
1376
+ else:
1377
+ self.norm = nn.LayerNorm(dim, eps=1e-6)
1378
+ layer_scale_init_value = layer_scale_init_value or 1 / num_layers
1379
+ self.convnext = nn.ModuleList(
1380
+ [
1381
+ ConvNeXtBlock(
1382
+ dim=dim,
1383
+ intermediate_dim=intermediate_dim,
1384
+ layer_scale_init_value=layer_scale_init_value,
1385
+ adanorm_num_embeddings=adanorm_num_embeddings,
1386
+ )
1387
+ for _ in range(num_layers)
1388
+ ]
1389
+ )
1390
+ self.final_layer_norm = nn.LayerNorm(dim, eps=1e-6)
1391
+ self.apply(self._init_weights)
1392
+
1393
+ def _init_weights(self, m):
1394
+ if isinstance(m, (nn.Conv1d, nn.Linear)):
1395
+ nn.init.trunc_normal_(m.weight, std=0.02)
1396
+ nn.init.constant_(m.bias, 0)
1397
+
1398
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
1399
+ bandwidth_id = kwargs.get("bandwidth_id", None)
1400
+ x = self.embed(x)
1401
+ if self.adanorm:
1402
+ assert bandwidth_id is not None
1403
+ x = self.norm(x.transpose(1, 2), cond_embedding_id=bandwidth_id)
1404
+ else:
1405
+ x = self.norm(x.transpose(1, 2))
1406
+ x = x.transpose(1, 2)
1407
+ for conv_block in self.convnext:
1408
+ x = conv_block(x, cond_embedding_id=bandwidth_id)
1409
+ x = self.final_layer_norm(x.transpose(1, 2))
1410
+ return x
1411
+
1412
+
1413
+ class VocosResNetBackbone(Backbone):
1414
+ """
1415
+ Vocos backbone module built with ResBlocks.
1416
+
1417
+ Args:
1418
+ input_channels (int): Number of input features channels.
1419
+ dim (int): Hidden dimension of the model.
1420
+ num_blocks (int): Number of ResBlock1 blocks.
1421
+ layer_scale_init_value (float, optional): Initial value for layer scaling. Defaults to None.
1422
+ """
1423
+
1424
+ def __init__(
1425
+ self,
1426
+ input_channels,
1427
+ dim,
1428
+ num_blocks,
1429
+ layer_scale_init_value=None,
1430
+ ):
1431
+ super().__init__()
1432
+ self.input_channels = input_channels
1433
+ self.embed = weight_norm(
1434
+ nn.Conv1d(input_channels, dim, kernel_size=3, padding=1)
1435
+ )
1436
+ layer_scale_init_value = layer_scale_init_value or 1 / num_blocks / 3
1437
+ self.resnet = nn.Sequential(
1438
+ *[
1439
+ ResBlock1(dim=dim, layer_scale_init_value=layer_scale_init_value)
1440
+ for _ in range(num_blocks)
1441
+ ]
1442
+ )
1443
+
1444
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
1445
+ x = self.embed(x)
1446
+ x = self.resnet(x)
1447
+ x = x.transpose(1, 2)
1448
+ return x
1449
+
1450
+
1451
+ class Vocos(nn.Module):
1452
+ def __init__(
1453
+ self,
1454
+ input_channels: int = 128,
1455
+ dim: int = 512,
1456
+ intermediate_dim: int = 4096,
1457
+ num_layers: int = 30,
1458
+ n_fft: int = 640,
1459
+ hop_size: int = 160,
1460
+ padding: str = "same",
1461
+ adanorm_num_embeddings=None,
1462
+ ):
1463
+ super().__init__()
1464
+
1465
+ self.backbone = VocosBackbone(
1466
+ input_channels=input_channels,
1467
+ dim=dim,
1468
+ intermediate_dim=intermediate_dim,
1469
+ num_layers=num_layers,
1470
+ adanorm_num_embeddings=adanorm_num_embeddings,
1471
+ )
1472
+ self.head = ISTFTHead(dim, n_fft, hop_size, padding)
1473
+ self.hop_size = hop_size
1474
+
1475
+ def forward(self, x, input_length):
1476
+ x = self.backbone(x)
1477
+ x = self.head(x)
1478
+ output_length = input_length * self.hop_size
1479
+ return x[:, None, :], output_length
1480
+
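Putting the backbone and the ISTFT head together, Vocos maps a Mel spectrogram to a waveform at hop_size samples per frame. A sketch with the default configuration, assuming the classes above are in scope:

import torch

vocoder = Vocos()                       # defaults: 128 mel bins, n_fft=640, hop_size=160
mel = torch.randn(1, 128, 100)          # [B, input_channels, T]
mel_len = torch.tensor([100])
wav, wav_len = vocoder(mel, mel_len)    # wav: [1, 1, 16000], wav_len: tensor([16000])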
XY_Tokenizer/xy_tokenizer/nn/quantizer.py ADDED
@@ -0,0 +1,370 @@
1
+ import logging
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.distributed as dist
6
+
7
+ from einops import rearrange
8
+ from torch.nn.utils import weight_norm
9
+
10
+ def WNConv1d(*args, **kwargs):
11
+ return weight_norm(nn.Conv1d(*args, **kwargs))
12
+
13
+ def WNConvTranspose1d(*args, **kwargs):
14
+ return weight_norm(nn.ConvTranspose1d(*args, **kwargs))
15
+
16
+ def sample_vectors(samples, num):
17
+ # samples: (N, D), num_samples: N, feature dim: D
18
+ num_samples, device = samples.shape[0], samples.device
19
+ if num_samples >= num:
20
+ indices = torch.randperm(num_samples, device=device)[:num]
21
+ else:
22
+ indices = torch.randint(0, num_samples, (num,), device=device)
23
+ return samples[indices].float() # (num, D), ensure fp32
24
+
25
+ def kmeans(samples, num_clusters, num_iters=10):
26
+ # samples: (N, D), N samples with D dimensions
27
+ dim, dtype = samples.shape[-1], torch.float32 # Force fp32
28
+ means = sample_vectors(samples, num_clusters).float() # (num_clusters, D), ensure fp32
29
+
30
+ for _ in range(num_iters):
31
+ dists = -(samples.float().pow(2).sum(1, keepdim=True) - # (N, 1), ensure fp32
32
+ 2 * samples.float() @ means.t() + # (N, num_clusters), ensure fp32
33
+ means.t().float().pow(2).sum(0, keepdim=True)) # (1, num_clusters), ensure fp32
34
+ # dists: (N, num_clusters)
35
+ buckets = dists.max(dim=-1).indices # (N)
36
+ bins = torch.bincount(buckets, minlength=num_clusters) # (num_clusters)
37
+ zero_mask = bins == 0 # (num_clusters)
38
+ bins_min_clamped = bins.masked_fill(zero_mask, 1) # (num_clusters)
39
+
40
+ new_means = buckets.new_zeros(num_clusters, dim, dtype=torch.float32) # (num_clusters, D), ensure fp32
41
+ new_means.scatter_add_(0, buckets.unsqueeze(1).expand(-1, dim), samples.float()) # (num_clusters, D), ensure fp32
42
+ new_means = new_means / bins_min_clamped[..., None] # (num_clusters, D)
43
+ means = torch.where(zero_mask[..., None], means, new_means) # (num_clusters, D)
44
+
45
+ # Final cluster assignments for returning cluster sizes
46
+ dists = -(samples.float().pow(2).sum(1, keepdim=True) -
47
+ 2 * samples.float() @ means.t() +
48
+ means.t().float().pow(2).sum(0, keepdim=True)) # (N, num_clusters), ensure fp32
49
+ buckets = dists.max(dim=-1).indices # (N)
50
+ bins = torch.bincount(buckets, minlength=num_clusters).float() # (num_clusters), ensure fp32
51
+
52
+ return means, bins # (num_clusters, D), (num_clusters)
53
+
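Note that the distance is written as a negative squared Euclidean distance, so max(dim=-1) picks the nearest centroid. A small usage sketch with illustrative sizes:

import torch

feats = torch.randn(4096, 8)                         # (N, D) samples
centroids, counts = kmeans(feats, num_clusters=16)   # (16, 8) fp32 centroids, (16,) cluster sizes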
54
+ class VectorQuantize(nn.Module):
55
+ def __init__(
56
+ self,
57
+ input_dim,
58
+ codebook_size,
59
+ codebook_dim,
60
+ commitment=1.0,
61
+ decay=0.99, # EMA decay
62
+ epsilon=1e-5, # Laplace smoothing epsilon
63
+ threshold_ema_dead=2, # Dead code threshold
64
+ kmeans_init=True, # Use kmeans initialization
65
+ kmeans_iters=10, # Kmeans iterations
66
+ ):
67
+ super().__init__()
68
+ self.input_dim = input_dim
69
+ self.codebook_size = codebook_size
70
+ self.codebook_dim = codebook_dim
71
+ self.commitment = commitment
72
+ self.decay = decay
73
+ self.epsilon = epsilon
74
+ self.threshold_ema_dead = threshold_ema_dead
75
+ self.kmeans_init = kmeans_init
76
+ self.kmeans_iters = kmeans_iters
77
+
78
+ if self.input_dim != self.codebook_dim:
79
+ self.in_project = WNConv1d(self.input_dim, self.codebook_dim, kernel_size=1) # (B, D, T) -> (B, D', T)
80
+ self.out_project = WNConv1d(self.codebook_dim, self.input_dim, kernel_size=1) # (B, D', T) -> (B, D, T)
81
+ else:
82
+ self.in_project = nn.Identity()
83
+ self.out_project = nn.Identity()
84
+
85
+ # Initialize codebook and EMA buffers
86
+ init_fn = torch.zeros if kmeans_init else lambda x, y: torch.randn(x, y)
87
+ self.register_buffer("codebook", init_fn(codebook_size, codebook_dim).float()) # (codebook_size, D'), ensure fp32
88
+ self.register_buffer("inited", torch.tensor([not kmeans_init], dtype=torch.bool)) # (1)
89
+ self.register_buffer("cluster_size", torch.zeros(codebook_size).float()) # (codebook_size), ensure fp32
90
+ self.register_buffer("embed_avg", self.codebook.clone().float()) # (codebook_size, D'), ensure fp32
91
+
92
+ def ema_update(self, encodings, embed_onehot):
93
+ # encodings: (B*T, D'), embed_onehot: (B*T, codebook_size)
94
+ """Update codebook using EMA"""
95
+ encodings = encodings.float() # Ensure fp32
96
+ embed_onehot = embed_onehot.float() # Ensure fp32
97
+ cluster_size_new = embed_onehot.sum(0) # (codebook_size)
98
+ embed_sum = encodings.t() @ embed_onehot # (D', codebook_size)
99
+
100
+ # Distributed reduction
101
+ if dist.is_initialized():
102
+ dist.all_reduce(cluster_size_new, op=dist.ReduceOp.SUM)
103
+ dist.all_reduce(embed_sum, op=dist.ReduceOp.SUM)
104
+
105
+ ema_inplace(self.cluster_size, cluster_size_new, self.decay) # (codebook_size)
106
+ ema_inplace(self.embed_avg, embed_sum.t(), self.decay) # (codebook_size, D')
107
+
108
+ # Laplace smoothing
109
+ cluster_size = (self.cluster_size + self.epsilon) / (self.cluster_size.sum() + self.codebook_size * self.epsilon) # (codebook_size)
110
+ cluster_size = cluster_size * self.cluster_size.sum() # (codebook_size)
111
+ self.codebook.copy_(self.embed_avg / cluster_size.unsqueeze(1)) # (codebook_size, D')
112
+
113
+ def replace_dead_codes(self, encodings):
114
+ # encodings: (B*T, D')
115
+ """Replace dead codes with random samples from current batch"""
116
+ if self.threshold_ema_dead == 0:
117
+ return
118
+
119
+ dead_mask = self.cluster_size < self.threshold_ema_dead # (codebook_size)
120
+ if dead_mask.any():
121
+ if dist.is_initialized() and dist.get_rank() == 0:
122
+ samples = sample_vectors(encodings.float(), self.codebook_size) # (codebook_size, D'), ensure fp32
123
+ else:
124
+ samples = torch.zeros_like(self.codebook).float() # Placeholder, ensure fp32
125
+
126
+ # Broadcast samples
127
+ if dist.is_initialized():
128
+ dist.broadcast(samples, src=0)
129
+
130
+ self.codebook[dead_mask] = samples[:dead_mask.sum()].to(self.codebook.dtype) # Update dead codes
131
+
132
+ def init_codebook(self, encodings):
133
+ # encodings: (B*T, D')
134
+ """Initialize codebook with k-means and update cluster_size"""
135
+ if self.inited.item():
136
+ return
137
+
138
+ if dist.is_initialized() and dist.get_rank() == 0:
139
+ embed, cluster_sizes = kmeans(encodings.float(), self.codebook_size, self.kmeans_iters) # (codebook_size, D'), (codebook_size), ensure fp32
140
+ else:
141
+ embed = torch.zeros(self.codebook_size, self.codebook_dim, device=encodings.device).float() # ensure fp32
142
+ cluster_sizes = torch.zeros(self.codebook_size, device=encodings.device, dtype=torch.float32) # ensure fp32
143
+
144
+ # Broadcast results
145
+ if dist.is_initialized():
146
+ dist.broadcast(embed, src=0)
147
+ dist.broadcast(cluster_sizes, src=0)
148
+
149
+ self.codebook.copy_(embed) # (codebook_size, D')
150
+ self.embed_avg.copy_(embed.clone()) # (codebook_size, D')
151
+ self.cluster_size.copy_(cluster_sizes.float()) # (codebook_size)
152
+ self.inited.fill_(True)
153
+
154
+ def forward(self, z): # z: (B, D, T)
155
+ # logging.info(f"{self.cluster_size = }, {self.codebook = }, {self.embed_avg = }, {self.inited = }")
156
+ z = z.float() # Ensure fp32
157
+ z_e = self.in_project(z).float() # (B, D', T), ensure fp32
158
+
159
+ # Rearrange for quantization
160
+ encodings = rearrange(z_e, "b d t -> (b t) d").float() # (B*T, D'), ensure fp32
161
+
162
+ # Initialize codebook if needed
163
+ if self.kmeans_init and not self.inited.item():
164
+ self.init_codebook(encodings)
165
+
166
+ # Quantization
167
+ distances = (encodings.pow(2).sum(1, keepdim=True) - # (B*T, 1); renamed from "dist" to avoid shadowing torch.distributed
168
+ 2 * encodings @ self.codebook.float().t() + # (B*T, codebook_size)
169
+ self.codebook.float().pow(2).sum(1, keepdim=True).t()) # (1, codebook_size)
170
+ # distances: (B*T, codebook_size)
171
+
172
+ indices = (-distances).max(1)[1] # (B*T)
173
+ indices = rearrange(indices, "(b t) -> b t", b=z.size(0)) # (B, T)
174
+
175
+ # Get quantized vectors
176
+ z_q = self.decode_code(indices).float() # (B, D', T), ensure fp32
177
+
178
+ # Commitment loss
179
+ commit_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) * self.commitment # (B)
180
+
181
+ # EMA updates and dead code replacement during training
182
+ if self.training and torch.is_grad_enabled():
183
+ embed_onehot = F.one_hot(indices.view(-1), self.codebook_size).float() # (B*T, codebook_size), ensure fp32
184
+ self.ema_update(encodings, embed_onehot)
185
+ self.replace_dead_codes(encodings)
186
+
187
+ # Straight-through estimator
188
+ z_q = z_e + (z_q - z_e).detach() # (B, D', T)
189
+ z_q = self.out_project(z_q).float() # (B, D, T), ensure fp32
190
+
191
+ return z_q, commit_loss, torch.tensor(0.0, device=z.device, dtype=torch.float32), indices, z # (B, D, T), (B), scalar, (B, T), (B, D', T)
192
+
193
+ def decode_code(self, embed_id): # embed_id: (B, T)
194
+ return F.embedding(embed_id, self.codebook).transpose(1, 2).float() # (B, D', T), ensure fp32
195
+
196
+ class ResidualVQ(nn.Module):
197
+ def __init__(
198
+ self,
199
+ input_dim: int = 1280, # Input dimension, unrelated to RVQ
200
+ rvq_dim = None, # RVQ dimension. If different from input_dim/output_dim, will add input_dim->rvq_dim/rvq_dim->output_dim projection
201
+ output_dim: int = None, # Output dimension, unrelated to RVQ
202
+ num_quantizers: int = 32,
203
+ codebook_size: int = 1024,
204
+ codebook_dim: int = 8, # Dimension of each codebook. If different from rvq_dim, will add rvq_dim->codebook_dim and codebook_dim->rvq_dim projections
205
+ quantizer_dropout: float = 0.5,
206
+ decay=0.99,
207
+ epsilon=1e-5,
208
+ threshold_ema_dead=2,
209
+ kmeans_init=True,
210
+ kmeans_iters=10,
211
+ skip_rvq_ratio: float = 0.0, # New parameter: probability of skipping RVQ
212
+ **kwargs,
213
+ ):
214
+ super().__init__()
215
+ self.input_dim = input_dim
216
+
217
+ self.num_quantizers = num_quantizers
218
+ self.codebook_size = codebook_size
219
+ self.codebook_dim = codebook_dim
220
+ self.quantizer_dropout = quantizer_dropout
221
+ self.skip_rvq_ratio = skip_rvq_ratio # Store skip probability
222
+ self.rvq_dim = rvq_dim
223
+
224
+ self.input_proj = WNConv1d(input_dim, rvq_dim, kernel_size=1) if input_dim != rvq_dim else nn.Identity()
225
+ self.output_proj = WNConv1d(rvq_dim, output_dim, kernel_size=1) if rvq_dim != output_dim else nn.Identity()
226
+
227
+ self.quantizers = nn.ModuleList(
228
+ [
229
+ VectorQuantize(
230
+ input_dim=rvq_dim,
231
+ codebook_size=codebook_size,
232
+ codebook_dim=codebook_dim,
233
+ decay=decay,
234
+ epsilon=epsilon,
235
+ threshold_ema_dead=threshold_ema_dead,
236
+ kmeans_init=kmeans_init,
237
+ kmeans_iters=kmeans_iters,
238
+ **kwargs,
239
+ )
240
+ for _ in range(num_quantizers)
241
+ ]
242
+ )
243
+
244
+ def forward(self, z, input_length, n_quantizers: int = None): # z: (B, D, T), input_length: (B)
245
+ z = self.input_proj(z)
246
+
247
+ with torch.autocast('cuda', enabled = False):
248
+ batch_size, _, max_time = z.shape
249
+ mask = torch.arange(max_time, device=z.device).expand(batch_size, max_time) < input_length.unsqueeze(1) # (B, T)
250
+
251
+ quantized_out = torch.zeros_like(z, dtype=torch.float32) # (B, D, T), ensure fp32
252
+ residual = z.clone().float() # (B, D, T), ensure fp32
253
+
254
+ all_commit_losses = []
255
+ all_indices = []
256
+ all_quantized = []
257
+
258
+ n_quantizers = n_quantizers or self.num_quantizers
259
+
260
+ # Randomly decide whether to skip RVQ during training
261
+ skip_mask = None
262
+ if self.training and torch.is_grad_enabled() and self.skip_rvq_ratio > 0:
263
+ # Generate random mask with skip_rvq_ratio probability
264
+ skip_mask = torch.rand(batch_size, device=z.device) < self.skip_rvq_ratio # (B,)
265
+ # If all samples are skipped, force the first sample to be unskipped
266
+ if skip_mask.all():
267
+ skip_mask[0] = False # Ensure at least one sample (index 0) is not skipped
268
+
269
+ if self.training and torch.is_grad_enabled():
270
+ n_quantizers_tensor = torch.ones((z.shape[0],), dtype=torch.float32, device=z.device) * self.num_quantizers + 1 # (B)
271
+ dropout = torch.randint(1, self.num_quantizers + 1, (z.shape[0],), dtype=torch.float32, device=z.device) # (B)
272
+ n_dropout = int(z.shape[0] * self.quantizer_dropout)
273
+ n_quantizers_tensor[:n_dropout] = dropout[:n_dropout] # (B)
274
+ else:
275
+ n_quantizers_tensor = torch.full((z.shape[0],), n_quantizers, dtype=torch.float32, device=z.device) # (B)
276
+
277
+ for i, quantizer in enumerate(self.quantizers):
278
+ if not self.training and i >= n_quantizers:
279
+ break
280
+
281
+ masked_residual = residual * mask.unsqueeze(1) # (B, D, T)
282
+
283
+ # If skipping RVQ, directly use input value
284
+ if self.training and skip_mask is not None and skip_mask.any():
285
+ z_q_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
286
+ commit_loss_i = torch.zeros(batch_size, device=z.device, dtype=torch.float32) # (B), ensure fp32
287
+ indices_i = torch.zeros(batch_size, max_time, device=z.device, dtype=torch.long) # (B, T)
288
+ z_e_i = torch.zeros_like(masked_residual, dtype=torch.float32) # (B, D, T), ensure fp32
289
+
290
+ # Quantize non-skipped samples
291
+ non_skipped_mask = ~skip_mask # (B)
292
+ if non_skipped_mask.any():
293
+ z_q_i_non_skipped, commit_loss_i_non_skipped, _, indices_i_non_skipped, z_e_i_non_skipped = quantizer(
294
+ masked_residual[non_skipped_mask].float() # Ensure fp32
295
+ )
296
+ z_q_i[non_skipped_mask] = z_q_i_non_skipped
297
+ commit_loss_i[non_skipped_mask] = commit_loss_i_non_skipped
298
+ indices_i[non_skipped_mask] = indices_i_non_skipped
299
+ z_e_i[non_skipped_mask] = z_e_i_non_skipped
300
+ else:
301
+ z_q_i, commit_loss_i, _, indices_i, z_e_i = quantizer(masked_residual.float()) # (B, D, T), (B), scalar, (B, T), (B, D', T), ensure fp32
302
+
303
+ quantizer_mask = (torch.full((z.shape[0],), i, device=z.device, dtype=torch.float32) < n_quantizers_tensor) # (B)
304
+ update_mask = (mask & quantizer_mask.unsqueeze(-1)).unsqueeze(1) # (B, 1, T)
305
+
306
+ # If skipping, output is directly the input
307
+ if skip_mask is not None:
308
+ skip_mask_expanded = skip_mask.unsqueeze(1).unsqueeze(2) # (B, 1, 1)
309
+ z_q_i = torch.where(skip_mask_expanded, masked_residual, z_q_i) # (B, D, T)
310
+ commit_loss_i = torch.where(skip_mask, torch.zeros_like(commit_loss_i), commit_loss_i) # (B)
311
+
312
+ quantized_out = quantized_out + z_q_i * update_mask # (B, D, T)
313
+
314
+ residual_fp32 = residual.to(dtype=torch.float32) # (B, D, T)
315
+ z_q_i_fp32 = z_q_i.to(dtype=torch.float32) # (B, D, T)
316
+ residual_fp32 = residual_fp32 - z_q_i_fp32 * update_mask # (B, D, T)
317
+ residual = residual_fp32.to(dtype=torch.float32) # (B, D, T), ensure fp32
318
+
319
+ valid_mask = mask & quantizer_mask.unsqueeze(-1) # (B, T)
320
+ if valid_mask.any():
321
+ commit_loss_i = (commit_loss_i * quantizer_mask).sum() / quantizer_mask.sum() # scalar
322
+ else:
323
+ commit_loss_i = torch.tensor(0.0, device=z.device, dtype=torch.float32) # scalar, ensure fp32
324
+
325
+ all_commit_losses.append(commit_loss_i) # scalar
326
+ all_indices.append(indices_i) # (B, T)
327
+ all_quantized.append(z_q_i) # (B, D, T)
328
+
329
+ all_commit_losses = torch.stack(all_commit_losses) # (N)
330
+ all_indices = torch.stack(all_indices) # (N, B, T)
331
+ all_quantized = torch.stack(all_quantized) # (N, B, D, T)
332
+
333
+ output_length = input_length # (B)
334
+
335
+ quantized_out = self.output_proj(quantized_out)
336
+
337
+ return (
338
+ quantized_out, # (B, D, T)
339
+ all_indices, # (N, B, T)
340
+ all_commit_losses,# (N)
341
+ all_quantized, # (N, B, D, T)
342
+ output_length, # (B)
343
+ )
344
+
345
+ def decode_codes(self, codes): # codes: (nq, B, T)
346
+ """Decode codes from multiple quantizers to embeddings.
347
+
348
+ Args:
349
+ codes: Tensor of shape (nq, B, T) containing code indices for each quantizer.
350
+
351
+ Returns:
352
+ emb: Tensor of shape (B, D, T) representing the decoded embeddings.
353
+ """
354
+ nq, B, T = codes.shape
355
+ device = codes.device
356
+ emb = torch.zeros(B, self.rvq_dim, T, device=device, dtype=torch.float32) # (B, D, T)
357
+
358
+ for i, quantizer in enumerate(self.quantizers[:nq]):
359
+ code_i = codes[i] # (B, T)
360
+ quantized_i = quantizer.decode_code(code_i) # (B, D', T)
361
+ emb += quantized_i # Accumulate quantized embeddings
362
+
363
+ emb = self.output_proj(emb) # (B, D, T), apply output projection
364
+ return emb # (B, D, T)
365
+
366
+
367
+ def ema_inplace(moving_avg, new, decay):
368
+ # moving_avg: (codebook_size) or (codebook_size, D'), new: same as moving_avg
369
+ """Update exponential moving average in-place"""
370
+ moving_avg.data.mul_(decay).add_(new.float(), alpha=(1 - decay)) # ensure fp32
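As a rough usage sketch of the quantizer stack above (the sizes are illustrative, not the values used for XY_Tokenizer; decode_codes accumulates raw codebook vectors, so the sketch keeps codebook_dim equal to rvq_dim):

import torch

rvq = ResidualVQ(
    input_dim=1280,        # incoming feature dimension
    rvq_dim=256,           # internal RVQ dimension (projected in/out with 1x1 convs)
    output_dim=1280,
    num_quantizers=8,
    codebook_size=1024,
    codebook_dim=256,      # kept equal to rvq_dim so decode_codes can accumulate directly
    kmeans_init=False,     # skip k-means so the sketch runs without a distributed setup
)
rvq.eval()

z = torch.randn(2, 1280, 50)       # (B, input_dim, T)
lengths = torch.tensor([50, 32])   # valid frames per item

with torch.no_grad():
    z_q, codes, commit_losses, per_stage, out_len = rvq(z, lengths)

print(z_q.shape)    # torch.Size([2, 1280, 50]) - quantized features after the output projection
print(codes.shape)  # torch.Size([8, 2, 50])    - one code index per quantizer per frame

emb = rvq.decode_codes(codes)      # (2, 1280, 50): embeddings reconstructed from the codes alone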
app.py CHANGED
@@ -1,7 +1,495 @@
1
  import gradio as gr
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
1
  import gradio as gr
2
+ import torch
3
+ import torchaudio
4
+ import tempfile
5
+ import json
6
+ import os
7
+ from typing import Optional, Tuple
8
 
9
+ from generation_utils import load_model, process_batch
 
10
 
11
+ def load_examples_from_jsonl():
12
+ """
13
+ Load examples from examples/examples.jsonl and convert to ROLE_EXAMPLES format
14
+ """
15
+ examples = []
16
+ jsonl_path = "examples/examples.jsonl"
17
+
18
+ if not os.path.exists(jsonl_path):
19
+ print(f"Warning: {jsonl_path} not found")
20
+ return []
21
+
22
+ with open(jsonl_path, 'r', encoding='utf-8') as f:
23
+ for line in f:
24
+ line = line.strip()
+ if not line: # skip blank lines so json.loads does not fail on empty rows
+ continue
25
+
26
+ data = json.loads(line)
27
+
28
+ # Extract required fields
29
+ text = data.get('text', '')
30
+ base_path = data.get('base_path', 'examples')
31
+
32
+ # Check if this is a role-based example (has speaker1 and speaker2 audio)
33
+ if 'prompt_audio_speaker1' in data and 'prompt_audio_speaker2' in data:
34
+ # Role mode example
35
+ audio_mode = "Role"
36
+ prompt_audio_1 = os.path.join(base_path, data['prompt_audio_speaker1'])
37
+ prompt_text_1 = data.get('prompt_text_speaker1', '')
38
+ prompt_audio_2 = os.path.join(base_path, data['prompt_audio_speaker2'])
39
+ prompt_text_2 = data.get('prompt_text_speaker2', '')
40
+ use_normalize = True
41
+
42
+ example = [text, audio_mode, prompt_audio_1, prompt_text_1, prompt_audio_2, prompt_text_2, use_normalize]
43
+ examples.append(example)
44
+
45
+ print(f"Loaded {len(examples)} examples from {jsonl_path}")
46
+ return examples
47
+
48
+ # Load examples from JSONL file
49
+ ROLE_EXAMPLES = load_examples_from_jsonl()
50
+
51
+ # Language configuration
52
+ LANGUAGES = {
53
+ "English": {
54
+ "title": "MOSS-TTSD🪐 Dialogue Generation",
55
+ "script_input": "### Script Input",
56
+ "text_to_synthesize": "Text to Synthesize",
57
+ "text_placeholder": "Text to be synthesized, format: [S1]Role1 text[S2]Role2 text",
58
+ "use_normalize": "Use text normalization",
59
+ "normalize_info": "Recommended to enable, improves handling of numbers, punctuation, etc.",
60
+ "audio_input_mode": "### Audio Input Mode",
61
+ "select_input_mode": "Select input mode",
62
+ "mode_info": "Single Audio: Upload one audio with [S1][S2] text; Role Audio: Upload separate audio for Role1 and Role2",
63
+ "drag_drop_audio": "Drag and drop audio here - or - click to upload",
64
+ "prompt_text": "Prompt Text",
65
+ "prompt_placeholder": "Format: [S1]Role1 text[S2]Role2 text",
66
+ "role1_audio": "**Role1 Audio**",
67
+ "role1_audio_file": "Role1 Audio File",
68
+ "role1_text": "Role1 Text",
69
+ "role1_placeholder": "Role1 text content",
70
+ "role2_audio": "**Role2 Audio**",
71
+ "role2_audio_file": "Role2 Audio File",
72
+ "role2_text": "Role2 Text",
73
+ "role2_placeholder": "Role2 text content",
74
+ "generate_audio": "Generate Audio",
75
+ "generated_audio": "Generated Audio",
76
+ "status_info": "Status Information",
77
+ "examples": "### Examples",
78
+ "examples_desc": "Click on examples below to auto-fill the form",
79
+ "role_headers": ["Text to Synthesize", "Input Mode", "Role1 Audio File", "Role1 Text", "Role2 Audio File", "Role2 Text", "Use Normalize"]
80
+ },
81
+ "中文": {
82
+ "title": "MOSS-TTSD🪐 对话语音生成",
83
+ "script_input": "### 文本输入",
84
+ "text_to_synthesize": "要合成的文本",
85
+ "text_placeholder": "要合成的文本,格式:[S1]角色1文本[S2]角色2文本",
86
+ "use_normalize": "使用文本规范化",
87
+ "normalize_info": "建议启用,改善数字、标点符号等的处理",
88
+ "audio_input_mode": "### 音频输入模式",
89
+ "select_input_mode": "选择输入模式",
90
+ "mode_info": "单音频:上传一个包含[S1][S2]文本的音频;角色音频:分别为角色1和角色2上传音频",
91
+ "drag_drop_audio": "拖拽音频文件到此处 - 或 - 点击上传",
92
+ "prompt_text": "提示文本",
93
+ "prompt_placeholder": "格式:[S1]角色1文本[S2]角色2文本",
94
+ "role1_audio": "**角色1音频**",
95
+ "role1_audio_file": "角色1音频文件",
96
+ "role1_text": "角色1文本",
97
+ "role1_placeholder": "角色1文本内容",
98
+ "role2_audio": "**角色2音频**",
99
+ "role2_audio_file": "角色2音频文件",
100
+ "role2_text": "角色2文本",
101
+ "role2_placeholder": "角色2文本内容",
102
+ "generate_audio": "生成音频",
103
+ "generated_audio": "生成的音频",
104
+ "status_info": "状态信息",
105
+ "examples": "### 示例",
106
+ "examples_desc": "点击下方示例自动填充表单",
107
+ "role_headers": ["要合成的文本", "输入模式", "角色1音频文件", "角色1文本", "角色2音频文件", "角色2文本", "使用规范化"]
108
+ }
109
+ }
110
+
111
+ # Model configuration
112
+ SYSTEM_PROMPT = "You are a speech synthesizer that generates natural, realistic, and human-like conversational audio from dialogue text."
113
+ MODEL_PATH = "fnlp/MOSS-TTSD-v0"
114
+ SPT_CONFIG_PATH = "XY_Tokenizer/config/xy_tokenizer_config.yaml"
115
+ SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt"
116
+ MAX_CHANNELS = 8
117
+
118
+ from huggingface_hub import hf_hub_download
119
+
120
+ ckpt_path = hf_hub_download(
121
+ repo_id="fnlp/XY_Tokenizer_TTSD_V0",
122
+ filename="xy_tokenizer.ckpt",
123
+ cache_dir="XY_Tokenizer/weights"
124
+ )
125
+
126
+ print("Checkpoint downloaded to:", ckpt_path)
127
+
128
+ # Global variables for caching loaded models
129
+ tokenizer = None
130
+ model = None
131
+ spt = None
132
+ device = None
133
+
134
+ def initialize_model():
135
+ """Initialize model (load only on first call)"""
136
+ global tokenizer, model, spt, device
137
+
138
+ if tokenizer is None:
139
+ print("Initializing model...")
140
+ device = "cuda" if torch.cuda.is_available() else "cpu"
141
+ tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH)
142
+ spt = spt.to(device)
143
+ model = model.to(device)
144
+ print("Model initialization completed!")
145
+
146
+ return tokenizer, model, spt, device
147
+
148
+ def process_single_audio_generation(
149
+ text_input: str,
150
+ audio_mode: str,
151
+ prompt_text_single: str,
152
+ prompt_audio_single: Optional[str] = None,
153
+ prompt_text_1: str = "",
154
+ prompt_audio_1: Optional[str] = None,
155
+ prompt_text_2: str = "",
156
+ prompt_audio_2: Optional[str] = None,
157
+ use_normalize: bool = True
158
+ ) -> Tuple[Optional[str], str]:
159
+ """
160
+ Process single audio generation request
161
+
162
+ Args:
163
+ text_input: Text to synthesize
164
+ prompt_text_single: Prompt text for single audio
165
+ prompt_audio_single: Single audio file path
166
+ prompt_text_1: Role1 text
167
+ prompt_audio_1: Role1 audio file path
168
+ prompt_text_2: Role2 text
169
+ prompt_audio_2: Role2 audio file path
170
+ use_normalize: Whether to use text normalization
171
+
172
+ Returns:
173
+ Generated audio file path and status information
174
+ """
175
+ try:
176
+ # Initialize model
177
+ tokenizer, model, spt, device = initialize_model()
178
+
179
+ # Build input item
180
+ item = {
181
+ "text": text_input,
182
+ }
183
+
184
+ # Handle different audio input modes (mutually exclusive)
185
+ if audio_mode == "Single":
186
+ # Use single audio mode
187
+ item["prompt_audio"] = prompt_audio_single
188
+ item["prompt_text"] = prompt_text_single
189
+ elif audio_mode == "Role" and prompt_audio_1 and prompt_audio_2:
190
+ # Use role audio mode (requires both audio files)
191
+ item["prompt_audio_speaker1"] = prompt_audio_1
192
+ item["prompt_text_speaker1"] = prompt_text_1 if prompt_text_1 else ""
193
+ item["prompt_audio_speaker2"] = prompt_audio_2
194
+ item["prompt_text_speaker2"] = prompt_text_2 if prompt_text_2 else ""
195
+ elif audio_mode == "Role" and prompt_audio_1:
196
+ # Only Role 1 audio provided, treat as single audio
197
+ print("Only Role 1 audio provided, treating as single audio.")
198
+ item["prompt_audio"] = prompt_audio_1
199
+ item["prompt_text"] = prompt_text_1 if prompt_text_1 else ""
200
+ elif audio_mode == "Role" and prompt_audio_2:
201
+ # Only Role 2 audio provided, treat as single audio
202
+ print("Only Role 2 audio provided, treating as single audio.")
203
+ item["prompt_audio"] = prompt_audio_2
204
+ item["prompt_text"] = prompt_text_2 if prompt_text_2 else ""
205
+ else:
206
+ return None, "Error: Please select a mode and provide corresponding audio files\n- Single Audio Mode: Provide one audio file and corresponding text\n- Role Mode: Provide audio files for Role1 and Role2"
207
+
208
+ # Set random seed to ensure reproducible results
209
+ import accelerate
210
+ accelerate.utils.set_seed(42)
211
+
212
+ # Process batch (single item)
213
+ actual_texts_data, audio_results = process_batch(
214
+ batch_items=[item],
215
+ tokenizer=tokenizer,
216
+ model=model,
217
+ spt=spt,
218
+ device=device,
219
+ system_prompt=SYSTEM_PROMPT,
220
+ start_idx=0,
221
+ use_normalize=use_normalize
222
+ )
223
+
224
+ # Check results
225
+ if not audio_results or audio_results[0] is None:
226
+ return None, "Error: Audio generation failed"
227
+
228
+ audio_result = audio_results[0]
229
+
230
+ # Create temporary output file
231
+ output_path = tempfile.NamedTemporaryFile(suffix=".wav", delete=False).name
232
+
233
+ # Save audio
234
+ torchaudio.save(output_path, audio_result["audio_data"], audio_result["sample_rate"])
235
+
236
+ # Build status information (using English since this is server-side output)
237
+ status_info = f"""
238
+ ✅ Generation successful!
239
+ 📊 Audio Information:
240
+ - Sample Rate: {audio_result["sample_rate"]} Hz
241
+ - Audio Length: {audio_result["audio_data"].shape[-1] / audio_result["sample_rate"]:.2f} seconds
242
+ - Channels: {audio_result["audio_data"].shape[0]}
243
+
244
+ 📝 Text Processing Information:
245
+ - Original Text: {actual_texts_data[0]['original_text'][:100]}...
246
+ - Final Text: {actual_texts_data[0]['final_text'][:100]}...
247
+ - Use Normalize: {actual_texts_data[0]['use_normalize']}
248
+ """
249
+
250
+ return output_path, status_info
251
+
252
+ except Exception as e:
253
+ import traceback
254
+ error_msg = f"Error: Audio generation failed: {str(e)}\n\nDetails:\n{traceback.format_exc()}"
255
+ return None, error_msg
256
+
257
+ # Create Gradio interface
258
+ def create_gradio_interface() -> gr.Blocks:
259
+ with gr.Blocks(title="MOSS-TTSD🪐 Dialogue Generation", theme=gr.themes.Soft()) as demo:
260
+
261
+ # Language selection at the top
262
+ with gr.Row():
263
+ language_selector = gr.Radio(
264
+ choices=["English", "中文"],
265
+ value="English",
266
+ label="Language / 语言",
267
+ info="Select interface language / 选择界面语言"
268
+ )
269
+
270
+ # Title and header (will be updated based on language)
271
+ title_md = gr.Markdown("# MOSS-TTSD🪐 Dialogue Generation")
272
+ github_md = gr.Markdown("### [Github](https://github.com/OpenMOSS/MOSS-TTSD)")
273
+
274
+ with gr.Row():
275
+ # Left input area
276
+ with gr.Column(scale=1):
277
+ script_input_md = gr.Markdown("### Script Input")
278
+
279
+ text_input = gr.Textbox(
280
+ label="Text to Synthesize",
281
+ placeholder="Text to be synthesized, format: [S1]Role1 text[S2]Role2 text",
282
+ lines=6,
283
+ )
284
+
285
+ use_normalize_single = gr.Checkbox(
286
+ label="Use text normalization",
287
+ value=True,
288
+ info="Recommended to enable, improves handling of numbers, punctuation, etc."
289
+ )
290
+
291
+ # Right audio input area
292
+ with gr.Column(scale=1):
293
+ audio_input_mode_md = gr.Markdown("### Audio Input Mode")
294
+
295
+ # Audio input mode selection
296
+ audio_mode = gr.Radio(
297
+ choices=["Single", "Role"],
298
+ value="Single",
299
+ label="Select input mode",
300
+ info="Single Audio: Upload one audio with [S1][S2] text; Role Audio: Upload separate audio for Role1 and Role2"
301
+ )
302
+
303
+ # Single audio mode
304
+ with gr.Group(visible=True) as single_mode_group:
305
+ prompt_audio_single = gr.File(
306
+ label="Drag and drop audio here - or - click to upload",
307
+ file_types=["audio"],
308
+ type="filepath"
309
+ )
310
+ prompt_text_single = gr.Textbox(
311
+ label="Prompt Text",
312
+ placeholder="Format: [S1]Role1 text[S2]Role2 text",
313
+ lines=3,
314
+ )
315
+
316
+ # Role audio mode
317
+ with gr.Group(visible=False) as role_mode_group:
318
+ with gr.Row():
319
+ with gr.Column():
320
+ role1_audio_md = gr.Markdown("**Role1 Audio**")
321
+ prompt_audio_1 = gr.File(
322
+ label="Role1 Audio File",
323
+ file_types=["audio"],
324
+ type="filepath"
325
+ )
326
+ prompt_text_1 = gr.Textbox(
327
+ label="Role1 Text",
328
+ placeholder="Role1 text content",
329
+ lines=2
330
+ )
331
+
332
+ with gr.Column():
333
+ role2_audio_md = gr.Markdown("**Role2 Audio**")
334
+ prompt_audio_2 = gr.File(
335
+ label="Role2 Audio File",
336
+ file_types=["audio"],
337
+ type="filepath"
338
+ )
339
+ prompt_text_2 = gr.Textbox(
340
+ label="Role2 Text",
341
+ placeholder="Role2 text content",
342
+ lines=2
343
+ )
344
+
345
+ # Generate button
346
+ with gr.Row():
347
+ generate_btn = gr.Button("Generate Audio", variant="primary", size="lg")
348
+
349
+ # Output area
350
+ with gr.Row():
351
+ with gr.Column():
352
+ output_audio = gr.Audio(label="Generated Audio", type="filepath")
353
+ status_info = gr.Textbox(
354
+ label="Status Information",
355
+ lines=10,
356
+ interactive=False
357
+ )
358
+
359
+ # Examples area
360
+ with gr.Row():
361
+ with gr.Column():
362
+ examples_md = gr.Markdown("### Examples")
363
+ examples_desc_md = gr.Markdown("Click on examples below to auto-fill the form")
364
+
365
+ role_examples = gr.Examples(
366
+ examples=ROLE_EXAMPLES,
367
+ inputs=[text_input, audio_mode, prompt_audio_1, prompt_text_1, prompt_audio_2, prompt_text_2, use_normalize_single],
368
+ )
369
+
370
+ # Event handlers
371
+
372
+ # Language change event
373
+ def update_language(lang):
374
+ """Update interface language"""
375
+ texts = LANGUAGES[lang]
376
+
377
+ # Update demo title
378
+ demo.title = texts["title"]
379
+
380
+ return (
381
+ gr.Markdown(f"# {texts['title']}"), # title_md
382
+ texts["script_input"], # script_input_md
383
+ gr.Textbox(
384
+ label=texts["text_to_synthesize"],
385
+ placeholder=texts["text_placeholder"],
386
+ lines=6,
387
+ ), # text_input
388
+ gr.Checkbox(
389
+ label=texts["use_normalize"],
390
+ value=True,
391
+ info=texts["normalize_info"]
392
+ ), # use_normalize_single
393
+ texts["audio_input_mode"], # audio_input_mode_md
394
+ gr.Radio(
395
+ choices=["Single", "Role"],
396
+ value="Single",
397
+ label=texts["select_input_mode"],
398
+ info=texts["mode_info"]
399
+ ), # audio_mode
400
+ gr.File(
401
+ label=texts["drag_drop_audio"],
402
+ file_types=["audio"],
403
+ type="filepath"
404
+ ), # prompt_audio_single
405
+ gr.Textbox(
406
+ label=texts["prompt_text"],
407
+ placeholder=texts["prompt_placeholder"],
408
+ lines=3,
409
+ ), # prompt_text_single
410
+ texts["role1_audio"], # role1_audio_md
411
+ gr.File(
412
+ label=texts["role1_audio_file"],
413
+ file_types=["audio"],
414
+ type="filepath"
415
+ ), # prompt_audio_1
416
+ gr.Textbox(
417
+ label=texts["role1_text"],
418
+ placeholder=texts["role1_placeholder"],
419
+ lines=2
420
+ ), # prompt_text_1
421
+ texts["role2_audio"], # role2_audio_md
422
+ gr.File(
423
+ label=texts["role2_audio_file"],
424
+ file_types=["audio"],
425
+ type="filepath"
426
+ ), # prompt_audio_2
427
+ gr.Textbox(
428
+ label=texts["role2_text"],
429
+ placeholder=texts["role2_placeholder"],
430
+ lines=2
431
+ ), # prompt_text_2
432
+ gr.Button(texts["generate_audio"], variant="primary", size="lg"), # generate_btn
433
+ gr.Audio(label=texts["generated_audio"], type="filepath"), # output_audio
434
+ gr.Textbox(
435
+ label=texts["status_info"],
436
+ lines=10,
437
+ interactive=False
438
+ ), # status_info
439
+ texts["examples"], # examples_md
440
+ texts["examples_desc"], # examples_desc_md
441
+ gr.Dataset(headers=texts["role_headers"])
442
+ )
443
+
444
+ language_selector.change(
445
+ fn=update_language,
446
+ inputs=[language_selector],
447
+ outputs=[
448
+ title_md, script_input_md, text_input, use_normalize_single,
449
+ audio_input_mode_md, audio_mode, prompt_audio_single, prompt_text_single,
450
+ role1_audio_md, prompt_audio_1, prompt_text_1,
451
+ role2_audio_md, prompt_audio_2, prompt_text_2,
452
+ generate_btn, output_audio, status_info,
453
+ examples_md, examples_desc_md, role_examples.dataset,
454
+ ]
455
+ )
456
+
457
+ # Audio mode toggle event
458
+ def toggle_audio_mode(mode):
459
+ if mode == "Single":
460
+ return gr.Group(visible=True), gr.Group(visible=False)
461
+ else:
462
+ return gr.Group(visible=False), gr.Group(visible=True)
463
+
464
+ audio_mode.change(
465
+ fn=toggle_audio_mode,
466
+ inputs=[audio_mode],
467
+ outputs=[single_mode_group, role_mode_group]
468
+ )
469
+
470
+ # Audio generation event
471
+ generate_btn.click(
472
+ fn=process_single_audio_generation,
473
+ inputs=[
474
+ text_input,
475
+ audio_mode,
476
+ prompt_text_single,
477
+ prompt_audio_single,
478
+ prompt_text_1,
479
+ prompt_audio_1,
480
+ prompt_text_2,
481
+ prompt_audio_2,
482
+ use_normalize_single
483
+ ],
484
+ outputs=[output_audio, status_info],
485
+ show_progress=True
486
+ )
487
+
488
+ return demo
489
+
490
+ # Main function
491
+ if __name__ == "__main__":
492
+ demo = create_gradio_interface()
493
+
494
+ # Launch interface
495
+ demo.launch()
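The same generation path can also be exercised without the UI, provided the model weights are available locally; a rough headless sketch (audio paths and text below are placeholders):

audio_path, status = process_single_audio_generation(
    text_input="[S1]Hello there.[S2]Hi, nice to meet you.",
    audio_mode="Role",
    prompt_text_single="",
    prompt_audio_single=None,
    prompt_text_1="Reference line for speaker one.",
    prompt_audio_1="examples/speaker1.wav",
    prompt_text_2="Reference line for speaker two.",
    prompt_audio_2="examples/speaker2.wav",
    use_normalize=True,
)
print(status)
if audio_path:
    print("Generated audio written to", audio_path)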
generation_utils.py ADDED
@@ -0,0 +1,452 @@
1
+ import os
2
+ import re
3
+
4
+ import torch
5
+ import torchaudio
6
+ import numpy as np
7
+
8
+ from transformers import AutoTokenizer
9
+ from modeling_asteroid import AsteroidTTSInstruct
10
+ from XY_Tokenizer.xy_tokenizer.model import XY_Tokenizer
11
+
12
+ MAX_CHANNELS = 8
13
+ SILENCE_DURATION = 5.0 # Fixed silence duration: 5 seconds
14
+
15
+ def load_model(model_path, spt_config_path, spt_checkpoint_path):
16
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
17
+
18
+ model = AsteroidTTSInstruct.from_pretrained(model_path, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2")
19
+
20
+ spt = XY_Tokenizer.load_from_checkpoint(config_path=spt_config_path, ckpt_path=spt_checkpoint_path)
21
+
22
+ model.eval()
23
+ spt.eval()
24
+ return tokenizer, model, spt
25
+
26
+
27
+ def process_jsonl_item(item):
28
+ """Process JSONL data items and extract audio and text information according to the new format"""
29
+ base_path = item.get("base_path", "")
30
+ text = item.get("text", "")
31
+
32
+ # Process prompt audio and text
33
+ if "prompt_audio" in item and "prompt_text" in item:
34
+ print("Using prompt_audio and prompt_text directly from item.")
35
+ # If prompt_audio and prompt_text exist, use them directly
36
+ prompt_audio = item["prompt_audio"]
37
+ prompt_text = item["prompt_text"]
38
+
39
+ # Only perform path joining when prompt_audio is a string path
40
+ if isinstance(prompt_audio, str) and base_path and prompt_audio:
41
+ prompt_audio = os.path.join(base_path, prompt_audio)
42
+ else:
43
+ print("Using speaker1 and speaker2 information for prompt audio and text.")
44
+ # Otherwise, merge speaker1 and speaker2 information
45
+ prompt_audio_speaker1 = item.get("prompt_audio_speaker1", "")
46
+ prompt_text_speaker1 = item.get("prompt_text_speaker1", "")
47
+ prompt_audio_speaker2 = item.get("prompt_audio_speaker2", "")
48
+ prompt_text_speaker2 = item.get("prompt_text_speaker2", "")
49
+
50
+ # Process audio: if it's a string path, perform path joining; if it's a tuple, use directly
51
+ if isinstance(prompt_audio_speaker1, str):
52
+ speaker1_audio = os.path.join(base_path, prompt_audio_speaker1) if base_path and prompt_audio_speaker1 else prompt_audio_speaker1
53
+ else:
54
+ speaker1_audio = prompt_audio_speaker1 # Use tuple directly
55
+
56
+ if isinstance(prompt_audio_speaker2, str):
57
+ speaker2_audio = os.path.join(base_path, prompt_audio_speaker2) if base_path and prompt_audio_speaker2 else prompt_audio_speaker2
58
+ else:
59
+ speaker2_audio = prompt_audio_speaker2 # Use tuple directly
60
+
61
+ prompt_audio = {
62
+ "speaker1": speaker1_audio,
63
+ "speaker2": speaker2_audio
64
+ }
65
+
66
+ # Merge text
67
+ prompt_text = ""
68
+ if prompt_text_speaker1:
69
+ prompt_text += f"[S1]{prompt_text_speaker1}"
70
+ if prompt_text_speaker2:
71
+ prompt_text += f"[S2]{prompt_text_speaker2}"
72
+ prompt_text = prompt_text.strip()
73
+
74
+ return {
75
+ "text": text,
76
+ "prompt_text": prompt_text,
77
+ "prompt_audio": prompt_audio
78
+ }
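As an illustration of the merge path above, a role-style item (placeholder values) is folded into a single [S1]/[S2] prompt and a per-speaker audio dict:

item = {
    "base_path": "examples",
    "text": "[S1]How are you?[S2]Doing well, thanks.",
    "prompt_audio_speaker1": "speaker1.wav",
    "prompt_text_speaker1": "Hello, this is speaker one.",
    "prompt_audio_speaker2": "speaker2.wav",
    "prompt_text_speaker2": "And this is speaker two.",
}
processed = process_jsonl_item(item)
# processed["prompt_text"]  == "[S1]Hello, this is speaker one.[S2]And this is speaker two."
# processed["prompt_audio"] == {"speaker1": "examples/speaker1.wav", "speaker2": "examples/speaker2.wav"}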
79
+
80
+
81
+ def load_audio_data(prompt_audio, target_sample_rate=16000):
82
+ """Load audio data and return processed audio tensor
83
+
84
+ Args:
85
+ prompt_audio: Can be in the following formats:
86
+ - String: audio file path
87
+ - Tuple: (wav, sr) result from torchaudio.load
88
+ - Dict: {"speaker1": path_or_tuple, "speaker2": path_or_tuple}
89
+ """
90
+ if prompt_audio is None:
91
+ return None
92
+
93
+ try:
94
+ # Check if prompt_audio is a dictionary (containing speaker1 and speaker2)
95
+ if isinstance(prompt_audio, dict) and "speaker1" in prompt_audio and "speaker2" in prompt_audio:
96
+ # Process audio from both speakers separately
97
+ wav1, sr1 = _load_single_audio(prompt_audio["speaker1"])
98
+ wav2, sr2 = _load_single_audio(prompt_audio["speaker2"])
99
+ # Merge audio from both speakers
100
+ wav = merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate)
101
+ if wav is None:
102
+ return None
103
+ else:
104
+ # Single audio
105
+ wav, sr = _load_single_audio(prompt_audio)
106
+ # Resample to 16k
107
+ if sr != target_sample_rate:
108
+ wav = torchaudio.functional.resample(wav, sr, target_sample_rate)
109
+ # Ensure mono channel
110
+ if wav.shape[0] > 1:
111
+ wav = wav.mean(dim=0, keepdim=True) # Convert multi-channel to mono
112
+ if len(wav.shape) == 1:
113
+ wav = wav.unsqueeze(0)
114
+
115
+ return wav
116
+ except Exception as e:
117
+ print(f"Error loading audio data: {e}")
118
+ import traceback
119
+ traceback.print_exc()
120
+ return None
121
+
122
+
123
+ def _load_single_audio(audio_input):
124
+ """Load single audio, supports file path or (wav, sr) tuple
125
+
126
+ Args:
127
+ audio_input: String (file path) or tuple (wav, sr)
128
+
129
+ Returns:
130
+ tuple: (wav, sr)
131
+ """
132
+ if isinstance(audio_input, tuple) and len(audio_input) == 2:
133
+ # Already a (wav, sr) tuple
134
+ wav, sr = audio_input
135
+ return wav, sr
136
+ elif isinstance(audio_input, str):
137
+ # Is a file path, needs to be loaded
138
+ wav, sr = torchaudio.load(audio_input)
139
+ return wav, sr
140
+ else:
141
+ raise ValueError(f"Unsupported audio input format: {type(audio_input)}")
142
+
143
+
144
+ def merge_speaker_audios(wav1, sr1, wav2, sr2, target_sample_rate=16000):
145
+ """Merge audio data from two speakers"""
146
+ try:
147
+ # Process first audio
148
+ if sr1 != target_sample_rate:
149
+ wav1 = torchaudio.functional.resample(wav1, sr1, target_sample_rate)
150
+ # Ensure mono channel
151
+ if wav1.shape[0] > 1:
152
+ wav1 = wav1.mean(dim=0, keepdim=True) # Convert multi-channel to mono
153
+ if len(wav1.shape) == 1:
154
+ wav1 = wav1.unsqueeze(0)
155
+
156
+ # Process second audio
157
+ if sr2 != target_sample_rate:
158
+ wav2 = torchaudio.functional.resample(wav2, sr2, target_sample_rate)
159
+ # Ensure mono channel
160
+ if wav2.shape[0] > 1:
161
+ wav2 = wav2.mean(dim=0, keepdim=True) # Convert multi-channel to mono
162
+ if len(wav2.shape) == 1:
163
+ wav2 = wav2.unsqueeze(0)
164
+
165
+ # Concatenate audio
166
+ merged_wav = torch.cat([wav1, wav2], dim=1)
167
+ return merged_wav
168
+ except Exception as e:
169
+ print(f"Error merging audio: {e}")
170
+ return None
171
+
172
+
173
+ def process_inputs(tokenizer, spt, prompt, text, device, audio_data=None, max_channels=8, pad_token=1024):
174
+ seq = f"<|begin_of_style|>{prompt}<|end_of_style|>\n<|begin_of_text|>{text}<|end_of_text|>\n<|begin_of_speech|>"
175
+ inputs1 = np.array(tokenizer.encode(seq))
176
+ input_ids = np.full((inputs1.shape[0], max_channels), pad_token)
177
+ input_ids[:, 0] = inputs1
178
+
179
+ if audio_data is not None:
180
+ try:
181
+ # audio_data should now be a processed audio tensor
182
+ wav = audio_data
183
+
184
+ # Add fixed 5-second silence at the end of audio (using 16k sample rate)
185
+ silence_samples = int(SILENCE_DURATION * 16000)
186
+ silence = torch.zeros(wav.shape[0], silence_samples)
187
+ wav = torch.cat([wav, silence], dim=1)
188
+
189
+ with torch.no_grad():
190
+ # Use SPT encoding
191
+ encode_result = spt.encode([wav.squeeze().to(device)])
192
+ audio_token = encode_result["codes_list"][0].permute(1, 0).cpu().numpy() # Adjust dimension order
193
+
194
+ # similar to DAC encoding adjustment
195
+ audio_token[:, 0] = audio_token[:, 0] + 151665 # Offset channel-0 codes into the LM's speech-token id range
196
+ input_ids = np.concatenate([input_ids, audio_token])[:-60]
197
+ except Exception as e:
198
+ print(f"Error processing audio data: {e}")
199
+ import traceback
200
+ traceback.print_exc()
201
+ # If error occurs, still return input without audio
202
+
203
+ return input_ids
204
+
205
+
206
+ def shifting_inputs(input_ids, tokenizer, pad_token=1024, max_channels=8):
207
+ seq_len = input_ids.shape[0]
208
+ new_seq_len = seq_len + max_channels - 1
209
+ shifted_input_ids = np.full((new_seq_len, max_channels), pad_token, dtype=np.int64)
210
+ shifted_input_ids[:, 0] = np.full(new_seq_len, tokenizer.pad_token_id, dtype=np.int64)
211
+ for i in range(max_channels):
212
+ shifted_input_ids[i : (seq_len + i), i] = input_ids[:, i]
213
+ return shifted_input_ids
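A tiny worked example of the delay pattern applied above: channel c is shifted down by c rows, so a (seq_len, channels) block becomes (seq_len + channels - 1, channels), with the text pad id filling column 0 and the speech pad token filling the other columns. The stand-in tokenizer and its pad id of 0 are assumptions made only for readability:

import numpy as np

class _Tok:                      # hypothetical stand-in; only pad_token_id is used
    pad_token_id = 0

seq = np.array([[11, 21, 31],
                [12, 22, 32],
                [13, 23, 33]])   # (seq_len=3, channels=3)

shifted = shifting_inputs(seq, _Tok(), pad_token=1024, max_channels=3)
print(shifted)
# [[  11 1024 1024]
#  [  12   21 1024]
#  [  13   22   31]
#  [   0   23   32]
#  [   0 1024   33]]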
214
+
215
+
216
+ def rpadding(input_ids, channels, tokenizer):
217
+ attention_masks = [np.ones(inputs.shape[0]) for inputs in input_ids]
218
+ max_length = max(ids.shape[0] for ids in input_ids)
219
+ padded_input_ids, padded_attns = [], []
220
+
221
+ for ids, attn in zip(input_ids, attention_masks):
222
+ pad_len = max_length - ids.shape[0]
223
+ input_pad = np.full((pad_len, channels), 1024)
224
+ input_pad[:, 0] = tokenizer.pad_token_id
225
+ padded_input_ids.append(np.concatenate([input_pad, ids]))
226
+ attn_pad = np.zeros(pad_len)
227
+ padded_attns.append(np.concatenate([attn_pad, attn]))
228
+
229
+ input_ids = torch.tensor(np.stack(padded_input_ids))
230
+ attention_mask = torch.tensor(np.stack(padded_attns))
231
+
232
+ return input_ids, attention_mask
233
+
234
+
235
+ def find_max_valid_positions(C: torch.Tensor, invalid_value=1024) -> torch.Tensor:
236
+ values = C[:, :, 1]
237
+ mask = (values != invalid_value)
238
+ reversed_mask = mask.flip(dims=[1])
239
+ reversed_indices = torch.argmax(reversed_mask.int(), dim=1)
240
+ seq_len = C.size(1)
241
+ original_indices = seq_len - 1 - reversed_indices
242
+ has_valid = mask.any(dim=1)
243
+ original_indices = torch.where(has_valid, original_indices, -1)
244
+ return original_indices
245
+
246
+
247
+ def normalize_text(text: str) -> str:
248
+ """
249
+ Normalize multi-speaker script.
250
+
251
+ 1. Don't preserve line breaks.
252
+ 2. Remove brackets for non-speaker tags (if [] doesn't contain S1/S2...Sx format, remove the brackets themselves).
253
+ 3. Remove decorative symbols: 【】《》()『』「」, quotation marks, and "-".
254
+ 4. Internal punctuation !;:、?(and ASCII equivalents) → ,; only , and 。 remain.
255
+ 5. Of multiple 。 keep only the last one; earlier ones → ,.
256
+ 6. Replace consecutive "哈" (>=2) with "(笑)".
257
+ 7. Auto-recognize [S1] / [S2] … tags; if missing, treat as whole segment.
258
+ """
259
+ # Replace [1], [2] etc. format with [S1], [S2] etc. format
260
+ text = re.sub(r'\[(\d+)\]', r'[S\1]', text)
261
+
262
+ # Remove decorative characters
263
+ remove_chars = "【】《》()『』「」""\"-“”"
264
+
265
+
266
+ # Remove brackets for non-speaker tags (keep content, only remove brackets themselves)
267
+ text = re.sub(r'\[(?!S\d+\])([^\]]*)\]', r'\1', text)
268
+
269
+ # Use positive lookahead to split text by speaker tags (tags themselves are still preserved)
270
+ segments = re.split(r'(?=\[S\d+\])', text.replace("\n", " "))
271
+ normalized_lines = []
272
+
273
+ for seg in segments:
274
+ seg = seg.strip()
275
+ if not seg:
276
+ continue
277
+
278
+ # Extract tags
279
+ m = re.match(r'^(\[S\d+\])\s*(.*)', seg)
280
+ tag, content = m.groups() if m else ('', seg)
281
+
282
+ # Remove irrelevant symbols
283
+ content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
284
+
285
+ # Handle consecutive "哈" characters: replace 2 or more with "(笑)"
286
+ content = re.sub(r'哈{2,}', '(笑)', content)
287
+
288
+ # First handle multi-character punctuation marks
289
+ content = content.replace('——', ',')
290
+ content = content.replace('……', ',')
291
+
292
+ # Handle single-character internal punctuation marks
293
+ internal_punct_map = str.maketrans({
294
+ '!': ',', '!': ',',
295
+ ';': ',', ';': ',',
296
+ ':': ',', ':': ',',
297
+ '、': ',',
298
+ '?': ',', '?': ','
299
+ })
300
+ content = content.translate(internal_punct_map)
301
+ content = content.strip()
302
+
303
+ # Keep only the final period
304
+ if len(content) > 1:
305
+ last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
306
+ body = content[:-1].replace('。', ',')
307
+ content = body + last_ch
308
+
309
+ normalized_lines.append(f"{tag}{content}".strip())
310
+
311
+ return "".join(normalized_lines)
312
+
313
+
314
+ def process_batch(batch_items, tokenizer, model, spt, device, system_prompt, start_idx, use_normalize=False):
315
+ """Process a batch of data items and generate audio, return audio data and metadata"""
316
+ try:
317
+ # Prepare batch data
318
+ batch_size = len(batch_items)
319
+ texts = []
320
+ prompts = [system_prompt] * batch_size
321
+ prompt_audios = []
322
+ actual_texts_data = [] # Store actual text data used
323
+
324
+ print(f"Processing {batch_size} samples starting from index {start_idx}...")
325
+
326
+ # Extract text and audio from each sample
327
+ for i, item in enumerate(batch_items):
328
+ # Use new processing function
329
+ processed_item = process_jsonl_item(item)
330
+
331
+ text = processed_item["text"]
332
+ prompt_text = processed_item["prompt_text"]
333
+
334
+ # Merge text
335
+ full_text = prompt_text + text
336
+ original_full_text = full_text # Save original text
337
+
338
+ # Apply text normalization based on parameter
339
+ if use_normalize:
340
+ full_text = normalize_text(full_text)
341
+
342
+ # Replace speaker tags
343
+ final_text = full_text.replace("[S1]", "<speaker1>").replace("[S2]", "<speaker2>")
344
+ texts.append(final_text)
345
+
346
+ # Save actual text information used
347
+ actual_texts_data.append({
348
+ "index": start_idx + i,
349
+ "original_text": original_full_text,
350
+ "normalized_text": normalize_text(original_full_text) if use_normalize else None,
351
+ "final_text": final_text,
352
+ "use_normalize": use_normalize
353
+ })
354
+
355
+ # Get reference audio
356
+ prompt_audios.append(processed_item["prompt_audio"])
357
+
358
+ # Process inputs
359
+ input_ids_list = []
360
+ for i, (text, prompt, audio_path) in enumerate(zip(texts, prompts, prompt_audios)):
361
+ # Load audio data here
362
+ audio_data = load_audio_data(audio_path) if audio_path else None
363
+ inputs = process_inputs(tokenizer, spt, prompt, text, device, audio_data)
364
+ inputs = shifting_inputs(inputs, tokenizer)
365
+ input_ids_list.append(inputs)
366
+
367
+ # Pad batch inputs
368
+ input_ids, attention_mask = rpadding(input_ids_list, MAX_CHANNELS, tokenizer)
369
+
370
+ # Batch generation
371
+ print(f"Starting batch audio generation...")
372
+ start = input_ids.shape[1] - MAX_CHANNELS + 1
373
+
374
+ # Move inputs to GPU
375
+ input_ids = input_ids.to(device)
376
+ attention_mask = attention_mask.to(device)
377
+
378
+ # Generate model outputs
379
+ outputs = model.generate(
380
+ input_ids=input_ids,
381
+ attention_mask=attention_mask,
382
+ )
383
+ print(f"Original outputs shape: {outputs.shape}")
384
+ print(f"Start value: {start}")
385
+ print(f"Shape after slicing: {outputs[:, start:].shape}")
386
+ print(f"MAX_CHANNELS: {MAX_CHANNELS}")
387
+ print(f"Calculated seq_len: {outputs.shape[1] - MAX_CHANNELS + 1}")
388
+ # Process outputs
389
+ outputs = outputs[:, start:]
390
+ seq_len = outputs.shape[1] - MAX_CHANNELS + 1
391
+ speech_ids = torch.full((outputs.shape[0], seq_len, MAX_CHANNELS), 0).to(device)
392
+
393
+
394
+ # Adjust output format
395
+ for j in range(MAX_CHANNELS):
396
+ speech_ids[..., j] = outputs[:, j : seq_len + j, j]
397
+ if j == 0:
398
+ speech_ids[..., j] = speech_ids[..., j] - 151665
399
+
400
+ # Find valid positions for each sample
401
+ li = find_max_valid_positions(speech_ids)
402
+
403
+ # Store audio result data
404
+ audio_results = []
405
+
406
+ # Process batch sample results individually
407
+ for i in range(batch_size):
408
+ try:
409
+ # Extract valid speech tokens
410
+ end_idx = li[i] + 1
411
+ if end_idx <= 0:
412
+ print(f"Sample {start_idx + i} has no valid speech tokens")
413
+ audio_results.append(None)
414
+ continue
415
+
416
+ this_speech_id = speech_ids[i, :end_idx]
417
+ print(f"Speech token shape for sample {start_idx + i}: {this_speech_id.shape}")
418
+
419
+ # Decode generated audio
420
+ with torch.no_grad():
421
+ codes_list = [this_speech_id.permute(1, 0)] # Convert to SPT expected format
422
+ decode_result = spt.decode(codes_list, overlap_seconds=10)
423
+ audio_result = decode_result["syn_wav_list"][0].cpu().detach()
424
+
425
+ if audio_result.ndim == 1: # If 1D [samples]
426
+ audio_result = audio_result.unsqueeze(0) # Convert to 2D [1, samples]
427
+
428
+ # Save audio data instead of file path
429
+ audio_results.append({
430
+ "audio_data": audio_result,
431
+ "sample_rate": spt.output_sample_rate,
432
+ "index": start_idx + i
433
+ })
434
+ print(f"Audio generation completed: sample {start_idx + i}")
435
+
436
+ except Exception as e:
437
+ print(f"Error processing sample {start_idx + i}: {str(e)}")
438
+ import traceback
439
+ traceback.print_exc()
440
+ audio_results.append(None)
441
+
442
+ # Clean up GPU memory
443
+ torch.cuda.empty_cache()
444
+
445
+ # Return text data and audio data
446
+ return actual_texts_data, audio_results
447
+
448
+ except Exception as e:
449
+ print(f"Error during batch processing: {str(e)}")
450
+ import traceback
451
+ traceback.print_exc()
452
+ return [], [None] * len(batch_items)
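Putting the helpers in this file together, a minimal offline driver might look roughly as follows (the constants mirror those in app.py; the prompt files and dialogue text are placeholders):

import torchaudio
from generation_utils import load_model, process_batch

MODEL_PATH = "fnlp/MOSS-TTSD-v0"
SPT_CONFIG_PATH = "XY_Tokenizer/config/xy_tokenizer_config.yaml"
SPT_CHECKPOINT_PATH = "XY_Tokenizer/weights/xy_tokenizer.ckpt"
SYSTEM_PROMPT = ("You are a speech synthesizer that generates natural, realistic, "
                 "and human-like conversational audio from dialogue text.")

tokenizer, model, spt = load_model(MODEL_PATH, SPT_CONFIG_PATH, SPT_CHECKPOINT_PATH)
device = "cuda"
model, spt = model.to(device), spt.to(device)

items = [{
    "base_path": "examples",
    "text": "[S1]Welcome to the show.[S2]Glad to be here.",
    "prompt_audio_speaker1": "speaker1.wav",
    "prompt_text_speaker1": "Reference line for speaker one.",
    "prompt_audio_speaker2": "speaker2.wav",
    "prompt_text_speaker2": "Reference line for speaker two.",
}]

texts_meta, audio_results = process_batch(
    batch_items=items, tokenizer=tokenizer, model=model, spt=spt,
    device=device, system_prompt=SYSTEM_PROMPT, start_idx=0, use_normalize=True,
)
if audio_results and audio_results[0] is not None:
    torchaudio.save("output_0.wav",
                    audio_results[0]["audio_data"],
                    audio_results[0]["sample_rate"])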
modeling_asteroid.py ADDED
@@ -0,0 +1,399 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from dataclasses import dataclass
4
+ from transformers.utils import ModelOutput
5
+ from transformers.cache_utils import Cache
6
+ from typing import Optional, List, Tuple, Union
7
+ from transformers.loss.loss_utils import ForCausalLMLoss
8
+ from transformers.generation.streamers import BaseStreamer
9
+ from transformers.modeling_outputs import BaseModelOutputWithPast
10
+ from transformers.generation.configuration_utils import GenerationConfig
11
+ from transformers.generation.stopping_criteria import StoppingCriteriaList
12
+ from transformers import PreTrainedModel, GenerationMixin, Qwen3Config, Qwen3Model
13
+ from transformers.generation.logits_process import LogitsProcessorList, RepetitionPenaltyLogitsProcessor, TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
14
+
15
+
16
+ class AsteroidTTSConfig(Qwen3Config):
17
+ def __init__(self,
18
+ channels = 8,
19
+ speech_pad_token = 1024,
20
+ speech_vocab_size = 1025,
21
+ speech_token_range = [],
22
+ **kwargs):
23
+ super().__init__(**kwargs)
24
+ self.channels = channels
25
+ self.speech_pad_token = speech_pad_token
26
+ self.speech_vocab_size = speech_vocab_size
27
+ self.speech_token_range = speech_token_range
28
+
29
+
30
+ @dataclass
31
+ class AsteroidTTSOutputWithPast(ModelOutput):
32
+ loss: Optional[torch.FloatTensor] = None
33
+ logits: torch.FloatTensor = None
34
+ loss_all: Optional[Tuple[torch.FloatTensor]] = None
35
+ logits_all: Optional[Tuple[torch.FloatTensor]] = None
36
+ past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
37
+ hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
38
+ attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
39
+
40
+
41
+ @dataclass
42
+ class GenerateDecoderOnlyOutput(ModelOutput):
43
+ sequences: torch.LongTensor = None
44
+ scores: Optional[Tuple[torch.FloatTensor]] = None
45
+ logits: Optional[Tuple[torch.FloatTensor]] = None
46
+ attentions: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
47
+ hidden_states: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
48
+ past_key_values: Optional[Tuple[Tuple[Tuple[torch.FloatTensor]]]] = None
49
+
50
+
51
+ class CustomMixin(GenerationMixin):
52
+ def _sample(
53
+ self,
54
+ input_ids: torch.LongTensor,
55
+ logits_processor: LogitsProcessorList,
56
+ stopping_criteria: StoppingCriteriaList,
57
+ generation_config: GenerationConfig,
58
+ synced_gpus: bool,
59
+ streamer: Optional["BaseStreamer"],
60
+ **model_kwargs,
61
+ ) -> Union[GenerateDecoderOnlyOutput, torch.LongTensor]:
62
+ # Extract configuration parameters
63
+ speech_pad_idx = self.config.speech_pad_token
64
+
65
+ eos_token_id = generation_config.eos_token_id
66
+ output_attentions = generation_config.output_attentions
67
+ output_hidden_states = generation_config.output_hidden_states
68
+ output_scores = generation_config.output_scores
69
+ output_logits = generation_config.output_logits
70
+ return_dict_in_generate = generation_config.return_dict_in_generate
71
+ max_length = generation_config.max_length
72
+ has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
73
+ do_sample = generation_config.do_sample
74
+
75
+ # Initialize output tuples
76
+ scores = () if (return_dict_in_generate and output_scores) else None
77
+ raw_logits = () if (return_dict_in_generate and output_logits) else None
78
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
79
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
80
+
81
+ # Initialize tracking variables
82
+ batch_size, cur_len, channels = input_ids.shape # channels = 8
83
+ this_peer_finished = False
84
+ unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
85
+ needs_additional_steps = -1 * torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
86
+ tf_inputs = input_ids[:]
87
+ input_ids = input_ids[:, :-(channels - 1)]
88
+ model_kwargs["attention_mask"] = model_kwargs["attention_mask"][:, :-(channels - 1)]
89
+ base_length = input_ids.shape[1]
90
+ model_kwargs = self._get_initial_cache_position(cur_len, input_ids.device, model_kwargs)
91
+
92
+ # Build per-channel logits processors
93
+ if getattr(generation_config, "do_samples", None) is not None:
94
+ do_samples = generation_config.do_samples
95
+ realprocessor = [LogitsProcessorList() for _ in range(channels)]
96
+ for i, layer_config in enumerate(generation_config.layers):
97
+ if layer_config.get("repetition_penalty") is not None:
98
+ realprocessor[i].append(RepetitionPenaltyLogitsProcessor(penalty=layer_config.get("repetition_penalty")))
99
+ if layer_config.get("temperature") is not None:
100
+ realprocessor[i].append(TemperatureLogitsWarper(temperature=layer_config.get("temperature")))
101
+ if layer_config.get("top_k") is not None:
102
+ realprocessor[i].append(TopKLogitsWarper(top_k=layer_config.get("top_k")))
103
+ if layer_config.get("top_p") is not None:
104
+ realprocessor[i].append(TopPLogitsWarper(top_p=layer_config.get("top_p")))
105
+ else:
106
+ do_samples = [do_sample for _ in range(channels)]
107
+ realprocessor = [logits_processor for _ in range(channels)]
108
+ while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
109
+ # Prepare model inputs
110
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
111
+ model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
112
+ model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
113
+ # Forward pass
114
+ outputs = self(**model_inputs, return_dict=True)
115
+ model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs)
116
+
117
+ if synced_gpus and this_peer_finished:
118
+ continue
119
+
120
+ # Get next-token logits for every channel
121
+ next_token_logits = [logits[:, -1, :].clone().float().to(input_ids.device) for logits in outputs.logits_all]
122
+ for i, channel_logits in enumerate(next_token_logits):
123
+ if i != 0 and input_ids.shape[1] + 1 > tf_inputs.shape[1] - 7 + i:
124
+ channel_logits[:, 1024] = - torch.inf
125
+ if i == 0 and input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
126
+ channel_logits[:, 152694] = - torch.inf
127
+ next_token_scores = [realprocessor[i](input_ids[..., i], logits) for i, logits in enumerate(next_token_logits)]
128
+ # Sample (or argmax) the next token for each channel
129
+ next_tokens = []
130
+ for i, channel_score in enumerate(next_token_scores):
131
+ if do_samples[i]:
132
+ channel_ntk = torch.multinomial(nn.functional.softmax(channel_score, dim=-1), num_samples=1).squeeze(1)
133
+ elif not do_samples[i]:
134
+ channel_ntk = torch.argmax(channel_score, dim=-1)
135
+ next_tokens.append(channel_ntk)
136
+ next_tokens = torch.stack(next_tokens, dim=-1) # [batch_size, channels]
137
+ # Extra-step logic: once channel 0 stops emitting speech tokens, run extra steps to flush the delayed channels
138
+ indices = (~self.is_speech_token(next_tokens[:, 0])) & (needs_additional_steps < 0)
139
+ needs_additional_steps[indices] = channels - 1 # e.g. 7 extra steps for 8 channels
140
+
141
+ if input_ids.shape[1] + 1 <= tf_inputs.shape[1]:
142
+ i = input_ids.shape[1] + 1 - base_length
143
+ next_tokens[:, i:] = tf_inputs[:, input_ids.shape[1], i:]
144
+
145
+ # During the extra steps, substitute EOS (channel 0) and speech padding (remaining channels)
146
+ mask = (needs_additional_steps > 0) & (needs_additional_steps < 7)
147
+ if mask.any().item():
148
+ next_tokens[mask, 0] = self.config.eos_token_id
149
+ for i in range(1, channels):
150
+ mask_i = mask & (needs_additional_steps < channels - i)
151
+ next_tokens[mask_i, i] = speech_pad_idx
152
+
153
+ if has_eos_stopping_criteria:
154
+ for i in range(channels):
155
+ pddp = self.config.eos_token_id if i == 0 else speech_pad_idx
156
+ next_tokens[:, i] = next_tokens[:, i] * unfinished_sequences + pddp * (1 - unfinished_sequences)
157
+
158
+ input_ids = torch.cat([input_ids, next_tokens[:, None, :]], dim=1)
159
+ if streamer is not None:
160
+ streamer.put(next_tokens[:, 0].cpu())
161
+
162
+ # Update unfinished_sequences
163
+ needs_additional_steps = torch.where(needs_additional_steps > 0, needs_additional_steps - 1, needs_additional_steps)
164
+ stopping = stopping_criteria(input_ids[..., 0], scores) | (needs_additional_steps == 0)
165
+ unfinished_sequences = unfinished_sequences & ~stopping
166
+ unfinished_sequences = unfinished_sequences | (needs_additional_steps > 0)
167
+ this_peer_finished = unfinished_sequences.max() == 0
168
+
169
+ if return_dict_in_generate:
170
+ if output_scores:
171
+ scores += (next_token_scores,)
172
+ if output_logits:
173
+ raw_logits += (next_token_logits,)
174
+ if output_attentions:
175
+ decoder_attentions += (outputs.attentions,)
176
+ if output_hidden_states:
177
+ decoder_hidden_states += (outputs.hidden_states,)
178
+
179
+ cur_len += 1
180
+ del outputs
181
+
182
+ if streamer is not None:
183
+ streamer.end()
184
+
185
+ if return_dict_in_generate:
186
+ return GenerateDecoderOnlyOutput(
187
+ sequences=input_ids,
188
+ scores=scores,
189
+ logits=raw_logits,
190
+ attentions=decoder_attentions,
191
+ hidden_states=decoder_hidden_states,
192
+ past_key_values=model_kwargs.get("past_key_values"),
193
+ )
194
+ else:
195
+ return input_ids
196
+
197
+
198
+ class AsteroidTTSPretrainedModel(PreTrainedModel):
199
+ config_class = AsteroidTTSConfig
200
+ base_model_prefix = "model"
201
+ supports_gradient_checkpointing = True
202
+ _no_split_modules = ["Qwen3DecoderLayer"]
203
+ _skip_keys_device_placement = ["past_key_values"]
204
+ _supports_flash_attn_2 = True
205
+ _supports_sdpa = True
206
+ _supports_flex_attn = True
207
+ _supports_cache_class = True
208
+ _supports_quantized_cache = True
209
+ _supports_static_cache = True
210
+ _supports_attention_backend = True
211
+
212
+
213
+ class AsteroidTTSModel(AsteroidTTSPretrainedModel):
214
+ def __init__(self, config: AsteroidTTSConfig):
215
+ super().__init__(config)
216
+ self.text_pad_idx = config.pad_token_id
217
+ self.speech_pad_idx = config.speech_pad_token
218
+ self.embedding_list = nn.ModuleList([])
219
+ self.embedding_list.append(nn.Embedding(config.vocab_size, config.hidden_size, self.text_pad_idx))
220
+ # Channels 1 to channels-1: Speech tokens only
221
+ for _ in range(1, config.channels):
222
+ self.embedding_list.append(nn.Embedding(config.speech_vocab_size, config.hidden_size, self.speech_pad_idx))
223
+
224
+ self.language_model = Qwen3Model(config)
225
+ self.post_init()
226
+
227
+ def get_input_embeddings(self):
228
+ return self.embedding_list[0]
229
+
230
+ def set_input_embeddings(self, value: nn.Embedding):
231
+ self.embedding_list[0] = value
232
+
233
+ def _prepare_multi_modal_inputs(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
234
+ """
235
+ Prepares multi-modal embeddings from input_ids of shape (batch_size, sequence_length, channels).
236
+ For channel 0: text + speech tokens, for channels 1 to channels-1: speech tokens padded with speech_pad_token.
237
+ """
238
+ batch_size, seq_length, channels = input_ids.shape
239
+ if channels != self.config.channels:
240
+ raise ValueError(f"Expected {self.config.channels} channels, got {channels}")
241
+
242
+ inputs_embeds = torch.zeros(batch_size, seq_length, self.config.hidden_size, device=input_ids.device, dtype=self.embedding_list[0].weight.dtype)
243
+ for i in range(channels):
244
+ embed_layer = self.embedding_list[i]
245
+ channel_input = input_ids[...,i]
246
+ inputs_embeds += embed_layer(channel_input)
247
+
248
+ return inputs_embeds
249
+
250
+ def forward(
251
+ self,
252
+ input_ids: torch.LongTensor = None, # Shape: (batch_size, sequence_length, channels)
253
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[List[torch.FloatTensor]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, BaseModelOutputWithPast]:
+
+         if (input_ids is None) ^ (inputs_embeds is not None):
+             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
+         if input_ids is not None:
+             inputs_embeds = self._prepare_multi_modal_inputs(input_ids)
+
+         outputs = self.language_model(
+             input_ids=None,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+         )
+         return outputs
+
+
+ class AsteroidTTSInstruct(AsteroidTTSPretrainedModel, CustomMixin):
+     _tied_weights_keys = []
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
+
+     def __init__(self, config: AsteroidTTSConfig):
+         super().__init__(config)
+         self.model = AsteroidTTSModel(config)
+         self.channels = config.channels
+         self.weights = [1 for _ in range(self.channels)]
+         self._tied_weights_keys = [f"lm_heads.{i}.weight" for i in range(self.channels)]
+         self.vocab_size = config.vocab_size
+         self.lm_heads = nn.ModuleList([])
+         self.lm_heads.append(nn.Linear(config.hidden_size, config.vocab_size, bias=False))
+         for _ in range(1, config.channels):
+             self.lm_heads.append(nn.Linear(config.hidden_size, config.speech_vocab_size, bias=False))
+         self.post_init()
+
+     def get_input_embeddings(self):
+         return self.model.embedding_list[0]
+
+     def can_generate(self):
+         return True
+
+     def is_speech_token(self, tokens):
+         return (tokens >= self.config.speech_token_range[0]) & (tokens < self.config.speech_token_range[1])
+
+     def tie_weights(self):
+         for i in range(self.config.channels):
+             self._tie_or_clone_weights(self.lm_heads[i], self.model.embedding_list[i])
+
+     def set_input_embeddings(self, value):
+         self.model.embedding_list[0] = value
+
+     def get_output_embeddings(self):
+         return self.lm_heads[0]
+
+     def set_output_embeddings(self, new_embeddings):
+         self.lm_heads[0] = new_embeddings
+
+     def set_decoder(self, decoder):
+         self.model = decoder
+
+     def get_decoder(self):
+         return self.model
+
+     def set_weights(self, weights):
+         self.weights = weights
+
+     def forward(
+         self,
+         input_ids: torch.LongTensor = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
+         **kwargs,
+     ) -> Union[Tuple, AsteroidTTSOutputWithPast]:
+         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+         outputs = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             position_ids=position_ids,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+         logits_all = [lm_head(hidden_states) for lm_head in self.lm_heads]
+
+         loss_all = None
+         total_loss = None
+         if labels is not None:
+             device = input_ids.device if input_ids is not None else inputs_embeds.device
+             loss_all = torch.empty(self.channels, device=device)
+             for i in range(self.config.channels):
+                 vocab_size = self.config.vocab_size if i == 0 else self.config.speech_vocab_size
+                 loss_all[i] = ForCausalLMLoss(logits_all[i], labels[..., i], vocab_size)
+
+             # Combine the per-channel losses with the (unnormalized) channel weights.
+             # total_weight = sum(self.weights)
+             # normalized_weights = [w / total_weight for w in self.weights]
+             normalized_weights = self.weights
+
+             total_loss = 0
+             for w, loss in zip(normalized_weights, loss_all):
+                 total_loss += w * loss
+
+         if not return_dict:
+             output = (logits_all,) + outputs[1:]
+             return (total_loss, loss_all) + output if total_loss is not None else output
+
+         return AsteroidTTSOutputWithPast(
+             loss=total_loss,
+             logits=logits_all[0],
+             loss_all=loss_all,
+             logits_all=logits_all,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
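
For readers skimming this diff: the core idea of _prepare_multi_modal_inputs above is that every position carries one token id per channel, and the per-channel embeddings are summed into a single hidden vector before being fed to the Qwen3 backbone. The standalone sketch below illustrates only that summation; every size in it (vocabulary sizes, hidden size, channel count) is a made-up placeholder, not the model's actual configuration.

# Illustrative sketch only, not part of the commit; sizes are arbitrary placeholders.
import torch
import torch.nn as nn

batch_size, seq_length, channels, hidden_size = 2, 5, 4, 8
text_vocab, speech_vocab = 100, 50

# Channel 0 embeds text-channel tokens; the remaining channels embed speech tokens.
embedding_list = nn.ModuleList(
    [nn.Embedding(text_vocab, hidden_size)]
    + [nn.Embedding(speech_vocab, hidden_size) for _ in range(channels - 1)]
)

# input_ids holds one token id per channel: shape (batch_size, seq_length, channels).
input_ids = torch.cat(
    [torch.randint(0, text_vocab, (batch_size, seq_length, 1))]
    + [torch.randint(0, speech_vocab, (batch_size, seq_length, 1)) for _ in range(channels - 1)],
    dim=-1,
)

# Sum the channel embeddings into a single hidden vector per position.
inputs_embeds = torch.zeros(batch_size, seq_length, hidden_size)
for i, embed in enumerate(embedding_list):
    inputs_embeds = inputs_embeds + embed(input_ids[..., i])

print(inputs_embeds.shape)  # torch.Size([2, 5, 8])

Summing (rather than concatenating) keeps the hidden size fixed regardless of the number of channels, which is why each channel gets its own embedding table and, symmetrically, its own lm_head in AsteroidTTSInstruct.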
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ torch>=2.0.0
+ torchaudio>=2.0.0
+ transformers>=4.30.0
+ gradio>=4.0.0
+ numpy>=1.21.0
+ accelerate>=0.20.0
+ PyPDF2
+ beautifulsoup4
+ soundfile
+ librosa
+ tqdm
+ requests
+ openai
+ PyYAML
+ einops
+ huggingface_hub
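
After installing the dependencies (for example with pip install -r requirements.txt in a fresh virtual environment), a quick check such as the sketch below confirms that the main minimum-version pins are satisfied; the selection of packages checked here is illustrative only, not part of the commit.

# Illustrative check: report installed versions of the core pinned dependencies
# so they can be compared against requirements.txt.
import importlib.metadata as md

for pkg, minimum in [
    ("torch", "2.0.0"),
    ("torchaudio", "2.0.0"),
    ("transformers", "4.30.0"),
    ("gradio", "4.0.0"),
    ("accelerate", "0.20.0"),
]:
    try:
        print(f"{pkg}: installed {md.version(pkg)} (requires >= {minimum})")
    except md.PackageNotFoundError:
        print(f"{pkg}: NOT installed (requires >= {minimum})")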